From 2c8123fbe8f58b9078339a2d2cf9ee025bc8a783 Mon Sep 17 00:00:00 2001
From: zhaoting <zhaoting23@huawei.com>
Date: Sat, 25 Jun 2022 10:04:11 +0800
Subject: [PATCH] fix some bugs

---
 official/cv/unet/scripts/run_distribute_train.sh | 3 +++
 official/cv/unet/src/data_loader.py              | 2 +-
 official/cv/unet/src/utils.py                    | 9 +++++----
 official/cv/unet/train.py                        | 2 +-
 official/recommend/naml/script/run_train.sh      | 7 -------
 5 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/official/cv/unet/scripts/run_distribute_train.sh b/official/cv/unet/scripts/run_distribute_train.sh
index 27bfa1073..3d96d8de4 100644
--- a/official/cv/unet/scripts/run_distribute_train.sh
+++ b/official/cv/unet/scripts/run_distribute_train.sh
@@ -39,6 +39,9 @@ DATASET=$(get_real_path $2)
 CONFIG_PATH=$(get_real_path $3)
 RANK_TABLE=$(get_real_path $1)
 export RANK_TABLE_FILE=$RANK_TABLE
+
+ulimit -u unlimited
+
 for((i=0;i<RANK_SIZE;i++))
 do
     rm -rf LOG$i
diff --git a/official/cv/unet/src/data_loader.py b/official/cv/unet/src/data_loader.py
index 16cd33e04..bdbb0f386 100644
--- a/official/cv/unet/src/data_loader.py
+++ b/official/cv/unet/src/data_loader.py
@@ -257,7 +257,7 @@ def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_class
     mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle)
     dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, shuffle=True,
                                   num_shards=group_size, shard_id=rank,
-                                  num_parallel_workers=num_parallel_workers, python_multiprocessing=True)
+                                  num_parallel_workers=num_parallel_workers, python_multiprocessing=is_train)
     compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size),
                                                                 augment and is_train, eval_resize))
     dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names,
diff --git a/official/cv/unet/src/utils.py b/official/cv/unet/src/utils.py
index 0c3079a8a..d8222fe3e 100644
--- a/official/cv/unet/src/utils.py
+++ b/official/cv/unet/src/utils.py
@@ -92,10 +92,11 @@ class dice_coeff(nn.Metric):
         self._iou_sum = 0
         self._samples_num = 0
         self.img_num = 0
-        self.eval_images_path = "./draw_eval"
-        if os.path.exists(self.eval_images_path):
-            shutil.rmtree(self.eval_images_path)
-        os.mkdir(self.eval_images_path)
+        if self.show_eval:
+            self.eval_images_path = "./draw_eval"
+            if os.path.exists(self.eval_images_path):
+                shutil.rmtree(self.eval_images_path)
+            os.mkdir(self.eval_images_path)
 
     def draw_img(self, gray, index):
         """
diff --git a/official/cv/unet/train.py b/official/cv/unet/train.py
index 472c52c8d..22543d39a 100644
--- a/official/cv/unet/train.py
+++ b/official/cv/unet/train.py
@@ -114,7 +114,7 @@ def train_net(cross_valid_ind=1,
                   amp_level=amp_level)
     print("============== Starting Training ==============")
     callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
-    if config.run_eval:
+    if config.run_eval and rank == 0:
         eval_model = Model(UnetEval(net, need_slice=need_slice, eval_activate=config.eval_activate.lower()),
                            loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(False, config.show_eval)})
         eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
diff --git a/official/recommend/naml/script/run_train.sh b/official/recommend/naml/script/run_train.sh
index 6c2873d40..1855b5ad1 100644
--- a/official/recommend/naml/script/run_train.sh
+++ b/official/recommend/naml/script/run_train.sh
@@ -40,10 +40,3 @@ python ${PROJECT_DIR}/../train.py \
     --save_checkpoint_path=${CHECKPOINT_PATH} \
     --weight_decay=False \
     --sink_mode=True
-
-python ${PROJECT_DIR}/../eval.py \
-    --config_path=${config_path} \
-    --platform=${PLATFORM} \
-    --dataset=${DATASET} \
-    --dataset_path=${DATASET_PATH} \
-    --checkpoint_path=${CHECKPOINT_PATH}/naml_last.ckpt
-- 
GitLab