From 4067d264d35501e824a26f6e84b624d39fde7284 Mon Sep 17 00:00:00 2001
From: deepr <hexiangdong2020@outlook.com>
Date: Thu, 15 Jul 2021 22:28:09 +0800
Subject: [PATCH] Add other infrastructure files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_loader.py   | 264 +++++++++++++++++++++++++++++++++++++++++++++++
 eval_callback.py |  90 ++++++++++++++++
 utils.py         | 183 ++++++++++++++++++++++++++++++++
 3 files changed, 537 insertions(+)
 create mode 100644 data_loader.py
 create mode 100644 eval_callback.py
 create mode 100644 utils.py

diff --git a/data_loader.py b/data_loader.py
new file mode 100644
index 0000000..4bc3455
--- /dev/null
+++ b/data_loader.py
@@ -0,0 +1,264 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+from collections import deque
+import cv2
+import numpy as np
+from PIL import Image, ImageSequence
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as c_vision
+from mindspore.dataset.vision.utils import Inter
+from mindspore.communication.management import get_rank, get_group_size
+
+
+def _load_multipage_tiff(path):
+    """Load a multi-page TIFF as a numpy array of images stacked along the first axis."""
+    return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
+
+def _get_val_train_indices(length, fold, ratio=0.8):
+    """Split shuffled indices into train/val sets, rotating by `fold` for cross-validation."""
+    assert 0 < ratio <= 1, "Train/total data ratio must be in range (0.0, 1.0]"
+    np.random.seed(0)
+    indices = np.arange(0, length, 1, dtype=np.int64)
+    np.random.shuffle(indices)
+
+    if fold is not None:
+        indices = deque(indices)
+        indices.rotate(fold * round((1.0 - ratio) * length))
+        indices = np.array(indices)
+        train_indices = indices[:round(ratio * len(indices))]
+        val_indices = indices[round(ratio * len(indices)):]
+    else:
+        train_indices = indices
+        val_indices = []
+    return train_indices, val_indices
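+
+# Worked example (hypothetical numbers, not from this patch): with length=30 and
+# ratio=0.8 the rotation step is round(0.2 * 30) = 6, so fold=1 shifts the
+# shuffled order by 6 and each fold sees a different 24/6 train/val partition:
+#
+#     train_idx, val_idx = _get_val_train_indices(30, fold=1)
+#     assert len(train_idx) == 24 and len(val_idx) == 6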
+
+def data_post_process(img, mask):
+    """Add a channel axis to the image and convert the mask to one-hot CHW float32."""
+    img = np.expand_dims(img, axis=0)
+    mask = (mask > 0.5).astype(np.int64)
+    mask = (np.arange(mask.max() + 1) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
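+
+# Shape walk-through (illustrative values): a 388x388 binary mask becomes a
+# (2, 388, 388) one-hot array and the image gains a leading channel axis:
+#
+#     img, mask = data_post_process(np.zeros((388, 388), np.float32),
+#                                   np.ones((388, 388), np.float32))
+#     assert img.shape == (1, 388, 388) and mask.shape == (2, 388, 388)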
+
+
+def train_data_augmentation(img, mask):
+    """Randomly flip, crop, and brightness-shift an image/mask pair."""
+    if np.random.random() > 0.5:  # flip up-down
+        img = np.flipud(img)
+        mask = np.flipud(mask)
+    if np.random.random() > 0.5:  # flip left-right
+        img = np.fliplr(img)
+        mask = np.fliplr(mask)
+
+    # random crop: each border moves inward by at most 30% of the 572-pixel edge
+    left = int(np.random.uniform() * 0.3 * 572)
+    right = int((1 - np.random.uniform() * 0.3) * 572)
+    top = int(np.random.uniform() * 0.3 * 572)
+    bottom = int((1 - np.random.uniform() * 0.3) * 572)
+
+    img = img[top:bottom, left:right]
+    mask = mask[top:bottom, left:right]
+
+    # adjust brightness
+    brightness = np.random.uniform(-0.2, 0.2)
+    img = np.float32(img + brightness * np.ones(img.shape))
+    img = np.clip(img, -1.0, 1.0)
+
+    return img, mask
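+
+# Illustrative crop bounds: if the uniform draws were 0.1 and 0.2, then
+# left = int(0.1 * 0.3 * 572) = 17 and right = int((1 - 0.2 * 0.3) * 572) = 537,
+# so a random central region of the 572x572 input survives the crop.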
+
+
+def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False,
+                   do_crop=None, img_size=None):
+    """Build the ISBI train/validation datasets from the multi-page TIFF volumes in `data_dir`."""
+
+    images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif'))
+    masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif'))
+
+    train_indices, val_indices = _get_val_train_indices(len(images), cross_val_ind)
+    train_images = images[train_indices]
+    train_masks = masks[train_indices]
+    train_images = np.repeat(train_images, repeat, axis=0)
+    train_masks = np.repeat(train_masks, repeat, axis=0)
+    val_images = images[val_indices]
+    val_masks = masks[val_indices]
+
+    train_image_data = {"image": train_images}
+    train_mask_data = {"mask": train_masks}
+    valid_image_data = {"image": val_images}
+    valid_mask_data = {"mask": val_masks}
+
+    if run_distribute:
+        rank_id = get_rank()
+        rank_size = get_group_size()
+        ds_train_images = ds.NumpySlicesDataset(data=train_image_data,
+                                                sampler=None,
+                                                shuffle=False,
+                                                num_shards=rank_size,
+                                                shard_id=rank_id)
+        ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data,
+                                               sampler=None,
+                                               shuffle=False,
+                                               num_shards=rank_size,
+                                               shard_id=rank_id)
+    else:
+        ds_train_images = ds.NumpySlicesDataset(data=train_image_data, sampler=None, shuffle=False)
+        ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data, sampler=None, shuffle=False)
+
+    ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
+    ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
+
+    if do_crop != "None":  # do_crop arrives from config as an [h, w] list or the string "None"
+        resize_size = [int(img_size[x] * do_crop[x] / 572) for x in range(len(img_size))]
+    else:
+        resize_size = img_size
+    c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
+    c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2)
+    c_rescale_image = c_vision.Rescale(1.0/127.5, -1)
+    c_rescale_mask = c_vision.Rescale(1.0/255.0, 0)
+
+    c_trans_normalize_img = [c_rescale_image, c_resize_op, c_pad]
+    c_trans_normalize_mask = [c_rescale_mask, c_resize_op, c_pad]
+    c_center_crop = c_vision.CenterCrop(size=388)
+
+    train_image_ds = ds_train_images.map(input_columns="image", operations=c_trans_normalize_img)
+    train_mask_ds = ds_train_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
+    train_ds = ds.zip((train_image_ds, train_mask_ds))
+    train_ds = train_ds.project(columns=["image", "mask"])
+    if augment:
+        augment_process = train_data_augmentation
+        c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR)
+        train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process)
+        train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
+        train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
+
+    if do_crop != "None":
+        train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
+    post_process = data_post_process
+    train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
+    train_ds = train_ds.shuffle(repeat*24)
+    train_ds = train_ds.batch(batch_size=train_batch_size, drop_remainder=True)
+
+    valid_image_ds = ds_valid_images.map(input_columns="image", operations=c_trans_normalize_img)
+    valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
+    valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
+    valid_ds = valid_ds.project(columns=["image", "mask"])
+    if do_crop != "None":
+        valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
+    post_process = data_post_process
+    valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
+    valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
+
+    return train_ds, valid_ds
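+
+# Minimal usage sketch (paths and sizes are assumptions, not part of this patch):
+#
+#     train_ds, valid_ds = create_dataset("./data", repeat=1, train_batch_size=16,
+#                                         augment=True, do_crop=[388, 388],
+#                                         img_size=[572, 572])
+#     print(train_ds.get_dataset_size(), valid_ds.get_dataset_size())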
+
+class MultiClassDataset:
+    """
+    Read images and masks from the original data and split them into train and val sets by `split`.
+    Each sample is a sub-directory of the data tree containing its image as `image.png`
+    and its mask as `mask.png`.
+    """
+    def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False):
+        self.data_dir = data_dir
+        self.is_train = is_train
+        self.split = (split != 1.0)
+        if self.split:
+            self.img_ids = sorted(next(os.walk(self.data_dir))[1])
+            self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
+            self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
+        else:
+            self.train_ids = sorted(next(os.walk(os.path.join(self.data_dir, "train")))[1])
+            self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1])
+        if shuffle:
+            np.random.shuffle(self.train_ids)
+
+    def _read_img_mask(self, img_id):
+        if self.split:
+            path = os.path.join(self.data_dir, img_id)
+        elif self.is_train:
+            path = os.path.join(self.data_dir, "train", img_id)
+        else:
+            path = os.path.join(self.data_dir, "val", img_id)
+        img = cv2.imread(os.path.join(path, "image.png"))
+        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
+        return img, mask
+
+    def __getitem__(self, index):
+        if self.is_train:
+            return self._read_img_mask(self.train_ids[index])
+        return self._read_img_mask(self.val_ids[index])
+
+    @property
+    def column_names(self):
+        column_names = ['image', 'mask']
+        return column_names
+
+    def __len__(self):
+        if self.is_train:
+            return len(self.train_ids)
+        return len(self.val_ids)
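+
+# Expected directory layout (illustrative names). With split != 1.0 the samples
+# sit directly under data_dir; with split == 1.0 they are pre-split:
+#
+#     data_dir/case_0001/image.png          data_dir/train/case_0001/image.png
+#     data_dir/case_0001/mask.png           data_dir/train/case_0001/mask.png
+#     data_dir/case_0002/...                data_dir/val/case_0042/...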
+
+def preprocess_img_mask(img, mask, num_classes, img_size, augment=False, eval_resize=False):
+    """
+    Preprocess for the multi-class dataset.
+    Randomly crop and flip the image and mask when `augment` is True.
+    """
+    if augment:
+        img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
+        img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
+        img = cv2.resize(img, (img_size_w, img_size_h))
+        mask = cv2.resize(mask, (img_size_w, img_size_h))
+        dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
+        dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
+        img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
+        mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
+        if np.random.random() > 0.5:
+            flip_code = int(np.random.randint(-1, 2, 1))
+            img = cv2.flip(img, flip_code)
+            mask = cv2.flip(mask, flip_code)
+    else:
+        img = cv2.resize(img, img_size)
+        if not eval_resize:
+            mask = cv2.resize(mask, img_size)
+    img = (img.astype(np.float32) - 127.5) / 127.5
+    img = img.transpose(2, 0, 1)
+    if num_classes == 2:
+        mask = mask.astype(np.float32) / mask.max()
+        mask = (mask > 0.5).astype(np.int64)
+    else:
+        mask = mask.astype(np.int64)
+    mask = (np.arange(num_classes) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
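+
+# Output contract (illustrative shapes): an HWC uint8 image and a grayscale mask
+# come back as CHW float32 arrays sized to img_size:
+#
+#     img, mask = preprocess_img_mask(np.zeros((128, 128, 3), np.uint8),
+#                                     np.full((128, 128), 255, np.uint8),
+#                                     num_classes=2, img_size=(96, 96))
+#     assert img.shape == (3, 96, 96) and mask.shape == (2, 96, 96)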
+
+def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_classes=2, is_train=False, augment=False,
+                               eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True,
+                               num_parallel_workers=8, shuffle=True):
+    """
+    Get generator dataset for multi-class dataset.
+    """
+    mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle)
+    sampler = ds.DistributedSampler(group_size, rank, shuffle=shuffle)
+    dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, sampler=sampler)
+    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size),
+                                                                augment and is_train, eval_resize))
+    dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names,
+                          output_columns=mc_dataset.column_names, column_order=mc_dataset.column_names,
+                          python_multiprocessing=python_multiprocessing,
+                          num_parallel_workers=num_parallel_workers)
+    dataset = dataset.batch(batch_size, drop_remainder=is_train)
+    dataset = dataset.repeat(1)
+    return dataset
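+
+# Distributed usage sketch (arguments are assumptions; communication must be
+# initialized before get_rank()/get_group_size() can be called):
+#
+#     dataset = create_multi_class_dataset("./data", img_size=[96, 96], repeat=1,
+#                                          batch_size=16, num_classes=2,
+#                                          is_train=True, augment=True,
+#                                          rank=get_rank(), group_size=get_group_size())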
diff --git a/eval_callback.py b/eval_callback.py
new file mode 100644
index 0000000..205fce0
--- /dev/null
+++ b/eval_callback.py
@@ -0,0 +1,90 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation callback when training"""
+
+import os
+import stat
+from mindspore import save_checkpoint
+from mindspore import log as logger
+from mindspore.train.callback import Callback
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback when training.
+
+    Args:
+        eval_function (function): evaluation function.
+        eval_param_dict (dict): evaluation parameters' configuration dict.
+        interval (int): run evaluation interval, default is 1.
+        eval_start_epoch (int): evaluation start epoch, default is 1.
+        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
+        best_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
+        metrics_name (str): evaluation metrics name, default is `acc`.
+
+    Returns:
+        None
+
+    Examples:
+        >>> EvalCallBack(eval_function, eval_param_dict)
+    """
+
+    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
+                 ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
+        super(EvalCallBack, self).__init__()
+        self.eval_param_dict = eval_param_dict
+        self.eval_function = eval_function
+        self.eval_start_epoch = eval_start_epoch
+        if interval < 1:
+            raise ValueError("interval should be >= 1.")
+        self.interval = interval
+        self.save_best_ckpt = save_best_ckpt
+        self.best_res = 0
+        self.best_epoch = 0
+        if not os.path.isdir(ckpt_directory):
+            os.makedirs(ckpt_directory)
+        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
+        self.metrics_name = metrics_name
+
+    def remove_checkpoint_file(self, file_name):
+        """Remove the specified checkpoint file from disk."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def epoch_end(self, run_context):
+        """Run evaluation at the end of an epoch and keep the best checkpoint."""
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
+            res = self.eval_function(self.eval_param_dict)
+            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
+            if res >= self.best_res:
+                self.best_res = res
+                self.best_epoch = cur_epoch
+                print("update best result: {}".format(res), flush=True)
+                if self.save_best_ckpt:
+                    if os.path.exists(self.best_ckpt_path):
+                        self.remove_checkpoint_file(self.best_ckpt_path)
+                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
+                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
+
+    def end(self, run_context):
+        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
+                                                                                     self.best_res,
+                                                                                     self.best_epoch), flush=True)
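+
+# Typical wiring (hypothetical variable names): pair the callback with an
+# apply_eval-style function and hand it to model.train:
+#
+#     eval_cb = EvalCallBack(apply_eval,
+#                            {"model": model, "dataset": valid_ds,
+#                             "metrics_name": "dice_coeff"},
+#                            interval=1, ckpt_directory="./ckpt")
+#     model.train(epochs, train_ds, callbacks=[loss_cb, eval_cb])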
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..cd291ae
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,183 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import time
+import cv2
+import numpy as np
+from PIL import Image
+from mindspore import nn
+from mindspore.ops import operations as ops
+from mindspore.train.callback import Callback
+from mindspore.common.tensor import Tensor
+from src.model_utils.config import config
+
+class UnetEval(nn.Cell):
+    """
+    Wrap a UNet so evaluation yields both softmax probabilities and argmax predictions.
+    """
+    def __init__(self, net, need_slice=False):
+        super(UnetEval, self).__init__()
+        self.net = net
+        self.need_slice = need_slice
+        self.transpose = ops.Transpose()
+        self.softmax = ops.Softmax(axis=-1)
+        self.argmax = ops.Argmax(axis=-1)
+        self.squeeze = ops.Squeeze(axis=0)
+
+    def construct(self, x):
+        out = self.net(x)
+        if self.need_slice:
+            out = self.squeeze(out[-1:])
+        out = self.transpose(out, (0, 2, 3, 1))
+        softmax_out = self.softmax(out)
+        argmax_out = self.argmax(out)
+        return (softmax_out, argmax_out)
+
+class TempLoss(nn.Cell):
+    """A placeholder loss that returns the logits unchanged."""
+    def __init__(self):
+        super(TempLoss, self).__init__()
+        self.identity = ops.identity()
+
+    def construct(self, logits, label):
+        return self.identity(logits)
+
+def apply_eval(eval_param_dict):
+    """Run evaluation and return the requested metric."""
+    model = eval_param_dict["model"]
+    dataset = eval_param_dict["dataset"]
+    metrics_name = eval_param_dict["metrics_name"]
+    # dice_coeff.eval() returns (dice, iou): index 0 selects Dice, index 1 IoU
+    index = 0 if metrics_name == "dice_coeff" else 1
+    eval_score = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"][index]
+    return eval_score
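+
+# Sketch of the expected eval_param_dict (keys as consumed above; the model is
+# assumed to have been compiled with metrics={"dice_coeff": dice_coeff()}):
+#
+#     eval_param_dict = {"model": model, "dataset": valid_ds,
+#                        "metrics_name": "dice_coeff"}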
+
+class dice_coeff(nn.Metric):
+    """Unet metric that accumulates the Dice coefficient and IoU across samples."""
+    def __init__(self, print_res=True):
+        super(dice_coeff, self).__init__()
+        self.clear()
+        self.print_res = print_res
+
+    def clear(self):
+        self._dice_coeff_sum = 0
+        self._iou_sum = 0
+        self._samples_num = 0
+
+    def update(self, *inputs):
+        if len(inputs) != 2:
+            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
+        y = self._convert_data(inputs[1])
+        self._samples_num += y.shape[0]
+        y = y.transpose(0, 2, 3, 1)
+        b, h, w, c = y.shape
+        if b != 1:
+            raise ValueError('Batch size should be 1 when in evaluation.')
+        y = y.reshape((h, w, c))
+        if config.eval_activate.lower() == "softmax":
+            y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
+            if config.eval_resize:
+                y_pred = []
+                for i in range(config.num_classes):
+                    y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
+                y_pred = np.stack(y_pred, axis=-1)
+            else:
+                y_pred = y_softmax
+        elif config.eval_activate.lower() == "argmax":
+            y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
+            y_pred = []
+            for i in range(config.num_classes):
+                if config.eval_resize:
+                    y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
+                else:
+                    y_pred.append(np.float32(y_argmax == i))
+            y_pred = np.stack(y_pred, axis=-1)
+        else:
+            raise ValueError('config eval_activate should be softmax or argmax.')
+        y_pred = y_pred.astype(np.float32)
+        # dice = 2 * inter / (|A|^2 + |B|^2), computed over the flattened one-hot maps
+        inter = np.dot(y_pred.flatten(), y.flatten())
+        union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
+
+        single_dice_coeff = 2 * float(inter) / float(union + 1e-6)
+        # for binary masks, IoU follows from Dice: iou = dice / (2 - dice)
+        single_iou = single_dice_coeff / (2 - single_dice_coeff)
+        if self.print_res:
+            print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
+        self._dice_coeff_sum += single_dice_coeff
+        self._iou_sum += single_iou
+
+    def eval(self):
+        if self._samples_num == 0:
+            raise RuntimeError('Total samples num must not be 0.')
+        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
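+
+# Stand-alone usage sketch (inputs shaped as update() expects, batch size 1):
+#
+#     metric = dice_coeff(print_res=False)
+#     metric.clear()
+#     metric.update((softmax_out, argmax_out), label)
+#     dice, iou = metric.eval()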
+
+class StepLossTimeMonitor(Callback):
+    """Log per-step loss and throughput (images per second) during training."""
+
+    def __init__(self, batch_size, per_print_times=1):
+        super(StepLossTimeMonitor, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be an int and >= 0.")
+        self._per_print_times = per_print_times
+        self.batch_size = batch_size
+        self.losses = []
+        self.step_time = time.time()
+        self.epoch_start = time.time()
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        step_seconds = time.time() - self.step_time
+        step_fps = self.batch_size * 1.0 / step_seconds
+
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+        self.losses.append(loss)
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
+
+    def epoch_begin(self, run_context):
+        self.epoch_start = time.time()
+        self.losses = []
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_cost = time.time() - self.epoch_start
+        step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
+        print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
+            cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
+
+def mask_to_image(mask):
+    return Image.fromarray((mask * 255).astype(np.uint8))
+
+
+def filter_checkpoint_parameter_by_list(param_dict, filter_list):
+    """Remove every parameter from `param_dict` whose name contains an entry of `filter_list`."""
+    for key in list(param_dict.keys()):
+        for name in filter_list:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del param_dict[key]
+                break
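+
+# Illustrative fine-tuning use ("final.weight"/"final.bias" are hypothetical
+# parameter names for a head being replaced):
+#
+#     param_dict = load_checkpoint("pretrained.ckpt")
+#     filter_checkpoint_parameter_by_list(param_dict, ["final.weight", "final.bias"])
+#     load_param_into_net(net, param_dict)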
-- 
GitLab