Commit b66839a3 authored by gengdongjie's avatar gengdongjie

upload u2net

parent 656d24c7
# Contents
- [U-2-Net Description](#u-2-net-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Pretrained model](#pretrained-model)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Run On Modelarts](#run-on-modelarts)
- [Training Process](#training-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [U-2-Net Description](#contents)
This is a MindSpore implementation of the paper [**U<sup>2</sup>-Net: Going deeper with nested U-structure for
salient object detection**](http://arxiv.org/abs/2005.09007).
Authors: Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin.
The paper proposes U<sup>2</sup>-Net, a simple yet powerful deep network architecture for salient object
detection (SOD). The architecture of U<sup>2</sup>-Net is a two-level nested U-structure.
# [Model Architecture](#contents)
![](assets/network.png)
The architecture of our U<sup>2</sup>-Net is a two-level nested U-structure.
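
As a rough illustration of the outer U-structure, the sketch below traces feature-map shapes through the six encoder stages. The channel widths and the 2x2, stride-2 max pooling between stages are taken from `src/blocks.py` in this commit. The 288x288 input is an assumption based on the random crop size in `src/data_loader.py`, and the helper name `encoder_shapes` is hypothetical.

```python
# Output channels of encoder stages 1-6, as defined in U2NET (src/blocks.py).
ENCODER_CHANNELS = [64, 128, 256, 512, 512, 512]


def encoder_shapes(height, width):
    """Return (channels, height, width) for the output of each encoder stage.

    Each RSU stage preserves the spatial size; a 2x2 max pool with stride 2
    (floor division) sits between consecutive stages.
    """
    shapes = []
    for index, channels in enumerate(ENCODER_CHANNELS):
        if index > 0:
            height, width = height // 2, width // 2
        shapes.append((channels, height, width))
    return shapes


if __name__ == "__main__":
    # Assumed input size: the 288x288 random crop used for training.
    for stage, shape in enumerate(encoder_shapes(288, 288), start=1):
        print(f"stage {stage}: {shape}")
```

The decoder mirrors this path: each decoder stage upsamples its input to the size of the matching encoder output and concatenates the two along the channel axis.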
# [Dataset](#contents)
To train U<sup>2</sup>-Net, we use the [DUTS-TR](http://saliencydetection.net/duts/download/DUTS-TR.zip) dataset.
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```shell
U-2-Net
 ├─ README.md                       # descriptions about U-2-Net
 ├─ scripts
 │   └─ run_distribute_train.sh     # launch Ascend training (8 Ascend)
 ├─ assets                          # save pics for README.MD
 ├─ ckpts                           # save ckpt
 ├─ src
 │   ├─ data_loader.py              # generate dataset for training
 │   ├─ loss.py                     # loss function define
 │   └─ blocks.py                   # U-2-Net model define
 ├─ train_modelarts.py              # train script for online train
 ├─ test.py                         # generate detection images
 ├─ eval.py                         # eval script
 └─ train.py                        # train script
```
## [Script Parameters](#contents)
### [Training Script Parameters](#contents)
```shell
# distributed training
cd scripts
./run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_TABLE_FILE]
# standalone training
python train.py --content_path [/path/to/content] --label_path [/path/to/label]
```
### Training Result
Training results are stored in `./ckpts`, where the checkpoint files are saved.
### Evaluation Script Parameters
- Run `test.py` to generate semantically segmented pictures
- Run `eval.py` for evaluation.
```bash
# generate semantically segmented pictures
python test.py --content_path [/path/to/content] &>test.log &
# evaluate the generated pictures
python eval.py --pred_dir [/path/to/pred_dir] --label_dir [/path/to/label] &>evaluation.log &
```
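
The score printed by `eval.py` is a maximum F-measure. The sketch below shows the idea for a single image using plain Python lists; it is an illustration, not the repository's implementation. The value β² = 0.3 and the 255 thresholds follow `eval.py`, precision and recall use their standard definitions, and the function names are hypothetical. Note that `eval.py` averages the per-threshold scores over all images before taking the maximum, which this single-image sketch does not do.

```python
BETA2 = 0.3  # beta squared, the value used in eval.py


def f_measure(pred, label, threshold):
    """F-measure of a flat saliency map against a flat binary mask."""
    binary = [1.0 if p >= threshold else 0.0 for p in pred]
    true_pos = sum(b * l for b, l in zip(binary, label))
    precision = true_pos / (sum(binary) + 1e-20)  # TP / predicted positives
    recall = true_pos / (sum(label) + 1e-20)      # TP / ground-truth positives
    return (1 + BETA2) * precision * recall / (BETA2 * precision + recall + 1e-20)


def max_f_measure(pred, label, num=255):
    """Best F-measure over `num` evenly spaced thresholds in [0, 1)."""
    return max(f_measure(pred, label, i / num) for i in range(num))


if __name__ == "__main__":
    pred = [0.9, 0.8, 0.2, 0.1]    # predicted saliency, scaled to [0, 1]
    label = [1.0, 1.0, 0.0, 0.0]   # ground-truth mask
    print(max_f_measure(pred, label))
```

With β² below 1, precision is weighted more heavily than recall.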
### [Run On Modelarts](#contents)
- Run `train.py` to train on modelarts
```bash
# To run on ModelArts, place the data files like this:
# └─ dataset
#     ├─ pre_ckpt
#     └─ DUTS
# params:
#   run_distribute  # run distributed training, default: false
#   data_url        # path to data on OBS, default: None
#   train_url       # output path on OBS, default: None
#   ckpt_name       # prefix of ckpt files, default: u2net
#   run_online      # set run_online to 1 to run online
#   is_load_pre     # whether to use a pretrained model, default: false
```
- Run `test.py` to generate semantically segmented pictures on modelarts
```bash
# To run on ModelArts, place the data files like this:
# └─ dataset
#     ├─ pre_ckpt
#     └─ content_dir
# params:
#   data_url        # path to data on OBS, default: None
#   run_online      # set run_online to 1 to run online
```
- Run `eval.py` to evaluate on modelarts
```bash
# To run on ModelArts, place the data files like this:
# └─ dataset
#     ├─ pre_ckpt
#     ├─ label_dir
#     └─ content_dir
# params:
#   data_url        # path to data on OBS, default: None
#   run_online      # set run_online to 1 to run online
```
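
The `run_online` flag is parsed with `ast.literal_eval` in `eval.py`, which is why `1` works as a value. The sketch below mirrors that flag definition to show how different values are interpreted; it assumes the other scripts define the flag the same way, and the `obs://bucket/dataset` path is a hypothetical placeholder.

```python
import argparse
import ast


def build_parser():
    """Build a parser with the online-run flags used in eval.py."""
    parser = argparse.ArgumentParser()
    # ast.literal_eval turns the text "1", "0", "True" or "False" into a Python literal.
    parser.add_argument("--run_online", type=ast.literal_eval, default=False)
    parser.add_argument("--data_url", type=str, default=None)
    return parser


if __name__ == "__main__":
    # Hypothetical OBS path, for illustration only.
    args = build_parser().parse_args(["--run_online", "1", "--data_url", "obs://bucket/dataset"])
    print(bool(args.run_online), args.data_url)
```

Because the value must be a valid Python literal, `1` and `True` are accepted, while lowercase `true` is not a literal and would be rejected by the parser.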
## [Training Process](#contents)
### [Training](#contents)
- Run `train.py` for non-distributed (standalone) training of the U-2-Net model.
```bash
# standalone training
python train.py --content_path [/path/to/content] --label_path [/path/to/label]
```
### [Distributed Training](#contents)
- Run `run_distribute_train.sh` for distributed training of the U-2-Net model.
```bash
# distributed training
cd scripts
bash run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_TABLE_FILE]
```
- Notes
    1. A distributed task needs an `hccl.json` rank table, whose path is passed as `RANK_TABLE_FILE`. You can generate it
       with [hccl_tools](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools).
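
To make the role of `RANK_TABLE_FILE` concrete, the sketch below builds the environment variables that `run_distribute_train.sh` exports for each of the 8 devices. It only mirrors the variable names and values set in that script; the function name `rank_environments` and the example path are hypothetical, and no process is actually launched.

```python
def rank_environments(rank_table_file, rank_size=8):
    """Build one environment mapping per device, mirroring the launch script."""
    envs = []
    for rank in range(rank_size):
        envs.append({
            "RANK_TABLE_FILE": rank_table_file,  # same rank table for every device
            "RANK_SIZE": str(rank_size),         # total number of devices
            "DEVICE_ID": str(rank),              # physical device used by this process
            "RANK_ID": str(rank),                # logical rank of this process
        })
    return envs


if __name__ == "__main__":
    # Hypothetical path, for illustration only.
    for env in rank_environments("/path/to/hccl.json"):
        print(env)
```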
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | |
| -------------------------- | ----------------------------------------------------- |
| Model Version | v1 |
| Resource | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192cores |
| MindSpore Version | 1.3.0 |
| Dataset | DUTS-TR. |
| Training Parameters | epoch=1500, batch_size = 16 |
| Optimizer | Adam |
| Loss Function | BCELoss |
| outputs | semantically segmented pictures |
| Speed | 8 Ascend: 440 ms/step; 1 Ascend: 303 ms/step; |
| Total time | 8pcs: 15h |
| Checkpoint for Fine tuning | 512M (.ckpt file) |
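
The speed and total-time rows can be cross-checked with simple arithmetic. The sketch below rests on assumptions that are not stated in the table: DUTS-TR contains 10,553 images, `batch_size = 16` is the per-device batch size, and an incomplete final batch is dropped. The function name is hypothetical.

```python
def estimated_training_hours(dataset_size, devices, batch_per_device, epochs, seconds_per_step):
    """Return (steps per epoch, estimated total hours) for data-parallel training."""
    global_batch = devices * batch_per_device
    steps_per_epoch = dataset_size // global_batch  # assumes the remainder is dropped
    total_steps = steps_per_epoch * epochs
    return steps_per_epoch, total_steps * seconds_per_step / 3600.0


if __name__ == "__main__":
    # 8 devices, 16 images each, 1500 epochs, 440 ms per step (values from the table).
    steps, hours = estimated_training_hours(10553, 8, 16, 1500, 0.440)
    print(f"{steps} steps/epoch, about {hours:.1f} h in total")
```

Under these assumptions the estimate comes to roughly 15 hours, which is consistent with the total time reported for 8 devices.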
### Picture Generating Performance
| Parameters | single Ascend |
| ----------------- | ------------------------------------------------ |
| Model Version | v1 |
| Resource | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192cores |
| MindSpore Version | 1.3.0 |
| Dataset | content images |
| batch_size | 1 |
| outputs | semantically segmented pictures |
| Speed | 95 ms/pic |
### Evaluation Performance
| Parameters | single Ascend |
| ----------------- | ------------------------------------------------ |
| Model Version | v1 |
| Resource | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192cores |
| MindSpore Version | 1.3.0 |
| Dataset | DUTS-TE |
| batch_size | 1 |
| Accuracy | 85.52% |
# [Description of Random Situation](#contents)
We use random seed in train.py.
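
The data loader draws random flips and rotations with Python's `random` module, so a fixed seed makes the augmentation sequence repeatable. The sketch below illustrates that effect only: it uses a local `random.Random` instance rather than the module-level generator, assumes the rotation is drawn independently of the flip, and the function name is hypothetical.

```python
import random


def sample_augmentation(seed, count=5):
    """Draw (flip, rotations) decisions in the style of the data loader."""
    rng = random.Random(seed)
    decisions = []
    for _ in range(count):
        flip = rng.random() >= 0.5             # vertical flip with probability 0.5
        rotations = int(rng.random() * 4) % 4  # number of 90-degree rotations
        decisions.append((flip, rotations))
    return decisions


if __name__ == "__main__":
    # The same seed yields the same sequence of augmentation decisions.
    print(sample_augmentation(1) == sample_augmentation(1))
```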
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval process"""
import os
import argparse
import ast
import numpy as np
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument("--pred_dir", type=str, help='pred_dir, default: None')
parser.add_argument("--label_dir", type=str, help='label_dir, default: None')
# additional params for online evaluating
parser.add_argument("--run_online", type=ast.literal_eval, default=False, help='whether train online, default: false')
parser.add_argument("--data_url", type=str, help='path to data on obs, default: None')
parser.add_argument("--train_url", type=str, help='output path on obs, default: None')
args = parser.parse_args()


def generate(pred, label, num=255):
    """generate prec and recall"""
    prec = np.zeros(num)
    recall = np.zeros(num)
    min_num = 0
    max_num = 1 - 1e-10
    for i in range(num):
        tmp_num = (max_num - min_num) / num * i
        pred_ones = pred >= tmp_num
        acc = (pred_ones * label).sum()
        # precision: true positives over predicted positives
        prec[i] = acc / (pred_ones.sum() + 1e-20)
        # recall: true positives over ground-truth positives
        recall[i] = acc / (label.sum() + 1e-20)
    return prec, recall


def F_score(pred, label, num=255):
    """calculate f-score"""
    prec, recall = generate(pred, label, num)
    beta2 = 0.3
    f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall + 1e-20)
    return f_score


def max_F(pred_directory, label_directory, num=255):
    """calculate max f-score"""
    sum_value = np.zeros(num)
    content_list = os.listdir(pred_directory)
    pic_num = 0
    for i in range(len(content_list)):
        pred_path = os.path.join(pred_directory, content_list[i])
        pred = np.array(Image.open(pred_path), dtype='float32')
        pic_name = content_list[i].replace(".jpg", "").replace(".png", "").replace(".JPEG", "")
        print("%d / %d , %s \n" % (i, len(content_list), pic_name))
        label_path = os.path.join(label_directory, pic_name) + ".png"
        label = np.array(Image.open(label_path), dtype='float32')
        # keep a single channel for both label and prediction
        if len(label.shape) > 2:
            label = label[:, :, 0]
        if len(pred.shape) > 2:
            print(pred.shape)
            pred = pred[:, :, 0]
        pred = pred.squeeze()
        print(pred.shape)
        # scale both maps to [0, 1]
        label /= label.max()
        pred /= pred.max()
        tmp = F_score(pred, label, num)
        sum_value += tmp
        pic_num += 1
    score = sum_value / pic_num
    return score.max()


if __name__ == '__main__':
    if args.run_online:
        import moxing as mox
        mox.file.copy_parallel(args.data_url, "/cache/dataset")
        pred_dir = "/cache/dataset/pred_dir"
        label_dir = "/cache/dataset/label_dir"
    else:
        pred_dir = args.pred_dir
        label_dir = args.label_dir
    print("max_F measure, score = %f" % (max_F(pred_dir, label_dir)))
easydict~=1.9
imageio~=2.9.0
opencv_python~=4.5.1.48
Pillow~=8.4.0
scikit_image~=0.18.1
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh CONTENT_PATH LABEL_PATH RANK_TABLE_FILE"
echo "For example: bash run_distribute_train.sh /path/to/content /path/to/label /path/to/RANK_TABLE_FILE"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

if [ ! -d "$1" ]
then
    echo "error: CONTENT_PATH=$1 is not a directory"
    exit 1
fi

if [ ! -d "$2" ]
then
    echo "error: LABEL_PATH=$2 is not a directory"
    exit 1
fi

if [ ! -f "$3" ]
then
    echo "error: RANK_TABLE_FILE=$3 is not a file"
    exit 1
fi

set -e
export RANK_TABLE_FILE=$3
export RANK_SIZE=8
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# start one background training process per device, each in its own directory
for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cp ../*.py ./device$i
    cp -r ../src ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python train.py --run_distribute 1 --content_path "$1" --label_path "$2" > output.log 2>&1 &
    echo "$i finish"
    cd ../
done
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network blocks definition"""
import mindspore
from mindspore import nn, ops
init_mode = 'xavier_uniform'
has_bias = True
bias_init = 'zeros'
class REBNCONV(nn.Cell):
    """
    A basic unit consisting of convolution, batchnorm, and relu activation functions
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        """definition method"""
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, dilation=1 * dirate, weight_init=init_mode, has_bias=has_bias,
                                 bias_init=bias_init)
        self.bn_s1 = nn.BatchNorm2d(out_ch, affine=True)
        self.relu_s1 = nn.ReLU()

    def construct(self, x):
        """compute method"""
        hx = x
        hx = self.conv_s1(hx)
        hx = self.bn_s1(hx)
        xout = self.relu_s1(hx)
        return xout


def _upsample_like(src, tar):
    """generate upsample unit"""
    resize_bilinear = mindspore.ops.operations.ResizeBilinear(tar.shape[2:])
    src = resize_bilinear(src)
    return src
class RSU7(nn.Cell):
    """RSU7 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU7 definition"""
        super(RSU7, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU7 compute"""
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        hx6d = self.rebnconv6d(self.cat((hx7, hx6)))
        hx6dup = _upsample_like(hx6d, hx5)
        hx5d = self.rebnconv5d(self.cat((hx6dup, hx5)))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(self.cat((hx5dup, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))
        return hx1d + hxin
class RSU6(nn.Cell):
    """RSU6 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU6 definition"""
        super(RSU6, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU6 compute"""
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx6 = self.rebnconv6(hx5)
        hx5d = self.rebnconv5d(self.cat((hx6, hx5)))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(self.cat((hx5dup, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))
        return hx1d + hxin
class RSU5(nn.Cell):
    """RSU5 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU5 definition"""
        super(RSU5, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU5 compute"""
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)
        hx4d = self.rebnconv4d(self.cat((hx5, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))
        return hx1d + hxin
class RSU4(nn.Cell):
    """RSU4 block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU4 definition"""
        super(RSU4, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU4 compute"""
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(self.cat((hx4, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(self.cat((hx2dup, hx1)))
        return hx1d + hxin
class RSU4F(nn.Cell):
    """RSU4F block"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        """RSU4F definition"""
        super(RSU4F, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """RSU4F compute"""
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(self.cat((hx4, hx3)))
        hx2d = self.rebnconv2d(self.cat((hx3d, hx2)))
        hx1d = self.rebnconv1d(self.cat((hx2d, hx1)))
        return hx1d + hxin
class U2NET(nn.Cell):
    """U-2-Net model"""

    def __init__(self, in_ch=3, out_ch=1):
        """U-2-Net definition"""
        super(U2NET, self).__init__()
        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2)
        self.stage6 = RSU4F(512, 256, 512)
        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)
        # side outputs, one per decoder stage plus the deepest encoder stage
        self.side1 = nn.Conv2d(64, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side2 = nn.Conv2d(64, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side3 = nn.Conv2d(128, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side4 = nn.Conv2d(256, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side5 = nn.Conv2d(512, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        self.side6 = nn.Conv2d(512, out_ch, 3, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                               bias_init=bias_init)
        # 1x1 convolution that fuses the six side outputs
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1, pad_mode='same', weight_init=init_mode, has_bias=has_bias,
                                 bias_init=bias_init)
        self.cat = ops.Concat(axis=1)
        self.Sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, x):
        """U-2-Net compute"""
        hx = x
        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # -------------------- decoder --------------------
        hx5d = self.stage5d(self.cat((hx6up, hx5)))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(self.cat((hx5dup, hx4)))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(self.cat((hx4dup, hx3)))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(self.cat((hx3dup, hx2)))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(self.cat((hx2dup, hx1)))
        # side output
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(self.cat((d1, d2, d3, d4, d5, d6)))
        return self.cat((self.Sigmoid(d0), self.Sigmoid(d1), self.Sigmoid(d2), self.Sigmoid(d3), self.Sigmoid(d4),
                         self.Sigmoid(d5), self.Sigmoid(d6)))
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config of distributed training and standalone training"""
from easydict import EasyDict as edict

single_cfg = edict({
    "lr": 1e-3,
    "batch_size": 12,
    "max_epoch": 700,
    "keep_checkpoint_max": 10,
    "weight_decay": 0,
    "eps": 1e-8,
})

run_distribute_cfg = edict({
    "lr": 4e-3,
    "batch_size": 16,
    "max_epoch": 1500,
    "keep_checkpoint_max": 10,
    "weight_decay": 0,
    "eps": 1e-8,
})
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""definition of some data loading operations"""
from __future__ import print_function, division
import random
import os
import glob
import numpy as np
from skimage import io, transform, color
from mindspore import context
from mindspore import dataset as ds
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms.c_transforms as CC
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size
# ==========================dataset load==========================
class RescaleT:
    """Rescale operation"""

    def __init__(self, output_size):
        """RescaleT definition"""
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """RescaleT compute"""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
        lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0,
                               preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl}
class Rescale:
    """Rescale operation"""

    def __init__(self, output_size):
        """Rescale definition"""
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Rescale compute"""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]
        a = int(random.random() * 4) % 4
        image = np.rot90(image, a)
        label = np.rot90(label, a)
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w), mode='constant')
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl}
class RandomCrop:
    """RandomCrop operation"""

    def __init__(self, output_size):
        """RandomCrop definition"""
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """RandomCrop compute"""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]
        a = int(random.random() * 4) % 4
        image = np.rot90(image, a)
        label = np.rot90(label, a)
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)
        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]
        return {'imidx': imidx, 'image': image, 'label': label}
class ToTensor:
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        """Convert operation"""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        tmpLbl = np.zeros(label.shape)
        image = image / np.max(image)
        if np.max(label) < 1e-6:
            label = label
        else:
            label = label / np.max(label)
        if image.shape[2] == 1:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
        else:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
        tmpLbl[:, :, 0] = label[:, :, 0]
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))
        return imidx, tmpImg, tmpLbl
class ToTensorLab:
"""Convert ndarrays in sample to Tensors."""
    def __init__(self, flag=0):
        """ToTensorLab definition"""
        self.flag = flag

    def __call__(self, sample):
        """ToTensorLab compute"""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        tmpLbl = np.zeros(label.shape)
        # scale the label to [0, 1] unless it is (almost) all zeros
        if np.max(label) >= 1e-6:
            label = label / np.max(label)
        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)
            # normalize image to range [0, 1]
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
                np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
                np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
                np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]))
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
                np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]))
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
                np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]))
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
                np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]))
            # standardize every channel to zero mean and unit variance
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3])
            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4])
            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5])
        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image
            tmpImg = color.rgb2lab(tmpImg)
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
                np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
                np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
                np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]))
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image / np.max(image)
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
        tmpLbl[:, :, 0] = label[:, :, 0]
        # convert from HWC to CHW layout
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))
        return {'imidx': imidx, 'image': tmpImg, 'label': tmpLbl}
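The RGB branch above can also be written with NumPy broadcasting. The following is a minimal sketch, not the repository's implementation: `normalize_rgb`, `IMAGENET_MEAN` and `IMAGENET_STD` are hypothetical names, and the guard against an all-zero image is an added assumption that the original code does not make.

```python
import numpy as np

# Per-channel constants copied from the RGB branch of ToTensorLab.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_rgb(image):
    """Scale an HWC image by its maximum, standardize per channel, return CHW."""
    image = np.asarray(image, dtype=np.float64)
    peak = image.max()
    if peak > 0:  # assumption: skip scaling for all-black images
        image = image / peak
    if image.shape[2] == 1:
        # Grayscale input: replicate the channel and, as in the original branch,
        # apply the first-channel statistics to all three copies.
        out = np.repeat((image - IMAGENET_MEAN[0]) / IMAGENET_STD[0], 3, axis=2)
    else:
        # Broadcasting applies each mean/std to its own channel.
        out = (image[:, :, :3] - IMAGENET_MEAN) / IMAGENET_STD
    return out.transpose((2, 0, 1))
```

Broadcasting over the last axis replaces the three per-channel assignments, and the final transpose yields the channel-first layout that the network input expects.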
class SalObjDataset:
    """Preprocess the tensor"""

    def __init__(self, img_name_list, lbl_name_list, train_transform=None):
        """
        SalObjDataset definition
        """
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.train_transform = train_transform
        self.rescale = RescaleT(320)
        self.randomcrop = RandomCrop(288)
        self.totensor = ToTensorLab(flag=0)

    def __len__(self):
        """
        get the length of the dataset
        """
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """
        get data in one step
        """
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])
        if not self.label_name_list:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])
        label = np.zeros(label_3.shape[0:2])
        if len(label_3.shape) == 3:
            label = label_3[:, :, 0]
        elif len(label_3.shape) == 2:
            label = label_3
        if len(image.shape) == 3 and len(label.shape) == 2:
            label = label[:, :, np.newaxis]
        elif len(image.shape) == 2 and len(label.shape) == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]
        sample = {'imidx': imidx, 'image': image, 'label': label}
        if self.train_transform:
            sample = self.rescale(sample)
            sample = self.randomcrop(sample)
            sample = self.totensor(sample)
        else:
            sample = self.rescale(sample)
            sample = self.totensor(sample)
        return sample['image'], sample['label']
def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))
    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = 1
        rank_id = 0
    return rank_size, rank_id
def create_dataset(image_dir, label_dir, args):
    """
    create dataset
    """
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    tra_image_dir = image_dir
    tra_label_dir = label_dir
    image_ext = '.jpg'
    label_ext = '.png'
    tra_img_name_list = glob.glob(tra_image_dir + '*' + image_ext)
    tra_lbl_name_list = []
    for img_path in tra_img_name_list:
        # strip the extension from the file name, keeping any inner dots
        img_name = img_path.split(os.sep)[-1]
        aaa = img_name.split(".")
        bbb = aaa[0:-1]
        imidx = bbb[0]
        for i in range(1, len(bbb)):
            imidx = imidx + "." + bbb[i]
        tra_lbl_name_list.append(tra_label_dir + imidx + label_ext)
    print("---")
    print(image_dir)
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")
    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        train_transform=True)
    if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
        # shard the dataset across devices for distributed training
        device_num, rank_id = _get_rank_info()
        type_cast = CC.TypeCast(mstype.float32)
        data_set = ds.GeneratorDataset(salobj_dataset, ["image", "label"], num_parallel_workers=8, shuffle=True,
                                       num_shards=device_num, shard_id=rank_id)
        data_set = data_set.map(operations=type_cast, input_columns=["image"], num_parallel_workers=8)
        data_set = data_set.map(operations=type_cast, input_columns=["label"], num_parallel_workers=8)
        data_set = data_set.batch(args.batch_size, drop_remainder=True)
    else:
        type_cast = CC.TypeCast(mstype.float32)
        data_set = ds.GeneratorDataset(salobj_dataset, ["image", "label"], num_parallel_workers=8, shuffle=True)
        data_set = data_set.map(operations=type_cast, input_columns=["image"], num_parallel_workers=8)
        data_set = data_set.map(operations=type_cast, input_columns=["label"], num_parallel_workers=8)
        data_set = data_set.batch(args.batch_size, drop_remainder=True)
    return data_set
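The label list in `create_dataset` is built by splitting each image file name on dots and re-joining all parts except the extension. A minimal sketch of the same mapping with `os.path` helpers follows; `label_path_for` is a hypothetical helper, not part of the repository. Unlike the string concatenation above, it does not require the label directory to end with a path separator.

```python
import os


def label_path_for(image_path, label_dir, label_ext=".png"):
    """Map an image path to its mask path by swapping directory and extension."""
    # splitext removes only the final extension, so inner dots in the name survive.
    stem, _ = os.path.splitext(os.path.basename(image_path))
    return os.path.join(label_dir, stem + label_ext)


# Example with made-up paths: only the last extension is replaced.
example = label_path_for(os.path.join("images", "ILSVRC2012.test.jpg"), "masks")
```

Here `example` is the path of `ILSVRC2012.test.png` inside `masks`, matching what the loop in `create_dataset` produces for a file name that contains an extension.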
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss function definition"""
from mindspore import nn
from mindspore import ops
from mindspore.nn.loss.loss import LossBase
class total_loss(LossBase):
    """Loss function"""

    def __init__(self):
        """loss function definition"""
        super(total_loss, self).__init__()
        self.bceloss = nn.BCELoss(reduction='mean')
        self.reshape = ops.Reshape()
        self.cat = ops.Concat(axis=1)
        self.squeeze = ops.Squeeze()

    def construct(self, generated, target):
        """loss function compute"""
        # compare all seven outputs against the same target in a single BCE call
        return self.bceloss(generated, self.cat((target, target, target, target, target, target, target))) * 7
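`construct` takes the mean BCE over the target repeated seven times along axis 1 and multiplies the result by 7. Assuming `generated` holds seven equally sized saliency maps stacked along axis 1 (the network definition is not shown here, so this is an assumption), that value equals the sum of the seven per-map BCE losses. The NumPy sketch below checks the identity on random data; `bce_mean` is a hypothetical helper and does not use MindSpore.

```python
import numpy as np


def bce_mean(pred, target):
    """Mean binary cross-entropy over all elements."""
    pred = np.clip(pred, 1e-12, 1 - 1e-12)  # avoid log(0)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))


rng = np.random.default_rng(0)
target = (rng.random((2, 1, 4, 4)) > 0.5).astype(np.float64)         # N x 1 x H x W
outputs = [rng.uniform(0.05, 0.95, target.shape) for _ in range(7)]  # seven maps

# Form used in construct(): one BCE over the concatenation, times 7.
stacked = np.concatenate(outputs, axis=1)
tiled = np.concatenate([target] * 7, axis=1)
fused = bce_mean(stacked, tiled) * 7

# Equivalent form: sum of the seven per-map BCE terms.
summed = sum(bce_mean(d, target) for d in outputs)
```

Because every map has the same number of elements, the mean over the concatenation is the average of the seven per-map means, so `fused` and `summed` agree up to floating-point error.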
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""operation to generate semantically segmented pictures"""
import os
import time
import argparse
import numpy as np
import cv2
import imageio
from PIL import Image
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import src.blocks as blocks
parser = argparse.ArgumentParser()
parser.add_argument("--content_path", type=str, help='content_path, default: None')
parser.add_argument('--pre_trained', type=str, help='model_path, local pretrained model to load')
parser.add_argument("--output_dir", type=str, default='output_dir', help='output_path, path to store output')
# additional params for online generating
parser.add_argument("--run_online", type=int, default=0, help='whether to run online, default: false')
parser.add_argument("--data_url", type=str, help='path to data on obs, default: None')
parser.add_argument("--train_url", type=str, help='output path on obs, default: None')
args = parser.parse_args()
if __name__ == '__main__':
    if args.run_online:
        import moxing as mox
        mox.file.copy_parallel(args.data_url, "/cache/dataset")
        local_dataset_dir = "/cache/dataset/content_dir"
        pre_ckpt_dir = "/cache/dataset/pre_ckpt"
        pre_ckpt_path = pre_ckpt_dir + "/" + os.listdir(pre_ckpt_dir)[0]
        output_dir = "/cache/dataset/pred_dir"
    else:
        local_dataset_dir = args.content_path
        pre_ckpt_path = args.pre_trained
        output_dir = args.output_dir
    context.set_context(mode=context.GRAPH_MODE)
    param_dict = load_checkpoint(pre_ckpt_path)
    net = blocks.U2NET()
    net.set_train(False)
    load_param_into_net(net, param_dict)

    def normPRED(d):
        """rescale the value of tensor to between 0 and 1"""
        ma = d.max()
        mi = d.min()
        dn = (d - mi) / (ma - mi)
        return dn

    def normalize(img, im_type):
        """normalize tensor"""
        if im_type == "label":
            return img
        if len(img.shape) == 3:
            img[:, :, 0] = (img[:, :, 0] - 0.485) / 0.229
            img[:, :, 1] = (img[:, :, 1] - 0.456) / 0.224
            img[:, :, 2] = (img[:, :, 2] - 0.406) / 0.225
        else:
            img = (img - 0.485) / 0.229
        return img

    def crop_and_resize(img_path, im_type, size=320):
        """crop and resize tensors"""
        img = np.array(Image.open(img_path), dtype='float32')
        img = img / 255
        img = normalize(img, im_type)
        h, w = img.shape[:2]
        img = cv2.resize(img, dsize=(0, 0), fx=size / w, fy=size / h)
        if len(img.shape) == 2:
            img = np.expand_dims(img, 2).repeat(1, axis=2)
        im = img
        # HWC -> CHW, then add the batch dimension
        im = np.swapaxes(im, 1, 2)
        im = np.swapaxes(im, 0, 1)
        im = np.reshape(im, (1, im.shape[0], im.shape[1], im.shape[2]))
        return im

    content_list = os.listdir(local_dataset_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    start_time = time.time()
    for j in range(0, len(content_list)):
        pic_path = os.path.join(local_dataset_dir, content_list[j])
        content_pic = crop_and_resize(pic_path, im_type="content", size=320)
        image = net(Tensor(content_pic))
        content_name = content_list[j].replace(".jpg", "")
        content_name = content_name.replace(".png", "")
        file_path = os.path.join(local_dataset_dir, content_list[j])
        original = np.array(Image.open(file_path), dtype='float32')
        shape = original.shape
        # take the first output map of the first sample and resize it back to the input size
        image = normPRED(image[0][0].asnumpy())
        image = cv2.resize(image, dsize=(0, 0), fx=shape[1] / image.shape[1], fy=shape[0] / image.shape[0])
        file_path = os.path.join(output_dir, content_name) + ".png"
        imageio.imsave(file_path, image)
        print("%d / %d , %s \n" % (j, len(content_list), content_name))
    end_time = time.time()
    dtime = end_time - start_time
    print("finish generating in %.8s s" % (dtime))
    if args.run_online:
        mox.file.copy_parallel(output_dir, args.train_url)
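`normPRED` divides by `ma - mi`, which is zero when the predicted map is constant. A minimal sketch of a guarded variant follows, together with an explicit conversion to 8-bit before saving; `norm_pred`, `to_uint8` and the `eps` threshold are hypothetical and not part of the script above.

```python
import numpy as np


def norm_pred(d, eps=1e-8):
    """Min-max rescale to [0, 1]; return zeros for a (near-)constant map."""
    d = np.asarray(d, dtype=np.float32)
    lo, hi = d.min(), d.max()
    if hi - lo < eps:  # assumption: treat a flat prediction as "no salient object"
        return np.zeros_like(d)
    return (d - lo) / (hi - lo)


def to_uint8(mask):
    """Convert a [0, 1] float map to an 8-bit grayscale image."""
    return (np.clip(mask, 0.0, 1.0) * 255).round().astype(np.uint8)


# Example: rescale a small map and prepare it for writing as a PNG.
mask = to_uint8(norm_pred(np.array([[2.0, 4.0], [6.0, 10.0]])))
```

Converting to `uint8` explicitly keeps the saved pixel range under the caller's control instead of relying on the image writer's handling of float arrays.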
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train process"""
import os
import random
import argparse
import ast
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore import FixedLossScaleManager
from mindspore import Model
from mindspore import dataset as ds
from mindspore.context import ParallelMode
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
import src.config as config
from src.data_loader import create_dataset
from src.loss import total_loss
from src.blocks import U2NET
random.seed(1)
np.random.seed(1)
ds.config.set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
parser = argparse.ArgumentParser()
parser.add_argument("--content_path", type=str, help='content_path, default: None')
parser.add_argument("--label_path", type=str, help='label_path, default: None')
parser.add_argument('--ckpt_path', type=str, default='ckpts', help='checkpoint save location, default: ckpts')
parser.add_argument("--ckpt_name", default='u2net', type=str, help='prefix of ckpt files, default: u2net')
parser.add_argument("--loss_scale", type=int, default=8192)
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
# additional params for online training
parser.add_argument("--run_online", type=int, default=0, help='whether train online, default: false')
parser.add_argument("--data_url", type=str, help='path to data on obs, default: None')
parser.add_argument("--train_url", type=str, help='output path on obs, default: None')
parser.add_argument("--is_load_pre", type=int, default=0, help="whether use pretrained model, default: false.")
args = parser.parse_args()
if __name__ == '__main__':
    device_id = int(os.getenv('DEVICE_ID'))
    if args.run_online:
        import moxing as mox
        mox.file.copy_parallel(args.data_url, "/cache/dataset")
        content_path = "/cache/dataset/DUTS/DUTS-TR/DUTS-TR-Image/"
        label_path = "/cache/dataset/DUTS/DUTS-TR/DUTS-TR-Mask/"
        if args.run_distribute:
            args.ckpt_path = "/cache/ckpts/device" + str(device_id)
        else:
            args.ckpt_path = "/cache/ckpts"
        if args.is_load_pre:
            pre_ckpt_dir = "/cache/dataset/pre_ckpt"
            args.pre_trained = pre_ckpt_dir + "/" + os.listdir(pre_ckpt_dir)[0]
    else:
        content_path = args.content_path
        label_path = args.label_path
    if args.run_distribute:
        cfg = config.run_distribute_cfg
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(device_id=device_id)
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          device_num=device_num)
        init()
    else:
        cfg = config.single_cfg
    # copy the hyperparameters of the selected config onto args
    args.lr = cfg.lr
    args.batch_size = cfg.batch_size
    args.decay = cfg.weight_decay
    args.epoch_size = cfg.max_epoch
    args.eps = cfg.eps
    print("---")
    print("eps = %lf,batch_size = %d, epoch size = %d" % (args.eps, args.batch_size, args.epoch_size))
    print("lr = %lf,decay = %f, ckpt_name = %s, loss_scale = %d"
          % (args.lr, args.decay, args.ckpt_name, args.loss_scale))
    net = U2NET()
    net.set_train()
    if args.pre_trained != '':
        print("pretrained path = %s" % args.pre_trained)
        param_dict = load_checkpoint(args.pre_trained)
        load_param_into_net(net, param_dict)
    print("---")
    loss = total_loss()
    ds_train = create_dataset(content_path, label_path, args)
    print("dataset size: ", ds_train.get_dataset_size())
    opt = nn.Adam(net.get_parameters(), learning_rate=args.lr, beta1=0.9, beta2=0.999, eps=args.eps,
                  weight_decay=args.decay)
    loss_scale_manager = FixedLossScaleManager(args.loss_scale)
    model = Model(net, loss, opt, loss_scale_manager=loss_scale_manager, amp_level="O0")
    data_size = ds_train.get_dataset_size()
    # save one checkpoint per epoch
    config_ck = CheckpointConfig(save_checkpoint_steps=data_size, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_name, directory=args.ckpt_path, config=config_ck)
    net.set_train()
    model.train(args.epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(per_print_times=1), TimeMonitor()])
    if args.run_online:
        mox.file.copy_parallel(args.ckpt_path, args.train_url)
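The training script parses `--run_distribute` with `ast.literal_eval` and `--run_online` with `int`, so the two flags accept different spellings on the command line. The sketch below uses a reduced stand-in parser with only those two options (not the full parser above) to show the difference.

```python
import argparse
import ast

# Stand-in parser containing only the two flags under discussion.
parser = argparse.ArgumentParser()
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False)
parser.add_argument("--run_online", type=int, default=0)

# Python literals such as True/False are accepted by ast.literal_eval.
parsed = parser.parse_args(["--run_distribute", "True", "--run_online", "1"])

# A lowercase "true" is not a Python literal, so argparse rejects it and exits.
try:
    parser.parse_args(["--run_distribute", "true"])
    lowercase_ok = True
except SystemExit:
    lowercase_ok = False
```

In practice this means passing `--run_distribute True` (capitalized) and `--run_online 1` when launching the script.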