diff --git a/research/cv/FaceAttribute/default_config.yaml b/research/cv/FaceAttribute/default_config.yaml
index e8916fe6af6f1982a31a009cd82c4b893617435d..f5f8534fc72650f98705ce0ab611370efec06e69 100644
--- a/research/cv/FaceAttribute/default_config.yaml
+++ b/research/cv/FaceAttribute/default_config.yaml
@@ -75,4 +75,4 @@ model_path: "pretrained model to load"
 # export option
 ckpt_file: "pretrained model to load"
 file_name: "file name"
-file_format: "file format, choices in ['MINDIR', 'AIR']"
\ No newline at end of file
+file_format: "file format, choices in ['MINDIR', 'AIR']"
diff --git a/research/cv/FaceAttribute/src/dataset_train.py b/research/cv/FaceAttribute/src/dataset_train.py
index 80d2cc2979d8c6a8722f99c9c59d325ffecc344e..d3f0c14ed8ed9bee64f6a4ef472658a1fc927453 100644
--- a/research/cv/FaceAttribute/src/dataset_train.py
+++ b/research/cv/FaceAttribute/src/dataset_train.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Face attribute dataset for train"""
+import os
 import mindspore.dataset as de
 import mindspore.dataset.vision as F
 import mindspore.dataset.transforms as F2
@@ -27,7 +28,8 @@ def data_generator(args):
     dst_h = args.dst_h
     batch_size = args.per_batch_size
     attri_num = args.attri_num
-    max_epoch = args.max_epoch
+    if os.cpu_count() >= 192:
+        args.workers = 12
     transform_img = F2.Compose([F.Decode(True),
                                 F.Resize((dst_w, dst_h)),
                                 F.RandomHorizontalFlip(prob=0.5),
@@ -40,8 +42,7 @@ def data_generator(args):
                                 python_multiprocessing=True)
     de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
     steps_per_epoch = de_dataset.get_dataset_size()
-    de_dataset = de_dataset.repeat(max_epoch)
-    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
+    de_dataloader = de_dataset.create_tuple_iterator()
 
     num_classes = attri_num
 
diff --git a/research/cv/FaceAttribute/train.py b/research/cv/FaceAttribute/train.py
index 2fc9d7159a9b53cb50fdab3c38336630cc18d2a4..540b64ffc0c69e4b1a659a42c26cda05f43d209e 100644
--- a/research/cv/FaceAttribute/train.py
+++ b/research/cv/FaceAttribute/train.py
@@ -19,14 +19,11 @@ import datetime
 import mindspore
 import mindspore.nn as nn
 from mindspore import context
-from mindspore import Tensor
 from mindspore.nn.optim import Momentum
 from mindspore.communication.management import get_group_size, init, get_rank
 from mindspore.nn import TrainOneStepCell
 from mindspore.context import ParallelMode
-from mindspore.train.callback import ModelCheckpoint, RunContext, CheckpointConfig
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.common import dtype as mstype
+from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint
 from src.FaceAttribute.resnet18 import get_resnet18
 from src.FaceAttribute.loss_factory import get_loss
 from src.dataset_train import data_generator
@@ -176,69 +173,30 @@ def run_train():
     # mixed precision training
     criterion.add_flags_recursive(fp32=True)
     train_net = TrainOneStepCell(train_net, opt, sens=config.loss_scale)
-
-    if config.local_rank == 0:
-        ckpt_max_num = config.max_epoch
-        train_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
-        ckpt_cb = ModelCheckpoint(config=train_config, directory=config.outputs_dir,
-                                  prefix='{}'.format(config.local_rank))
-        cb_params = InternalCallbackParam()
-        cb_params.train_network = train_net
-        cb_params.epoch_num = ckpt_max_num
-        cb_params.cur_epoch_num = 0
-        run_context = RunContext(cb_params)
-        ckpt_cb.begin(run_context)
-
     train_net.set_train()
-    t_end = time.time()
-    t_epoch = time.time()
-    old_progress = -1
-
-    i = 0
-    for _, (data, gt_classes) in enumerate(de_dataloader):
-
-        data_tensor = Tensor(data, dtype=mstype.float32)
-        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
-
-        loss = train_net(data_tensor, gt_tensor)
-        loss_meter.update(loss.asnumpy()[0])
-
-        if config.local_rank == 0:
-            cb_params.cur_step_num = i + 1
-            cb_params.batch_num = i + 2
-            ckpt_cb.step_end(run_context)
-
-        if (i + 1) % config.steps_per_epoch == 0 and config.local_rank == 0:
-            cb_params.cur_epoch_num += 1
-
-        if i == 0:
-            time_for_graph_compile = time.time() - create_network_start
-            config.logger.important_info(
-                '{}, graph compile time={:.2f}s'.format(config.backbone, time_for_graph_compile))
-
-        if (i + 1) % config.log_interval == 0 and config.local_rank == 0:
-            time_used = time.time() - t_end
-            epoch = int((i + 1) / config.steps_per_epoch)
-            fps = config.per_batch_size * (i - old_progress) * config.world_size / time_used
-            config.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i + 1, loss_meter, fps))
-
-            t_end = time.time()
-            loss_meter.reset()
-            old_progress = i
-
-        if (i + 1) % config.steps_per_epoch == 0 and config.local_rank == 0:
-            epoch_time_used = time.time() - t_epoch
-            epoch = int((i + 1) / config.steps_per_epoch)
-            fps = config.per_batch_size * config.world_size * config.steps_per_epoch / epoch_time_used
-            per_step_time = epoch_time_used / config.steps_per_epoch
-            config.logger.info('=================================================')
-            config.logger.info('epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i + 1, fps))
-            config.logger.info('epoch[{}], epoch time: {:5.3f} ms, per step time: {:5.3f} ms'.format(
-                epoch, epoch_time_used * 1000, per_step_time * 1000))
-            config.logger.info('=================================================')
-            t_epoch = time.time()
-
-        i += 1
+
+    first_step = True
+    for epoch_idx in range(config.max_epoch):
+        epoch_begin_time = time.time()
+        for data_tensor, gt_tensor in de_dataloader:
+            loss = train_net(data_tensor, gt_tensor)
+            loss_meter.update(loss.asnumpy()[0])
+            if first_step:
+                time_for_graph_compile = time.time() - create_network_start
+                config.logger.important_info('{}, graph compile time={:.2f}s'.format(
+                    config.backbone, time_for_graph_compile))
+                first_step = False
+        epoch_end_time = time.time()
+        epoch_time = epoch_end_time - epoch_begin_time
+        fps = config.per_batch_size * config.world_size * config.steps_per_epoch / epoch_time
+        config.logger.info('=================================================')
+        config.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(
+            epoch_idx, (epoch_idx + 1) * config.steps_per_epoch, loss_meter, fps))
+        config.logger.info('epoch[{}], epoch time: {:5.3f} ms, per step time: {:5.3f} ms'.format(
+            epoch_idx, epoch_time * 1000, epoch_time / config.steps_per_epoch * 1000))
+        if config.local_rank % 8 == 0:
+            save_checkpoint(train_net, os.path.join(config.outputs_dir, "{}-{}_{}.ckpt".format(
+                config.local_rank, epoch_idx, config.steps_per_epoch)))
 
     config.logger.info('--------- trains out ---------')