Unverified commit 84501432, authored by i-robot, committed by Gitee

!3296 tbnet update for modelarts

Merge pull request !3296 from DMH_coco/master
parents 095f2b3f cb9adeca
Showing with 459 additions and 200 deletions
......@@ -165,6 +165,12 @@ python infer.py \
├─src_train.csv # training dataset
└─id_maps.json # explanation configuration
├─src
├─utils
├─__init__.py # init file
├─device_adapter.py # get cloud device id
├─local_adapter.py # get local device id
├─moxing_adapter.py # Parameter processing
└─param.py # parse args
├─aggregator.py # inference result aggregation
├─config.py # parsing parameter configuration
├─dataset.py # generate dataset
......@@ -172,10 +178,10 @@ python infer.py \
├─metrics.py # model metrics
├─steam.py # 'steam' dataset text explainer
└─tbnet.py # TB-Net model
├─export.py # export mindir script
├─preprocess_dataset.py # dataset preprocess script
├─preprocess.py # inference data preprocess script
├─postprocess.py # inference result calculation script
├─eval.py # evaluation
├─infer.py # inference and explanation
└─train.py # training
......@@ -183,6 +189,23 @@ python infer.py \
## [Script Parameters](#contents)
The main parameters of train.py and param.py are as follows:
```yaml
data_path: "."            # location of the input data
load_path: "./checkpoint" # path where training checkpoints are stored
checkpoint_id: 19         # checkpoint id
same_relation: False      # only generate paths whose relation1 equals relation2
dataset: "steam"          # dataset name
train_csv: "train.csv"    # train csv datafile inside the dataset folder
test_csv: "test.csv"      # test csv datafile inside the dataset folder
infer_csv: "infer.csv"    # infer csv datafile inside the dataset folder
device_id: 0              # device id
device_target: "GPU"      # run on GPU or Ascend
run_mode: "graph"         # run in GRAPH or PYNATIVE mode
epochs: 20                # number of training epochs
```
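These defaults live in `default_config.yaml` and are exposed at run time through the `param` object built by `src/utils/param.py`. A minimal sketch of programmatic access (attribute names mirror the keys above):

```python
# A minimal sketch, assuming default_config.yaml sits at its default path.
from src.utils.param import param

print(param.dataset)        # "steam" unless overridden, e.g. --dataset mydata
print(param.epochs)         # 20 by default
print(param.device_target)  # "GPU" or "Ascend"
```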
- preprocess_dataset.py parameters
```text
......
......@@ -160,6 +160,12 @@ python infer.py \
├─src_train.csv # training dataset
└─id_maps.json # explanation output configuration
├─src
├─utils
├─__init__.py # init file
├─device_adapter.py # get cloud device id
├─local_adapter.py # get local device id
├─moxing_adapter.py # parameter processing
└─param.py # parse arguments
├─aggregator.py # inference result aggregation
├─config.py # parameter configuration parsing
├─dataset.py # dataset creation
......@@ -171,6 +177,7 @@ python infer.py \
├─preprocess_dataset.py # dataset preprocessing script
├─preprocess.py # inference data preprocessing script
├─postprocess.py # inference result calculation script
├─default_config.yaml # yaml configuration file
├─eval.py # evaluation network
├─infer.py # inference and explanation
└─train.py # training network
......@@ -178,6 +185,23 @@ python infer.py \
## [Script Parameters](#contents)
The main parameters of train.py and param.py are as follows:
```yaml
data_path: "."            # dataset path
load_path: "./checkpoint" # checkpoint save path
checkpoint_id: 19         # checkpoint id
same_relation: False      # during preprocessing, only generate paths whose `relation1` equals `relation2`
dataset: "steam"          # dataset name
train_csv: "train.csv"    # training set filename inside the dataset
test_csv: "test.csv"      # test set filename inside the dataset
infer_csv: "infer.csv"    # inference data filename inside the dataset
device_id: 0              # device id
device_target: "GPU"      # target platform
run_mode: "graph"         # run mode
epochs: 20                # number of training epochs
```
- preprocess_dataset.py parameters
```text
......
# Builtin configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# url for openi
ckpt_url: ""
result_url: ""
# path for local
data_path: "."
output_path: "./output"
load_path: "./checkpoint"
# preprocess_data
same_relation: False
# train
dataset: "steam"
train_csv: "train.csv"
test_csv: "test.csv"
infer_csv: "infer.csv"
device_id: 0
epochs: 20
device_target: "GPU"
run_mode: "graph"
# eval
checkpoint_id: 19
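# Hedged illustration: src/utils/param.py (shown below) registers every scalar
# key in this file as a --<key> command-line flag, so the defaults here can be
# overridden per run (e.g. `--epochs 30 --device_target Ascend`). Boolean flags
# are parsed with ast.literal_eval, as parse_cli_to_yaml does:
import ast

assert ast.literal_eval('True') is True    # `--enable_modelarts True`
assert ast.literal_eval('False') is False  # `--enable_modelarts False`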
......@@ -15,7 +15,6 @@
"""TB-Net evaluation."""
import os
import argparse
import math
from mindspore import context, Model, load_checkpoint, load_param_into_net
......@@ -23,90 +22,36 @@ import mindspore.common.dtype as mstype
from src import tbnet, config, metrics, dataset
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Train TBNet.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--csv',
type=str,
required=False,
default='test.csv',
help="the csv datafile inside the dataset folder (e.g. test.csv)"
)
parser.add_argument(
'--checkpoint_id',
type=int,
required=True,
help="use which checkpoint(.ckpt) file to eval"
)
parser.add_argument(
'--device_id',
type=int,
required=False,
default=0,
help="device id"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
'--run_mode',
type=str,
required=False,
default='graph',
choices=['graph', 'pynative'],
help="run code by GRAPH mode or PYNATIVE mode"
)
return parser.parse_args()
from src.utils.param import param
from src.utils.moxing_adapter import moxing_wrapper
from preprocess_dataset import preprocess_data
@moxing_wrapper(preprocess_data)
def eval_tbnet():
"""Evaluation process."""
args = get_args()
home = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(home, 'data', args.dataset, 'config.json')
test_csv_path = os.path.join(home, 'data', args.dataset, args.csv)
ckpt_path = os.path.join(home, 'checkpoints')
config_path = os.path.join(param.data_path, 'data', param.dataset, 'config.json')
test_csv_path = os.path.join(param.data_path, 'data', param.dataset, param.test_csv)
ckpt_path = param.load_path
context.set_context(device_id=args.device_id)
if args.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(device_id=param.device_id)
if param.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=param.device_target)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
context.set_context(mode=context.PYNATIVE_MODE, device_target=param.device_target)
print(f"creating dataset from {test_csv_path}...")
net_config = config.TBNetConfig(config_path)
if args.device_target == 'Ascend':
if param.device_target == 'Ascend':
net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
eval_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
print(f"creating TBNet from checkpoint {args.checkpoint_id} for evaluation...")
print(f"creating TBNet from checkpoint {param.checkpoint_id} for evaluation...")
network = tbnet.TBNet(net_config)
if args.device_target == 'Ascend':
if param.device_target == 'Ascend':
network.to_float(mstype.float16)
param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{param.checkpoint_id}.ckpt'))
load_param_into_net(network, param_dict)
loss_net = tbnet.NetWithLossClass(network, net_config)
......@@ -117,7 +62,10 @@ def eval_tbnet():
print("evaluating...")
e_out = model.eval(eval_ds, dataset_sink_mode=False)
print(f'Test AUC:{e_out ["auc"]} ACC:{e_out ["acc"]}')
print(f'Test AUC:{e_out["auc"]} ACC:{e_out["acc"]}')
if param.enable_modelarts:
with open(os.path.join(param.output_path, 'result.txt'), 'w') as f:
f.write(f'Test AUC:{e_out["auc"]} ACC:{e_out["acc"]}')
if __name__ == '__main__':
......
......@@ -18,44 +18,12 @@
import os
import io
import argparse
import json
import math
from src.path_gen import PathGen
from src.config import TBNetConfig
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Preprocess TB-Net data.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
'--same_relation',
required=False,
action='store_true',
default=False,
help="only generate paths that relation1 is same as relation2"
)
return parser.parse_args()
from src.utils.param import param
def preprocess_csv(path_gen, data_home, src_name, out_name):
......@@ -69,17 +37,14 @@ def preprocess_csv(path_gen, data_home, src_name, out_name):
def preprocess_data():
"""Pre-process the dataset."""
args = get_args()
home = os.path.dirname(os.path.realpath(__file__))
data_home = os.path.join(home, 'data', args.dataset)
data_home = os.path.join(param.data_path, 'data', param.dataset)
config_path = os.path.join(data_home, 'config.json')
id_maps_path = os.path.join(data_home, 'id_maps.json')
cfg = TBNetConfig(config_path)
if args.device_target == 'Ascend':
if param.device_target == 'Ascend':
cfg.per_item_paths = math.ceil(cfg.per_item_paths / 16) * 16
path_gen = PathGen(per_item_paths=cfg.per_item_paths, same_relation=args.same_relation)
path_gen = PathGen(per_item_paths=cfg.per_item_paths, same_relation=param.same_relation)
preprocess_csv(path_gen, data_home, 'src_train.csv', 'train.csv')
......
......@@ -31,5 +31,6 @@ if [ $# == 4 ]; then
DEVICE_TARGET=$4
fi
python ../eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
--device_id $DEVICE_ID &> eval_standalone_gpu_log &
\ No newline at end of file
cd ..
python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
--device_id $DEVICE_ID &> scripts/eval_standalone_gpu_log &
\ No newline at end of file
......@@ -29,5 +29,6 @@ if [ $# == 3 ]; then
DEVICE_TARGET=$3
fi
python ../preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> train_standalone_log &&
python ../train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> train_standalone_log &
\ No newline at end of file
cd ..
python preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> scripts/train_standalone_log &&
python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> scripts/train_standalone_log &
\ No newline at end of file
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Device adapter for ModelArts"""
from .param import param

if param.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
]
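# Hedged usage sketch: callers import from device_adapter and transparently get
# either the ModelArts-backed or the local implementation, depending on
# param.enable_modelarts at import time.
from src.utils.device_adapter import get_device_id, get_device_num

device_id = get_device_id()    # DEVICE_ID env var when running locally
device_num = get_device_num()  # RANK_SIZE env var when running locally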
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return 'Local Job'
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .param import param
_global_syn_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"  # JOB_ID may be unset or empty locally
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory if the first path is a remote URL and the second is local;
upload data from a local directory to remote OBS otherwise.
"""
import moxing as mox
import time
global _global_syn_count
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
_global_syn_count += 1
# Each server hosts at most 8 devices; only the process on local device 0 performs the copy, the rest wait for the lock file
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print('from path: ', from_path)
print('to path: ', to_path)
mox.file.copy_parallel(from_path, to_path)
print('===finished data synchronization===')
try:
os.mknod(sync_lock)
except IOError:
pass
print('===save flag===')
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print('Finish sync data from {} to {}'.format(from_path, to_path))
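# A standalone sketch of the lock-file barrier sync_data relies on (the names
# here are hypothetical): one leader process performs the copy and then
# publishes a flag file; every process polls until the flag exists, so nobody
# proceeds before the data is in place.
import os
import time

def barrier_copy(copy_fn, lock_file='/tmp/copy_sync.lock0', is_leader=True):
    if is_leader and not os.path.exists(lock_file):
        copy_fn()                      # leader does the actual transfer
        open(lock_file, 'w').close()   # publish the "done" flag
    while not os.path.exists(lock_file):
        time.sleep(1)                  # followers wait for the flag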
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if param.enable_modelarts:
if param.data_url:
sync_data(param.data_url, param.data_path)
print('Dataset downloaded: ', os.listdir(param.data_path))
if param.checkpoint_url or param.ckpt_url:
if not os.path.exists(param.load_path):
os.makedirs(param.load_path)
print('=' * 20 + 'makedirs')
if os.path.isdir(param.load_path):
print('=' * 20 + 'makedirs success')
else:
print('=' * 20 + 'makedirs fail')
if param.checkpoint_url:
sync_data(param.checkpoint_url, param.load_path)
else:
sync_data(os.path.dirname(param.ckpt_url), param.load_path)
print('Preload downloaded: ', os.listdir(param.load_path))
if param.train_url:
sync_data(param.train_url, param.output_path)
print('Workspace downloaded: ', os.listdir(param.output_path))
context.set_context(save_graphs_path=os.path.join(param.output_path, str(get_rank_id())))
param.device_num = get_device_num()
param.device_id = get_device_id()
if not os.path.exists(param.output_path):
os.makedirs(param.output_path)
if pre_process:
pre_process()
run_func(*args, **kwargs)
# Upload data to train_url
if param.enable_modelarts:
if post_process:
post_process()
if param.train_url:
print('Start to copy output directory')
sync_data(param.output_path, param.train_url)
if param.result_url:
print('Start to copy output directory')
sync_data(param.output_path, param.result_url)
return wrapped_func
return wrapper
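# Hedged usage sketch, mirroring how train.py and eval.py in this merge request
# apply the wrapper: when param.enable_modelarts is True, inputs are pulled
# from OBS before the run and outputs are pushed back afterwards.
from src.utils.moxing_adapter import moxing_wrapper
from preprocess_dataset import preprocess_data

@moxing_wrapper(preprocess_data)
def main():
    ...  # training or evaluation body goes here

if __name__ == '__main__':
    main()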
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
_config_path = '../../default_config.yaml'
class Param:
"""
Configuration namespace. Convert dictionary to members
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Param(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Param(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
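# A minimal sketch of Param's attribute-style access (the example values are
# hypothetical, not taken from default_config.yaml):
_demo = Param({'epochs': 20, 'data': {'dataset': 'steam', 'csvs': ['train.csv', 'test.csv']}})
print(_demo.epochs)        # 20
print(_demo.data.dataset)  # 'steam'
print(_demo.data.csvs)     # ['train.csv', 'test.csv']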
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
"""
Parse command line arguments to the configuration according to the default yaml
Args:
parser: Parent parser
cfg: Base configuration
helper: Helper description
choices: Choice constraints for each argument
cfg_path: Path to the default yaml config
"""
parser = argparse.ArgumentParser(description='[REPLACE THIS at param.py]',
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file
Args:
yaml_path: Path to the yaml config
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError('At most 3 docs (config, helper description, choices) are supported in the config yaml')
print(cfg_helper)
except yaml.YAMLError as e:
raise ValueError('Failed to parse yaml') from e
return cfg, cfg_helper, cfg_choices
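# Hedged illustration of the layout parse_yaml accepts: up to three yaml
# documents separated by '---' (defaults, per-key help text, per-key choices).
# default_config.yaml in this commit ships only the first document.
import yaml as _yaml

_sample = """
epochs: 20
device_target: "GPU"
---
epochs: "number of training epochs"
device_target: "run code on GPU or Ascend NPU"
---
device_target: ["GPU", "Ascend"]
"""
_cfg, _helper, _choices = list(_yaml.load_all(_sample, Loader=_yaml.FullLoader))
assert _choices['device_target'] == ['GPU', 'Ascend']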
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments
Args:
args: command line arguments
cfg: Base configuration
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_param():
"""
Get Config according to the yaml file and cli arguments
"""
parser = argparse.ArgumentParser(description='default name', add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
help='Config file path')
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
pprint(final_config)
print("Please check the above information for the configurations", flush=True)
return Param(final_config)
param = get_param()
......@@ -15,7 +15,6 @@
"""TB-Net training."""
import os
import argparse
import math
import numpy as np
......@@ -26,6 +25,10 @@ import mindspore.common.dtype as mstype
from src import tbnet, config, metrics, dataset
from src.utils.param import param
from src.utils.moxing_adapter import moxing_wrapper
from preprocess_dataset import preprocess_data
class MyLossMonitor(Callback):
"""My loss monitor definition."""
......@@ -57,93 +60,26 @@ class MyLossMonitor(Callback):
print('loss:' + str(loss))
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Train TBNet.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--train_csv',
type=str,
required=False,
default='train.csv',
help="the train csv datafile inside the dataset folder"
)
parser.add_argument(
'--test_csv',
type=str,
required=False,
default='test.csv',
help="the test csv datafile inside the dataset folder"
)
parser.add_argument(
'--device_id',
type=int,
required=False,
default=0,
help="device id"
)
parser.add_argument(
'--epochs',
type=int,
required=False,
default=20,
help="number of training epochs"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
'--run_mode',
type=str,
required=False,
default='graph',
choices=['graph', 'pynative'],
help="run code by GRAPH mode or PYNATIVE mode"
)
return parser.parse_args()
@moxing_wrapper(preprocess_data)
def train_tbnet():
"""Training process."""
args = get_args()
home = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(home, 'data', args.dataset, 'config.json')
train_csv_path = os.path.join(home, 'data', args.dataset, args.train_csv)
test_csv_path = os.path.join(home, 'data', args.dataset, args.test_csv)
ckpt_path = os.path.join(home, 'checkpoints')
context.set_context(device_id=args.device_id)
if args.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
config_path = os.path.join(param.data_path, 'data', param.dataset, 'config.json')
train_csv_path = os.path.join(param.data_path, 'data', param.dataset, param.train_csv)
test_csv_path = os.path.join(param.data_path, 'data', param.dataset, param.test_csv)
ckpt_path = param.load_path
context.set_context(device_id=param.device_id)
if param.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=param.device_target)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
context.set_context(mode=context.PYNATIVE_MODE, device_target=param.device_target)
if not os.path.exists(ckpt_path):
os.makedirs(ckpt_path)
print(f"creating dataset from {train_csv_path}...")
net_config = config.TBNetConfig(config_path)
if args.device_target == 'Ascend':
if param.device_target == 'Ascend':
net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
train_ds = dataset.create(train_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
......@@ -152,7 +88,7 @@ def train_tbnet():
print("creating TBNet for training...")
network = tbnet.TBNet(net_config)
loss_net = tbnet.NetWithLossClass(network, net_config)
if args.device_target == 'Ascend':
if param.device_target == 'Ascend':
loss_net.to_float(mstype.float16)
train_net = tbnet.TrainStepWrap(loss_net, net_config.lr, loss_scale=True)
else:
......@@ -164,15 +100,18 @@ def train_tbnet():
loss_callback = MyLossMonitor()
model = Model(network=train_net, eval_network=eval_net, metrics={'auc': metrics.AUC(), 'acc': metrics.ACC()})
print("training...")
for i in range(args.epochs):
for i in range(param.epochs):
print(f'===================== Epoch {i} =====================')
model.train(epoch=1, train_dataset=train_ds, callbacks=[time_callback, loss_callback], dataset_sink_mode=False)
train_out = model.eval(train_ds, dataset_sink_mode=False)
test_out = model.eval(test_ds, dataset_sink_mode=False)
print(f'Train AUC:{train_out["auc"]} ACC:{train_out["acc"]} Test AUC:{test_out["auc"]} ACC:{test_out["acc"]}')
if i >= args.epochs-5:
save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
if i >= param.epochs - 5:
if param.enable_modelarts:
save_checkpoint(network, os.path.join(param.output_path, f'tbnet_epoch{i}.ckpt'))
else:
save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
if __name__ == '__main__':
......