From f1a42b8104129225e2bcb107ca7d203ad6c02926 Mon Sep 17 00:00:00 2001
From: hemaohua <hemaohua@huawei.com>
Date: Sat, 27 Aug 2022 19:01:44 +0800
Subject: [PATCH] update octsqueeze train.py

---
 official/cv/east/scripts/run_standalone_train_ascend.sh | 4 ++--
 official/cv/east/train.py                               | 2 +-
 official/cv/octsqueeze/train.py                         | 2 +-
 research/cv/nas-fpn/train.py                            | 2 ++
 research/nlp/hypertext/create_dataset.py                | 2 ++
 research/nlp/hypertext/src/data_preprocessing.py        | 9 ++++-----
 6 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/official/cv/east/scripts/run_standalone_train_ascend.sh b/official/cv/east/scripts/run_standalone_train_ascend.sh
index 22b1e7def..cfb958384 100644
--- a/official/cv/east/scripts/run_standalone_train_ascend.sh
+++ b/official/cv/east/scripts/run_standalone_train_ascend.sh
@@ -66,6 +66,6 @@ python train.py \
     --is_distributed=0 \
     --lr=0.001 \
     --max_epoch=600 \
-    --per_batch_size=24 \
+    --per_batch_size=8 \
     --lr_scheduler=my_lr > log.txt 2>&1 &
-cd ..
\ No newline at end of file
+cd ..
diff --git a/official/cv/east/train.py b/official/cv/east/train.py
index d1a50aea6..9399c0eaa 100644
--- a/official/cv/east/train.py
+++ b/official/cv/east/train.py
@@ -61,7 +61,7 @@ parser.add_argument(
     '--per_batch_size',
     default=8,
     type=int,
-    help='Batch size for Training. Default: 24.')
+    help='Batch size for Training. Default: 8.')
 parser.add_argument(
     '--outputs_dir',
     default='outputs/',
diff --git a/official/cv/octsqueeze/train.py b/official/cv/octsqueeze/train.py
index 631e727a8..8edfd3fed 100644
--- a/official/cv/octsqueeze/train.py
+++ b/official/cv/octsqueeze/train.py
@@ -69,7 +69,7 @@ def parallel_init(argv):
     context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree,
                                       parameter_broadcast=True)
 
-if name == '__main__':
+if __name__ == '__main__':
     if args.is_distributed == 1:
         network_init(args)
         parallel_init(args)
diff --git a/research/cv/nas-fpn/train.py b/research/cv/nas-fpn/train.py
index f3a5d58e7..6545d174e 100644
--- a/research/cv/nas-fpn/train.py
+++ b/research/cv/nas-fpn/train.py
@@ -188,5 +188,7 @@ def main():
     else:
         cb += [ckpt_cb]
     model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
+    print('train success!')
+
 if __name__ == '__main__':
     main()
diff --git a/research/nlp/hypertext/create_dataset.py b/research/nlp/hypertext/create_dataset.py
index 3443ae39f..a32cef166 100644
--- a/research/nlp/hypertext/create_dataset.py
+++ b/research/nlp/hypertext/create_dataset.py
@@ -35,6 +35,8 @@ def create_dir_not_exist(path):
 
 
 create_dir_not_exist(args.out_data_dir)
+args.data_dir = os.path.abspath(os.path.realpath(args.data_dir))
+args.out_data_dir = os.path.abspath(os.path.realpath(args.out_data_dir))
 
 if args.datasetType == 'iflytek':
     changeIflytek(args.data_dir, args.out_data_dir)
diff --git a/research/nlp/hypertext/src/data_preprocessing.py b/research/nlp/hypertext/src/data_preprocessing.py
index 783e6d26b..28a7b7933 100644
--- a/research/nlp/hypertext/src/data_preprocessing.py
+++ b/research/nlp/hypertext/src/data_preprocessing.py
@@ -18,7 +18,6 @@ import pkuseg
 from tqdm import tqdm
 
 seg = pkuseg.pkuseg()
-current_path = os.path.abspath(os.path.dirname(os.getcwd()))
 
 
 def changeListToText(content):
@@ -36,14 +35,14 @@ def changeIflytek(in_data_dir='', out_data_dir=''):
     for name in changeList:
         print(name)
         data = []
-        with open(current_path + in_data_dir + "/" + name, 'r', encoding='utf-8') as f:
+        with open(os.path.join(in_data_dir, name), 'r', encoding='utf-8') as f:
             line = f.readline()
             while line:
                 spData = line.split('_!_')
                 content = spData[1].strip('\n').replace('\t', '')
                 data.append({'content': content, 'label': spData[0]})
                 line = f.readline()
-        with open(current_path + out_data_dir + "/" + name, "w", encoding='utf-8') as f:
+        with open(os.path.join(out_data_dir, name), "w", encoding='utf-8') as f:
             for d in tqdm(data):
                 content = changeListToText(d['content'])
                 f.write(content + '\t' + d['label'] + '\n')
@@ -57,14 +56,14 @@ def changeTnews(in_data_dir='', out_data_dir=''):
         print(k)
         print(changeDict[k])
         data = []
-        with open(current_path + in_data_dir + "/" + k, 'r', encoding='utf-8') as f:
+        with open(os.path.join(in_data_dir, k), 'r', encoding='utf-8') as f:
             line = f.readline()
             while line:
                 spData = line.split('_!_')
                 content = spData[3].strip('\n').replace('\t', '')
                 data.append({'content': content, 'label': spData[1]})
                 line = f.readline()
-        with open(current_path + out_data_dir + "/" + changeDict[k], "w", encoding='utf-8') as f:
+        with open(os.path.join(out_data_dir, changeDict[k]), "w", encoding='utf-8') as f:
             for d in tqdm(data):
                 content = changeListToText(d['content'])
                 f.write(content + '\t' + d['label'] + '\n')
-- 
GitLab