Skip to content
Snippets Groups Projects
Commit 4067d264 authored by 何向东's avatar 何向东
Browse files

添加其他基础设施文件

parent 33d26efa
No related branches found
No related tags found
No related merge requests found
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from collections import deque
import cv2
import numpy as np
from PIL import Image, ImageSequence
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.vision.utils import Inter
from mindspore.communication.management import get_rank, get_group_size
def _load_multipage_tiff(path):
"""Load tiff images containing many images in the channel dimension"""
return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
def _get_val_train_indices(length, fold, ratio=0.8):
assert 0 < ratio <= 1, "Train/total data ratio must be in range (0.0, 1.0]"
np.random.seed(0)
indices = np.arange(0, length, 1, dtype=np.int)
np.random.shuffle(indices)
if fold is not None:
indices = deque(indices)
indices.rotate(fold * round((1.0 - ratio) * length))
indices = np.array(indices)
train_indices = indices[:round(ratio * len(indices))]
val_indices = indices[round(ratio * len(indices)):]
else:
train_indices = indices
val_indices = []
return train_indices, val_indices
def data_post_process(img, mask):
img = np.expand_dims(img, axis=0)
mask = (mask > 0.5).astype(np.int)
mask = (np.arange(mask.max() + 1) == mask[..., None]).astype(int)
mask = mask.transpose(2, 0, 1).astype(np.float32)
return img, mask
def train_data_augmentation(img, mask):
h_flip = np.random.random()
if h_flip > 0.5:
img = np.flipud(img)
mask = np.flipud(mask)
v_flip = np.random.random()
if v_flip > 0.5:
img = np.fliplr(img)
mask = np.fliplr(mask)
left = int(np.random.uniform()*0.3*572)
right = int((1-np.random.uniform()*0.3)*572)
top = int(np.random.uniform()*0.3*572)
bottom = int((1-np.random.uniform()*0.3)*572)
img = img[top:bottom, left:right]
mask = mask[top:bottom, left:right]
#adjust brightness
brightness = np.random.uniform(-0.2, 0.2)
img = np.float32(img+brightness*np.ones(img.shape))
img = np.clip(img, -1.0, 1.0)
return img, mask
def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False,
do_crop=None, img_size=None):
images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif'))
masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif'))
train_indices, val_indices = _get_val_train_indices(len(images), cross_val_ind)
train_images = images[train_indices]
train_masks = masks[train_indices]
train_images = np.repeat(train_images, repeat, axis=0)
train_masks = np.repeat(train_masks, repeat, axis=0)
val_images = images[val_indices]
val_masks = masks[val_indices]
train_image_data = {"image": train_images}
train_mask_data = {"mask": train_masks}
valid_image_data = {"image": val_images}
valid_mask_data = {"mask": val_masks}
ds_train_images = ds.NumpySlicesDataset(data=train_image_data, sampler=None, shuffle=False)
ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data, sampler=None, shuffle=False)
if run_distribute:
rank_id = get_rank()
rank_size = get_group_size()
ds_train_images = ds.NumpySlicesDataset(data=train_image_data,
sampler=None,
shuffle=False,
num_shards=rank_size,
shard_id=rank_id)
ds_train_masks = ds.NumpySlicesDataset(data=train_mask_data,
sampler=None,
shuffle=False,
num_shards=rank_size,
shard_id=rank_id)
ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
if do_crop != "None":
resize_size = [int(img_size[x] * do_crop[x] / 572) for x in range(len(img_size))]
else:
resize_size = img_size
c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2)
c_rescale_image = c_vision.Rescale(1.0/127.5, -1)
c_rescale_mask = c_vision.Rescale(1.0/255.0, 0)
c_trans_normalize_img = [c_rescale_image, c_resize_op, c_pad]
c_trans_normalize_mask = [c_rescale_mask, c_resize_op, c_pad]
c_center_crop = c_vision.CenterCrop(size=388)
train_image_ds = ds_train_images.map(input_columns="image", operations=c_trans_normalize_img)
train_mask_ds = ds_train_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
train_ds = ds.zip((train_image_ds, train_mask_ds))
train_ds = train_ds.project(columns=["image", "mask"])
if augment:
augment_process = train_data_augmentation
c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR)
train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process)
train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
if do_crop != "None":
train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process
train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
train_ds = train_ds.shuffle(repeat*24)
train_ds = train_ds.batch(batch_size=train_batch_size, drop_remainder=True)
valid_image_ds = ds_valid_images.map(input_columns="image", operations=c_trans_normalize_img)
valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
valid_ds = valid_ds.project(columns=["image", "mask"])
if do_crop != "None":
valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
post_process = data_post_process
valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
return train_ds, valid_ds
class MultiClassDataset:
"""
Read image and mask from original images, and split all data into train_dataset and val_dataset by `split`.
Get image path and mask path from a tree of directories,
images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`.
"""
def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False):
self.data_dir = data_dir
self.is_train = is_train
self.split = (split != 1.0)
if self.split:
self.img_ids = sorted(next(os.walk(self.data_dir))[1])
self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
else:
self.train_ids = sorted(next(os.walk(os.path.join(self.data_dir, "train")))[1])
self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1])
if shuffle:
np.random.shuffle(self.train_ids)
def _read_img_mask(self, img_id):
if self.split:
path = os.path.join(self.data_dir, img_id)
elif self.is_train:
path = os.path.join(self.data_dir, "train", img_id)
else:
path = os.path.join(self.data_dir, "val", img_id)
img = cv2.imread(os.path.join(path, "image.png"))
mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
return img, mask
def __getitem__(self, index):
if self.is_train:
return self._read_img_mask(self.train_ids[index])
return self._read_img_mask(self.val_ids[index])
@property
def column_names(self):
column_names = ['image', 'mask']
return column_names
def __len__(self):
if self.is_train:
return len(self.train_ids)
return len(self.val_ids)
def preprocess_img_mask(img, mask, num_classes, img_size, augment=False, eval_resize=False):
"""
Preprocess for multi-class dataset.
Random crop and flip images and masks when augment is True.
"""
if augment:
img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
img = cv2.resize(img, (img_size_w, img_size_h))
mask = cv2.resize(mask, (img_size_w, img_size_h))
dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
if np.random.random() > 0.5:
flip_code = int(np.random.randint(-1, 2, 1))
img = cv2.flip(img, flip_code)
mask = cv2.flip(mask, flip_code)
else:
img = cv2.resize(img, img_size)
if not eval_resize:
mask = cv2.resize(mask, img_size)
img = (img.astype(np.float32) - 127.5) / 127.5
img = img.transpose(2, 0, 1)
if num_classes == 2:
mask = mask.astype(np.float32) / mask.max()
mask = (mask > 0.5).astype(np.int)
else:
mask = mask.astype(np.int)
mask = (np.arange(num_classes) == mask[..., None]).astype(int)
mask = mask.transpose(2, 0, 1).astype(np.float32)
return img, mask
def create_multi_class_dataset(data_dir, img_size, repeat, batch_size, num_classes=2, is_train=False, augment=False,
eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True,
num_parallel_workers=8, shuffle=True):
"""
Get generator dataset for multi-class dataset.
"""
mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle)
sampler = ds.DistributedSampler(group_size, rank, shuffle=shuffle)
dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, sampler=sampler)
compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, num_classes, tuple(img_size),
augment and is_train, eval_resize))
dataset = dataset.map(operations=compose_map_func, input_columns=mc_dataset.column_names,
output_columns=mc_dataset.column_names, column_order=mc_dataset.column_names,
python_multiprocessing=python_multiprocessing,
num_parallel_workers=num_parallel_workers)
dataset = dataset.batch(batch_size, drop_remainder=is_train)
dataset = dataset.repeat(1)
return dataset
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)
utils.py 0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import cv2
import numpy as np
from PIL import Image
from mindspore import nn
from mindspore.ops import operations as ops
from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor
from src.model_utils.config import config
class UnetEval(nn.Cell):
"""
Add Unet evaluation activation.
"""
def __init__(self, net, need_slice=False):
super(UnetEval, self).__init__()
self.net = net
self.need_slice = need_slice
self.transpose = ops.Transpose()
self.softmax = ops.Softmax(axis=-1)
self.argmax = ops.Argmax(axis=-1)
self.squeeze = ops.Squeeze(axis=0)
def construct(self, x):
out = self.net(x)
if self.need_slice:
out = self.squeeze(out[-1:])
out = self.transpose(out, (0, 2, 3, 1))
softmax_out = self.softmax(out)
argmax_out = self.argmax(out)
return (softmax_out, argmax_out)
class TempLoss(nn.Cell):
"""A temp loss cell."""
def __init__(self):
super(TempLoss, self).__init__()
self.identity = ops.identity()
def construct(self, logits, label):
return self.identity(logits)
def apply_eval(eval_param_dict):
"""run Evaluation"""
model = eval_param_dict["model"]
dataset = eval_param_dict["dataset"]
metrics_name = eval_param_dict["metrics_name"]
index = 0 if metrics_name == "dice_coeff" else 1
eval_score = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"][index]
return eval_score
class dice_coeff(nn.Metric):
"""Unet Metric, return dice coefficient and IOU."""
def __init__(self, print_res=True):
super(dice_coeff, self).__init__()
self.clear()
self.print_res = print_res
def clear(self):
self._dice_coeff_sum = 0
self._iou_sum = 0
self._samples_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
y = y.transpose(0, 2, 3, 1)
b, h, w, c = y.shape
if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c))
if config.eval_activate.lower() == "softmax":
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
if config.eval_resize:
y_pred = []
for i in range(config.num_classes):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1)
else:
y_pred = y_softmax
elif config.eval_activate.lower() == "argmax":
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
y_pred = []
for i in range(config.num_classes):
if config.eval_resize:
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
else:
y_pred.append(np.float32(y_argmax == i))
y_pred = np.stack(y_pred, axis=-1)
else:
raise ValueError('config eval_activate should be softmax or argmax.')
y_pred = y_pred.astype(np.float32)
inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2 * float(inter) / float(union+1e-6)
single_iou = single_dice_coeff / (2 - single_dice_coeff)
if self.print_res:
print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
self._dice_coeff_sum += single_dice_coeff
self._iou_sum += single_iou
def eval(self):
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
class StepLossTimeMonitor(Callback):
def __init__(self, batch_size, per_print_times=1):
super(StepLossTimeMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.batch_size = batch_size
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_seconds = time.time() - self.step_time
step_fps = self.batch_size*1.0/step_seconds
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
self.losses.append(loss)
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
# TEST
print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
def epoch_begin(self, run_context):
self.epoch_start = time.time()
self.losses = []
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_cost = time.time() - self.epoch_start
step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
def filter_checkpoint_parameter_by_list(param_dict, filter_list):
"""remove useless parameters according to filter_list"""
for key in list(param_dict.keys()):
for name in filter_list:
if name in key:
print("Delete parameter from checkpoint: ", key)
del param_dict[key]
break
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment