diff --git a/research/cv/ssd_ghostnet/src/dataset.py b/research/cv/ssd_ghostnet/src/dataset.py index 8eb7fa4dcfb311ade7a185c87b0de78d96d66a77..312f14fcecc132d841526d0c2ba4c222160a62f4 100644 --- a/research/cv/ssd_ghostnet/src/dataset.py +++ b/research/cv/ssd_ghostnet/src/dataset.py @@ -30,6 +30,9 @@ from src.model_utils.config import config from .box_utils import jaccard_numpy, ssd_bboxes_encode +de.config.set_seed(1) + + def _rand(a=0., b=1.): """Generate random.""" return np.random.rand() * (b - a) + a diff --git a/research/cv/ssd_ghostnet/train.py b/research/cv/ssd_ghostnet/train.py index 2c72d63b20f5777bb0b9cd6e362948d768f919a5..82384d738faa713cc5e90bf7796d22280545b01c 100644 --- a/research/cv/ssd_ghostnet/train.py +++ b/research/cv/ssd_ghostnet/train.py @@ -16,6 +16,8 @@ """Train SSD and get checkpoint files.""" import os +import random +from numpy.random import seed as seed_np import mindspore.nn as nn from mindspore import context, Tensor from mindspore.communication.management import init, get_rank @@ -23,6 +25,7 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMoni from mindspore.train import Model from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed as seed_ms from src.ssd_ghostnet import SSD300, SSDWithLossCell, TrainingWrapper, ssd_ghostnet from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord from src.lr_schedule import get_lr @@ -31,6 +34,11 @@ from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper +random.seed(0) +seed_ms(0) +seed_np(0) + + @moxing_wrapper() def train_net(): """train net"""