diff --git a/research/cv/sknet/train.py b/research/cv/sknet/train.py
index 3354c01c84b4f87000777b306d9784c62c45d394..2f1c518bd45b0f0805f87a6c331fc5f3cfedbcd0 100644
--- a/research/cv/sknet/train.py
+++ b/research/cv/sknet/train.py
@@ -30,7 +30,7 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
 from src.config import config1 as config
-from src.dataset import create_dataset1 as create_dataset
+from src.dataset import create_dataset_cifar10 as create_dataset
 from src.lr_generator import get_lr
 from src.sknet50 import sknet50 as sknet
 from src.var_init import KaimingNormal
@@ -50,6 +50,7 @@ parser.add_argument('--device_target', type=str, default='Ascend', choices=["Asc
                     help="Device target, support Ascend.")
 parser.add_argument('--pre_trained', type=str, default=None,
                     help='Pretrained checkpoint path')
+parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
 args_opt = parser.parse_args()
 
 set_seed(1)