diff --git a/research/cv/yolox/default_config.yaml b/research/cv/yolox/default_config.yaml
index 8b09560864a399afbf31e56c02f1a80cda35d451..de4a7cdd857b37b33532bc3ed87e2a7cb56c30f4 100644
--- a/research/cv/yolox/default_config.yaml
+++ b/research/cv/yolox/default_config.yaml
@@ -23,27 +23,30 @@ per_batch_size: 8
 
 # network configuration
 depth_wise: False
-max_gt: 48
+max_gt: 120
 num_classes: 80
 input_size: [640, 640]
 fpn_strides: [8, 16, 32]
 use_l1: False
 use_syc_bn: True
+updates: 0.0
 
 # dynamic_k
 n_candidate_k: 10
 
 # optimizer and lr related
-lr: 0.01
+lr: 0.011
 min_lr_ratio: 0.01
 warmup_epochs: 5
 weight_decay: 0.0005
 momentum: 0.9
 no_aug_epochs: 15
 # logging related
-log_interval: 36
+log_interval: 30
 ckpt_interval: -1
 is_save_on_master: 1
+ckpt_max_num: 60
+opt: "Momentum"
 
 # distributed related
 is_distributed: 1
@@ -72,6 +75,7 @@ log_path: "val/outputs/"
 val_ckpt: "0-2755_64.ckpt"
 conf_thre: 0.001
 nms_thre: 0.65
+eval_interval: 10
 # modelarts
 is_modelart: False
 result_path: ''
diff --git a/research/cv/yolox/eval.py b/research/cv/yolox/eval.py
index 640f7b057c7c66e239c33b629de6e624471111af..a5581c714a3005f201212b5945ff914a7b87e983 100644
--- a/research/cv/yolox/eval.py
+++ b/research/cv/yolox/eval.py
@@ -29,6 +29,7 @@ from src.yolox import DetectionBlock
 from src.yolox_dataset import create_yolox_dataset
 from src.initializer import default_recurisive_init
 
+
 def run_test():
     """The function of eval"""
     config.data_root = os.path.join(config.data_dir, 'val2017')
@@ -53,13 +54,20 @@ def run_test():
         backbone = "yolofpn"
     else:
         backbone = "yolopafpn"
-    network = DetectionBlock(config, backbone=backbone, is_training=False)  # default yolo-darknet53
+    network = DetectionBlock(config, backbone=backbone)  # default yolo-darknet53
     default_recurisive_init(network)
     config.logger.info(config.val_ckpt)
     if os.path.isfile(config.val_ckpt):
         param_dict = load_checkpoint(config.val_ckpt)
+        ema_param_dict = {}
+        for param in param_dict:
+            if param.startswith("ema."):
+                new_name = param.split("ema.")[1]
+                data = param_dict[param]
+                data.name = new_name
+                ema_param_dict[new_name] = data
 
-        load_param_into_net(network, param_dict)
+        load_param_into_net(network, ema_param_dict)
         config.logger.info('load model %s success', config.val_ckpt)
     else:
         config.logger.info('%s doesn''t exist or is not a pre-trained file', config.val_ckpt)
diff --git a/research/cv/yolox/scripts/run_distribute_train.sh b/research/cv/yolox/scripts/run_distribute_train.sh
index 5763acfe57020d2d1e6bb3c9a88024dd99065a94..f522edb234b6425e00f3b3144b6386a9b92903a6 100644
--- a/research/cv/yolox/scripts/run_distribute_train.sh
+++ b/research/cv/yolox/scripts/run_distribute_train.sh
@@ -14,8 +14,8 @@
 # limitations under the License.
 # ===========================================================================
 if [[ $# -lt 3 || $# -gt 4 ]];then
-    echo "Usage1: sh run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE] [BACKBONE]  for first data aug epochs"
-    echo "Usage2: sh run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE] [BACKBONE] [RESUME_CKPT] for last no data aug epochs"
+    echo "Usage1: bash run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE] [BACKBONE]  for first data aug epochs"
+    echo "Usage2: bash run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE] [BACKBONE] [RESUME_CKPT] for last no data aug epochs"
 exit 1
 fi
 
@@ -93,6 +93,7 @@ then
           --warmup_epochs=5 \
           --no_aug_epochs=15  \
           --min_lr_ratio=0.001 \
+          --eval_interval=10 \
           --lr_scheduler=yolox_warm_cos_lr  > log.txt 2>&1 &
       cd ..
   done
@@ -127,6 +128,7 @@ then
           --warmup_epochs=5 \
           --no_aug_epochs=15  \
           --min_lr_ratio=0.001 \
+          --eval_interval=1 \
           --lr_scheduler=yolox_warm_cos_lr  > log.txt 2>&1 &
       cd ..
   done
diff --git a/research/cv/yolox/scripts/run_eval.sh b/research/cv/yolox/scripts/run_eval.sh
index 2253b3f3f9e3020819678eeda1cb055e44a7b5a8..ca3c330c27fa517a70c083e3b746bc57bc3149d1 100644
--- a/research/cv/yolox/scripts/run_eval.sh
+++ b/research/cv/yolox/scripts/run_eval.sh
@@ -16,7 +16,7 @@
 
 if [ $# != 4 ]
 then
-    echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [BACKBONE] [BATCH_SIZE] "
+    echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [BACKBONE] [BATCH_SIZE] "
 exit 1
 fi
 
diff --git a/research/cv/yolox/scripts/run_standalone_train.sh b/research/cv/yolox/scripts/run_standalone_train.sh
index 9d9255b85d9e4fe99f7c46ed4bb18a8c17fc5c9a..c829aa638392b3f5f759d9837104509350d31058 100644
--- a/research/cv/yolox/scripts/run_standalone_train.sh
+++ b/research/cv/yolox/scripts/run_standalone_train.sh
@@ -62,8 +62,10 @@ if [ $# == 2 ]
 then
   echo "Start to launch first data augment epochs..."
   python train.py \
+        --data_dir=$DATASET_PATH \
         --data_aug=True \
         --is_distributed=0 \
+        --eval_interval=10 \
         --backbone=$BACKBONE > log.txt 2>&1 &
 fi
 
@@ -73,8 +75,10 @@ then
   CKPT_FILE=$(get_real_path $3)
   echo $CKPT_FILE
   python train.py \
+      --data_dir=$DATASET_PATH \
       --data_aug=False \
       --is_distributed=0 \
+      --eval_interval=1 \
       --backbone=$BACKBONE \
       --yolox_no_aug_ckpt=$CKPT_FILE > log.txt 2>&1 &
 fi
\ No newline at end of file
diff --git a/research/cv/yolox/src/transform.py b/research/cv/yolox/src/transform.py
index c62c4ec9baffc191ffc7c8035f401650efa6967e..bc68397985cb37d3bf01bdac2ceff05eec4efb95 100644
--- a/research/cv/yolox/src/transform.py
+++ b/research/cv/yolox/src/transform.py
@@ -20,80 +20,102 @@ import cv2
 import numpy as np
 
 
-def random_perspective(
-        img,
-        targets=(),
+def get_aug_params(value, center=0):
+    if not isinstance(value, float) and len(value) != 2:
+        raise ValueError(
+            "Affine params should be either a sequence containing two values\
+             or single float values. Got {}".format(value)
+        )
+    if isinstance(value, float):
+        return random.uniform(center - value, center + value)
+
+    return random.uniform(value[0], value[1])
+
+
+def get_affine_matrix(
+        target_size,
         degrees=10,
         translate=0.1,
-        scale=0.1,
+        scales=0.1,
         shear=10,
-        perspective=0.0,
-        border=(0, 0),
 ):
-    """ random perspective for images"""
-    height = img.shape[0] + border[0] * 2
-    width = img.shape[1] + border[1] * 2
-
-    # Center
-    C = np.eye(3)
-    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
-    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)
+    twidth, theight = target_size
 
     # Rotation and Scale
-    R = np.eye(3)
-    a = random.uniform(-degrees, degrees)
-    s = random.uniform(scale[0], scale[1])
-    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
+    angle = get_aug_params(degrees)
+    scale = get_aug_params(scales, center=1.0)
 
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
+
+    R = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=scale)
+
+    M = np.ones([2, 3])
     # Shear
-    S = np.eye(3)
-    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)
-    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)
+    shear_x = math.tan(get_aug_params(shear) * math.pi / 180)
+    shear_y = math.tan(get_aug_params(shear) * math.pi / 180)
+
+    M[0] = R[0] + shear_y * R[1]
+    M[1] = R[1] + shear_x * R[0]
 
     # Translation
-    T = np.eye(3)
-    T[0, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * width)  # x translation (pixels)
-    T[1, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * height)  # y translation (pixels)
-
-    # Combined rotation matrix
-    M = T @ S @ R @ C  # order of operations (right to left) is IMPORTANT
-
-    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
-        if perspective:
-            img = cv2.warpPerspective(
-                img, M, dsize=(width, height), borderValue=(114, 114, 114)
-            )
-        else:  # affine
-            img = cv2.warpAffine(
-                img, M[:2], dsize=(width, height), borderValue=(114, 114, 114)
-            )
+    translation_x = get_aug_params(translate) * twidth  # x translation (pixels)
+    translation_y = get_aug_params(translate) * theight  # y translation (pixels)
 
-    # Transform label coordinates
-    n = len(targets)
-    if n:
-        xy = np.ones((n * 4, 3))
-        xy[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
-            n * 4, 2
+    M[0, 2] = translation_x
+    M[1, 2] = translation_y
+
+    return M, scale
+
+
+def apply_affine_to_bboxes(targets, target_size, M):
+    num_gts = len(targets)
+
+    # warp corner points
+    twidth, theight = target_size
+    corner_points = np.ones((4 * num_gts, 3))
+    corner_points[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
+        4 * num_gts, 2
+    )  # x1y1, x2y2, x1y2, x2y1
+    corner_points = corner_points @ M.T  # apply affine transform
+    corner_points = corner_points.reshape(num_gts, 8)
+
+    # create new boxes
+    corner_xs = corner_points[:, 0::2]
+    corner_ys = corner_points[:, 1::2]
+    new_bboxes = (
+        np.concatenate(
+            (corner_xs.min(1), corner_ys.min(1), corner_xs.max(1), corner_ys.max(1))
         )
-        xy = xy @ M.T
-        if perspective:
-            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
-        else:
-            xy = xy[:, :2].reshape(n, 8)
+        .reshape(4, num_gts)
+        .T
+    )
+
+    # clip boxes
+    new_bboxes[:, 0::2] = new_bboxes[:, 0::2].clip(0, twidth)
+    new_bboxes[:, 1::2] = new_bboxes[:, 1::2].clip(0, theight)
+
+    targets[:, :4] = new_bboxes
+
+    return targets
+
 
-        # create new boxes
-        x = xy[:, [0, 2, 4, 6]]
-        y = xy[:, [1, 3, 5, 7]]
-        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+def random_affine(
+        img,
+        targets=(),
+        target_size=(640, 640),
+        degrees=10,
+        translate=0.1,
+        scales=0.1,
+        shear=10,
+):
+    M, _ = get_affine_matrix(target_size, degrees, translate, scales, shear)
 
-        # clip boxes
-        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
-        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
+    img = cv2.warpAffine(img, M, dsize=target_size, borderValue=(114, 114, 114))
 
-        # filter candidates
-        i = box_candidates(box1=targets[:, :4].T * s, box2=xy.T)
-        targets = targets[i]
-        targets[:, :4] = xy[i]
+    # Transform label coordinates
+    if len(targets) > 0:
+        targets = apply_affine_to_bboxes(targets, target_size, M)
 
     return img, targets
 
@@ -214,7 +236,9 @@ class TrainTransform:
         padded_labels[range(len(targets_t))[: self.max_labels]] = targets_t[: self.max_labels]
         padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
         gt_bboxes_per_image = padded_labels[:, 1:5]
+        # is_in_boxes_all [gt_max, 8400]
         is_in_boxes_all, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, true_labels)
+        # is_in_boxes_all [gt_max, 8400]
         is_in_boxes_all = is_in_boxes_all.any(1).reshape((-1, 1)) * is_in_boxes_all.any(0).reshape((1, -1))
         return image_t, padded_labels, is_in_boxes_all, is_in_boxes_and_center
 
diff --git a/research/cv/yolox/src/util.py b/research/cv/yolox/src/util.py
index e07beeb3134ae14009cc31cbb2dcb4488ffd8ed8..66e1f05f08046d1d0bfad26acccf97c500b0e5f4 100644
--- a/research/cv/yolox/src/util.py
+++ b/research/cv/yolox/src/util.py
@@ -19,13 +19,12 @@ import time
 import math
 import json
 import stat
-
 from datetime import datetime
 from collections import Counter
-
 import numpy as np
 import mindspore.common.dtype as mstype
-from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
+from mindspore import load_checkpoint, load_param_into_net, save_checkpoint, Tensor, Parameter
+from mindspore.common.parameter import ParameterTuple
 from mindspore.train.callback import Callback
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
@@ -280,7 +279,6 @@ def load_backbone(net, ckpt_path, args):
         else:
             param_not_load.append(param.name)
     args.logger.info("not loading param is :", len(param_not_load))
-
     return net
 
 
@@ -327,6 +325,41 @@ def keep_loss_fp32(network):
             cell.to_float(mstype.float32)
 
 
+class EMACallBack(Callback):
+
+    def __init__(self, network, steps_per_epoch, cur_steps=0):
+        self.steps_per_epoch = steps_per_epoch
+        self.cur_steps = cur_steps
+        self.network = network
+
+    def epoch_begin(self, run_context):
+        if self.network.ema:
+            if not isinstance(self.network.ema_moving_weight, list):
+                tmp_moving = []
+                for weight in self.network.ema_moving_weight:
+                    tmp_moving.append(weight.asnumpy())
+                self.network.ema_moving_weight = tmp_moving
+
+    def step_end(self, run_context):
+        if self.network.ema:
+            self.network.moving_parameter_update()
+            self.cur_steps += 1
+
+            if self.cur_steps % self.steps_per_epoch == 0:
+                if isinstance(self.network.ema_moving_weight, list):
+                    tmp_moving = []
+                    moving_name = []
+                    idx = 0
+                    for key in self.network.moving_name:
+                        moving_name.append(key)
+
+                    for weight in self.network.ema_moving_weight:
+                        param = Parameter(Tensor(weight), name=moving_name[idx])
+                        tmp_moving.append(param)
+                        idx += 1
+                    self.network.ema_moving_weight = ParameterTuple(tmp_moving)
+
+
 class YOLOXCB(Callback):
     """
     YOLOX Callback.
@@ -345,6 +378,7 @@ class YOLOXCB(Callback):
         self.current_step = 0
         self.save_ckpt_path = save_ckpt_path
         self.iter_time = time.time()
+        self.epoch_start_time = time.time()
         self.average_loss = []
         self.logger = logger
 
@@ -355,6 +389,8 @@ class YOLOXCB(Callback):
         Args:
             run_context (RunContext): Include some information of the model.
         """
+        self.epoch_start_time = time.time()
+        self.iter_time = time.time()
 
     def epoch_end(self, run_context):
         """
@@ -363,6 +399,15 @@ class YOLOXCB(Callback):
         Args:
             run_context (RunContext): Include some information of the model.
         """
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        loss = cb_params.net_outputs
+        loss = "loss: %.4f, overflow: %s, scale: %s" % (float(loss[0].asnumpy()),
+                                                        bool(loss[1].asnumpy()),
+                                                        int(loss[2].asnumpy()))
+        self.logger.info(
+            "epoch: %s epoch time %.2fs %s" % (cur_epoch, time.time() - self.epoch_start_time, loss))
+
         if self.current_step % (self.step_per_epoch * 1) == 0:
             if self.is_modelarts:
                 import moxing as mox
@@ -387,19 +432,20 @@ class YOLOXCB(Callback):
         Args:
             run_context (RunContext): Include some information of the model.
         """
-        cb_params = run_context.original_args()
-        loss = cb_params.net_outputs
 
-        cur_epoch = self.current_step // self.step_per_epoch
-
-        loss = "loss: %.4f, overflow: %s, scale: %s" % (float(loss[0].asnumpy()),
-                                                        bool(loss[1].asnumpy()),
-                                                        int(loss[2].asnumpy()))
-        self.logger.info("epoch: %s step: [%s/%s], %s, lr: %.6f, time: %.2f" % (
-            cur_epoch, self.current_step % self.step_per_epoch, self.step_per_epoch, loss, self.lr[self.current_step],
-            (time.time() - self.iter_time) * 1000 / self._per_print_times))
-        self.iter_time = time.time()
-        self.current_step += self._per_print_times
+        cur_epoch_step = (self.current_step + 1) % self.step_per_epoch
+        if cur_epoch_step % self._per_print_times == 0 and cur_epoch_step != 0:
+            cb_params = run_context.original_args()
+            cur_epoch = cb_params.cur_epoch_num
+            loss = cb_params.net_outputs
+            loss = "loss: %.4f, overflow: %s, scale: %s" % (float(loss[0].asnumpy()),
+                                                            bool(loss[1].asnumpy()),
+                                                            int(loss[2].asnumpy()))
+            self.logger.info("epoch: %s step: [%s/%s], %s, lr: %.6f, avg step time: %.2f ms" % (
+                cur_epoch, cur_epoch_step, self.step_per_epoch, loss, self.lr[self.current_step],
+                (time.time() - self.iter_time) * 1000 / self._per_print_times))
+            self.iter_time = time.time()
+        self.current_step += 1
 
     def end(self, run_context):
         """
@@ -411,25 +457,47 @@ class YOLOXCB(Callback):
 
 
 class EvalCallBack(Callback):
-    def __init__(self, dataset, net, detection, logger, start_epoch=0,
-                 end_epoch=100, save_path=None, interval=1):
+    def __init__(self, dataset, test_net, train_net, detection, config, start_epoch=0, interval=1):
         self.dataset = dataset
-        self.network = net
+        self.network = train_net
+        self.test_network = test_net
         self.detection = detection
-        self.logger = logger
+        self.logger = config.logger
         self.start_epoch = start_epoch
-        self.save_path = save_path
         self.interval = interval
+        self.max_epoch = config.max_epoch
         self.best_result = 0
         self.best_epoch = 0
-        self.end_epoch = end_epoch
+        self.rank = config.rank
+
+    def load_ema_parameter(self):
+        param_dict = {}
+        for name, param in self.network.parameters_and_names():
+            if name.startswith("ema."):
+                new_name = name.split('ema.')[-1]
+                param_new = param.clone()
+                param_new.name = new_name
+                param_dict[new_name] = param_new
+        load_param_into_net(self.test_network, param_dict)
+
+    def load_network_parameter(self):
+        param_dict = {}
+        for name, param in self.network.parameters_and_names():
+            if name.startswith("network."):
+                param_new = param.clone()
+                param_dict[name] = param_new
+        load_param_into_net(self.test_network, param_dict)
 
     def epoch_end(self, run_context):
         cb_param = run_context.original_args()
         cur_epoch = cb_param.cur_epoch_num
         if cur_epoch >= self.start_epoch:
-            if (cur_epoch - self.start_epoch) % self.interval == 0:
-                self.network.set_train(False)
+            if (cur_epoch - self.start_epoch) % self.interval == 0 or cur_epoch == self.max_epoch:
+                if self.rank == 0:
+                    self.load_ema_parameter()
+                else:
+                    self.load_network_parameter()
+                self.test_network.set_train(False)
                 eval_print_str, results = self.inference()
                 if results >= self.best_result:
                     self.best_result = results
@@ -441,6 +509,9 @@ class EvalCallBack(Callback):
                 self.logger.info(eval_print_str)
                 self.logger.info('Ending inference...')
 
+    def end(self, run_context):
+        self.logger.info("Best result %s at %s epoch" % (self.best_result, self.best_epoch))
+
     def inference(self):
         self.logger.info('Start inference...')
         self.logger.info("eval dataset size, %s" % self.dataset.get_dataset_size())
@@ -449,7 +520,7 @@ class EvalCallBack(Callback):
             image = data['image']
             img_info = data['image_shape']
             img_id = data['img_id']
-            prediction = self.network(image)
+            prediction = self.test_network(image)
             prediction = prediction.asnumpy()
             img_shape = img_info.asnumpy()
             img_id = img_id.asnumpy()
@@ -461,8 +532,10 @@ class EvalCallBack(Callback):
         result_file_path = self.detection.evaluate_prediction()
         self.logger.info('result file path: %s', result_file_path)
         eval_result, results = self.detection.get_eval_result()
-        eval_print_str = '\n=============coco eval result=========\n' + eval_result
-        return eval_print_str, results
+        if eval_result is not None and results is not None:
+            eval_print_str = '\n=============coco eval result=========\n' + eval_result
+            return eval_print_str, results
+        return None, 0
 
     def remove_ckpoint_file(self, file_name):
         """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
@@ -640,10 +713,18 @@ class DetectionEngine:
             raise RuntimeError("Unable to open json file to dump. What():{}".format(str(e)))
         else:
             f.close()
+            if not self.data_list:
+                self.file_path = ''
+                return self.file_path
+
+            self.data_list.clear()
             return self.file_path
 
     def get_eval_result(self):
         """Get eval result"""
+        if not self.file_path:
+            return None, None
+
         cocoGt = self._coco
         cocoDt = cocoGt.loadRes(self.file_path)
         cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
diff --git a/research/cv/yolox/src/yolox.py b/research/cv/yolox/src/yolox.py
index d09945dda46513448a190839c9c280105762e94a..48455b148c99b81ca876e458f85cec416ef4d0d8 100644
--- a/research/cv/yolox/src/yolox.py
+++ b/research/cv/yolox/src/yolox.py
@@ -18,6 +18,7 @@ import mindspore
 import mindspore.nn as nn
 from mindspore import Tensor, Parameter
 from mindspore import ops
+from mindspore.common.parameter import ParameterTuple
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -113,138 +114,14 @@ class DetectionPerFPN(nn.Cell):
         return cls_output, reg_output, obj_output
 
 
-class DetectionBaseBlock(nn.Cell):
-    def __init__(self, config, backbone="yolopafpn"):
-        super(DetectionBaseBlock, self).__init__()
-        self.num_classes = config.num_classes
-        self.attr_num = self.num_classes + 5
-        self.depthwise = config.depth_wise
-        self.strides = Tensor([8, 16, 32], mindspore.float32)
-        self.input_size = config.input_size
-        # network
-        if backbone == "yolopafpn":
-            self.backbone = YOLOPAFPN(depth=1.33, width=1.25, input_w=self.input_size[1], input_h=self.input_size[0])
-            self.head_inchannels = [1024, 512, 256]
-            self.activation = "silu"
-            self.width = 1.25
-        else:
-            self.backbone = YOLOFPN(input_w=self.input_size[1], input_h=self.input_size[0])
-            self.head_inchannels = [512, 256, 128]
-            self.activation = "lrelu"
-            self.width = 1.0
-        self.head_l = DetectionPerFPN(in_channels=self.head_inchannels, num_classes=self.num_classes, scale='l',
-                                      act=self.activation, width=self.width)
-        self.head_m = DetectionPerFPN(in_channels=self.head_inchannels, num_classes=self.num_classes, scale='m',
-                                      act=self.activation, width=self.width)
-        self.head_s = DetectionPerFPN(in_channels=self.head_inchannels, num_classes=self.num_classes, scale='s',
-                                      act=self.activation, width=self.width)
-
-    def construct(self, x):
-        x_l, x_m, x_s = self.backbone(x)
-        cls_output_l, reg_output_l, obj_output_l = self.head_l(x_l)  # (bs, 80, 80, 80)(bs, 4, 80, 80)(bs, 1, 80, 80)
-        cls_output_m, reg_output_m, obj_output_m = self.head_m(x_m)  # (bs, 80, 40, 40)(bs, 4, 40, 40)(bs, 1, 40, 40)
-        cls_output_s, reg_output_s, obj_output_s = self.head_s(x_s)  # (bs, 80, 20, 20)(bs, 4, 20, 20)(bs, 1, 20, 20)
-        return cls_output_l, reg_output_l, obj_output_l, \
-               cls_output_m, reg_output_m, obj_output_m, \
-               cls_output_s, reg_output_s, obj_output_s
-
-    def mapping_to_img(self, output, stride):
-        """ map to origin image scale for each fpn """
-
-        batch_size = P.Shape()(output)[0]
-        n_ch = self.attr_num
-        grid_size = P.Shape()(output)[2:4]
-        range_x = range(grid_size[1])
-        range_y = range(grid_size[0])
-        stride = P.Cast()(stride, output.dtype)
-        grid_x = P.Cast()(F.tuple_to_array(range_x), output.dtype)
-        grid_y = P.Cast()(F.tuple_to_array(range_y), output.dtype)
-        grid_y = P.ExpandDims()(grid_y, 1)
-        grid_x = P.ExpandDims()(grid_x, 0)
-        yv = P.Tile()(grid_y, (1, grid_size[1]))
-        xv = P.Tile()(grid_x, (grid_size[0], 1))
-        grid = P.Stack(axis=2)([xv, yv])  # (80, 80, 2)
-        grid = P.Reshape()(grid, (1, 1, grid_size[0], grid_size[1], 2))  # (1,1,80,80,2)
-        output = P.Reshape()(output,
-                             (batch_size, n_ch, grid_size[0], grid_size[1]))  # bs, 6400, 85-->(bs,85,80,80)
-        output = P.Transpose()(output, (0, 2, 1, 3))  # (bs,85,80,80)-->(bs,80,85,80)
-        output = P.Transpose()(output, (0, 1, 3, 2))  # (bs,80,85,80)--->(bs, 80, 80, 85)
-        output = P.Reshape()(output, (batch_size, 1 * grid_size[0] * grid_size[1], -1))  # bs, 6400, 85
-        grid = P.Reshape()(grid, (1, -1, 2))  # grid(1, 6400, 2)
-
-        output_xy = output[..., :2]
-        output_xy = (output_xy + grid) * stride
-        output_wh = output[..., 2:4]
-        output_wh = P.Exp()(output_wh) * stride
-        output_other = output[..., 4:]
-        output_t = P.Concat(axis=-1)([output_xy, output_wh, output_other])
-        return output_t
-
-
-class DetectionTrainBlock(nn.Cell):
-    def __init__(self, base_network):
-        super(DetectionTrainBlock, self).__init__()
-        self.base_network = base_network
-
-    def construct(self, x):
-        cls_output_l, reg_output_l, obj_output_l, \
-        cls_output_m, reg_output_m, obj_output_m, \
-        cls_output_s, reg_output_s, obj_output_s = self.base_network(x)
-        output_l = P.Concat(axis=1)((reg_output_l, obj_output_l, cls_output_l))  # (bs, 85, 80, 80)
-        output_m = P.Concat(axis=1)((reg_output_m, obj_output_m, cls_output_m))  # (bs, 85, 40, 40)
-        output_s = P.Concat(axis=1)((reg_output_s, obj_output_s, cls_output_s))  # (bs, 85, 20, 20)
-        outputs = []
-        output_l = self.base_network.mapping_to_img(output_l,
-                                                    stride=self.base_network.strides[0])  # (bs, 6400, 85)x_c, y_c, w, h
-        output_m = self.base_network.mapping_to_img(output_m,
-                                                    stride=self.base_network.strides[1])  # (bs, 1600, 85)x_c, y_c, w, h
-        output_s = self.base_network.mapping_to_img(output_s,
-                                                    stride=self.base_network.strides[2])  # (bs,  400, 85)x_c, y_c, w, h
-        outputs.append(output_l)
-        outputs.append(output_m)
-        outputs.append(output_s)
-        return P.Concat(axis=1)(outputs)  # batch_size, 8400, 85
-
-
-class DetectionTestBlock(nn.Cell):
-    def __init__(self, base_network):
-        super(DetectionTestBlock, self).__init__()
-        self.base_network = base_network
-
-    def construct(self, x):
-        cls_output_l, reg_output_l, obj_output_l, \
-        cls_output_m, reg_output_m, obj_output_m, \
-        cls_output_s, reg_output_s, obj_output_s = self.base_network(x)
-        output_l = P.Concat(axis=1)(
-            (reg_output_l, P.Sigmoid()(obj_output_l), P.Sigmoid()(cls_output_l)))  # bs, 85, 80, 80
-
-        output_m = P.Concat(axis=1)(
-            (reg_output_m, P.Sigmoid()(obj_output_m), P.Sigmoid()(cls_output_m)))  # bs, 85, 40, 40
-
-        output_s = P.Concat(axis=1)(
-            (reg_output_s, P.Sigmoid()(obj_output_s), P.Sigmoid()(cls_output_s)))  # bs, 85, 20, 20
-        output_l = self.base_network.mapping_to_img(output_l,
-                                                    stride=self.base_network.strides[0])  # (bs, 6400, 85)x_c, y_c, w, h
-        output_m = self.base_network.mapping_to_img(output_m,
-                                                    stride=self.base_network.strides[1])  # (bs, 1600, 85)x_c, y_c, w, h
-        output_s = self.base_network.mapping_to_img(output_s,
-                                                    stride=self.base_network.strides[2])  # (bs,  400, 85)x_c, y_c, w, h
-        outputs = []
-        outputs.append(output_l)
-        outputs.append(output_m)
-        outputs.append(output_s)
-        return P.Concat(axis=1)(outputs)  # batch_size, 8400, 85
-
-
 class DetectionBlock(nn.Cell):
     """ connect yolox backbone and head """
 
-    def __init__(self, config, backbone="yolopafpn", is_training=True):
+    def __init__(self, config, backbone="yolopafpn"):
         super(DetectionBlock, self).__init__()
         self.num_classes = config.num_classes
         self.attr_num = self.num_classes + 5
         self.depthwise = config.depth_wise
-        self.is_training = is_training
         self.strides = Tensor([8, 16, 32], mindspore.float32)
         self.input_size = config.input_size
 
@@ -274,7 +151,7 @@ class DetectionBlock(nn.Cell):
         cls_output_l, reg_output_l, obj_output_l = self.head_l(x_l)  # (bs, 80, 80, 80)(bs, 4, 80, 80)(bs, 1, 80, 80)
         cls_output_m, reg_output_m, obj_output_m = self.head_m(x_m)  # (bs, 80, 40, 40)(bs, 4, 40, 40)(bs, 1, 40, 40)
         cls_output_s, reg_output_s, obj_output_s = self.head_s(x_s)  # (bs, 80, 20, 20)(bs, 4, 20, 20)(bs, 1, 20, 20)
-        if self.is_training:
+        if self.training:
             output_l = P.Concat(axis=1)((reg_output_l, obj_output_l, cls_output_l))  # (bs, 85, 80, 80)
             output_m = P.Concat(axis=1)((reg_output_m, obj_output_m, cls_output_m))  # (bs, 85, 40, 40)
             output_s = P.Concat(axis=1)((reg_output_s, obj_output_s, cls_output_s))  # (bs, 85, 20, 20)
@@ -533,25 +410,56 @@ def _tensor_grad_overflow(grad):
 class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
     """ Train one step with ema model """
 
-    def __init__(self, network, optimizer, scale_sense, ema=True, decay=0.9998, updates=0):
+    def __init__(self, network, optimizer, scale_sense, ema=True, decay=0.9998, updates=0, moving_name=None,
+                 ema_moving_weight=None):
         super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
         self.ema = ema
+        self.moving_name = moving_name
+        self.ema_moving_weight = ema_moving_weight
         if self.ema:
             self.ema_weight = self.weights.clone("ema", init='same')
             self.decay = decay
             self.updates = Parameter(Tensor(updates, mindspore.float32))
             self.assign = ops.Assign()
+            self.ema_moving_parameters()
+
+    def ema_moving_parameters(self):
+        self.moving_name = {}
+        moving_list = []
+        idx = 0
+        for key, param in self.network.parameters_and_names():
+            if "moving_mean" in key or "moving_variance" in key:
+                new_param = param.clone()
+                new_param.name = "ema." + param.name
+                moving_list.append(new_param)
+                self.moving_name["ema." + key] = idx
+                idx += 1
+        self.ema_moving_weight = ParameterTuple(moving_list)
 
     def ema_update(self):
         """Update EMA parameters."""
         if self.ema:
             self.updates += 1
             d = self.decay * (1 - ops.Exp()(-self.updates / 2000))
+            # update trainable parameters
             for ema_v, weight in zip(self.ema_weight, self.weights):
                 tep_v = ema_v * d
                 self.assign(ema_v, (1.0 - d) * weight + tep_v)
         return self.updates
 
+    # moving_parameter_update is executed inside the callback (EMACallBack)
+    def moving_parameter_update(self):
+        if self.ema:
+            d = (self.decay * (1 - ops.Exp()(-self.updates / 2000))).asnumpy().item()
+            # update moving mean and moving var
+            for key, param in self.network.parameters_and_names():
+                if "moving_mean" in key or "moving_variance" in key:
+                    idx = self.moving_name["ema." + key]
+                    moving_weight = param.asnumpy()
+                    tep_v = self.ema_moving_weight[idx] * d
+                    ema_value = (1.0 - d) * moving_weight + tep_v
+                    # ParameterTuple is a tuple subclass: item assignment raises TypeError,
+                    # so write the new EMA value into the Parameter in place instead.
+                    self.ema_moving_weight[idx].set_data(Tensor(ema_value))
+
     def construct(self, *inputs):
         """ Forward """
         weights = self.weights
diff --git a/research/cv/yolox/src/yolox_dataset.py b/research/cv/yolox/src/yolox_dataset.py
index 8d3d4e111c0bf092562137cb56b636a3591d297b..aac2255ca16b54bb69ba826bb1bf6ced8be2bd26 100644
--- a/research/cv/yolox/src/yolox_dataset.py
+++ b/research/cv/yolox/src/yolox_dataset.py
@@ -22,7 +22,7 @@ import cv2
 import mindspore.dataset as de
 from pycocotools.coco import COCO
 
-from src.transform import box_candidates, random_perspective, TrainTransform, ValTransform
+from src.transform import box_candidates, random_affine, TrainTransform, ValTransform
 
 min_keypoints_per_image = 10
 
@@ -173,16 +173,15 @@ class COCOYoloXDataset:
                 np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
                 np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])
 
-            mosaic_img, mosaic_labels = random_perspective(
+            mosaic_img, mosaic_labels = random_affine(
                 mosaic_img,
                 mosaic_labels,
+                target_size=(input_w, input_h),
                 degrees=self.degrees,
                 translate=self.translate,
-                scale=self.scale,
+                scales=self.scale,
                 shear=self.shear,
-                perspective=self.perspective,
-                border=[-input_h // 2, -input_w // 2],
-            )  # border to remove
+            )
 
             if (
                     self.enable_mixup
diff --git a/research/cv/yolox/train.py b/research/cv/yolox/train.py
index 1a2a60bab0ba1d8775a966dd2da805863de5744d..c8dc0b1f60ff2bd24687c281d35b823e3c1c2d54 100644
--- a/research/cv/yolox/train.py
+++ b/research/cv/yolox/train.py
@@ -20,8 +20,8 @@ import argparse
 
 from mindspore.context import ParallelMode
 from mindspore.common import set_seed
-from mindspore.nn import Momentum
-from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, SummaryCollector
+from mindspore.common.parameter import ParameterTuple
+from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore import context, Model, DynamicLossScaleManager, load_checkpoint, load_param_into_net
 from mindspore.profiler.profiling import Profiler
@@ -30,12 +30,11 @@ from mindspore.common.tensor import Tensor
 from model_utils.config import config
 from model_utils.device_adapter import get_device_id, get_device_num
 from model_utils.moxing_adapter import moxing_wrapper
-from src.initializer import default_recurisive_init, load_resume_params
+from src.initializer import default_recurisive_init
 from src.logger import get_logger
 from src.network_blocks import use_syc_bn
-from src.util import get_param_groups, YOLOXCB, get_lr, load_backbone, EvalCallBack, DetectionEngine
-from src.yolox import YOLOLossCell, TrainOneStepWithEMA, DetectionBaseBlock, DetectionTrainBlock, \
-    DetectionTestBlock
+from src.util import get_param_groups, YOLOXCB, get_lr, load_backbone, EvalCallBack, DetectionEngine, EMACallBack
+from src.yolox import YOLOLossCell, TrainOneStepWithEMA, DetectionBlock
 from src.yolox_dataset import create_yolox_dataset
 
 set_seed(42)
@@ -193,6 +192,50 @@ def get_val_dataset():
     return ds_test
 
 
+def get_optimizer(cfg, network, lr):
+    param_group = get_param_groups(network, cfg.weight_decay)
+    if cfg.opt == "SGD":
+        from mindspore.nn import SGD
+        opt = SGD(params=param_group, learning_rate=Tensor(lr), momentum=cfg.momentum, nesterov=True)
+        cfg.logger.info("Use SGD Optimizer")
+    else:
+        from mindspore.nn import Momentum
+        opt = Momentum(params=param_group,
+                       learning_rate=Tensor(lr),
+                       momentum=cfg.momentum,
+                       use_nesterov=True)
+        cfg.logger.info("Use Momentum Optimizer")
+    return opt
+
+
+def load_resume_checkpoint(cfg, network, ckpt_path):
+    param_dict = load_checkpoint(ckpt_path)
+
+    ema_train_weight = []
+    ema_moving_weight = []
+    param_load = {}
+    for key, param in param_dict.items():
+        if key.startswith("network.") or key.startswith("moments."):
+            param_load[key] = param
+        elif "updates" in key:
+            cfg.updates = param
+            network.updates = cfg.updates
+            cfg.logger.info("network_ema updates:%s" % network.updates.asnumpy().item())
+    load_param_into_net(network, param_load)
+
+    for key, param in network.parameters_and_names():
+        if key.startswith("ema.") and "moving_mean" not in key and "moving_variance" not in key:
+            ema_train_weight.append(param_dict[key])
+        elif key.startswith("ema.") and ("moving_mean" in key or "moving_variance" in key):
+            ema_moving_weight.append(param_dict[key])
+
+    if network.ema:
+        if ema_train_weight and ema_moving_weight:
+            network.ema_weight = ParameterTuple(ema_train_weight)
+            network.ema_moving_weight = ParameterTuple(ema_moving_weight)
+            config.logger.info("successful loading ema weights")
+
+
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def run_train():
     """ Launch Train process """
@@ -217,18 +260,16 @@ def run_train():
         backbone = "yolofpn"
     else:
         backbone = "yolopafpn"
-    base_network = DetectionBaseBlock(config, backbone=backbone)
-    test_network = DetectionTestBlock(base_network)
-    network = DetectionTrainBlock(base_network)  # train network
+    base_network = DetectionBlock(config, backbone=backbone)
     if config.pretrained:
-        network = load_backbone(network, config.pretrained, config)
+        base_network = load_backbone(base_network, config.pretrained, config)
     config.logger.info('Training backbone is: %s' % config.backbone)
     if config.use_syc_bn:
         config.logger.info("Using Synchronized batch norm layer...")
-        use_syc_bn(network)
-    default_recurisive_init(network)
+        use_syc_bn(base_network)
+    default_recurisive_init(base_network)
     config.logger.info("Network weights have been initialized...")
-    network = YOLOLossCell(network, config)
+    network = YOLOLossCell(base_network, config)
     config.logger.info('Finish getting network...')
     config.data_root = os.path.join(config.data_dir, 'train2017')
     config.annFile = os.path.join(config.data_dir, 'annotations/instances_train2017.json')
@@ -247,43 +288,45 @@ def run_train():
     lr = get_lr(config)
     config.logger.info("Learning rate scheduler:%s, base_lr:%s, min lr ratio:%s" % (config.lr_scheduler, config.lr,
                                                                                     config.min_lr_ratio))
-    opt = Momentum(params=get_param_groups(network, config.weight_decay),
-                   learning_rate=Tensor(lr),
-                   momentum=config.momentum,
-                   use_nesterov=True)
+    opt = get_optimizer(config, network, lr)
     loss_scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 22)
     update_cell = loss_scale_manager.get_update_cell()
     network_ema = TrainOneStepWithEMA(network, opt, update_cell,
-                                      ema=True, decay=0.9999, updates=0).set_train()
+                                      ema=True, decay=0.9998, updates=config.updates).set_train()
     if config.resume_yolox:
-        load_resume_params(config, network_ema)
+        resume_steps = config.updates.asnumpy().item()  # NOTE(review): was .items() — not a numpy method; also confirm config.updates is a Parameter before load_resume_checkpoint runs
+        config.resume_epoch = resume_steps // config.steps_per_epoch
+        lr = lr[resume_steps:]
+        opt = get_optimizer(config, network, lr)
+        network_ema = TrainOneStepWithEMA(network, opt, update_cell,
+                                          ema=True, decay=0.9998, updates=resume_steps).set_train()
+        load_resume_checkpoint(config, network_ema, config.resume_yolox)
+
     if not config.data_aug:
         if os.path.isfile(config.yolox_no_aug_ckpt):  # Loading the resume checkpoint for the last no data aug epochs
-            param_dict = load_checkpoint(config.yolox_no_aug_ckpt)
-            if "learning_rate" in param_dict:
-                param_dict.pop("learning_rate")
-            if "global_step" in param_dict:
-                param_dict.pop("global_step")
-            load_param_into_net(network_ema, param_dict)
+            load_resume_checkpoint(config, network_ema, config.yolox_no_aug_ckpt)
             config.logger.info("Finish load the resume checkpoint, begin to train the last...")
         else:
             raise FileNotFoundError('{} not exist or not a pre-trained file'.format(config.yolox_no_aug_ckpt))
+
     config.logger.info("Add ema model")
     model = Model(network_ema, amp_level="O0")
-    cb = [SummaryCollector(summary_dir="./summary_dir", collect_freq=1)]
+    cb = []
     save_ckpt_path = None
     if config.rank_save_ckpt_flag:
-        ckpt_max_num = config.max_epoch * config.steps_per_epoch // config.ckpt_interval
+        cb.append(EMACallBack(network_ema, config.steps_per_epoch))
         ckpt_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch * config.ckpt_interval,
-                                       keep_checkpoint_max=ckpt_max_num)
+                                       keep_checkpoint_max=config.ckpt_max_num)
         save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '/')
         cb.append(ModelCheckpoint(config=ckpt_config,
                                   directory=save_ckpt_path,
-                                  prefix='{}'.format(config.rank)))
+                                  prefix='{}'.format(config.backbone)))
     cb.append(YOLOXCB(config.logger, config.steps_per_epoch, lr=lr, save_ckpt_path=save_ckpt_path,
                       is_modelart=config.enable_modelarts,
                       per_print_times=config.log_interval, train_url=args_opt.train_url))
-    cb.append(EvalCallBack(ds_test, test_network, DetectionEngine(config), config.logger, interval=513))
+    cb.append(
+        EvalCallBack(ds_test, DetectionBlock(config, backbone=backbone), network_ema, DetectionEngine(config), config,
+                     interval=config.eval_interval))
     if config.need_profiler:
         model.train(3, ds, callbacks=cb, dataset_sink_mode=True, sink_size=config.log_interval)
         profiler.analyse()
@@ -291,8 +334,8 @@ def run_train():
         config.logger.info("Epoch number:%s" % config.max_epoch)
         config.logger.info("All steps number:%s" % (config.max_epoch * config.steps_per_epoch))
         config.logger.info("==================Start Training=========================")
-        model.train(config.max_epoch * config.steps_per_epoch // config.log_interval, ds, callbacks=cb,
-                    dataset_sink_mode=True, sink_size=config.log_interval)
+        model.train(config.max_epoch, ds, callbacks=cb,
+                    dataset_sink_mode=False, sink_size=-1)
     config.logger.info("==================Training END======================")