From 4067d264d35501e824a26f6e84b624d39fde7284 Mon Sep 17 00:00:00 2001
From: deepr <hexiangdong2020@outlook.com>
Date: Thu, 15 Jul 2021 22:28:09 +0800
Subject: [PATCH] Add other infrastructure files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 data_loader.py   | 275 ++++++++++++++++++++++++++++++++++++++++++++++++
 eval_callback.py |  91 +++++++++++++++++
 utils.py         | 189 +++++++++++++++++++++++++++++++++
 3 files changed, 555 insertions(+)
 create mode 100644 data_loader.py
 create mode 100644 eval_callback.py
 create mode 100644 utils.py

diff --git a/data_loader.py b/data_loader.py
new file mode 100644
index 0000000..4bc3455
--- /dev/null
+++ b/data_loader.py
@@ -0,0 +1,275 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+from collections import deque
+import cv2
+import numpy as np
+from PIL import Image, ImageSequence
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as c_vision
+from mindspore.dataset.vision.utils import Inter
+from mindspore.communication.management import get_rank, get_group_size
+
+
+def _load_multipage_tiff(path):
+    """Load a multi-page TIFF as an array of shape (num_pages, H, W)."""
+    return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
+
+def _get_val_train_indices(length, fold, ratio=0.8):
+    assert 0 < ratio <= 1, "Train/total data ratio must be in range (0.0, 1.0]"
+    np.random.seed(0)
+    indices = np.arange(0, length, 1, dtype=int)
+    np.random.shuffle(indices)
+
+    if fold is not None:
+        indices = deque(indices)
+        indices.rotate(fold * round((1.0 - ratio) * length))
+        indices = np.array(indices)
+        train_indices = indices[:round(ratio * len(indices))]
+        val_indices = indices[round(ratio * len(indices)):]
+    else:
+        train_indices = indices
+        val_indices = []
+    return train_indices, val_indices
+
+def data_post_process(img, mask):
+    """Add a channel dimension to the image and one-hot encode the binarized mask as (C, H, W)."""
+    img = np.expand_dims(img, axis=0)
+    mask = (mask > 0.5).astype(int)
+    mask = (np.arange(mask.max() + 1) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
+
+
+def train_data_augmentation(img, mask):
+    """Randomly flip, crop, and brightness-jitter an image/mask pair."""
+    h_flip = np.random.random()
+    if h_flip > 0.5:
+        img = np.flipud(img)
+        mask = np.flipud(mask)
+    v_flip = np.random.random()
+    if v_flip > 0.5:
+        img = np.fliplr(img)
+        mask = np.fliplr(mask)
+
+    left = int(np.random.uniform() * 0.3 * 572)
+    right = int((1 - np.random.uniform() * 0.3) * 572)
+    top = int(np.random.uniform() * 0.3 * 572)
+    bottom = int((1 - np.random.uniform() * 0.3) * 572)
+
+
+    img = img[top:bottom, left:right]
+    mask = mask[top:bottom, left:right]
+
+    # adjust brightness
+    brightness = np.random.uniform(-0.2, 0.2)
+    img = np.float32(img + brightness * np.ones(img.shape))
+    img = np.clip(img, -1.0, 1.0)
+
+    return img, mask
+
+
+def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False,
+                   do_crop=None, img_size=None):
+    """Build the train/valid datasets from the multi-page TIFF stacks under `data_dir`."""
+    images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif'))
+    masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif'))
+
+    train_indices, val_indices = _get_val_train_indices(len(images), cross_val_ind)
+    train_images = images[train_indices]
+    train_masks = masks[train_indices]
+    train_images = np.repeat(train_images, repeat, axis=0)
+    train_masks = np.repeat(train_masks, repeat, axis=0)
+    val_images = images[val_indices]
+    val_masks = masks[val_indices]
+
+    train_image_data = {"image": train_images}
+    train_mask_data = {"mask": train_masks}
+    valid_image_data = {"image": val_images}
+    valid_mask_data = {"mask": val_masks}
+
+    if run_distribute:
+        rank_id = get_rank()
+        rank_size = get_group_size()
+        ds_train_images = ds.NumpySlicesDataset(data=train_image_data,
+                                                sampler=None,
+                                                shuffle=False,
+                                                num_shards=rank_size,
+                                                shard_id=rank_id)
+        ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data,
+                                               sampler=None,
+                                               shuffle=False,
+                                               num_shards=rank_size,
+                                               shard_id=rank_id)
+    else:
+        ds_train_images = ds.NumpySlicesDataset(data=train_image_data, sampler=None, shuffle=False)
+        ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data, sampler=None, shuffle=False)
+
+    ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
+    ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
+
+    if do_crop != "None":
+        resize_size = [int(img_size[x] * do_crop[x] / 572) for x in range(len(img_size))]
+    else:
+        resize_size = img_size
+    c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
+    c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2)
+    c_rescale_image = c_vision.Rescale(1.0/127.5, -1)
+    c_rescale_mask = c_vision.Rescale(1.0/255.0, 0)
+
+    c_trans_normalize_img = [c_rescale_image, c_resize_op, c_pad]
+    c_trans_normalize_mask = [c_rescale_mask, c_resize_op, c_pad]
+    c_center_crop = c_vision.CenterCrop(size=388)
+
+    train_image_ds = ds_train_images.map(input_columns="image", operations=c_trans_normalize_img)
+    train_mask_ds = ds_train_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
+    train_ds = ds.zip((train_image_ds, train_mask_ds))
+    train_ds = train_ds.project(columns=["image", "mask"])
+    if augment:
+        augment_process = train_data_augmentation
+        c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR)
+        train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process)
+        train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
+        train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
+
+    if do_crop != "None":
+        train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
+    post_process = data_post_process
+    train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
+    train_ds = train_ds.shuffle(repeat * 24)
+    train_ds = train_ds.batch(batch_size=train_batch_size, drop_remainder=True)
+
+    valid_image_ds = ds_valid_images.map(input_columns="image", operations=c_trans_normalize_img)
+    valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
+    valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
+    valid_ds = valid_ds.project(columns=["image", "mask"])
+    if do_crop != "None":
+        valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
+    post_process = data_post_process
+    valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
+    valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
+
+    return train_ds, valid_ds
+
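+# A minimal usage sketch (illustrative assumptions only: the data directory is
+# hypothetical, and the 572/388 geometry simply mirrors the constants used above):
+#
+#     train_ds, valid_ds = create_dataset("./data", repeat=1, train_batch_size=16,
+#                                         augment=True, do_crop=[388, 388], img_size=[572, 572])
+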
+class MultiClassDataset:
+    """
+    Read images and masks from the original data directory and split them into a train set and a
+    validation set according to `split`. Image and mask paths are collected from a tree of
+    directories: each sub-folder is one sample, holding an "image.png" and a "mask.png".
+    """
+    def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False):
+        self.data_dir = data_dir
+        self.is_train = is_train
+        self.split = (split != 1.0)
+        if self.split:
+            self.img_ids = sorted(next(os.walk(self.data_dir))[1])
+            self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
+            self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
+        else:
+            self.train_ids = sorted(next(os.walk(os.path.join(self.data_dir, "train")))[1])
+            self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1])
+        if shuffle:
+            np.random.shuffle(self.train_ids)
+
+    def _read_img_mask(self, img_id):
+        if self.split:
+            path = os.path.join(self.data_dir, img_id)
+        elif self.is_train:
+            path = os.path.join(self.data_dir, "train", img_id)
+        else:
+            path = os.path.join(self.data_dir, "val", img_id)
+        img = cv2.imread(os.path.join(path, "image.png"))
+        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
+        return img, mask
+
+    def __getitem__(self, index):
+        if self.is_train:
+            return self._read_img_mask(self.train_ids[index])
+        return self._read_img_mask(self.val_ids[index])
+
+    @property
+    def column_names(self):
+        column_names = ['image', 'mask']
+        return column_names
+
+    def __len__(self):
+        if self.is_train:
+            return len(self.train_ids)
+        return len(self.val_ids)
+
+def preprocess_img_mask(img, mask, num_classes, img_size, augment=False, eval_resize=False):
+    """
+    Preprocess for the multi-class dataset.
+    Randomly crop and flip the image and mask when augment is True.
+    """
+    if augment:
+        img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
+        img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
+        img = cv2.resize(img, (img_size_w, img_size_h))
+        mask = cv2.resize(mask, (img_size_w, img_size_h))
+        dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
+        dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
+        img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
+        mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
+        if np.random.random() > 0.5:
+            flip_code = int(np.random.randint(-1, 2, 1))
+            img = cv2.flip(img, flip_code)
+            mask = cv2.flip(mask, flip_code)
+    else:
+        img = cv2.resize(img, img_size)
+        if not eval_resize:
+            mask = cv2.resize(mask, img_size)
+    img = (img.astype(np.float32) - 127.5) / 127.5
+    img = img.transpose(2, 0, 1)
+    if num_classes == 2:
+        mask = mask.astype(np.float32) / mask.max()
+        mask = (mask > 0.5).astype(int)
+    else:
+        mask = mask.astype(int)
+    mask = (np.arange(num_classes) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
+
+def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_classes=2, is_train=False, augment=False,
+                               eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True,
+                               num_parallel_workers=8, shuffle=True):
+    """
+    Get a generator dataset for the multi-class dataset.
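+
+    Example (illustrative; the path and sizes below are assumptions, not values
+    fixed by this patch):
+
+    >>> train_ds = create_multi_class_dataset("./data", img_size=[96, 96], repeat=1,
+    ...                                       batch_size=16, num_classes=2, is_train=True)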
+    """
+    mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle)
+    sampler = ds.DistributedSampler(group_size, rank, shuffle=shuffle)
+    dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, sampler=sampler)
+    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size),
+                                                                augment and is_train, eval_resize))
+    dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names,
+                          output_columns=mc_dataset.column_names, column_order=mc_dataset.column_names,
+                          python_multiprocessing=python_multiprocessing,
+                          num_parallel_workers=num_parallel_workers)
+    dataset = dataset.batch(batch_size, drop_remainder=is_train)
+    dataset = dataset.repeat(1)
+    return dataset
diff --git a/eval_callback.py b/eval_callback.py
new file mode 100644
index 0000000..205fce0
--- /dev/null
+++ b/eval_callback.py
@@ -0,0 +1,91 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation callback used during training"""
+
+import os
+import stat
+from mindspore import save_checkpoint
+from mindspore import log as logger
+from mindspore.train.callback import Callback
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback used during training.
+
+    Args:
+        eval_function (function): evaluation function.
+        eval_param_dict (dict): evaluation parameters' configure dict.
+        interval (int): run evaluation every `interval` epochs, default is 1.
+        eval_start_epoch (int): evaluation start epoch, default is 1.
+        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
+        ckpt_directory (str): directory to save the best checkpoint in, default is "./".
+        best_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
+        metrics_name (str): evaluation metrics name, default is `acc`.
+
+    Returns:
+        None
+
+    Examples:
+        >>> EvalCallBack(eval_function, eval_param_dict)
+    """
+
+    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
+                 ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
+        super(EvalCallBack, self).__init__()
+        self.eval_param_dict = eval_param_dict
+        self.eval_function = eval_function
+        self.eval_start_epoch = eval_start_epoch
+        if interval < 1:
+            raise ValueError("interval should be >= 1.")
+        self.interval = interval
+        self.save_best_ckpt = save_best_ckpt
+        self.best_res = 0
+        self.best_epoch = 0
+        if not os.path.isdir(ckpt_directory):
+            os.makedirs(ckpt_directory)
+        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
+        self.metrics_name = metrics_name
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def epoch_end(self, run_context):
+        """Callback at the end of each epoch."""
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
+            res = self.eval_function(self.eval_param_dict)
+            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
+            if res >= self.best_res:
+                self.best_res = res
+                self.best_epoch = cur_epoch
+                print("update best result: {}".format(res), flush=True)
+                if self.save_best_ckpt:
+                    if os.path.exists(self.best_ckpt_path):
+                        self.remove_ckpoint_file(self.best_ckpt_path)
+                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
+                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
+
+    def end(self, run_context):
+        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
+                                                                                     self.best_res,
+                                                                                     self.best_epoch), flush=True)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..cd291ae
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,189 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import time
+import cv2
+import numpy as np
+from PIL import Image
+from mindspore import nn
+from mindspore.ops import operations as ops
+from mindspore.train.callback import Callback
+from mindspore.common.tensor import Tensor
+from src.model_utils.config import config
+
+class UnetEval(nn.Cell):
+    """
+    Add Unet evaluation activation.
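+
+    Wraps `net` so that its NCHW logits are transposed to NHWC and reduced with
+    both softmax and argmax over the class axis; `construct` returns the tuple
+    (softmax_out, argmax_out) that the `dice_coeff` metric below consumes.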
+    """
+    def __init__(self, net, need_slice=False):
+        super(UnetEval, self).__init__()
+        self.net = net
+        self.need_slice = need_slice
+        self.transpose = ops.Transpose()
+        self.softmax = ops.Softmax(axis=-1)
+        self.argmax = ops.Argmax(axis=-1)
+        self.squeeze = ops.Squeeze(axis=0)
+
+    def construct(self, x):
+        out = self.net(x)
+        if self.need_slice:
+            out = self.squeeze(out[-1:])
+        out = self.transpose(out, (0, 2, 3, 1))
+        softmax_out = self.softmax(out)
+        argmax_out = self.argmax(out)
+        return (softmax_out, argmax_out)
+
+class TempLoss(nn.Cell):
+    """A temp loss cell."""
+    def __init__(self):
+        super(TempLoss, self).__init__()
+        self.identity = ops.identity()
+    def construct(self, logits, label):
+        return self.identity(logits)
+
+def apply_eval(eval_param_dict):
+    """Run evaluation and return the requested score (dice coefficient or IOU)."""
+    model = eval_param_dict["model"]
+    dataset = eval_param_dict["dataset"]
+    metrics_name = eval_param_dict["metrics_name"]
+    index = 0 if metrics_name == "dice_coeff" else 1
+    eval_score = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"][index]
+    return eval_score
+
+class dice_coeff(nn.Metric):
+    """Unet Metric, return dice coefficient and IOU."""
+    def __init__(self, print_res=True):
+        super(dice_coeff, self).__init__()
+        self.clear()
+        self.print_res = print_res
+
+    def clear(self):
+        self._dice_coeff_sum = 0
+        self._iou_sum = 0
+        self._samples_num = 0
+
+    def update(self, *inputs):
+        if len(inputs) != 2:
+            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
+        y = self._convert_data(inputs[1])
+        self._samples_num += y.shape[0]
+        y = y.transpose(0, 2, 3, 1)
+        b, h, w, c = y.shape
+        if b != 1:
+            raise ValueError('Batch size should be 1 when in evaluation.')
+        y = y.reshape((h, w, c))
+        if config.eval_activate.lower() == "softmax":
+            y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
+            if config.eval_resize:
+                y_pred = []
+                for i in range(config.num_classes):
+                    y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
+                y_pred = np.stack(y_pred, axis=-1)
+            else:
+                y_pred = y_softmax
+        elif config.eval_activate.lower() == "argmax":
+            y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
+            y_pred = []
+            for i in range(config.num_classes):
+                if config.eval_resize:
+                    y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
+                else:
+                    y_pred.append(np.float32(y_argmax == i))
+            y_pred = np.stack(y_pred, axis=-1)
+        else:
+            raise ValueError('config eval_activate should be softmax or argmax.')
+        y_pred = y_pred.astype(np.float32)
+        inter = np.dot(y_pred.flatten(), y.flatten())
+        union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
+
+        single_dice_coeff = 2 * float(inter) / float(union + 1e-6)
+        single_iou = single_dice_coeff / (2 - single_dice_coeff)  # IOU = inter / (union - inter)
+        if self.print_res:
+            print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
+        self._dice_coeff_sum += single_dice_coeff
+        self._iou_sum += single_iou
+
+    def eval(self):
+        if self._samples_num == 0:
+            raise RuntimeError('Total samples num must not be 0.')
+        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
+
+class StepLossTimeMonitor(Callback):
+    """Monitor per-step loss and time, and report throughput in images per second."""
+    def __init__(self, batch_size, per_print_times=1):
+        super(StepLossTimeMonitor, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be an int and >= 0.")
+        self._per_print_times = per_print_times
+        self.batch_size = batch_size
+
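+    # step_begin/step_end bracket every training step to measure its wall-clock
+    # time, from which per-step throughput (images per second) is derived.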
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+
+        step_seconds = time.time() - self.step_time
+        step_fps = self.batch_size * 1.0 / step_seconds
+
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+        self.losses.append(loss)
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            # log loss and throughput for the step that just finished
+            print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
+
+    def epoch_begin(self, run_context):
+        self.epoch_start = time.time()
+        self.losses = []
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_cost = time.time() - self.epoch_start
+        step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
+        print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
+            cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
+
+def mask_to_image(mask):
+    return Image.fromarray((mask * 255).astype(np.uint8))  # scale a [0, 1] mask to 8-bit grayscale
+
+
+def filter_checkpoint_parameter_by_list(param_dict, filter_list):
+    """Remove from `param_dict` every parameter whose name contains an entry of `filter_list`."""
+    for key in list(param_dict.keys()):
+        for name in filter_list:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del param_dict[key]
+                break
-- 
GitLab