diff --git a/research/cv/sknet/train.py b/research/cv/sknet/train.py index 3354c01c84b4f87000777b306d9784c62c45d394..2f1c518bd45b0f0805f87a6c331fc5f3cfedbcd0 100644 --- a/research/cv/sknet/train.py +++ b/research/cv/sknet/train.py @@ -30,7 +30,7 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.config import config1 as config -from src.dataset import create_dataset1 as create_dataset +from src.dataset import create_dataset_cifar10 as create_dataset from src.lr_generator import get_lr from src.sknet50 import sknet50 as sknet from src.var_init import KaimingNormal @@ -50,6 +50,7 @@ parser.add_argument('--device_target', type=str, default='Ascend', choices=["Asc help="Device target, support Ascend.") parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') +parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train') args_opt = parser.parse_args() set_seed(1)