Commit 2119aed8 authored by i-robot's avatar i-robot Committed by Gitee

!787 [Model King Challenge]_[MindSpore Task]_[UNet3+]

Merge pull request !787 from 谭华林/unet3plus
parents d564a7fd 9a47783b
Showing with 1660 additions and 0 deletions
# UNet3+
<!-- TOC -->
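- [UNet3+](#unet3)
    - [Introduction to UNet3+](#introduction-to-unet3)
    - [Model Architecture](#model-architecture)
    - [Dataset](#dataset)
    - [Environment Requirements](#environment-requirements)
    - [Quick Start](#quick-start)
    - [Script Description](#script-description)
        - [Scripts and Sample Code](#scripts-and-sample-code)
        - [Script Parameters](#script-parameters)
        - [Training Process](#training-process)
            - [Training](#training)
        - [Evaluation Process](#evaluation-process)
            - [Evaluation](#evaluation)
    - [Model Description](#model-description)
        - [Performance](#performance)
            - [Evaluation Performance](#evaluation-performance)
    - [Description of Random Situation](#description-of-random-situation)
    - [ModelZoo Homepage](#modelzoo-homepage)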
<!-- /TOC -->
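## Introduction to UNet3+
UNet 3+ (*UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation*) uses full-scale skip connections and deep supervision to perform semantic segmentation of medical images. The full-scale skip connections combine high-level semantics with low-level semantics from feature maps at different scales, while deep supervision learns hierarchical representations from the aggregated multi-scale feature maps, which is particularly useful for organs that appear at varying scales. Besides improving accuracy, UNet 3+ also reduces the number of network parameters and improves computational efficiency.
[Paper](https://arxiv.org/abs/2004.08790): Huang H, Lin L, Tong R, et al. UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation[J]. arXiv, 2020.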
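## Model Architecture
Compared with UNet and UNet++, UNet 3+ combines multi-scale features by redesigning the skip connections and applying multi-scale deep supervision. It therefore needs fewer parameters than either network while producing more accurate, position-aware, and boundary-enhanced segmentation maps.
Neither the plain connections in U-Net nor the dense nested connections in U-Net++ explore enough information from all scales, so they cannot explicitly learn the position and boundary of an organ. Each decoder layer in UNet 3+ fuses smaller-scale and same-scale feature maps from the encoder with larger-scale feature maps from the decoder, capturing both fine-grained and coarse-grained semantics at full scale.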
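The fusion rule described above can be sketched as a small planning function. It only lists which feature maps feed a given decoder level and how each would have to be resampled. It assumes five scales (matching the five entries of `filters` in `src/models.py`), max pooling for downsampling, and upsampling for the coarser decoder outputs; the function name and tuple layout are hypothetical and not taken from the repository.

```python
def full_scale_sources(level, depth=5):
    """List (source, operation, scale factor) for the decoder at `level`.

    Level 1 is the finest resolution and `depth` the coarsest. Names and
    operators are assumptions of this sketch, not the repository's code.
    """
    plan = []
    for src in range(1, depth + 1):
        if src < level:
            # Finer encoder map: shrink it to the target resolution.
            plan.append((f"encoder{src}", "max-pool", 2 ** (level - src)))
        elif src == level:
            # Same-scale encoder map: used as is.
            plan.append((f"encoder{src}", "identity", 1))
        else:
            # Coarser decoder map: enlarge it to the target resolution.
            plan.append((f"decoder{src}", "upsample", 2 ** (src - level)))
    return plan


for source, operation, factor in full_scale_sources(3):
    print(source, operation, factor)
```

Every decoder level receives exactly `depth` inputs, one per scale, which is what distinguishes the full-scale connection from the single same-scale skip in a plain U-Net.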
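## Dataset
Dataset: [**LiTS2017**](<https://competitions.codalab.org/competitions/15595>)
The Liver Tumor Segmentation Challenge (LiTS) dataset contains contrast-enhanced CT images provided by hospitals around the world. It has 131 training cases and 70 test cases; labels for the test set have not been released. The paper selects 103 and 28 of the 131 training cases for training and validation respectively. The source data is in 'nii' format.
UNet3+ processes RGB images, so the source data must be preprocessed into image files before training.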
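As a rough illustration of that preprocessing step, the sketch below clips a CT volume to the [-200, 200] intensity window used in `dataset_preprocess.py` and stacks three neighbouring slices as the three channels of one image. The function name is hypothetical, and clamping the neighbour index at the volume borders is an assumption of this sketch rather than the behaviour of the repository script.

```python
import numpy as np


def window_and_stack(volume, index, low=-200, high=200):
    """Clip a (slices, H, W) volume and return an (H, W, 3) image for `index`."""
    clipped = np.clip(volume, low, high)
    # Assumption of this sketch: reuse the border slice when a neighbour is missing.
    prev_i = max(index - 1, 0)
    next_i = min(index + 1, volume.shape[0] - 1)
    return np.stack([clipped[prev_i], clipped[index], clipped[next_i]], axis=-1)


# Toy volume with 6 slices of 2x2 pixels, values from -300 to 275.
volume = np.arange(-300, 300, 25, dtype=np.int16).reshape(6, 2, 2)
image = window_and_stack(volume, 0)
print(image.shape)
```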
## Environment Requirements
- Hardware (Ascend/ModelArts)
    - Prepare a hardware environment with Ascend processors or ModelArts.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
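## Quick Start
After installing MindSpore from the official website, you can train and evaluate as follows:
```bash
# Edit default_config.yaml to configure parameters.
# Run the single-device training script with python.
python train.py > log.txt 2>&1 &
# Or start single-device training with bash.
bash ./scripts/run_train.sh [root path of code]
# The training log is written to log.txt.
# Multi-device training on Ascend.
bash ./scripts/run_distribute_train.sh [root path of code] [rank size] [rank start id] [rank table file]
# Run the evaluation script with python.
# pretrain_path is the directory that contains the ckpt; for ModelArts compatibility the checkpoint location is split into a "path" and a "file name".
python eval.py > eval_log.txt 2>&1 &
# Or start evaluation with bash.
bash ./scripts/run_eval.sh [root path of code]
# The evaluation log is written to eval_log.txt.
```
Ascend training: generate the [RANK_TABLE_FILE](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)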
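## Script Description
### Scripts and Sample Code
```text
├── model_zoo
    ├── README.md                       // descriptions of all models
    ├── UNet3+
        ├── README_CN.md                // description of UNet3+
        ├── scripts
        │   ├──run_distribute_train.sh  // Ascend 8-device training script
        │   ├──run_eval.sh              // evaluation launch script
        │   ├──run_train.sh             // training launch script
        ├── src
        │   ├──config.py                // configuration loader
        │   ├──dataset.py               // dataset processing
        │   ├──models.py                // model architecture
        │   ├──logger.py                // logging
        │   ├──util.py                  // utilities
        ├── default_config.yaml         // default configuration for training, evaluation, export, etc.
        ├── train.py                    // training script
        ├── eval.py                     // evaluation script
        ├── export.py                   // exports the checkpoint to MINDIR and other formats
        ├── dataset_preprocess.py       // dataset preprocessing script
```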
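### Script Parameters
```text
Parameters for training, evaluation, export, and other operations are all configured in default_config.yaml.
Key parameters and their defaults:
--aug: whether to enable data augmentation (1 for True, 0 for False); default: 1
--epochs: number of training epochs; default: 200
--lr: learning rate; default: 3e-4
--batch_size: batch size; default: 2
```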
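The relationship between the YAML defaults and the `--` flags can be sketched as follows. This is a simplified stand-in for what `src/config.py` does: the defaults are written as a plain dictionary here instead of being loaded from `default_config.yaml`, and `build_parser` is a hypothetical helper name.

```python
import argparse


def build_parser(defaults):
    """Create one command-line flag per default, typed after the default value."""
    parser = argparse.ArgumentParser(description="defaults overridden by CLI flags")
    for key, value in defaults.items():
        parser.add_argument("--" + key, type=type(value), default=value)
    return parser


# Stand-in for the values that would come from default_config.yaml.
defaults = {"aug": 1, "epochs": 200, "lr": 3e-4, "batch_size": 2}

# Flags given on the command line win; everything else keeps its default.
args = build_parser(defaults).parse_args(["--epochs", "50", "--lr", "1e-3"])
print(vars(args))
```

Typing each flag after its default keeps overrides consistent with the configuration file, so `--epochs 50` arrives as an integer and `--lr 1e-3` as a float.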
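### Training Process
#### Training
- Dataset preprocessing
Preprocess the dataset before training or evaluating the network.
```bash
# Edit default_config.yaml to configure parameters. source_path is the root directory of the source nii dataset and must contain
# two folders, "CT" and "seg", holding the source volumes and the corresponding segmentation labels. dest_path is the directory
# where the processed images are stored; it is created automatically if it does not exist. buffer_path is a scratch directory
# that is deleted recursively once processing has finished.
python dataset_preprocess.py
```
- Running on Ascend processors
```bash
# Edit default_config.yaml to configure parameters.
# Run the single-device training script with python.
python train.py > log.txt 2>&1 &
# Or start single-device training with bash.
bash ./scripts/run_train.sh [root path of code]
# Both commands run the script in the background and write the log to log.txt; check that file for training details.
# Multi-device training on Ascend.
bash ./scripts/run_distribute_train.sh [root path of code] [rank size] [rank start id] [rank table file]
```
After training, the saved weight files can be found in the directory specified by the output_path parameter. Part of the loss convergence during training (8 devices in parallel) is shown below: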
```text
# grep "epoch time:" log.txt
epoch: 170 step: 960, loss is 0.51230466
epoch time: 58413.158 ms, per step time: 60.847 ms
epoch time: 58448.345 ms, per step time: 60.884 ms
epoch time: 58446.879 ms, per step time: 60.882 ms
epoch time: 58480.166 ms, per step time: 60.917 ms
epoch time: 58409.484 ms, per step time: 60.843 ms
epoch: 175 step: 960, loss is 0.50975895
epoch time: 58429.310 ms, per step time: 60.864 ms
epoch time: 58543.156 ms, per step time: 60.982 ms
epoch time: 58455.628 ms, per step time: 60.891 ms
epoch time: 58453.604 ms, per step time: 60.889 ms
epoch time: 58422.367 ms, per step time: 60.857 ms
epoch: 180 step: 960, loss is 0.51502335
epoch time: 58416.837 ms, per step time: 60.851 ms
[WARNING] SESSION(53798,fffed29421e0,python):2021-11-01-15:55:11.115.617 [mindspore/ccsrc/backend/session/ascend_session.cc:1380] SelectKernel] There are 42 node/nodes used reduce precision to selected the kernel!
2021-11-01 15:56:54,111 :INFO: epoch: 180, Dice: 97.20967
2021-11-01 15:56:56,486 :INFO: update best result: 97.20967
2021-11-01 15:56:56,709 :INFO: update best checkpoint at: ./output/unet_2021-11-01_time_12_56_54/0_best_map.ckpt
epoch time: 62822.634 ms, per step time: 65.440 ms
2021-11-01 15:59:10,762 :INFO: epoch: 181, Dice: 97.1946
epoch time: 66539.150 ms, per step time: 69.312 ms
2021-11-01 16:01:30,357 :INFO: epoch: 182, Dice: 97.19583
epoch time: 64837.935 ms, per step time: 67.540 ms
2021-11-01 16:03:46,606 :INFO: epoch: 183, Dice: 97.33418
2021-11-01 16:03:46,608 :INFO: update best result: 97.33418
2021-11-01 16:03:46,828 :INFO: update best checkpoint at: ./output/unet_2021-11-01_time_12_56_54/0_best_map.ckpt
epoch time: 65825.663 ms, per step time: 68.568 ms
2021-11-01 16:06:07,652 :INFO: epoch: 184, Dice: 97.15482
epoch: 185 step: 960, loss is 0.5108043
epoch time: 62547.918 ms, per step time: 65.154 ms
2021-11-01 16:08:26,350 :INFO: epoch: 185, Dice: 97.32324
epoch time: 62356.042 ms, per step time: 64.954 ms
2021-11-01 16:10:40,546 :INFO: epoch: 186, Dice: 97.008
epoch time: 66353.477 ms, per step time: 69.118 ms
2021-11-01 16:13:00,183 :INFO: epoch: 187, Dice: 97.37989
2021-11-01 16:13:00,186 :INFO: update best result: 97.37989
2021-11-01 16:13:00,408 :INFO: update best checkpoint at: ./output/unet_2021-11-01_time_12_56_54/0_best_map.ckpt
...
```
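To pull the evaluation results out of a log like the one above, a short regular-expression pass is enough. The sketch below assumes the `epoch: N, Dice: X` line format shown in the excerpt; the helper name `best_dice` is hypothetical and not part of the repository.

```python
import re

# Matches lines such as "... :INFO: epoch: 183, Dice: 97.33418".
DICE_LINE = re.compile(r"epoch: (\d+), Dice: ([\d.]+)")


def best_dice(log_text):
    """Return (epoch, dice) with the highest Dice in the log text, or None."""
    scores = [(int(m.group(1)), float(m.group(2))) for m in DICE_LINE.finditer(log_text)]
    return max(scores, key=lambda item: item[1]) if scores else None


sample = """\
epoch: 180 step: 960, loss is 0.51502335
2021-11-01 15:56:54,111 :INFO: epoch: 180, Dice: 97.20967
2021-11-01 16:03:46,606 :INFO: epoch: 183, Dice: 97.33418
2021-11-01 16:13:00,183 :INFO: epoch: 187, Dice: 97.37989
"""
print(best_dice(sample))  # (187, 97.37989)
```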
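### Evaluation Process
#### Evaluation
Before running the commands below, check that the path of the weight file used for evaluation is correct.
- Running on Ascend processors
```bash
# Edit default_config.yaml to configure parameters.
# Run the evaluation script with python.
# pretrain_path is the directory that contains the ckpt; for ModelArts compatibility the checkpoint location is split into a "path" and a "file name".
python eval.py > eval_log.txt 2>&1 &
# Or start evaluation with bash.
bash ./scripts/run_eval.sh [root path of code]
# The evaluation log is written to eval_log.txt.
```
After the run finishes, the evaluation log can be found in the directory specified by output_path.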
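The metric reported by the evaluation is the Dice coefficient. The sketch below mirrors the `dice_coef` function in `eval.py` using NumPy arrays; the toy prediction and mask are made up for illustration only.

```python
import numpy as np


def dice_coef(output, target, smooth=1e-5):
    """Dice coefficient: 2 * |A ∩ B| / (|A| + |B|), smoothed to avoid division by zero."""
    intersection = (output * target).sum()
    return (2.0 * intersection + smooth) / (output.sum() + target.sum() + smooth)


# Toy example: the prediction marks two pixels, one of which is correct.
pred = np.array([[1.0, 1.0], [0.0, 0.0]])
mask = np.array([[1.0, 0.0], [0.0, 0.0]])
print(float(dice_coef(pred, mask)))
```

Because `eval.py` passes sigmoid outputs to this function without thresholding, the value it logs is a soft Dice over probabilities rather than over binarized masks.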
## Model Description
### Performance
#### Evaluation Performance
UNet3+ on "LiTS2017"
| Parameters                 | UNet3+                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G              |
| Uploaded Date              | 11/01/2021 (month/day/year)                                  |
| MindSpore Version          | 1.3.0                                                        |
| Dataset                    | LiTS2017                                                     |
| Training Parameters        | epoch=200, batch_size=2, lr=3e-4, aug=1                      |
| Optimizer                  | Adam                                                         |
| Loss Function              | BCEDiceLoss                                                  |
| Outputs                    | image with segmentation mask                                 |
| Loss                       | 0.5271476                                                    |
| Accuracy                   | 97.71%                                                       |
| Total time                 | 8p: 2h44m (without validation)                               |
| Checkpoint for Fine tuning | 8p: 19.30MB (.ckpt file)                                     |
| Scripts                    | [UNet3+ scripts](https://gitee.com/mindspore/models/tree/master/research/cv/UNet3+) |
## Description of Random Situation
Random seeds are set in train.py and eval.py.
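Seeding also matters for the train/test split: `dataset_preprocess.py` calls `random.seed(1000)` before shuffling the slice files and keeping 80% for training. A minimal sketch of such a reproducible split is given below; sorting the names first and using a private `random.Random` instance are choices of this sketch, and `split_files` is a hypothetical helper name.

```python
import math
import random


def split_files(names, train_ratio=0.8, seed=1000):
    """Shuffle a sorted copy of `names` with a fixed seed and split it in two."""
    shuffled = sorted(names)
    # A private generator keeps the split independent of other random calls.
    random.Random(seed).shuffle(shuffled)
    cut = math.floor(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


names = [f"0_{i}.png" for i in range(10)]
train_a, test_a = split_files(names)
train_b, test_b = split_files(list(reversed(names)))
print(len(train_a), len(test_a), train_a == train_b)
```

Sorting before shuffling makes the result independent of the order in which the file system lists the files, so the same seed always yields the same split.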
## ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/models).
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Preprocessing of the dataset: turn nii into png'''
import os
import math
import random
import shutil
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm
import cv2
from src.config import config as cfg
if __name__ == "__main__":
    # Stage 1: convert each nii volume into per-slice png files in the buffer directory.
    os.makedirs(os.path.join(cfg.buffer_path, "ct"), exist_ok=True)
    os.makedirs(os.path.join(cfg.buffer_path, "seg"), exist_ok=True)
    ct_path = os.path.join(cfg.source_path, "CT")
    for index, file in enumerate(tqdm(os.listdir(ct_path))):
        ct_src = sitk.ReadImage(os.path.join(ct_path, file), sitk.sitkInt16)
        mask = sitk.ReadImage(os.path.join(cfg.source_path, "seg",
                                           file.replace('volume', 'segmentation')), sitk.sitkUInt8)
        ct_array = sitk.GetArrayFromImage(ct_src)
        mask_array = sitk.GetArrayFromImage(mask)
        # Clip intensities to the [-200, 200] window.
        ct_array[ct_array > 200] = 200
        ct_array[ct_array < -200] = -200
        # Keep only the range of slices whose mask contains at least one labelled pixel.
        z = np.any(mask_array, axis=(1, 2))
        start_slice, end_slice = np.where(z)[0][[0, -1]]
        # Note: ct_crop starts at start_slice - 1, so ct_crop[n_slice] is the slice
        # just before mask_crop[n_slice].
        ct_crop = ct_array[start_slice - 1:end_slice + 1, :, :]
        mask_crop = mask_array[start_slice:end_slice + 1, :, :]
        for n_slice in range(mask_crop.shape[0]):
            maskImg = mask_crop[n_slice, :, :] * 255
            cv2.imwrite(os.path.join(cfg.buffer_path, "seg", str(index) + "_" + str(n_slice) + ".png"), maskImg)
            # Stack three neighbouring CT slices as the three image channels.
            ctImageArray = np.zeros((ct_crop.shape[1], ct_crop.shape[2], 3), np.float64)
            ctImageArray[:, :, 0] = ct_crop[n_slice - 1, :, :]
            ctImageArray[:, :, 1] = ct_crop[n_slice, :, :]
            ctImageArray[:, :, 2] = ct_crop[n_slice + 1, :, :]
            cv2.imwrite(os.path.join(cfg.buffer_path, "ct", str(index) + "_" + str(n_slice) + ".png"), ctImageArray)
    print("Data transform Done!")

    # Stage 2: shuffle with a fixed seed and split 80/20 into train and test folders.
    for subset in ("train", "test"):
        for kind in ("ct", "seg"):
            os.makedirs(os.path.join(cfg.dest_path, subset, kind), exist_ok=True)
    seg = os.listdir(os.path.join(cfg.buffer_path, "seg"))
    random.seed(1000)
    random.shuffle(seg)
    n_train = math.floor(len(seg) * 0.8)
    print("Start to split train data!")
    for index, i in enumerate(seg[:n_train]):
        if (index + 1) % 1000 == 0:
            print(index + 1, "/", n_train)
        shutil.move(os.path.join(cfg.buffer_path, "ct", i), os.path.join(cfg.dest_path, "train", "ct"))
        shutil.move(os.path.join(cfg.buffer_path, "seg", i), os.path.join(cfg.dest_path, "train", "seg"))
    print("Start to split val data!")
    for index, i in enumerate(seg[n_train:]):
        if (index + 1) % 1000 == 0:
            print(index + 1, "/", len(seg) - n_train)
        shutil.move(os.path.join(cfg.buffer_path, "ct", i), os.path.join(cfg.dest_path, "test", "ct"))
        shutil.move(os.path.join(cfg.buffer_path, "seg", i), os.path.join(cfg.dest_path, "test", "seg"))
    # The buffer is only scratch space; remove it once everything has been moved.
    shutil.rmtree(cfg.buffer_path)
    print("Data processing and splitting finished!")
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
use_modelarts: 0
# url for modelarts
data_url: ""
train_url: ""
outer_path: 's3://output/'
# mainly hyperparameters for training
aug: 1
epochs: 200
lr: 3.0e-4
batch_size: 2
# dataset options, we recommend the absolute path
train_data_path: "../dataset/LiTS2017/train"
val_data_path: "../dataset/LiTS2017/test"
# eval settings while training
eval_while_train: 1
eval_steps: 1
eval_start_epoch: 180
# checkpoint config while training
save_every: 1
is_save_on_master: 1
ckpt_save_max: 5
output_path: './output/'
resume_path: ""
resume_name: ""
# eval settings stand alone, other hyperparameters are shared with training
pretrain_path: "./"
ckpt_name: "best_map.ckpt"
# export settings stand alone, other hyperparameters are shared with training
export_batch_size: 1
image_height: 512
image_width: 512
ckpt_file: "./best_map.ckpt"
file_name: "unet3plus"
file_format: "AIR"
# dataset preprocess settings
source_path: "../origin_dataset/train/"
dest_path: "../dataset/LiTS2017"
buffer_path: "./buffer"
# ======================================================================================
# common options
device_target: 'Ascend'
is_distributed: 0
rank: 0
group_size: 1
---
# Help description for each configuration
use_modelarts: "Whether to train on ModelArts, 1 for True, 0 for False; default: 0"
data_url: "needed by ModelArts, but we do not use it because the name is ambiguous"
train_url: "needed by ModelArts, but we do not use it because the name is ambiguous"
outer_path: "OBS path used to store outputs, e.g. ckpt files"
aug: "Whether to apply data augmentation, 1 for True, 0 for False; default: 1"
epochs: "number of training epochs"
lr: "learning rate"
batch_size: "batch size"
train_data_path: "root path of train data"
val_data_path: "root path of val data"
eval_while_train: "Whether to evaluate while training, 1 for True, 0 for False; default: 1"
eval_steps: "evaluate every N epochs"
eval_start_epoch: "epoch at which evaluation starts"
save_every: "save the model every N epochs"
is_save_on_master: "save ckpt on the master rank only or on all ranks"
ckpt_save_max: "maximum number of checkpoint files that can be saved"
output_path: "output path; when use_modelarts is set to 1, cache/output/ is recommended"
resume_path: "path of the checkpoint to resume from, if needed"
resume_name: "file name of the checkpoint to resume from"
pretrain_path: "path of the ckpt to eval"
ckpt_name: "name of the ckpt to eval"
export_batch_size: "batch size for exporting the ckpt"
image_height: "image height for exporting the ckpt"
image_width: "image width for exporting the ckpt"
ckpt_file: "the ckpt to export"
file_name: "name of the exported file"
file_format: "file format, choose from ['MINDIR','AIR','ONNX']"
source_path: "the root path of MICCAI-LITS-2017 in nii format"
dest_path: "the root path of the folder in which to store the data in png format"
buffer_path: "the buffer used while processing data; it will be REMOVED after processing finishes"
device_target: "device on which the code will run. (Default: Ascend)"
is_distributed: "whether to run on multiple devices"
rank: "local rank of distributed"
group_size: "world size of distributed"
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''eval of UNet3+'''
import datetime
import os
import mindspore.ops as ops
from mindspore import context
from mindspore.common import set_seed
from mindspore import load_checkpoint, load_param_into_net
from src.logger import get_logger
from src.dataset import create_Dataset
from src.models import UNet3Plus
from src.config import config as cfg
def copy_data_from_obs():
    '''copy_data_from_obs'''
    if cfg.use_modelarts:
        import moxing as mox
        cfg.logger.info("copying test weights from obs to cache....")
        mox.file.copy_parallel(cfg.pretrain_path, 'cache/weight')
        cfg.logger.info("copying test weights finished....")
        cfg.pretrain_path = 'cache/weight/'

        cfg.logger.info("copying val dataset from obs to cache....")
        mox.file.copy_parallel(cfg.val_data_path, 'cache/val')
        cfg.logger.info("copying val dataset finished....")
        cfg.val_data_path = 'cache/val/'


def copy_data_to_obs():
    '''copy_data_to_obs'''
    if cfg.use_modelarts:
        import moxing as mox
        cfg.logger.info("copying files from cache to obs....")
        mox.file.copy_parallel(cfg.save_dir, cfg.outer_path)
        cfg.logger.info("copying finished....")


def dice_coef(output, target):
    '''Dice coefficient with a small smoothing term to avoid division by zero'''
    smooth = 1e-5
    intersection = (output * target).sum()
    return (2. * intersection + smooth) / \
        (output.sum() + target.sum() + smooth)


class AverageMeter():
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def test(model_path):
    '''test'''
    model = UNet3Plus()
    model.set_train(False)
    cfg.logger.info('load test weights from %s', str(model_path))
    load_param_into_net(model, load_checkpoint(model_path))
    cfg.logger.info('loaded test weights from %s', str(model_path))

    val_dataset, _ = create_Dataset(cfg.val_data_path, 0, cfg.batch_size,
                                    1, 0, shuffle=False)
    data_loader = val_dataset.create_dict_iterator()
    dices = AverageMeter()
    sigmoid = ops.Sigmoid()
    for data in data_loader:
        output = sigmoid(model(data["image"])).asnumpy()
        dice = dice_coef(output, data["mask"].asnumpy())
        dices.update(dice, cfg.batch_size)
    cfg.logger.info("Final dices: %s", str(dices.avg))


if __name__ == '__main__':
    set_seed(1)
    cfg.save_dir = os.path.join(cfg.output_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    if not cfg.use_modelarts and not os.path.exists(cfg.save_dir):
        os.makedirs(cfg.save_dir)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target,
                        device_id=device_id, save_graphs=False)
    cfg.logger = get_logger(cfg.save_dir, "UNet3Plus", 0)
    cfg.logger.save_args(cfg)

    copy_data_from_obs()
    test(os.path.join(cfg.pretrain_path, cfg.ckpt_name))
    copy_data_to_obs()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import os
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.models import UNet3Plus
from src.config import config as cfg
if __name__ == '__main__':
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target,
                        device_id=device_id)
    net = UNet3Plus()
    param_dict = load_checkpoint(cfg.ckpt_file)
    load_param_into_net(net, param_dict)
    # Dummy input that fixes the exported graph's input shape: (batch, 3, height, width).
    input_arr = Tensor(np.zeros([cfg.export_batch_size, 3,
                                 cfg.image_height, cfg.image_width]), ms.float32)
    export(net, input_arr, file_name=cfg.file_name, file_format=cfg.file_format)
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# help message
if [ $# != 4 ]; then
echo "Usage: bash run_distribute_train.sh [root path of code] [rank size]" \
"[rank start id] [rank table file]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path
if [ ! -d $train_code_path ]
then
echo "error: train_code_path=$train_code_path is not a directory."
exit 1
fi
ulimit -c unlimited
ulimit -n 65530
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=$4
export RANK_SIZE=$2
export RANK_START_ID=$3
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
if [ -d ${train_code_path}/device${DEVICE_ID} ]; then
rm -rf ${train_code_path}/device${DEVICE_ID}
fi
mkdir ${train_code_path}/device${DEVICE_ID}
cd ${train_code_path}/device${DEVICE_ID} || exit
nohup python ${train_code_path}/train.py > log.txt 2>&1 &
done
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# help message
if [ $# != 1 ]; then
echo "Usage: bash run_eval.sh [root path of code]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path
if [ ! -d $train_code_path ]
then
echo "error: train_code_path=$train_code_path is not a directory."
exit 1
fi
ulimit -n 65530
nohup python ${train_code_path}/eval.py > eval_log.txt 2>&1 &
echo 'Validation task has been started successfully!'
echo 'Please check the log at eval_log.txt'
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# help message
if [ $# != 1 ]; then
echo "Usage: bash run_train.sh [root path of code]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path
if [ ! -d $train_code_path ]
then
echo "error: train_code_path=$train_code_path is not a directory."
exit 1
fi
ulimit -n 65530
nohup python ${train_code_path}/train.py > log.txt 2>&1 &
echo 'Train task has been started successfully!'
echo 'Please check the log at log.txt'
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Allowed values for each option.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        # Only scalar options become command line flags; lists and dicts are skipped.
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except Exception as err:
            # Chain the original error so the real cause is not lost.
            raise ValueError("Failed to parse yaml") from err
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    pprint(final_config)
    print("Please check the above information for the configurations", flush=True)
    return Config(final_config)


config = get_config()

if __name__ == '__main__':
    print(config)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''data loader'''
import os
import glob
import random
import numpy as np
from skimage.io import imread
from skimage import color
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
class Dataset:
    '''Dataset'''
    def __init__(self, data_path, aug=False):
        super(Dataset, self).__init__()
        # Sort both lists so that image i and mask i refer to the same file name.
        self.img_paths = sorted(glob.glob(os.path.join(data_path, "ct", "*")))
        self.mask_paths = sorted(glob.glob(os.path.join(data_path, "seg", "*")))
        self.aug = aug

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]
        image = imread(img_path)
        mask = imread(mask_path)
        # Scale pixel values to [0, 1].
        image = image.astype('float32') / 255
        mask = mask.astype('float32') / 255
        if self.aug:
            # Random horizontal flip, applied to image and mask together.
            if random.uniform(0, 1) > 0.5:
                image = image[:, ::-1, :].copy()
                mask = mask[:, ::-1].copy()
            # Random vertical flip, applied to image and mask together.
            if random.uniform(0, 1) > 0.5:
                image = image[::-1, :, :].copy()
                mask = mask[::-1, :].copy()
        image = color.gray2rgb(image)
        mask = mask[:, :, np.newaxis]
        return image, mask

    def __len__(self):
        return len(self.img_paths)


def create_Dataset(data_path, aug, batch_size, device_num, rank, shuffle):
    '''Build the batched MindSpore dataset and return it with its number of batches'''
    dataset = Dataset(data_path, aug)
    hwc_to_chw = CV.HWC2CHW()
    data_set = ds.GeneratorDataset(dataset, column_names=["image", "mask"],
                                   num_parallel_workers=8, shuffle=shuffle, num_shards=device_num, shard_id=rank)
    data_set = data_set.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
    data_set = data_set.map(input_columns=["mask"], operations=hwc_to_chw, num_parallel_workers=8)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set, data_set.get_dataset_size()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
    """
    Logger.

    Args:
        logger_name: String. Logger name.
        rank: Integer. Rank id.
    """
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.rank = rank
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """Setup logging file."""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, logger_name, rank):
    """Get Logger."""
    logger = LOGGER(logger_name, rank)
    logger.setup_logging_file(path, rank)
    return logger
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''model of UNet3+'''
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
class unetConv2(nn.Cell):
    '''unetConv2: n stacked Conv2d (+ optional BatchNorm2d) + ReLU layers'''
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1, weight_init="HeNormal"):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        conv_layer = []
        if is_batchnorm:
            for _ in range(1, n + 1):
                conv_layer.extend([
                    # use the weight_init argument instead of a hard-coded "HeNormal"
                    nn.Conv2d(in_size, out_size, ks, s, pad_mode="pad", padding=p, weight_init=weight_init),
                    nn.BatchNorm2d(out_size, gamma_init="ones"),
                    nn.ReLU()
                ])
                in_size = out_size
        else:
            for _ in range(1, n + 1):
                conv_layer.extend([
                    nn.Conv2d(in_size, out_size, ks, s, pad_mode="pad", padding=p, weight_init=weight_init),
                    nn.ReLU()
                ])
                in_size = out_size
        self.conv = nn.SequentialCell(conv_layer)

    def construct(self, inputs):
        '''construct'''
        return self.conv(inputs)
class UNet3Plus(nn.Cell):
    '''UNet3Plus'''
    def __init__(self, in_channels=3, n_classes=1, feature_scale=4,
                 is_deconv=True, is_batchnorm=True):
        super(UNet3Plus, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        filters = [16, 32, 64, 128, 256]

        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3,
                                        pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h1_PT_hd4_relu = nn.ReLU()
        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3,
                                        pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h2_PT_hd4_relu = nn.ReLU()
        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3,
                                        pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h3_PT_hd4_relu = nn.ReLU()
        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h4_Cat_hd4_relu = nn.ReLU()
        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.ResizeBilinear = nn.ResizeBilinear()
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd5_UT_hd4_relu = nn.ReLU()
        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3,
                                  pad_mode="pad", padding=1, weight_init="HeNormal")  # 16
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels, gamma_init="ones")
        self.relu4d_1 = nn.ReLU()

        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3,
                                        pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h1_PT_hd3_relu = nn.ReLU()
        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3,
                                        pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h2_PT_hd3_relu = nn.ReLU()
        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h3_Cat_hd3_relu = nn.ReLU()
        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd4_UT_hd3_relu = nn.ReLU()
        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd5_UT_hd3_relu = nn.ReLU()
        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3,
                                  pad_mode="pad", padding=1, weight_init="HeNormal")  # 16
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels, gamma_init="ones")
        self.relu3d_1 = nn.ReLU()

        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3,
                                        pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h1_PT_hd2_relu = nn.ReLU()
        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h2_Cat_hd2_relu = nn.ReLU()
        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd3_UT_hd2_relu = nn.ReLU()
        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd4_UT_hd2_relu = nn.ReLU()
        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd5_UT_hd2_relu = nn.ReLU()
        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3,
                                  pad_mode="pad", padding=1, weight_init="HeNormal")  # 16
        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels, gamma_init="ones")
        self.relu2d_1 = nn.ReLU()

        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.h1_Cat_hd1_relu = nn.ReLU()
        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd2_UT_hd1_relu = nn.ReLU()
        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd3_UT_hd1_relu = nn.ReLU()
        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd4_UT_hd1_relu = nn.ReLU()
        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3,
                                         pad_mode="pad", padding=1, weight_init="HeNormal")
        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels, gamma_init="ones")
        self.hd5_UT_hd1_relu = nn.ReLU()
        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3,
                                  pad_mode="pad", padding=1, weight_init="HeNormal")  # 16
        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels, gamma_init="ones")
        self.relu1d_1 = nn.ReLU()

        # output
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3,
                                  pad_mode="pad", padding=1, weight_init="HeNormal")
        self.concat1 = ops.Concat(1)
    def construct(self, inputs):
        '''construct'''
        ## -------------Encoder-------------
        h1 = self.conv1(inputs)  # h1->320*320*16
        h2 = self.maxpool1(h1)
        h2 = self.conv2(h2)  # h2->160*160*32
        h3 = self.maxpool2(h2)
        h3 = self.conv3(h3)  # h3->80*80*64
        h4 = self.maxpool3(h3)
        h4 = self.conv4(h4)  # h4->40*40*128
        h5 = self.maxpool4(h4)
        hd5 = self.conv5(h5)  # hd5->20*20*256

        ## -------------Decoder-------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(
            self.ResizeBilinear(hd5, scale_factor=2, align_corners=True))))
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
            self.concat1((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)))))  # hd4->40*40*UpChannels

        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(
            self.ResizeBilinear(hd4, scale_factor=2, align_corners=True))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(
            self.ResizeBilinear(hd5, scale_factor=4, align_corners=True))))
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
            self.concat1((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)))))  # hd3->80*80*UpChannels

        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(
            self.ResizeBilinear(hd3, scale_factor=2, align_corners=True))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(
            self.ResizeBilinear(hd4, scale_factor=4, align_corners=True))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(
            self.ResizeBilinear(hd5, scale_factor=8, align_corners=True))))
        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
            self.concat1((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)))))  # hd2->160*160*UpChannels

        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(
            self.ResizeBilinear(hd2, scale_factor=2, align_corners=True))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(
            self.ResizeBilinear(hd3, scale_factor=4, align_corners=True))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(
            self.ResizeBilinear(hd4, scale_factor=8, align_corners=True))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(
            self.ResizeBilinear(hd5, scale_factor=16, align_corners=True))))
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
            self.concat1((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)))))  # hd1->320*320*UpChannels

        d1 = self.outconv1(hd1)  # d1->320*320*n_classes
        return d1
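Each decoder stage in `UNet3Plus` fuses five branches: shallower encoder maps are max-pooled, the same-scale map is passed straight through, and deeper maps are bilinearly upsampled. The sketch below only reproduces that bookkeeping in plain Python so the pooling and upsampling factors can be checked against the layer definitions. The function `branch_plan` and the 320-pixel input size (taken from the code comments) are illustrative assumptions, not part of the source.

```python
FILTERS = [16, 32, 64, 128, 256]   # encoder channels per stage, as in the source
CAT_CHANNELS = FILTERS[0]          # channels of each branch after its 3x3 conv
UP_CHANNELS = CAT_CHANNELS * 5     # five branches are concatenated per decoder stage


def branch_plan(target_stage, input_size=320):
    """List (source, op, factor) for the five branches feeding decoder stage 1..4."""
    target = input_size >> (target_stage - 1)          # spatial size of the target stage
    plan = []
    for src in range(1, 6):
        size = input_size >> (src - 1)                 # spatial size of the source map
        # Stages at or above the target come from the encoder (h*), deeper ones from the decoder (hd*).
        name = f"h{src}" if src <= target_stage else f"hd{src}"
        if size > target:
            plan.append((name, "maxpool", size // target))
        elif size == target:
            plan.append((name, "identity", 1))
        else:
            plan.append((name, "upsample", target // size))
    return plan


plan_hd3 = branch_plan(3)
plan_hd1 = branch_plan(1)
```

For stage 3 this yields pooling by 4 and 2, one pass-through, and upsampling by 2 and 4, which matches `h1_PT_hd3`, `h2_PT_hd3`, `h3_Cat_hd3`, `hd4_UT_hd3`, and `hd5_UT_hd3`.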
class BCEDiceLoss(nn.Cell):
    '''BCEDiceLoss'''
    def __init__(self):
        super(BCEDiceLoss, self).__init__()
        self.bceloss = ops.BinaryCrossEntropy()
        self.sigmoid = ops.Sigmoid()
        self.reduceSum = ops.ReduceSum(keep_dims=False)
        # BCE weight tensor; its shape is fixed to batch size 2, 1 channel, 512*512 inputs
        self.one_tensor = Tensor(np.ones([2, 1, 512, 512]), mindspore.float32)

    def construct(self, predict, target):
        '''construct'''
        bce = self.bceloss(self.sigmoid(predict), target, self.one_tensor)
        smooth = 1e-5
        predict = self.sigmoid(predict)
        num = target.shape[0]
        predict = predict.view(num, -1)
        target = target.view(num, -1)
        intersection = (predict * target)
        # per-sample Dice coefficient, shape (num,)
        dice = (2. * self.reduceSum(intersection, 1) + smooth) / \
               (self.reduceSum(predict, 1) + self.reduceSum(target, 1) + smooth)
        # dice is not summed over the batch here, so the returned loss keeps shape (num,)
        dice = 1 - dice / num
        return 0.5 * bce + dice


class UNet3PlusWithLossCell(nn.Cell):
    '''UNet3PlusWithLossCell'''
    def __init__(self, network):
        super(UNet3PlusWithLossCell, self).__init__()
        self.network = network
        self.loss = BCEDiceLoss()

    def construct(self, image, mask):
        '''construct'''
        output = self.network(image)
        return self.loss(output, mask)
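To make the arithmetic of a combined BCE and Dice loss concrete, here is a small pure-Python sketch that works on nested lists instead of tensors. It uses the commonly seen scalar form `0.5 * BCE + (1 - mean Dice)`, which is an assumption on my part: the `construct` method above divides the per-sample Dice by `num` without summing, so the two are not numerically identical. The function name `bce_dice_loss` is hypothetical.

```python
import math


def bce_dice_loss(logits, targets, smooth=1e-5):
    """logits, targets: one flat list of values per sample. Returns a scalar loss."""
    # Sigmoid turns raw scores into probabilities.
    probs = [[1.0 / (1.0 + math.exp(-x)) for x in row] for row in logits]

    # Mean binary cross-entropy over every pixel in the batch.
    n_pixels = sum(len(row) for row in probs)
    bce = -sum(t * math.log(p) + (1 - t) * math.log(1 - p)
               for prow, trow in zip(probs, targets)
               for p, t in zip(prow, trow)) / n_pixels

    # Smoothed Dice coefficient per sample, then averaged over the batch.
    dice = [(2.0 * sum(p * t for p, t in zip(prow, trow)) + smooth)
            / (sum(prow) + sum(trow) + smooth)
            for prow, trow in zip(probs, targets)]
    return 0.5 * bce + (1.0 - sum(dice) / len(dice))


good = bce_dice_loss([[10.0, -10.0]], [[1.0, 0.0]])   # confident and correct
bad = bce_dice_loss([[-10.0, 10.0]], [[1.0, 0.0]])    # confident and wrong
```

A confident correct prediction drives both terms toward zero, while a confident wrong one is penalised by both the BCE and the Dice term.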
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
import os
import stat
from datetime import datetime
import mindspore.ops as ops
from mindspore import nn
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)


class TempLoss(nn.Cell):
    """A temp loss cell."""
    def construct(self, *inputs, **kwargs):
        return 0.1
class DiceMetric(nn.Metric):
    """DiceMetric"""
    def __init__(self):
        super(DiceMetric, self).__init__()
        self.sum = 0.
        self.count = 0.
        self.clear()
        self.sigmoid = ops.Sigmoid()
        self.reducesum = ops.ReduceSum()
        self.smooth = 1e-5

    def clear(self):
        """Resets the internal evaluation result to initial state."""
        self.sum = 0.
        self.count = 0.

    def update(self, output, target):
        """Updates the internal evaluation result.

        Parameters
        ----------
        output : Tensor
            Raw network outputs (logits); a sigmoid is applied here.
        target : Tensor
            Ground-truth masks.
        """
        output = self.sigmoid(output)
        intersection = self.reducesum(output * target)
        dice = (2. * intersection + self.smooth) / \
               (self.reducesum(output) + self.reducesum(target) + self.smooth)
        # weight the batch-level Dice by the batch size
        self.sum += dice*output.shape[0]
        self.count += output.shape[0]

    def eval(self):
        return self.sum / self.count
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Args:
        network (Model): model wrapper whose `eval` method is run on the validation data.
        dataloader (Dataset): validation dataset.
        interval (int): number of epochs between evaluations, default is 1.
        eval_start_epoch (int): first epoch at which to evaluate, default is 1.
        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
        ckpt_directory (str): directory for the best checkpoint, default is `./`.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(network, dataloader)
    """
    def __init__(self, network, dataloader, interval=1, eval_start_epoch=1,
                 save_best_ckpt=True, ckpt_directory="./", besk_ckpt_name="best.ckpt"):
        super(EvalCallBack, self).__init__()
        self.network = network
        self.dataloader = dataloader
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should be >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            acc = self.network.eval(self.dataloader, dataset_sink_mode=True)['DiceMetric']
            print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
                  ":INFO: epoch: {}, {}: {}".format(cur_epoch, "Dice", acc*100), flush=True)
            if acc >= self.best_res:
                self.best_res = acc
                self.best_epoch = cur_epoch
                print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
                      ":INFO: update best result: {}".format(acc*100), flush=True)
                if self.save_best_ckpt:
                    if os.path.exists(self.bast_ckpt_path):
                        self.remove_ckpoint_file(self.bast_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                    print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
                          ":INFO: update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],
              ":INFO: End training, the best {0} is: {1}, its epoch is {2}".format(
                  "Dice", self.best_res*100, self.best_epoch), flush=True)
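`EvalCallBack.epoch_end` evaluates only on epochs that satisfy `cur_epoch >= eval_start_epoch and (cur_epoch - eval_start_epoch) % interval == 0`, and it keeps the best score with a `>=` comparison. The sketch below isolates those two rules in plain Python so their behaviour is easy to check. The helper names `eval_epochs` and `track_best` are hypothetical and exist only for this illustration.

```python
def eval_epochs(total_epochs, eval_start_epoch=1, interval=1):
    """Epochs (1-based) on which the callback's condition would trigger an evaluation."""
    return [epoch for epoch in range(1, total_epochs + 1)
            if epoch >= eval_start_epoch and (epoch - eval_start_epoch) % interval == 0]


def track_best(results):
    """results: iterable of (epoch, score). Uses >= like the callback, so ties favour the later epoch."""
    best_res, best_epoch = 0, 0
    for epoch, score in results:
        if score >= best_res:
            best_res, best_epoch = score, epoch
    return best_res, best_epoch


schedule = eval_epochs(10, eval_start_epoch=3, interval=2)
best = track_best([(3, 0.80), (5, 0.90), (7, 0.90), (9, 0.85)])
```

Because of the `>=`, a later epoch that merely equals the best score overwrites the saved checkpoint. Switching to `>` would keep the earliest best epoch instead.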
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''train'''
import os
import datetime
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.dataset import config
from mindspore.train.callback import TimeMonitor, LossMonitor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_rank, get_group_size
from src.logger import get_logger
from src.dataset import create_Dataset
from src.models import UNet3Plus, UNet3PlusWithLossCell
from src.util import DiceMetric, EvalCallBack, TempLoss
from src.config import config as cfg
def copy_data_from_obs():
    '''copy_data_from_obs'''
    if cfg.use_modelarts:
        import moxing as mox
        cfg.logger.info("copying train data from obs to cache....")
        mox.file.copy_parallel(cfg.train_data_path, 'cache/dataset')
        cfg.logger.info("copying traindata finished....")
        cfg.train_data_path = 'cache/dataset/'
        if cfg.resume_path:
            cfg.logger.info("copying resume checkpoint from obs to cache....")
            mox.file.copy_parallel(cfg.resume_path, 'cache/resume_path')
            cfg.logger.info("copying resume checkpoint finished....")
            cfg.resume_path = 'cache/resume_path/'
        if cfg.eval_while_train:
            cfg.logger.info("copying val data from obs to cache....")
            mox.file.copy_parallel(cfg.val_data_path, 'cache/vatdataset')
            cfg.logger.info("copying val data finished....")
            cfg.val_data_path = 'cache/vatdataset/'


def copy_data_to_obs():
    if cfg.use_modelarts:
        import moxing as mox
        cfg.logger.info("copying files from cache to obs....")
        mox.file.copy_parallel(cfg.save_dir, cfg.outer_path)
        cfg.logger.info("copying finished....")
def train():
    '''train'''
    if cfg.is_distributed:
        assert cfg.device_target == "Ascend"
        init()
        context.set_context(device_id=device_id)
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        device_num = cfg.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
    else:
        if cfg.device_target in ["Ascend", "GPU"]:
            context.set_context(device_id=device_id)
    config.set_enable_shared_mem(False)

    train_dataset, cfg.steps_per_epoch = create_Dataset(cfg.train_data_path, cfg.aug, cfg.batch_size,
                                                        cfg.group_size, cfg.rank, shuffle=True)
    f_model = UNet3Plus()
    if cfg.resume_path:
        cfg.resume_path = os.path.join(cfg.resume_path, cfg.resume_name)
        cfg.logger.info('loading resume checkpoint %s into network', str(cfg.resume_path))
        load_param_into_net(f_model, load_checkpoint(cfg.resume_path))
        cfg.logger.info('loaded resume checkpoint %s into network', str(cfg.resume_path))

    optimizer = nn.Adam(params=f_model.trainable_params(), learning_rate=float(cfg.lr))
    time_cb = TimeMonitor(data_size=cfg.steps_per_epoch)
    loss_cb = LossMonitor(50)
    callbacks = [time_cb, loss_cb]

    if cfg.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=cfg.steps_per_epoch*cfg.save_every,
                                       keep_checkpoint_max=cfg.ckpt_save_max)
        save_ckpt_path = os.path.join(cfg.save_dir, 'ckpt_' + str(cfg.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='rank_'+str(cfg.rank))
        callbacks.append(ckpt_cb)

    if cfg.eval_while_train == 1:
        loss_f = TempLoss()
        val_dataset, _ = create_Dataset(cfg.val_data_path, 0, cfg.batch_size, 1, 0, shuffle=False)
        network_eval = Model(f_model, loss_fn=loss_f, metrics={"DiceMetric": DiceMetric()})
        eval_cb = EvalCallBack(network_eval, val_dataset, interval=cfg.eval_steps,
                               eval_start_epoch=cfg.eval_start_epoch, save_best_ckpt=True,
                               ckpt_directory=cfg.save_dir, besk_ckpt_name=str(cfg.rank)+"_best_map.ckpt")
        callbacks.append(eval_cb)

    model = UNet3PlusWithLossCell(f_model)
    model.set_train()
    model = nn.TrainOneStepCell(model, optimizer)
    model = Model(model)
    model.train(cfg.epochs, train_dataset, callbacks=callbacks, dataset_sink_mode=True)
    cfg.logger.info("training finished....")
if __name__ == '__main__':
    set_seed(1)
    cfg.save_dir = os.path.join(cfg.output_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    if not cfg.use_modelarts and not os.path.exists(cfg.save_dir):
        os.makedirs(cfg.save_dir)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=cfg.device_target, save_graphs=False)
    cfg.logger = get_logger(cfg.save_dir, "UNet3Plus", cfg.rank)
    cfg.logger.save_args(cfg)
    # Save checkpoints on every rank, unless is_save_on_master is set, in which case only rank 0 saves.
    cfg.rank_save_ckpt_flag = not (cfg.is_save_on_master and cfg.rank)
    copy_data_from_obs()
    train()
    copy_data_to_obs()
    cfg.logger.info('All task finished!')
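The expression `not (cfg.is_save_on_master and cfg.rank)` decides which ranks register the periodic checkpoint callback. A minimal sketch of its truth table follows. The function name `should_save_ckpt` is hypothetical; only the boolean expression itself comes from the script above.

```python
def should_save_ckpt(is_save_on_master, rank):
    """True when this rank should save checkpoints, using the same expression as the script."""
    return not (is_save_on_master and rank)


# With is_save_on_master enabled only rank 0 saves; with it disabled every rank saves.
master_only = [should_save_ckpt(1, rank) for rank in range(4)]
all_ranks = [should_save_ckpt(0, rank) for rank in range(4)]
```

The expression relies on rank 0 being falsy, so the flag is `True` for the master rank regardless of `is_save_on_master`.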