diff --git a/research/cv/PDarts/src/call_backs.py b/research/cv/PDarts/src/call_backs.py index 4cac04de37decff05c4eeefa3f1b183c64922e32..93d4df358b3ea788a457e461ac5f3850149142b7 100644 --- a/research/cv/PDarts/src/call_backs.py +++ b/research/cv/PDarts/src/call_backs.py @@ -13,12 +13,6 @@ # limitations under the License. # ============================================================================ """train callbacks""" -try: - from moxing.framework import file - print("import moxing success") -except ModuleNotFoundError as e: - print(f'not modelarts env, error={e}') - import os import time @@ -132,6 +126,7 @@ class Val_Callback(Callback): ckpt_file = os.path.join(ckpt_path, 'model_checkpoint.ckpt') save_checkpoint(cb_params.train_network, ckpt_file) if self.checkpoint_path.startswith('s3://') or self.checkpoint_path.startswith('obs://'): + from moxing.framework import file file.copy_parallel(save_path, os.path.join( self.checkpoint_path, model_info)) print('==============save checkpoint finished===================') diff --git a/research/cv/PDarts/train.py b/research/cv/PDarts/train.py index 5b101c5522cf5702250f658fdbdb628f1fe6cd12..d59beceb979ae24d714e763cfdb35f33c4fe3fd2 100644 --- a/research/cv/PDarts/train.py +++ b/research/cv/PDarts/train.py @@ -13,12 +13,6 @@ # limitations under the License. # ============================================================================ """train the PDarts model""" -try: - from moxing.framework import file - print("import moxing success") -except ModuleNotFoundError as e: - print(f'not modelarts env, error={e}') - import os import time import logging @@ -136,6 +130,7 @@ def main(): load_param_into_net(network, param_dict) if args.data_url.startswith('s3://') or args.data_url.startswith('obs://'): + from moxing.framework import file data_url_cache = os.path.join(args.local_data_root, 'data') file.copy_parallel(args.data_url, data_url_cache) args.data_url = data_url_cache