diff --git a/research/cv/glore_res/train.py b/research/cv/glore_res/train.py index 7628ed81ba0908c888946e02c643063de9061a1a..7ba94409c524710e3942763e4bf32f055f1a851c 100644 --- a/research/cv/glore_res/train.py +++ b/research/cv/glore_res/train.py @@ -46,14 +46,14 @@ from src.save_callback import SaveCallback if config.isModelArts: import moxing as mox -if config.net == 'resnet200' or config.net == 'resnet101' or config.net == 'resnet50': +if config.net == 'resnet200' or config.net == 'resnet101': if config.device_target == "GPU": config.cast_fp16 = False random.seed(1) np.random.seed(1) de.config.set_seed(1) -if config.net == 'resnet200' or config.net == 'resnet101': +if config.net == 'resnet200' or config.net == 'resnet101' or config.net == 'resnet50': set_seed(1) if __name__ == '__main__':