diff --git a/research/cv/u2net/README.md b/research/cv/u2net/README.md new file mode 100644 index 0000000000000000000000000000000000000000..701df7b930440c04ae681615f0064096a5c9afd9 --- /dev/null +++ b/research/cv/u2net/README.md @@ -0,0 +1,219 @@
+# Contents
+
+- [U-2-Net Description](#u-2-net-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Run On ModelArts](#run-on-modelarts)
+    - [Training Process](#training-process)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Evaluation Performance](#evaluation-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [U-2-Net Description](#contents)
+
+This is the MindSpore implementation of the paper [**U<sup>2</sup>-Net: Going deeper with nested U-structure for
+salient object detection**](http://arxiv.org/abs/2005.09007)
+
+Authors: Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin.
+
+In this paper, we design a simple yet powerful deep network architecture, U<sup>2</sup>-Net, for salient object
+detection (SOD). The architecture of our U<sup>2</sup>-Net is a two-level nested U-structure.
+
+# [Model Architecture](#contents)
+
+![architecture](assets/network.png)
+
+The architecture of our U<sup>2</sup>-Net is a two-level nested U-structure.
+
+# [Dataset](#contents)
+
+To train U<sup>2</sup>-Net, we use the [DUTS-TR](http://saliencydetection.net/duts/download/DUTS-TR.zip) dataset.
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend)
+    - Prepare hardware environment with Ascend processor.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
+
+# [Script Description](#contents)
+
+## [Script and Sample Code](#contents)
+
+```shell
+U-2-Net
+ ├─ README.md                  # descriptions about U-2-Net
+ ├─ scripts
+   └─ run_distribute_train.sh  # launch Ascend training (8 Ascend)
+ ├─ assets                     # images used in README.md
+ ├─ ckpts                      # save ckpt
+ ├─ src
+   ├─ data_loader.py           # generate dataset for training
+   ├─ loss.py                  # loss function definition
+   └─ blocks.py                # U-2-Net model definition
+ ├─ train_modelarts.py         # train script for online training
+ ├─ test.py                    # generate detection images
+ ├─ eval.py                    # eval script
+ └─ train.py                   # train script
+```
+
+## [Script Parameters](#contents)
+
+### [Training Script Parameters](#contents)
+
+```shell
+# distributed training
+cd scripts
+./run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_TABLE_FILE]
+
+# standalone training
+python train.py --content_path [/path/to/content] --label_path [/path/to/label]
+```
+
+### Training Result
+
+Training results are stored in './ckpts', where you can find the checkpoint files.
+
+### Evaluation Script Parameters
+
+- Run `test.py` to generate semantically segmented pictures.
+- Run `eval.py` for evaluation.
+
+```bash
+# generate semantically segmented pictures
+python test.py --content_path [/path/to/content] &>test.log &
+
+# evaluating
+python eval.py --pred_dir [/path/to/pred_dir] --label_dir [/path/to/label] &>evaluation.log &
+```
+
+### [Run On ModelArts](#contents)
+
+- Run `train.py` to train on ModelArts
+
+```bash
+# In order to run on ModelArts, you should place data files like this:
+#     └─ dataset
+#         ├─ pre_ckpt
+#         └─ DUTS
+# params:
+#     run_distribute        # whether to run distributed training, default: false
+#     data_url              # path to data on OBS, default: None
+#     train_url             # output path on OBS, default: None
+#     ckpt_name             # prefix of ckpt files, default: u2net
+#     run_online            # set run_online to 1 if you want to run online
+#     is_load_pre           # whether to use a pretrained model, default: false
+```
+
+- Run `test.py` to generate semantically segmented pictures on ModelArts
+
+```bash
+# In order to run on ModelArts, you should place data files like this:
+#     └─ dataset
+#         ├─ pre_ckpt
+#         └─ content_dir
+# params:
+#     data_url              # path to data on OBS, default: None
+#     run_online            # set run_online to 1 if you want to run online
+```
+
+- Run `eval.py` to evaluate on ModelArts
+
+```bash
+# In order to run on ModelArts, you should place data files like this:
+#     └─ dataset
+#         ├─ pre_ckpt
+#         ├─ label_dir
+#         └─ content_dir
+# params:
+#     data_url              # path to data on OBS, default: None
+#     run_online            # set run_online to 1 if you want to run online
+```
+
+## [Training Process](#contents)
+
+### [Training](#contents)
+
+- Run `train.py` for non-distributed training of the U-2-Net model.
+
+```bash
+# standalone training
+python train.py --content_path [/path/to/content] --label_path [/path/to/label]
+```
+
+### [Distributed Training](#contents)
+
+- Run `run_distribute_train.sh` for distributed training of the U-2-Net model.
+
+```bash
+# distributed training
+cd scripts
+bash run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_TABLE_FILE]
+```
+
+- Notes
+
+1. The hccl.json file specified by RANK_TABLE_FILE is needed when you run a distributed task. You can generate it
+   with the [hccl_tools](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools).
+
+# [Model Description](#contents)
+
+## [Performance](#contents)
+
+### Training Performance
+
+| Parameters                 |                                                   |
+| -------------------------- | ------------------------------------------------- |
+| Model Version              | v1                                                |
+| Resource                   | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192 cores |
+| MindSpore Version          | 1.3.0                                             |
+| Dataset                    | DUTS-TR                                           |
+| Training Parameters        | epoch=1500, batch_size=16                         |
+| Optimizer                  | Adam                                              |
+| Loss Function              | BCELoss                                           |
+| outputs                    | semantically segmented pictures                   |
+| Speed                      | 8 Ascend: 440 ms/step; 1 Ascend: 303 ms/step      |
+| Total time                 | 8pcs: 15h                                         |
+| Checkpoint for Fine tuning | 512M (.ckpt file)                                 |
+
+### Picture Generating Performance
+
+| Parameters        | single Ascend                                     |
+| ----------------- | ------------------------------------------------- |
+| Model Version     | v1                                                |
+| Resource          | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192 cores |
+| MindSpore Version | 1.3.0                                             |
+| Dataset           | content images                                    |
+| batch_size        | 1                                                 |
+| outputs           | semantically segmented pictures                   |
+| Speed             | 95 ms/pic                                         |
+
+### Evaluation Performance
+
+| Parameters               | single Ascend                                     |
+| ------------------------ | ------------------------------------------------- |
+| Model Version            | v1                                                |
+| Resource                 | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192 cores |
+| MindSpore Version        | 1.3.0                                             |
+| Dataset                  | DUTS-TE                                           |
+| batch_size               | 1                                                 |
+| Accuracy (max F-measure) | 85.52%                                            |
+
+# [Description of Random Situation](#contents)
+
+We set random seeds in train.py.
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
diff --git a/research/cv/u2net/assets/network.png b/research/cv/u2net/assets/network.png new file mode 100644 index 0000000000000000000000000000000000000000..299b36135e6c3ba94cdb04d25d3bfa8280558301 Binary files /dev/null and b/research/cv/u2net/assets/network.png differ diff --git a/research/cv/u2net/eval.py b/research/cv/u2net/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5cd608a57bc51a84c2a63085a5c68662006389 --- /dev/null +++ b/research/cv/u2net/eval.py @@ -0,0 +1,97 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval process"""
+import os
+import argparse
+import ast
+
+import numpy as np
+from PIL import Image
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--pred_dir", type=str, help='pred_dir, default: None')
+parser.add_argument("--label_dir", type=str, help='label_dir, default: None')
+
+# additional params for online evaluating
+parser.add_argument("--run_online", type=ast.literal_eval, default=False, help='whether to run online, default: false')
+parser.add_argument("--data_url", type=str, help='path to data on OBS, default: None')
+parser.add_argument("--train_url", type=str, help='output path on OBS, default: None')
+
+args = parser.parse_args()
+if __name__ == '__main__':
+    if args.run_online:
+        import moxing as mox
+        mox.file.copy_parallel(args.data_url, "/cache/dataset")
+        pred_dir = "/cache/dataset/pred_dir"
+        label_dir = "/cache/dataset/label_dir"
+
+    else:
+        pred_dir = args.pred_dir
+        label_dir = args.label_dir
+
+    def generate(pred, label, num=255):
+        """compute precision and recall at `num` evenly spaced thresholds"""
+        prec = np.zeros(num)
+        recall = np.zeros(num)
+        min_num = 0
+        max_num = 1 - 1e-10
+        for i in range(num):
+            tmp_num = (max_num - min_num) / num * i
+            pred_ones = pred >= tmp_num
+            acc = (pred_ones * label).sum()  # true positives at this threshold
+            prec[i] = acc / (pred_ones.sum() + 1e-20)
+            recall[i] = acc / (label.sum() + 1e-20)
+        return prec, recall
+
+
+    def F_score(pred, label, num=255):
+        """calculate the F-score curve over all thresholds (beta^2 = 0.3)"""
+        prec, recall = generate(pred, label, num)
+        beta2 = 0.3
+        f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall + 1e-20)
+        return f_score
+
+
+    def max_F(pred_directory, label_directory, num=255):
+        """calculate max F-score, averaged over the dataset"""
+        sum_value = np.zeros(num)
+        content_list = os.listdir(pred_directory)
+        pic_num = 0
+        for i in range(len(content_list)):
+
+            pred_path = pred_directory + "/" + content_list[i]
+            pred = np.array(Image.open(pred_path), dtype='float32')
+
+            pic_name = content_list[i].replace(".jpg", "").replace(".png", "").replace(".JPEG", "")
+            print("%d / %d , %s \n" % (i, len(content_list), pic_name))
+            label_path = os.path.join(label_directory, pic_name) + ".png"
+            label = np.array(Image.open(label_path), dtype='float32')
+            if len(label.shape) > 2:
+                label = label[:, :, 0]
+
+            if len(pred.shape) > 2:
+                pred = pred[:, :, 0]
+                pred = pred.squeeze()
+            label /= label.max()
+            pred /= pred.max()
+            tmp = F_score(pred, label, num)
+            sum_value += tmp
+            pic_num += 1
+        score = sum_value / pic_num
+        return score.max()
+
+    print("max_F measure, score = %f" % (max_F(pred_dir, label_dir)))
diff --git a/research/cv/u2net/requirements.txt b/research/cv/u2net/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f6e0db57d4dacf3ff725fade2008e584960097c --- /dev/null +++ b/research/cv/u2net/requirements.txt @@ -0,0 +1,5 @@
+easydict~=1.9
+imageio~=2.9.0
+opencv_python~=4.5.1.48
+Pillow~=8.4.0
+scikit_image~=0.18.1
diff --git a/research/cv/u2net/scripts/run_distribute_train.sh b/research/cv/u2net/scripts/run_distribute_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..4797ddad7deaea7a2bec010e887f5635102fddc0 --- /dev/null +++ b/research/cv/u2net/scripts/run_distribute_train.sh @@ -0,0 +1,72 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_distribute_train.sh CONTENT_PATH LABEL_PATH RANK_TABLE_FILE"
+echo "For example: bash run_distribute_train.sh /path/to/content /path/to/label /path/to/hccl.json"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+
+if [ ! -d $1 ]
+then
+    echo "error: CONTENT_PATH=$1 is not a directory"
+exit 1
+fi
+
+if [ ! -d $2 ]
+then
+    echo "error: LABEL_PATH=$2 is not a directory"
+exit 1
+fi
+
+if [ ! -f $3 ]
+then
+    echo "error: RANK_TABLE_FILE=$3 is not a file"
+exit 1
+fi
+
+set -e
+export RANK_TABLE_FILE=$3
+export RANK_SIZE=8
+EXEC_PATH=$(pwd)
+
+echo "$EXEC_PATH"
+
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+for((i=0;i<${RANK_SIZE};i++))
+do
+    rm -rf device$i
+    mkdir device$i
+    cp ../*.py ./device$i
+    cp -r ../src ./device$i
+    cd ./device$i
+    export DEVICE_ID=$i
+    export RANK_ID=$i
+    echo "start training for device $i"
+    env > env$i.log
+    python train.py --run_distribute 1 --content_path $1 --label_path $2 > output.log 2>&1 &
+    echo "$i finish"
+    cd ../
+done
+
diff --git a/research/cv/u2net/src/__init__.py b/research/cv/u2net/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6228b7132697d24157a4052193061e9913f031c4 --- /dev/null +++ b/research/cv/u2net/src/__init__.py @@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/research/cv/u2net/src/blocks.py b/research/cv/u2net/src/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..3f150073bd95f0e30cbe13b0fca0ca13ec01c73e --- /dev/null +++ b/research/cv/u2net/src/blocks.py @@ -0,0 +1,455 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""network blocks definition""" + +import mindspore +from mindspore import nn, ops + +init_mode = 'xavier_uniform' +has_bias = True +bias_init = 'zeros' + +class REBNCONV(nn.Cell): + """ + A basic unit consisting of convolution, batchnorm, and relu activation functions + """ + + def __init__(self, in_ch=3, out_ch=3, dirate=1): + """definition method""" + super(REBNCONV, self).__init__() + self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, dilation=1 * dirate, weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.bn_s1 = nn.BatchNorm2d(out_ch, affine=True) + self.relu_s1 = nn.ReLU() + + def construct(self, x): + """compute method""" + hx = x + hx = self.conv_s1(hx) + hx = self.bn_s1(hx) + xout = self.relu_s1(hx) + return xout + + +def _upsample_like(src, tar): + """generate upsample unit""" + resize_bilinear = mindspore.ops.operations.ResizeBilinear(tar.shape[2:]) + src = resize_bilinear(src) + return src + + +class RSU7(nn.Cell): + """RSU7 block""" + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + """RSU7 definition""" + super(RSU7, self).__init__() + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + self.cat = ops.Concat(axis=1) + + def construct(self, x): + """RSU7 compute""" + hx = x + hxin = self.rebnconvin(hx) + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + hx7 = self.rebnconv7(hx6) + hx6d = self.rebnconv6d(self.cat((hx7, hx6))) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(self.cat((hx6dup, hx5))) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(self.cat((hx5dup, hx4))) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(self.cat((hx4dup, hx3))) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(self.cat((hx3dup, hx2))) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(self.cat((hx2dup, hx1))) + + return hx1d + hxin + + +class RSU6(nn.Cell): + """RSU6 block""" + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + """RSU6 definition""" + super(RSU6, self).__init__() + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2) + + self.rebnconv2 = 
REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + self.cat = ops.Concat(axis=1) + + def construct(self, x): + """RSU6 compute""" + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(self.cat((hx6, hx5))) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(self.cat((hx5dup, hx4))) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(self.cat((hx4dup, hx3))) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(self.cat((hx3dup, hx2))) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(self.cat((hx2dup, hx1))) + + return hx1d + hxin + + +class RSU5(nn.Cell): + """RSU5 block""" + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + """RSU5 definition""" + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + self.cat = ops.Concat(axis=1) + + def construct(self, x): + """RSU5 compute""" + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(self.cat((hx5, hx4))) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(self.cat((hx4dup, hx3))) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(self.cat((hx3dup, hx2))) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(self.cat((hx2dup, hx1))) + + return hx1d + hxin + + +class RSU4(nn.Cell): + """RSU4 block""" + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + """RSU4 definition""" + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + 
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + self.cat = ops.Concat(axis=1) + + def construct(self, x): + """RSU4 compute""" + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(self.cat((hx4, hx3))) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(self.cat((hx3dup, hx2))) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(self.cat((hx2dup, hx1))) + return hx1d + hxin + + +class RSU4F(nn.Cell): + """RSU4F block""" + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + """RSU4F definition""" + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + self.cat = ops.Concat(axis=1) + + def construct(self, x): + """RSU4F compute""" + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(self.cat((hx4, hx3))) + hx2d = self.rebnconv2d(self.cat((hx3d, hx2))) + hx1d = self.rebnconv1d(self.cat((hx2d, hx1))) + + return hx1d + hxin + + +class U2NET(nn.Cell): + """U-2-Net model""" + + def __init__(self, in_ch=3, out_ch=1): + """U-2-Net definition""" + super(U2NET, self).__init__() + + self.stage1 = RSU7(in_ch, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.side2 = nn.Conv2d(64, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.side3 = nn.Conv2d(128, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.side4 = nn.Conv2d(256, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.side5 = nn.Conv2d(512, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.side6 = nn.Conv2d(512, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1, pad_mode='same', weight_init=init_mode, has_bias=has_bias, + bias_init=bias_init) + + self.cat = ops.Concat(axis=1) + self.Sigmoid = nn.Sigmoid() + self.reshape = ops.Reshape() + + def construct(self, x): + """U-2-Net compute""" + hx = x + + # stage 1 + 
hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+        # -------------------- decoder --------------------
+        # each decoder stage fuses the upsampled deeper feature with the encoder
+        # feature of the same resolution (skip connection)
+        hx5d = self.stage5d(self.cat((hx6up, hx5)))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(self.cat((hx5dup, hx4)))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(self.cat((hx4dup, hx3)))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(self.cat((hx3dup, hx2)))
+        hx2dup = _upsample_like(hx2d, hx1)
+        hx1d = self.stage1d(self.cat((hx2dup, hx1)))
+        # side output: one 1-channel saliency map per decoder stage, upsampled to input size
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(self.cat((d1, d2, d3, d4, d5, d6)))
+        return self.cat((self.Sigmoid(d0), self.Sigmoid(d1), self.Sigmoid(d2), self.Sigmoid(d3), self.Sigmoid(d4),
+                         self.Sigmoid(d5), self.Sigmoid(d6)))
diff --git a/research/cv/u2net/src/config.py b/research/cv/u2net/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d28290fba1599665382a567c2e1b569976e0716d --- /dev/null +++ b/research/cv/u2net/src/config.py @@ -0,0 +1,33 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""config of distributed training and standalone training"""
+from easydict import EasyDict as edict
+
+single_cfg = edict({
+    "lr": 1e-3,
+    "batch_size": 12,
+    "max_epoch": 700,
+    "keep_checkpoint_max": 10,
+    "weight_decay": 0,
+    "eps": 1e-8,
+})
+run_distribute_cfg = edict({
+    "lr": 4e-3,
+    "batch_size": 16,
+    "max_epoch": 1500,
+    "keep_checkpoint_max": 10,
+    "weight_decay": 0,
+    "eps": 1e-8,
+})
diff --git a/research/cv/u2net/src/data_loader.py b/research/cv/u2net/src/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb91711aea3067d5c40570c1d678226140bc1d --- /dev/null +++ b/research/cv/u2net/src/data_loader.py @@ -0,0 +1,380 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""definition of some data loading operations""" + +from __future__ import print_function, division +import random +import os +import glob + +import numpy as np +from skimage import io, transform, color + +from mindspore import context +from mindspore import dataset as ds +from mindspore.common import dtype as mstype +import mindspore.dataset.transforms.c_transforms as CC +from mindspore.context import ParallelMode +from mindspore.communication.management import get_rank, get_group_size + + +# ==========================dataset load========================== +class RescaleT: + """Rescale operation""" + + def __init__(self, output_size): + """RescaleT definition""" + assert isinstance(output_size, (int, tuple)) + self.output_size = output_size + + def __call__(self, sample): + """RescaleT compute""" + imidx, image, label = sample['imidx'], sample['image'], sample['label'] + + h, w = image.shape[:2] + + if isinstance(self.output_size, int): + if h > w: + new_h, new_w = self.output_size * h / w, self.output_size + else: + new_h, new_w = self.output_size, self.output_size * w / h + else: + new_h, new_w = self.output_size + + new_h, new_w = int(new_h), int(new_w) + + img = transform.resize(image, (self.output_size, self.output_size), mode='constant') + lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0, + preserve_range=True) + + return {'imidx': imidx, 'image': img, 'label': lbl} + + +class Rescale: + """Rescale operation""" + + def __init__(self, output_size): + """Rescale definition""" + assert isinstance(output_size, (int, tuple)) + self.output_size = output_size + + def __call__(self, sample): + """Rescale compute""" + imidx, image, label = sample['imidx'], sample['image'], sample['label'] + + if random.random() >= 0.5: + image = image[::-1] + label = label[::-1] + + a = int(random.random() * 4) % 4 + image = np.rot90(image, a) + label = np.rot90(label, a) + h, w = image.shape[:2] + + if isinstance(self.output_size, int): + if h > w: + new_h, new_w = self.output_size * h / w, self.output_size + else: + new_h, new_w = self.output_size, self.output_size * w / h + else: + new_h, new_w = self.output_size + + new_h, new_w = int(new_h), int(new_w) + img = transform.resize(image, (new_h, new_w), mode='constant') + lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True) + + return {'imidx': imidx, 'image': img, 'label': lbl} + + +class RandomCrop: + """RandomCrop operation""" + + def __init__(self, output_size): + """RandomCrop definition""" + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + assert len(output_size) == 2 + self.output_size = output_size + + def __call__(self, sample): + """RandomCrop compute""" + imidx, image, label = sample['imidx'], sample['image'], sample['label'] + + if random.random() >= 0.5: + image = image[::-1] + label = label[::-1] + + a = int(random.random() * 4) % 4 + image = np.rot90(image, a) + label = np.rot90(label, a) + h, w = image.shape[:2] + new_h, new_w = self.output_size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + + image = image[top: top + new_h, left: left + new_w] + label = label[top: top + new_h, left: left + new_w] + + return {'imidx': imidx, 'image': image, 'label': label} + + +class ToTensor: + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): 
+        """Convert operation"""
+
+        imidx, image, label = sample['imidx'], sample['image'], sample['label']
+
+        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+        tmpLbl = np.zeros(label.shape)
+
+        image = image / np.max(image)
+        if np.max(label) < 1e-6:
+            label = label
+        else:
+            label = label / np.max(label)
+
+        if image.shape[2] == 1:
+            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
+            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
+        else:
+            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
+            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
+
+        tmpLbl[:, :, 0] = label[:, :, 0]
+
+        tmpImg = tmpImg.transpose((2, 0, 1))
+        tmpLbl = label.transpose((2, 0, 1))
+
+        return imidx, tmpImg, tmpLbl
+
+
+class ToTensorLab:
+    """Convert ndarrays in sample to Tensors."""
+
+    def __init__(self, flag=0):
+        """ToTensorLab definition"""
+        self.flag = flag
+
+    def __call__(self, sample):
+        """ToTensorLab compute"""
+        imidx, image, label = sample['imidx'], sample['image'], sample['label']
+
+        tmpLbl = np.zeros(label.shape)
+
+        if np.max(label) < 1e-6:
+            label = label
+        else:
+            label = label / np.max(label)
+
+        # change the color space
+        if self.flag == 2:  # with rgb and Lab colors
+            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
+            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
+            if image.shape[2] == 1:
+                tmpImgt[:, :, 0] = image[:, :, 0]
+                tmpImgt[:, :, 1] = image[:, :, 0]
+                tmpImgt[:, :, 2] = image[:, :, 0]
+            else:
+                tmpImgt = image
+            tmpImgtl = color.rgb2lab(tmpImgt)
+
+            # normalize image to range [0,1]
+            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
+                np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]))
+            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
+                np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]))
+            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
+                np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]))
+            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
+                np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]))
+            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
+                np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]))
+            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
+                np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]))
+
+            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
+            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
+            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3])
+            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4])
+            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5])
+
+        elif self.flag == 1:  # with Lab color
+            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+
+            if image.shape[2] == 1:
+                tmpImg[:, :, 0] = image[:, :, 0]
+                tmpImg[:, :, 1] = image[:, :, 0]
+                tmpImg[:, :, 2] = image[:, :, 0]
+            else:
+                tmpImg = image
+
+            tmpImg = color.rgb2lab(tmpImg)
+
+            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
+                np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]))
+            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
+                np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]))
+            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
+                np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]))
+
+            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
+            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
+            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
+
+        else:  # with rgb color
+            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+            image = image / np.max(image)
+            if image.shape[2] == 1:
+                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
+                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
+            else:
+                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
+                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
+
+        tmpLbl[:, :, 0] = label[:, :, 0]
+
+        # HWC -> CHW
+        tmpImg = tmpImg.transpose((2, 0, 1))
+        tmpLbl = label.transpose((2, 0, 1))
+
+        return {'imidx': imidx, 'image': tmpImg, 'label': tmpLbl}
+
+
+class SalObjDataset:
+    """Preprocess the tensor"""
+
+    def __init__(self, img_name_list, lbl_name_list, train_transform=None):
+        """
+        SalObjDataset definition
+        """
+        self.image_name_list = img_name_list
+        self.label_name_list = lbl_name_list
+        self.train_transform = train_transform
+        self.rescale = RescaleT(320)
+        self.randomcrop = RandomCrop(288)
+        self.totensor = ToTensorLab(flag=0)
+
+    def __len__(self):
+        """
+        get the length of the dataset
+        """
+        return len(self.image_name_list)
+
+    def __getitem__(self, idx):
+        """
+        get data in one step
+        """
+        image = io.imread(self.image_name_list[idx])
+        imidx = np.array([idx])
+
+        if not self.label_name_list:
+            label_3 = np.zeros(image.shape)
+        else:
+            label_3 = io.imread(self.label_name_list[idx])
+
+        label = np.zeros(label_3.shape[0:2])
+        if len(label_3.shape) == 3:
+            label = label_3[:, :, 0]
+        elif len(label_3.shape) == 2:
+            label = label_3
+
+        if len(image.shape) == 3 and len(label.shape) == 2:
+            label = label[:, :, np.newaxis]
+        elif len(image.shape) == 2 and len(label.shape) == 2:
+            image = image[:, :, np.newaxis]
+            label = label[:, :, np.newaxis]
+
+        sample = {'imidx': imidx, 'image': image, 'label': label}
+
+        if self.train_transform:
+            sample = self.rescale(sample)
+            sample = self.randomcrop(sample)
+            sample = self.totensor(sample)
+        else:
+            sample = self.rescale(sample)
+            sample = self.totensor(sample)
+
+        return sample['image'], sample['label']
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        rank_size = 1
+        rank_id = 0
+
+    return rank_size, rank_id
+
+
+def create_dataset(image_dir, label_dir, args):
+    """
+    create dataset
+    """
+    parallel_mode = context.get_auto_parallel_context("parallel_mode")
+    tra_image_dir = image_dir
+    tra_label_dir = label_dir
+    image_ext = '.jpg'
+    label_ext = '.png'
+    tra_img_name_list = glob.glob(tra_image_dir + '*' + image_ext)
+    tra_lbl_name_list = []
+    for img_path in tra_img_name_list:
+        img_name = img_path.split(os.sep)[-1]
+        # the label file shares the image's base name (keeping any inner dots), with a .png extension
+        name_parts = img_name.split(".")[0:-1]
+        imidx = name_parts[0]
+        for i in range(1, len(name_parts)):
+            imidx = imidx + "." + name_parts[i]
+        tra_lbl_name_list.append(tra_label_dir + imidx + label_ext)
+    print("---")
+    print(image_dir)
+    print("train images: ", len(tra_img_name_list))
+    print("train labels: ", len(tra_lbl_name_list))
+    print("---")
+
+    salobj_dataset = SalObjDataset(
+        img_name_list=tra_img_name_list,
+        lbl_name_list=tra_lbl_name_list,
+        train_transform=True)
+    type_cast = CC.TypeCast(mstype.float32)
+    if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+        device_num, rank_id = _get_rank_info()
+        data_set = ds.GeneratorDataset(salobj_dataset, ["image", "label"], num_parallel_workers=8, shuffle=True,
+                                       num_shards=device_num, shard_id=rank_id)
+    else:
+        data_set = ds.GeneratorDataset(salobj_dataset, ["image", "label"], num_parallel_workers=8, shuffle=True)
+    data_set = data_set.map(operations=type_cast, input_columns=["image"], num_parallel_workers=8)
+    data_set = data_set.map(operations=type_cast, input_columns=["label"], num_parallel_workers=8)
+    data_set = data_set.batch(args.batch_size, drop_remainder=True)
+    return data_set
diff --git a/research/cv/u2net/src/loss.py b/research/cv/u2net/src/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6fdec29224247e0cae72b8bbc6b6df03340b994e --- /dev/null +++ b/research/cv/u2net/src/loss.py @@ -0,0 +1,34 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""loss function definition"""
+
+from mindspore import nn
+from mindspore import ops
+from mindspore.nn.loss.loss import LossBase
+
+class total_loss(LossBase):
+    """Loss function"""
+
+    def __init__(self):
+        """loss function definition"""
+        super(total_loss, self).__init__()
+        self.bceloss = nn.BCELoss(reduction='mean')
+        self.cat = ops.Concat(axis=1)
+
+    def construct(self, generated, target):
+        """loss function compute"""
+        # `generated` concatenates the fused map and the six side outputs, so the target
+        # is tiled seven times; multiplying the mean BCE by 7 recovers the sum of the
+        # per-output BCE losses
+        return self.bceloss(generated, self.cat((target, target, target, target, target, target, target))) * 7
diff --git a/research/cv/u2net/test.py b/research/cv/u2net/test.py new file mode 100644 index 0000000000000000000000000000000000000000..a78d7aca49649920be9748860fa3157c83814759 --- /dev/null +++ b/research/cv/u2net/test.py @@ -0,0 +1,124 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""operation to generate semantically segmented pictures"""
+import os
+import time
+import argparse
+
+import numpy as np
+import cv2
+import imageio
+from PIL import Image
+
+from mindspore import context
+from mindspore import Tensor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+import src.blocks as blocks
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--content_path", type=str, help='content_path, default: None')
+parser.add_argument('--pre_trained', type=str, help='model_path, local pretrained model to load')
+parser.add_argument("--output_dir", type=str, default='output_dir', help='output_path, path to store output')
+
+# additional params for online generating
+parser.add_argument("--run_online", type=int, default=0, help='whether to run online, default: false')
+parser.add_argument("--data_url", type=str, help='path to data on OBS, default: None')
+parser.add_argument("--train_url", type=str, help='output path on OBS, default: None')
+
+args = parser.parse_args()
+
+if __name__ == '__main__':
+
+    if args.run_online:
+        import moxing as mox
+        mox.file.copy_parallel(args.data_url, "/cache/dataset")
+        local_dataset_dir = "/cache/dataset/content_dir"
+        pre_ckpt_dir = "/cache/dataset/pre_ckpt"
+        pre_ckpt_path = pre_ckpt_dir + "/" + os.listdir(pre_ckpt_dir)[0]
+        output_dir = "/cache/dataset/pred_dir"
+    else:
+        local_dataset_dir = args.content_path
+        pre_ckpt_path = args.pre_trained
+        output_dir = args.output_dir
+
+    context.set_context(mode=context.GRAPH_MODE)
+    param_dict = load_checkpoint(pre_ckpt_path)
+    net = blocks.U2NET()
+    net.set_train(False)
+    load_param_into_net(net, param_dict)
+
+
+    def normPRED(d):
+        """rescale the values of a prediction map to [0, 1]"""
+        ma = d.max()
+        mi = d.min()
+        dn = (d - mi) / (ma - mi)
+        return dn
+
+
+    def normalize(img, im_type):
+        """normalize tensor"""
+        if im_type == "label":
+            return img
+        if len(img.shape) == 3:
+            img[:, :, 0] = (img[:, :, 0] - 0.485) / 0.229
+            img[:, :, 1] = (img[:, :, 1] - 0.456) / 0.224
+            img[:, :, 2] = (img[:, :, 2] - 0.406) / 0.225
+        else:
+            img = (img - 0.485) / 0.229
+        return img
+
+
+    def crop_and_resize(img_path, im_type, size=320):
+        """crop and resize tensors"""
+        img = np.array(Image.open(img_path), dtype='float32')
+        img = img / 255
+        img = normalize(img, im_type)
+        h, w = img.shape[:2]
+        img = cv2.resize(img, dsize=(0, 0), fx=size / w, fy=size / h)
+        if len(img.shape) == 2:
+            # replicate the single channel so grayscale inputs match the 3-channel network input
+            img = np.expand_dims(img, 2).repeat(3, axis=2)
+        im = img
+        im = np.swapaxes(im, 1, 2)
+        im = np.swapaxes(im, 0, 1)
+        im = np.reshape(im, (1, im.shape[0], im.shape[1], im.shape[2]))
+        return im
+
+
+    content_list = os.listdir(local_dataset_dir)
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    start_time = time.time()
+    for j in range(0, len(content_list)):
+        pic_path = os.path.join(local_dataset_dir, content_list[j])
+        content_pic = crop_and_resize(pic_path, im_type="content", size=320)
+        image = net(Tensor(content_pic))
+        content_name = content_list[j].replace(".jpg", "").replace(".png", "")
+        file_path = os.path.join(local_dataset_dir, content_list[j])
+        original = np.array(Image.open(file_path), dtype='float32')
+        shape = original.shape
+        image = normPRED(image[0][0].asnumpy())
+        image = cv2.resize(image, dsize=(0, 0), fx=shape[1] / image.shape[1], fy=shape[0] / image.shape[0])
+        file_path = os.path.join(output_dir, content_name) + ".png"
+        imageio.imsave(file_path, image)
+        print("%d / %d , %s \n" % (j, len(content_list), content_name))
+    end_time = time.time()
+    dtime = end_time - start_time
+    print("finish generating in %.2f s" % dtime)
+    if args.run_online:
+        mox.file.copy_parallel(output_dir, args.train_url)
diff --git a/research/cv/u2net/train.py b/research/cv/u2net/train.py new file mode 100644 index 0000000000000000000000000000000000000000..021bf63c69ceb5b665da2f828faf667cd0e70df8 --- /dev/null +++ b/research/cv/u2net/train.py @@ -0,0 +1,120 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train process"""
+
+import os
+import random
+import argparse
+import ast
+
+import numpy as np
+
+from mindspore import nn
+from mindspore import context
+from mindspore import FixedLossScaleManager
+from mindspore import Model
+from mindspore import dataset as ds
+from mindspore.context import ParallelMode
+from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig, TimeMonitor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.communication.management import init
+
+import src.config as config
+from src.data_loader import create_dataset
+from src.loss import total_loss
+from src.blocks import U2NET
+
+random.seed(1)
+np.random.seed(1)
+ds.config.set_seed(1)
+context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--content_path", type=str, help='content_path, default: None')
+parser.add_argument("--label_path", type=str, help='label_path, default: None')
+parser.add_argument('--ckpt_path', type=str, default='ckpts', help='checkpoint save location, default: ckpts')
+parser.add_argument("--ckpt_name", default='u2net', type=str, help='prefix of ckpt files, default: u2net')
+parser.add_argument("--loss_scale", type=int, default=8192)
+parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
+parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
+
+# additional params for online training
+parser.add_argument("--run_online", type=int, default=0, help='whether to run online, default: false')
+parser.add_argument("--data_url", type=str, help='path to data on OBS, default: None')
+parser.add_argument("--train_url", type=str, help='output path on OBS, default: None')
+parser.add_argument("--is_load_pre", type=int, default=0, help="whether to use a pretrained model, default: false.")
+
+args = parser.parse_args()
+
+if __name__ == '__main__':
+    device_id = int(os.getenv('DEVICE_ID', '0'))
+    if args.run_online:
+        import moxing as mox
+        mox.file.copy_parallel(args.data_url, "/cache/dataset")
+        content_path = "/cache/dataset/DUTS/DUTS-TR/DUTS-TR-Image/"
+        label_path = "/cache/dataset/DUTS/DUTS-TR/DUTS-TR-Mask/"
+        if args.run_distribute:
+            args.ckpt_path = "/cache/ckpts/device" + str(device_id)
+        else:
+            args.ckpt_path = "/cache/ckpts"
+        if args.is_load_pre:
+            pre_ckpt_dir = "/cache/dataset/pre_ckpt"
+            args.pre_trained = pre_ckpt_dir + "/" + os.listdir(pre_ckpt_dir)[0]
+    else:
+        content_path = args.content_path
+        label_path = args.label_path
+
+    if args.run_distribute:
+        cfg = config.run_distribute_cfg
+
+        device_num = int(os.getenv('RANK_SIZE'))
+        context.set_context(device_id=device_id)
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          gradients_mean=True,
+                                          device_num=device_num)
+        init()
+    else:
+        cfg = config.single_cfg
+
+    args.lr = cfg.lr
+    args.batch_size = cfg.batch_size
+    args.decay = cfg.weight_decay
+    args.epoch_size = cfg.max_epoch
+    args.eps = cfg.eps
+    print("---")
+    print("eps = %lf, batch_size = %d, epoch size = %d" % (args.eps, args.batch_size, args.epoch_size))
+    print("lr = %lf, decay = %f, ckpt_name = %s, loss_scale = %d"
+          % (args.lr, args.decay, args.ckpt_name, args.loss_scale))
+    net = U2NET()
+    net.set_train()
+    if args.pre_trained != '':
+        print("pretrained path = %s" % args.pre_trained)
+        param_dict = load_checkpoint(args.pre_trained)
+        load_param_into_net(net, param_dict)
+    print("---")
+    loss = total_loss()
+    ds_train = create_dataset(content_path, label_path, args)
+    print("dataset size: ", ds_train.get_dataset_size())
+    opt = nn.Adam(net.get_parameters(), learning_rate=args.lr, beta1=0.9, beta2=0.999, eps=args.eps,
+                  weight_decay=args.decay)
+    loss_scale_manager = FixedLossScaleManager(args.loss_scale)
+    model = Model(net, loss, opt, loss_scale_manager=loss_scale_manager, amp_level="O0")
+    data_size = ds_train.get_dataset_size()
+    config_ck = CheckpointConfig(save_checkpoint_steps=data_size, keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_name, directory=args.ckpt_path, config=config_ck)
+    model.train(args.epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(per_print_times=1), TimeMonitor()])
+    if args.run_online:
+        mox.file.copy_parallel(args.ckpt_path, args.train_url)