diff --git a/official/cv/east/scripts/run_standalone_train_ascend.sh b/official/cv/east/scripts/run_standalone_train_ascend.sh index 22b1e7def7cab44635c81038e14c6500d8e6015f..cfb958384e19ebfe8bac11e8231ade14aeb80693 100644 --- a/official/cv/east/scripts/run_standalone_train_ascend.sh +++ b/official/cv/east/scripts/run_standalone_train_ascend.sh @@ -66,6 +66,6 @@ python train.py \ --is_distributed=0 \ --lr=0.001 \ --max_epoch=600 \ - --per_batch_size=24 \ + --per_batch_size=8 \ --lr_scheduler=my_lr > log.txt 2>&1 & -cd .. \ No newline at end of file +cd .. diff --git a/official/cv/east/train.py b/official/cv/east/train.py index d1a50aea659b73d3bc2b5aba2defb8ff0490ba0b..9399c0eaa52f19782050c1d8818493be4900a34e 100644 --- a/official/cv/east/train.py +++ b/official/cv/east/train.py @@ -61,7 +61,7 @@ parser.add_argument( '--per_batch_size', default=8, type=int, - help='Batch size for Training. Default: 24.') + help='Batch size for Training. Default: 8.') parser.add_argument( '--outputs_dir', default='outputs/', diff --git a/official/cv/octsqueeze/train.py b/official/cv/octsqueeze/train.py index 631e727a8340b989c1224513f8ee1878d5d505c8..8edfd3fedda1c1415745378856ded19d6fc54c57 100644 --- a/official/cv/octsqueeze/train.py +++ b/official/cv/octsqueeze/train.py @@ -69,7 +69,7 @@ def parallel_init(argv): context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree, parameter_broadcast=True) -if name == '__main__': +if __name__ == '__main__': if args.is_distributed == 1: network_init(args) parallel_init(args) diff --git a/research/cv/nas-fpn/train.py b/research/cv/nas-fpn/train.py index f3a5d58e73e0e8449bd9127ac703d466f63113b1..6545d174e8451311c3df3aa815327c35084662d3 100644 --- a/research/cv/nas-fpn/train.py +++ b/research/cv/nas-fpn/train.py @@ -188,5 +188,7 @@ def main(): else: cb += [ckpt_cb] model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True) + print('train success!') + if __name__ == '__main__': main() diff --git a/research/nlp/hypertext/create_dataset.py b/research/nlp/hypertext/create_dataset.py index 3443ae39fab9336094edb5fffbaf591f2eb887be..a32cef1667ca62792170aaa5fe35f593cea704a5 100644 --- a/research/nlp/hypertext/create_dataset.py +++ b/research/nlp/hypertext/create_dataset.py @@ -35,6 +35,8 @@ def create_dir_not_exist(path): create_dir_not_exist(args.out_data_dir) +args.data_dir = os.path.abspath(os.path.realpath(args.data_dir)) +args.out_data_dir = os.path.abspath(os.path.realpath(args.out_data_dir)) if args.datasetType == 'iflytek': changeIflytek(args.data_dir, args.out_data_dir) diff --git a/research/nlp/hypertext/src/data_preprocessing.py b/research/nlp/hypertext/src/data_preprocessing.py index 783e6d26b6b475e167aac40e0f659fbb0608d727..28a7b79331fc9400927315a41768a2bdfe2c8dd1 100644 --- a/research/nlp/hypertext/src/data_preprocessing.py +++ b/research/nlp/hypertext/src/data_preprocessing.py @@ -18,7 +18,6 @@ import pkuseg from tqdm import tqdm seg = pkuseg.pkuseg() -current_path = os.path.abspath(os.path.dirname(os.getcwd())) def changeListToText(content): @@ -36,14 +35,14 @@ def changeIflytek(in_data_dir='', out_data_dir=''): for name in changeList: print(name) data = [] - with open(current_path + in_data_dir + "/" + name, 'r', encoding='utf-8') as f: + with open(os.path.join(in_data_dir, name), 'r', encoding='utf-8') as f: line = f.readline() while line: spData = line.split('_!_') content = spData[1].strip('\n').replace('\t', '') data.append({'content': content, 'label': spData[0]}) line = f.readline() - with open(current_path + out_data_dir + "/" + name, "w", encoding='utf-8') as f: + with open(os.path.join(out_data_dir, name), "w", encoding='utf-8') as f: for d in tqdm(data): content = changeListToText(d['content']) f.write(content + '\t' + d['label'] + '\n') @@ -57,14 +56,14 @@ def changeTnews(in_data_dir='', out_data_dir=''): print(k) print(changeDict[k]) data = [] - with open(current_path + in_data_dir + "/" + k, 'r', encoding='utf-8') as f: + with open(os.path.join(in_data_dir, k), 'r', encoding='utf-8') as f: line = f.readline() while line: spData = line.split('_!_') content = spData[3].strip('\n').replace('\t', '') data.append({'content': content, 'label': spData[1]}) line = f.readline() - with open(current_path + out_data_dir + "/" + changeDict[k], "w", encoding='utf-8') as f: + with open(os.path.join(out_data_dir, changeDict[k]), "w", encoding='utf-8') as f: for d in tqdm(data): content = changeListToText(d['content']) f.write(content + '\t' + d['label'] + '\n')