From f1a42b8104129225e2bcb107ca7d203ad6c02926 Mon Sep 17 00:00:00 2001
From: hemaohua <hemaohua@huawei.com>
Date: Sat, 27 Aug 2022 19:01:44 +0800
Subject: [PATCH] update octsqueeze train.py

---
 official/cv/east/scripts/run_standalone_train_ascend.sh | 4 ++--
 official/cv/east/train.py                               | 2 +-
 official/cv/octsqueeze/train.py                         | 2 +-
 research/cv/nas-fpn/train.py                            | 2 ++
 research/nlp/hypertext/create_dataset.py                | 2 ++
 research/nlp/hypertext/src/data_preprocessing.py        | 9 ++++-----
 6 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/official/cv/east/scripts/run_standalone_train_ascend.sh b/official/cv/east/scripts/run_standalone_train_ascend.sh
index 22b1e7def..cfb958384 100644
--- a/official/cv/east/scripts/run_standalone_train_ascend.sh
+++ b/official/cv/east/scripts/run_standalone_train_ascend.sh
@@ -66,6 +66,6 @@ python train.py \
     --is_distributed=0 \
     --lr=0.001 \
     --max_epoch=600 \
-    --per_batch_size=24 \
+    --per_batch_size=8 \
     --lr_scheduler=my_lr  > log.txt 2>&1 &
-cd ..
\ No newline at end of file
+cd ..
diff --git a/official/cv/east/train.py b/official/cv/east/train.py
index d1a50aea6..9399c0eaa 100644
--- a/official/cv/east/train.py
+++ b/official/cv/east/train.py
@@ -61,7 +61,7 @@ parser.add_argument(
     '--per_batch_size',
     default=8,
     type=int,
-    help='Batch size for Training. Default: 24.')
+    help='Batch size for Training. Default: 8.')
 parser.add_argument(
     '--outputs_dir',
     default='outputs/',
diff --git a/official/cv/octsqueeze/train.py b/official/cv/octsqueeze/train.py
index 631e727a8..8edfd3fed 100644
--- a/official/cv/octsqueeze/train.py
+++ b/official/cv/octsqueeze/train.py
@@ -69,7 +69,7 @@ def parallel_init(argv):
     context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
                                       device_num=degree, parameter_broadcast=True)
 
-if name == '__main__':
+if __name__ == '__main__':
     if args.is_distributed == 1:
         network_init(args)
         parallel_init(args)
diff --git a/research/cv/nas-fpn/train.py b/research/cv/nas-fpn/train.py
index f3a5d58e7..6545d174e 100644
--- a/research/cv/nas-fpn/train.py
+++ b/research/cv/nas-fpn/train.py
@@ -188,5 +188,7 @@ def main():
     else:
         cb += [ckpt_cb]
         model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
+    print('train success!')
+
 if __name__ == '__main__':
     main()
diff --git a/research/nlp/hypertext/create_dataset.py b/research/nlp/hypertext/create_dataset.py
index 3443ae39f..a32cef166 100644
--- a/research/nlp/hypertext/create_dataset.py
+++ b/research/nlp/hypertext/create_dataset.py
@@ -35,6 +35,8 @@ def create_dir_not_exist(path):
 
 
 create_dir_not_exist(args.out_data_dir)
+args.data_dir = os.path.abspath(os.path.realpath(args.data_dir))
+args.out_data_dir = os.path.abspath(os.path.realpath(args.out_data_dir))
 
 if args.datasetType == 'iflytek':
     changeIflytek(args.data_dir, args.out_data_dir)
diff --git a/research/nlp/hypertext/src/data_preprocessing.py b/research/nlp/hypertext/src/data_preprocessing.py
index 783e6d26b..28a7b7933 100644
--- a/research/nlp/hypertext/src/data_preprocessing.py
+++ b/research/nlp/hypertext/src/data_preprocessing.py
@@ -18,7 +18,6 @@ import pkuseg
 from tqdm import tqdm
 
 seg = pkuseg.pkuseg()
-current_path = os.path.abspath(os.path.dirname(os.getcwd()))
 
 
 def changeListToText(content):
@@ -36,14 +35,14 @@ def changeIflytek(in_data_dir='', out_data_dir=''):
     for name in changeList:
         print(name)
         data = []
-        with open(current_path + in_data_dir + "/" + name, 'r', encoding='utf-8') as f:
+        with open(os.path.join(in_data_dir, name), 'r', encoding='utf-8') as f:
             line = f.readline()
             while line:
                 spData = line.split('_!_')
                 content = spData[1].strip('\n').replace('\t', '')
                 data.append({'content': content, 'label': spData[0]})
                 line = f.readline()
-        with open(current_path + out_data_dir + "/" + name, "w", encoding='utf-8') as f:
+        with open(os.path.join(out_data_dir, name), "w", encoding='utf-8') as f:
             for d in tqdm(data):
                 content = changeListToText(d['content'])
                 f.write(content + '\t' + d['label'] + '\n')
@@ -57,14 +56,14 @@ def changeTnews(in_data_dir='', out_data_dir=''):
         print(k)
         print(changeDict[k])
         data = []
-        with open(current_path + in_data_dir + "/" + k, 'r', encoding='utf-8') as f:
+        with open(os.path.join(in_data_dir, k), 'r', encoding='utf-8') as f:
             line = f.readline()
             while line:
                 spData = line.split('_!_')
                 content = spData[3].strip('\n').replace('\t', '')
                 data.append({'content': content, 'label': spData[1]})
                 line = f.readline()
-        with open(current_path + out_data_dir  + "/" + changeDict[k], "w", encoding='utf-8') as f:
+        with open(os.path.join(out_data_dir, changeDict[k]), "w", encoding='utf-8') as f:
             for d in tqdm(data):
                 content = changeListToText(d['content'])
                 f.write(content + '\t' + d['label'] + '\n')
-- 
GitLab