# Contents
- [Contents](#contents)
- [IndexNet Description](#indexnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Standalone Training](#standalone-training)
- [Distribute Training](#distribute-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Export](#model-export)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
## [IndexNet Description](#contents)
Upsampling is an essential stage for most dense prediction tasks using deep convolutional neural networks (CNNs).
The frequently used upsampling operators include transposed convolution, unpooling, periodic shuffling
(also known as depth-to-space), and naive interpolation followed by convolution.
These operators, however, are not general-purpose designs and often exhibit different behaviors in different tasks.
Instead of using max pooling and unpooling, IndexNet is built on two novel operations, indexed pooling and indexed upsampling,
where downsampling and upsampling are guided by learned indices. The indices are generated dynamically, conditioned
on the feature map, and are learned without supervision by a fully convolutional network termed IndexNet.
[Paper](https://openaccess.thecvf.com/content_ICCV_2019/papers/Lu_Indices_Matter_Learning_to_Index_for_Deep_Image_Matting_ICCV_2019_paper.pdf): Indices Matter: Learning to Index for Deep Image Matting. Hao Lu, Yutong Dai, Chunhua Shen.
## [Model Architecture](#contents)
IndexNet is based on the UNet architecture and uses MobileNetV2 as the backbone.
MobileNetV2 was chosen because it is lightweight, allowing higher-resolution images to be used on the same GPU than would be possible with high-capacity backbones.
All stride-2 convolutions were replaced with stride-1 convolutions, and stride-2 2x2 max pooling was added after each encoding stage for downsampling, which allows the indices to be extracted.
Following the IndexNet idea, the max pooling and unpooling layers can then be replaced with IndexedPooling and IndexedUnpooling, respectively.
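To make the indexing idea concrete, below is a minimal NumPy sketch (illustrative only, not the project's MindSpore implementation) of a 2x2 indexed pooling step and the matching indexed upsampling step: an index map of the same spatial size as the feature map weights the features before sum pooling, and the same indices later redistribute the low-resolution features back to their original positions.

```python
import numpy as np

def indexed_pool_2x2(feat, idx):
    """Indexed pooling: weight features by the index map, then sum over 2x2 windows."""
    h, w = feat.shape
    weighted = feat * idx                              # apply learned indices
    windows = weighted.reshape(h // 2, 2, w // 2, 2)   # group pixels into 2x2 windows
    return windows.sum(axis=(1, 3))                    # sum-pool each window

def indexed_upsample_2x2(feat_small, idx):
    """Indexed upsampling: nearest-neighbour 2x upsample, then re-weight by the same indices."""
    upsampled = feat_small.repeat(2, axis=0).repeat(2, axis=1)
    return upsampled * idx

# Toy example with a 4x4 feature map and a (here random) index map.
# In IndexNet the index map is predicted from the feature map itself by a small
# fully convolutional block and is learned end-to-end without extra supervision.
rng = np.random.default_rng(0)
feature_map = rng.standard_normal((4, 4)).astype(np.float32)
index_map = rng.random((4, 4)).astype(np.float32)

pooled = indexed_pool_2x2(feature_map, index_map)    # shape (2, 2)
restored = indexed_upsample_2x2(pooled, index_map)   # shape (4, 4)
print(pooled.shape, restored.shape)
```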
## [Dataset](#contents)
The paper uses the Adobe Image Matting dataset, which is not publicly available.
Thus, we use the AIM-500 (Automatic Image Matting - 500) dataset, which is openly available for anyone to download.
Every image from the AIM-500 dataset is cut out by its mask and composited N times (96 for the train part, 20 for the test part) as foreground
over a unique image from the COCO-2014 train part, which serves as background.
Datasets used: AIM-500, COCO-2014 (train).
| | AIM-500 | COCO-2014 | Merged (after processing) |
| --------------|------------------------------------------------- |---------------------- |-------------------------- |
| Dataset size | ~0.35 Gb | ~13.0 Gb | ~86.0 Gb |
| Train | 0.35 Gb, 3 * 500 images (mask, original, trimap) | 13.0 Gb, 82783 images | 84 Gb, 43200 images |
| Test | - | - | 2 Gb, 1000 images |
| Data format | .png, .jpg images | .jpg images | .png images |
Note: We manually split AIM-500 for the train/test parts (450/50).
Download the [AIM-500](https://drive.google.com/drive/folders/1IyPiYJUp-KtOoa-Hsm922VU3aCcidjjz) dataset
(3 archives: original, mask, trimap), unzip them, and move the extracted folders into a single folder named AIM-500.
Download [COCO-2014 train](http://images.cocodataset.org/zips/train2014.zip) and unzip it.
The structure of the datasets will be as follows:
```text
.
└─AIM-500 <- data_dir
├─mask
│ └─***.png
├─original
│ └─***.jpg
└─trimap
└─***.png
.
└─train2014 <- bg_dir
└─***.jpg
Where *** is the image file name
```
To process the dataset, use the command below.
```bash
python -m data.process_dataset --data_dir /path/to/AIM-500 --bg_dir /path/to/coco/train2014
```
- --data_dir - path to the image matting dataset (the AIM-500 folder, in this case).
- --bg_dir - path to the backgrounds dataset (the COCO train2014 folder, in this case).
Note: The requirements will be installed before data processing. Make sure that you have ~100 Gb of free space on the disk
that holds the --data_dir path. Preparing the dataset can take about 20 hours, depending on hardware.
During processing, the data_dir structure is changed automatically, and
the merged images are saved into data_dir/train/merged and data_dir/validation/merged. The bg_dir remains unchanged.
Processed dataset will have the following structure:
```text
.
└─AIM-500 <- data_dir
├─train
│ ├─data.txt
│ ├─mask
│ ├─merged
│ └─original
└─validation
├─data.txt
├─mask
├─merged
├─original
└─trimap
.
└─train2014 <- bg_dir
```
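For reference, the core operation performed by the processing script is plain alpha compositing: each foreground pixel is blended with a background pixel according to the mask, `comp = alpha * fg + (1 - alpha) * bg`. A minimal NumPy sketch of that step (mirroring the `composite4` helper in `data/process_dataset.py`):

```python
import numpy as np

def composite(fg, bg, alpha):
    """Blend a foreground over a background using an 8-bit alpha mask."""
    a = (alpha.astype(np.float32) / 255.0)[..., None]  # HxW mask -> HxWx1 in [0, 1]
    comp = a * fg.astype(np.float32) + (1.0 - a) * bg.astype(np.float32)
    return comp.astype(np.uint8)

# Toy example with random images; the real script reads AIM-500 foregrounds/masks
# and COCO-2014 backgrounds, resizing each background to cover the foreground first.
rng = np.random.default_rng(0)
fg = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
bg = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
mask = rng.integers(0, 256, (64, 64), dtype=np.uint8)
print(composite(fg, bg, mask).shape)  # (64, 64, 3)
```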
## [Environment Requirements](#contents)
- Hardware (GPU)
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
Note: We use MindSpore 1.6.1 (GPU), so make sure that you install version 1.6.1 or later.
## [Quick Start](#contents)
After installing MindSpore through the official website, you can follow the steps below for training and evaluation.
Before training, install the requirements with `pip install -r requirements.txt`
and [download](https://mindspore.cn/resources/hub/details/en?MindSpore/ascend/1.2/mobilenetv2_v1.2_imagenet2012)
the MobileNetV2 backbone pre-trained on ImageNet.
```bash
# Run standalone training example
bash scripts/run_standalone_train_gpu.sh [DEVICE_ID] [LOGS_CKPT_DIR] [MOBILENET_CKPT] [DATA_DIR] [BG_DIR]
# Run distribute training example
bash scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [LOGS_CKPT_DIR] [MOBILENET_CKPT] [DATA_DIR] [BG_DIR]
```
- DEVICE_ID - process device ID.
- DEVICE_NUM - number of distribute training devices.
- LOGS_CKPT_DIR - path to the directory, where the training results (ckpts, logs) will be stored.
- MOBILENET_CKPT - path to the pre-trained mobilenetv2 backbone ([link](https://mindspore.cn/resources/hub/details/en?MindSpore/ascend/1.2/mobilenetv2_v1.2_imagenet2012)).
- DATA_DIR - path to image matting dataset (AIM-500 folder, in this case).
- BG_DIR - path to backgrounds dataset (COCO/train2014 folder, in this case).
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```text
.
└─IndexNet
├─README.md
├─requirements.txt
├─data
│ └─process_dataset.py # data preparation script
├─scripts
│ ├─run_distribute_train_gpu.sh # launch distribute train on GPU
│ ├─run_eval_gpu.sh # launch evaluation on GPU
│ └─run_standalone_train_gpu.sh # launch standalone train on GPU
├─src
│ ├─cfg
│ │ ├─__init__.py
│ │ └─config.py # parameter parser
│ ├─dataset.py # dataset script and utils
│ ├─layers.py # model layers
│ ├─model.py # model script
│ ├─modules.py # model modules
│ └─utils.py # utilities used in other scripts
├─default_config.yaml # default configs
├─eval.py # evaluation script
├─export.py # export to MINDIR script
└─train.py # training script
```
### [Script Parameters](#contents)
```yaml
# Main arguments:
# training params
batch_size: 16 # Batch size for training
epochs: 30 # Number of training epochs
learning_rate: 0.01 # Learning rate init
backbone_lr_mult: 100 # Learning rate scaling (division) for backbone params
lr_decay: 0.1 # Learning rate scaling at milestone
milestones: [20, 26] # Milestones for learning rate scheduler
input_size: 320 # Input crop size for training
```
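For illustration, the snippet below shows how these values translate into a per-epoch learning rate, assuming a standard multi-step schedule (multiply by `lr_decay` at each milestone) with the backbone learning rate divided by `backbone_lr_mult`; the authoritative schedule is defined in `train.py`.

```python
# Illustrative sketch of a multi-step schedule matching the config values above.
learning_rate = 0.01
lr_decay = 0.1
milestones = [20, 26]
backbone_lr_mult = 100  # backbone parameters use lr / backbone_lr_mult
epochs = 30

for epoch in range(epochs):
    passed = sum(epoch >= m for m in milestones)
    lr = learning_rate * lr_decay ** passed
    backbone_lr = lr / backbone_lr_mult
    if epoch in (0, 20, 26):
        print(f"epoch {epoch:2d}: lr={lr:.5f}, backbone_lr={backbone_lr:.7f}")
```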
### [Training Process](#contents)
#### Standalone Training
Note: All training runs require the pretrained MobileNetV2 backbone.
```bash
bash scripts/run_standalone_train_gpu.sh [DEVICE_ID] [LOGS_CKPT_DIR] [MOBILENET_CKPT] [DATA_DIR] [BG_DIR]
```
- DEVICE_ID - process device ID.
- LOGS_CKPT_DIR - path to the directory, where the training results (ckpts, logs) will be stored.
- MOBILENET_CKPT - path to the pre-trained mobilenetv2 backbone ([link](https://mindspore.cn/resources/hub/details/en?MindSpore/ascend/1.2/mobilenetv2_v1.2_imagenet2012)).
- DATA_DIR - path to image matting dataset (AIM-500 folder, in this case).
- BG_DIR - path to backgrounds dataset (COCO/train2014 folder, in this case).
The above command will run in the background; you can view the results in the generated standalone_train.log file.
After training, the training loss and time logs are available in the chosen logs directory.
The model checkpoints will be saved in `[LOGS_CKPT_DIR]` directory.
#### Distribute Training
```bash
bash scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [LOGS_CKPT_DIR] [MOBILENET_CKPT] [DATA_DIR] [BG_DIR]
```
- DEVICE_NUM - number of distribute training devices.
- LOGS_CKPT_DIR - path to the directory, where the training results (ckpts, logs) will be stored.
- MOBILENET_CKPT - path to the pre-trained mobilenetv2 backbone ([link](https://mindspore.cn/resources/hub/details/en?MindSpore/ascend/1.2/mobilenetv2_v1.2_imagenet2012)).
- DATA_DIR - path to image matting dataset (AIM-500 folder, in this case).
- BG_DIR - path to backgrounds dataset (COCO/train2014 folder, in this case).
The above command will run in the background; you can view the results in the generated distribute_train.log file.
After training, the training loss and time logs are available in the chosen logs directory.
The model checkpoints will be saved in `[LOGS_CKPT_DIR]` directory.
### [Evaluation Process](#contents)
#### Evaluation
To start evaluation, run the command below.
```bash
bash scripts/run_eval_gpu.sh [DEVICE_ID] [CKPT_URL] [DATA_DIR] [LOGS_DIR]
```
- DEVICE_ID - process device ID.
- CKPT_URL - path to the trained IndexNet model.
- DATA_DIR - path to image matting dataset (AIM-500 folder, in this case).
- LOGS_DIR - path to the directory, where the eval results (outputs, logs) will be stored.
The above command will run in the background. Predicted masks (.png) will be stored in the chosen `[LOGS_DIR]`,
where you can also view the results in the eval.log file.
### [Model Export](#contents)
To export the model to MINDIR format, run the following command:
```bash
python export.py --ckpt_url [CKPT_URL]
```
- CKPT_URL - path to the trained IndexNet model.
## [Model Description](#contents)
### [Performance](#contents)
#### Training Performance
| Parameters | GPU (1p) | GPU (8p) |
| -------------------------- |------------------------------------------------------------ |--------------------------------------------------------------- |
| Model | IndexNet | IndexNet |
| Hardware | 1 Nvidia Tesla V100-PCIE, CPU @ 3.40GHz | 8 Nvidia RTX 3090, Intel Xeon Gold 6226R CPU @ 2.90GHz |
| Upload Date | 07/04/2022 (day/month/year) | 07/04/2022 (day/month/year) |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | AIM-500, COCO-2014 (composition of datasets) | AIM-500, COCO-2014 (composition of datasets) |
| Training Parameters | epochs=30, lr=0.01, batch_size=16, num_workers=12 | epochs=30, lr=0.01, batch_size=16 (each device), num_workers=4 |
| Optimizer | Adam, beta1=0.9, beta2=0.999, eps=1e-8 | Adam, beta1=0.9, beta2=0.999, eps=1e-8 |
| Loss Function | Weighted loss (alpha predictions loss and composition loss) | Weighted loss (alpha predictions loss and composition loss) |
| Speed | ~ 516 ms/step | ~ 2670 ms/step |
| Total time | ~ 11.6 hours | ~ 7.5 hours |
#### Evaluation Performance
| Parameters | GPU (1p) | GPU (8p) |
| -----------------------|--------------------------------------------------------|--------------------------------------------------------|
| Model | IndexNet | IndexNet |
| Resource | 1 Nvidia RTX 3090, Intel Xeon Gold 6226R CPU @ 2.90GHz | 1 Nvidia RTX 3090, Intel Xeon Gold 6226R CPU @ 2.90GHz |
| Upload Date | 07/04/2022 (day/month/year) | 07/04/2022 (day/month/year) |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | AIM-500, COCO-2014 (composition of datasets) | AIM-500, COCO-2014 (composition of datasets) |
| Batch_size | 1 | 1 |
| Outputs | .png images of alpha masks | .png images of alpha masks |
| Metrics | 21.51 SAD, 0.0096 MSE, 13.43 Grad, 20.43 Conn | 22.06 SAD, 0.0134 MSE, 12.84 Grad, 21.32 Conn |
| Metrics expected range | < 24.00 SAD, < 0.0120 MSE, < 13.70 Grad, < 23.20 Conn | < 24.20 SAD, < 0.0145 MSE, < 13.40 Grad, < 22.70 Conn |
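SAD and MSE are computed over the unknown (trimap == 128) region of each predicted alpha mask, as in `eval.py`. Below is a rough NumPy sketch of both metrics, assuming the common matting-evaluation convention (SAD reported in thousands of pixels, MSE averaged over unknown pixels); the exact implementations used here live in `src/utils.py`.

```python
import numpy as np

def sad_and_mse(pred_alpha, gt_alpha, trimap):
    """Rough SAD/MSE sketch over the unknown trimap region (alpha values in [0, 255])."""
    unknown = np.equal(trimap, 128).astype(np.float32)
    diff = (pred_alpha.astype(np.float32) - gt_alpha.astype(np.float32)) / 255.0
    sad = np.sum(np.abs(diff) * unknown) / 1000.0                  # thousands of pixels
    mse = np.sum(diff ** 2 * unknown) / max(np.sum(unknown), 1.0)  # mean over unknown pixels
    return sad, mse
```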
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preparation script."""
import math
import shutil
from pathlib import Path
import cv2 as cv
import numpy as np
from src.cfg.config import config
# The original dataset has no train/test split.
# We manually selected 50 test images and fixed their names to make the results reproducible.
_AIM500_VALIDATION_SUBSET = (
'247d445d', '0530e2e7', '8431ecb9', '76c7dc77', 'bab88684',
'fabaf6e3', 'bde965af', 'e4a72c06', 'a974d3a3', 'e92f575c',
'dc288b1a', 'e30ce6cb', '7c642e86', '57e4d780', '52a2115b',
'33a4da38', '77d7a529', '2be0f6c9', '0a5e5a64', 'b71d875b',
'b4157744', '0723d28f', 'a3798f05', '404618b5', '780f404a',
'8e2eb72f', '7c88f64e', '6e470e4a', '819ee421', '12c828bc',
'b2dfc20d', '80193d44', '51873814', 'b40d228a', '26268b4b',
'4729fa87', 'fe6f4047', 'e92b90fc', '23bbcc9d', '197d7af1',
'868c53f0', '7b1db264', '4852f7d4', 'd2a1bff3', '751bcc44',
'44dd34a4', '1ea2b894', '9992c618', 'e62df10b', 'dbef692f',
)
def listdir(folder, file_format=".jpg"):
"""
Search files into chosen dir.
Args:
folder (pathlib.Path): Path to folder.
file_format (str): Search files of current format.
Returns:
out (list): Names of files found into folder.
"""
return [file.name for file in folder.iterdir() if file.name.endswith(f'{file_format}')]
def set_data_structure(dataset_path):
"""
Train test split and organize dataset structure.
Args:
dataset_path (pathlib.Path): Path to main dataset folder.
"""
test_names = ['o_' + file + '.jpg' for file in _AIM500_VALIDATION_SUBSET]
data_dirs = ['validation/original', 'validation/mask', 'validation/trimap', 'train']
for data_dir in data_dirs:
Path(dataset_path, data_dir).mkdir(parents=True)
for name in test_names:
shutil.move(f"{dataset_path}/original/{name}", f"{dataset_path}/validation/original/{name}")
name = name.replace('jpg', 'png')
shutil.move(f"{dataset_path}/mask/{name}", f"{dataset_path}/validation/mask/{name}")
shutil.move(f"{dataset_path}/trimap/{name}", f"{dataset_path}/validation/trimap/{name}")
shutil.move(f"{dataset_path}/original", f"{dataset_path}/train/original")
shutil.move(f"{dataset_path}/mask", f"{dataset_path}/train/mask")
shutil.rmtree(f"{dataset_path}/trimap")
def composite4(fg, bg, a, w, h):
"""
Place foreground over background by mask.
Args:
fg (np.array): Foreground image.
bg (np.array): Background image.
a (np.array): Mask of the foreground image.
w (int): Width of the foreground image.
h (int): Height of the foreground image.
Returns:
comp (np.array): Foreground placed by mask over background.
"""
fg = np.array(fg, np.float32)
bg = np.array(bg[0:h, 0:w], np.float32)
alpha = np.zeros((h, w, 1), np.float32)
alpha[:, :, 0] = a / 255.
comp = alpha * fg + (1 - alpha) * bg
comp = comp.astype(np.uint8)
return comp
def process(start, num_bgs, data_root, bg_path, part="train"):
"""
Compose foregrounds and backgrounds by mask to make dataset.
Args:
start (int): Start index of the background images.
num_bgs (int): Number of unique backgrounds per one image from matting dataset.
data_root (pathlib.Path): Path to matting dataset directory.
bg_path (pathlib.Path): Path to backgrounds dataset directory.
part (str): Which part of dataset to prepare (subfolder).
Returns:
data_part_size (int): Number of processed images at the current part of dataset.
"""
print(f'Start processing {part} part ...')
a_path = data_root / part / 'mask'
out_path = data_root / part / 'merged'
fg_dir = data_root / part / 'original'
bg_dir = bg_path
out_path.mkdir(parents=True, exist_ok=True)
fg_files = listdir(fg_dir, file_format='.jpg')
data_part_size = int(len(fg_files) * num_bgs)
bg_files = listdir(bg_dir, file_format='.jpg')[start:start + data_part_size]
fg_files.sort()
bg_files.sort()
data_schema = []
bg_iter = iter(bg_files)
for k, im_name in enumerate(fg_files):
im = cv.imread(str(fg_dir / im_name))
a = cv.imread(str(a_path / im_name.replace('jpg', 'png')), 0)
h, w = im.shape[:2]
bcount = 0
for _ in range(num_bgs):
bg_name = next(bg_iter)
bg = cv.imread(str(bg_dir / bg_name))
bh, bw = bg.shape[:2]
wratio = float(w) / float(bw)
hratio = float(h) / float(bh)
ratio = wratio if wratio > hratio else hratio
if ratio > 1:
bg = cv.resize(bg, (math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
out = composite4(im, bg, a, w, h)
filename = Path(out_path, im_name[:len(im_name) - 4] + '_' + str(bcount) + '.png')
cv.imwrite(str(filename), out, [cv.IMWRITE_PNG_COMPRESSION, 9])
if part == 'train':
# [processed, mask, foreground, background]
data_schema_cell = [
str(Path(part, 'merged', filename.name)),
str(Path(part, 'mask', im_name.replace('.jpg', '.png'))),
str(Path(part, 'original', im_name)),
str(Path(bg_name)) + '\n',
]
else:
# [processed, mask, trimap]
data_schema_cell = [
str(Path(part, 'merged', filename.name)),
str(Path(part, 'mask', im_name.replace('.jpg', '.png'))),
str(Path(part, 'trimap', im_name.replace('.jpg', '.png'))) + '\n',
]
data_schema.append('|'.join(data_schema_cell))
bcount += 1
print(f'{k * num_bgs + bcount}/{data_part_size}')
with Path(Path(out_path).resolve().parent, 'data.txt').open('w') as file:
file.writelines(data_schema)
print(f'Successfully end {part} part processing.')
return data_part_size
if __name__ == "__main__":
main_dataset = Path(config.data_dir)
backgrounds = Path(config.bg_dir)
if not main_dataset.is_dir():
raise NotADirectoryError(f'Not valid path to main dataset: {main_dataset}')
if not backgrounds.is_dir():
raise NotADirectoryError(f'Not valid path to backgrounds: {backgrounds}')
set_data_structure(dataset_path=main_dataset)
end = process(start=0, num_bgs=config.num_bgs_train, data_root=main_dataset, bg_path=backgrounds, part='train')
_ = process(start=end, num_bgs=config.num_bgs_val, data_root=main_dataset, bg_path=backgrounds, part='validation')
# training params
batch_size: 16
epochs: 30
learning_rate: 0.01 # 1P training
backbone_lr_mult: 100
lr_decay: 0.1
milestones: [20, 26]
input_size: 320
# validation params
img_size: [1080, 1620]
# model params
width_mult: 1.0
output_stride: 32
decoder_kernel_size: 5
conv_operator: 'std_conv'
apply_aspp: True
use_context: True
use_nonlinear: True
# backbone params (mobilenetv2)
rate: 1
current_stride: 1
# expand_ratio, input_chn, output_chn, num_blocks, stride, dilation
inverted_residual_setting: [
[1, 32, 16, 1, 1, 1],
[6, 16, 24, 2, 2, 1],
[6, 24, 32, 3, 2, 1],
[6, 32, 64, 4, 2, 1],
[6, 64, 96, 3, 1, 1],
[6, 96, 160, 3, 2, 1],
[6, 160, 320, 1, 1, 1],
]
# data processing params
num_bgs_train: 96
num_bgs_val: 20
# data normalization params
img_scale: 255
scales: [1, 1.5, 2]
img_std: [0.229, 0.224, 0.225, 1]
img_mean: [0.485, 0.456, 0.406, 0]
# other
data_dir: '/path/to/matting/dataset/'
bg_dir: '/path/to/coco/backgrounds/'
ckpt_url: '/path/to/checkpoint/'
logs_dir: 'logs'
device_target: 'GPU'
device_id: 0
device_start: 0
keep_checkpoint_max: 10
is_distributed: False
num_workers: 12
---
# Config description for each option
batch_size: "Batch size for training."
epochs: "Number of training epochs."
learning_rate: "Learning rate init."
backbone_lr_mult: "Learning rate scaling (division) for backbone params."
lr_decay: "Learning rate scaling at milestone."
milestones: "Milestones for learning rate scheduler."
input_size: "Input crop size for training."
img_size: "Validation input image size."
width_mult: "Hidden layers ratio."
output_stride: "Output image stride."
decoder_kernel_size: "Decoder conv kernel size."
conv_operator: "Conv operator for decoder."
apply_aspp: "Use ASPP."
use_context: "Use context in index blocks."
use_nonlinear: "Use nonlinear in index blocks."
rate: "Encoder (mobilenetv2) rate."
current_stride: "Encoder (mobilenetv2) stride."
inverted_residual_setting: "Encoder (mobilenetv2) settings."
num_bgs_train: "Number of backgrounds to merge with foreground (for processing train part)."
num_bgs_val: "Number of backgrounds to merge with foreground (for processing validation part)."
img_scale: "Image scaling value."
scales: "Scales for cropping."
img_std: "Std to every channel for image normalization."
img_mean: "Mean to every channel for image normalization."
data_dir: "Image matting dataset dir."
bg_dir: "COCO dataset train part dir."
ckpt_url: "Checkpoint url (and url for pretrained)."
logs_dir: "Output logs dir."
device_target: "Target device platform."
device_id: "Device id of the target platform."
device_start: "Main device for distribute training."
keep_checkpoint_max: "Save last N checkpoints during train."
is_distributed: "Run distribute train or not."
num_workers: "Number of the parallel CPU workers for dataloader."
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation script."""
from pathlib import Path
from time import time
import cv2
import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore import load_checkpoint
from mindspore.common import set_seed
from src.cfg.config import config as default_config
from src.dataset import ImageMattingDatasetVal
from src.model import MobileNetV2UNetDecoderIndexLearning
from src.utils import compute_connectivity_loss
from src.utils import compute_gradient_loss
from src.utils import compute_mse_loss
from src.utils import compute_sad_loss
from src.utils import image_alignment
def evaluation(config):
"""
Init model, dataset, run evaluation.
Args:
config: Config parameters.
"""
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
val_loader = ImageMattingDatasetVal(
data_dir=config.data_dir,
config=config,
sub_folder='validation',
data_file='data.txt',
)
net = MobileNetV2UNetDecoderIndexLearning(
encoder_rate=config.rate,
encoder_current_stride=config.current_stride,
encoder_settings=config.inverted_residual_setting,
output_stride=config.output_stride,
width_mult=config.width_mult,
conv_operator=config.conv_operator,
decoder_kernel_size=config.decoder_kernel_size,
apply_aspp=config.apply_aspp,
use_nonlinear=config.use_nonlinear,
use_context=config.use_context,
)
load_checkpoint(config.ckpt_url, net)
net.set_train(False)
with Path(config.data_dir, 'validation/data.txt').open() as file:
image_list = [name.split('|') for name in file.read().splitlines()]
eval_logs_dir = Path(config.logs_dir)
eval_logs_dir.mkdir(parents=True, exist_ok=True)
sad = []
mse = []
grad = []
conn = []
avg_frame_rate = 0.0
stride = config.output_stride
start = time()
for i, (image, gt_alpha, trimap, transposed, pad_mask, size) in enumerate(val_loader):
h, w = image.shape[1:]
image = image.transpose(1, 2, 0)
image = image_alignment(image, stride, odd=False)
inputs = Tensor(np.expand_dims(image.transpose(2, 0, 1), axis=0), mstype.float32)
# Inference
outputs = net(inputs).asnumpy().squeeze()
alpha = cv2.resize(outputs, dsize=(w, h), interpolation=cv2.INTER_CUBIC)
alpha = alpha[pad_mask].reshape(size)
alpha = np.clip(alpha, 0, 1) * 255.
# Trimap edge region
mask = np.equal(trimap, 128).astype(np.float32)
alpha = (1 - mask) * trimap + mask * alpha
gt_alpha = gt_alpha * 255.
save_path = eval_logs_dir / Path(image_list[i][0]).name
if transposed:
Image.fromarray(alpha.transpose(1, 0).astype(np.uint8)).save(save_path)
else:
Image.fromarray(alpha.astype(np.uint8)).save(save_path)
# compute loss
sad.append(compute_sad_loss(alpha, gt_alpha, mask))
mse.append(compute_mse_loss(alpha, gt_alpha, mask))
grad.append(compute_gradient_loss(alpha, gt_alpha, mask))
conn.append(compute_connectivity_loss(alpha, gt_alpha, mask))
end = time()
running_frame_rate = 1 * float(1 / (end - start))
avg_frame_rate = (avg_frame_rate * i + running_frame_rate) / (i + 1)
print(
f'test: {i + 1}/{len(val_loader)}, sad: {sad[-1]:.2f},'
f' mse: {mse[-1]:.4f}, grad: {grad[-1]:.2f}, conn: {conn[-1]:.2f},'
f' frame: {running_frame_rate:.2f}FPS',
)
start = time()
print(60 * '=')
print(
f'SAD: {np.mean(sad):.2f}, MSE: {np.mean(mse):.4f},'
f' Grad: {np.mean(grad):.2f}, Conn: {np.mean(conn):.2f},'
f' frame: {avg_frame_rate:.2f}FPS',
)
print('Evaluation success')
if __name__ == '__main__':
set_seed(1)
evaluation(config=default_config)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export to MINDIR."""
from pathlib import Path
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore import load_checkpoint
from mindspore.train.serialization import export
from src.cfg.config import config as default_config
from src.model import MobileNetV2UNetDecoderIndexLearning
def _calculate_size(size, stride, odd):
new_size = np.ceil(size / stride) * stride
if odd:
new_size += 1
return int(new_size)
def run_export(config):
"""
Export model to MINDIR.
Args:
config: Config parameters.
"""
model = MobileNetV2UNetDecoderIndexLearning(
encoder_rate=config.rate,
encoder_current_stride=config.current_stride,
encoder_settings=config.inverted_residual_setting,
output_stride=config.output_stride,
width_mult=config.width_mult,
conv_operator=config.conv_operator,
decoder_kernel_size=config.decoder_kernel_size,
apply_aspp=config.apply_aspp,
use_nonlinear=config.use_nonlinear,
use_context=config.use_context,
)
load_checkpoint(config.ckpt_url, model)
model.set_train(False)
# Correctly process input
odd_input = config.input_size % 2 == 1
h = _calculate_size(config.img_size[0], config.output_stride, odd_input)
w = _calculate_size(config.img_size[1], config.output_stride, odd_input)
model_input = Tensor(np.ones([1, 4, h, w]), mstype.float32)
save_path = Path(config.ckpt_url).resolve().with_suffix('').as_posix()
export(model, model_input, file_name=save_path, file_format='MINDIR')
print('Model exported successfully!')
print(f'Path to exported model {save_path}.mindir')
if __name__ == "__main__":
context.set_context(
mode=context.GRAPH_MODE,
device_target=default_config.device_target,
device_id=default_config.device_id,
)
run_export(default_config)
PyYAML
opencv-python
Pillow
SciPy
scikit-image
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 5 ]]; then
echo "Usage: bash ./scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [LOGS_CKPT_DIR] [MOBILENET_CKPT] [DATA_DIR] [BG_DIR]"
exit 1;
fi
export RANK_SIZE=$1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
realpath -m "$PWD/$1"
fi
}
LOGS_DIR=$(get_real_path "$2")
CKPT_URL=$(get_real_path "$3")
DATA_DIR=$(get_real_path "$4")
BG_DIR=$(get_real_path "$5")
if [ ! -d "$LOGS_DIR" ]; then
mkdir "$LOGS_DIR"
mkdir "$LOGS_DIR/training_configs"
fi
cp ./*.py "$LOGS_DIR"/training_configs
cp ./*.yaml "$LOGS_DIR"/training_configs
cp -r ./src "$LOGS_DIR"/training_configs
mpirun -n $1 --allow-run-as-root \
python train.py \
--device_target="GPU" \
--logs_dir="$LOGS_DIR" \
--ckpt_url="$CKPT_URL" \
--data_dir="$DATA_DIR" \
--bg_dir="$BG_DIR" \
--is_distributed=True \
--num_workers=4 \
> "$LOGS_DIR"/distribute_train.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 4 ]]; then
echo "Usage: bash ./scripts/run_eval_gpu.sh [DEVICE_ID] [CKPT_URL] [DATA_DIR] [LOGS_DIR]"
exit 1;
fi
export CUDA_VISIBLE_DEVICES=$1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
realpath -m "$PWD/$1"
fi
}
CKPT_URL=$(get_real_path "$2")
DATA_DIR=$(get_real_path "$3")
LOGS_DIR=$(get_real_path "$4")
if [ ! -d "$LOGS_DIR" ]; then
mkdir "$LOGS_DIR"
fi
python eval.py \
--device_target="GPU" \
--device_id=0 \
--ckpt_url="$CKPT_URL" \
--data_dir="$DATA_DIR" \
--logs_dir="$LOGS_DIR" \
> "$LOGS_DIR"/eval.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 5 ]]; then
echo "Usage: bash ./scripts/run_standalone_train_gpu.sh [DEVICE_ID] [LOGS_CKPT_DIR] [MOBILENET_CKPT] [DATA_DIR] [BG_DIR]"
exit 1
fi
export CUDA_VISIBLE_DEVICES=$1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
realpath -m "$PWD/$1"
fi
}
LOGS_DIR=$(get_real_path "$2")
CKPT_URL=$(get_real_path "$3")
DATA_DIR=$(get_real_path "$4")
BG_DIR=$(get_real_path "$5")
if [ ! -d "$LOGS_DIR" ]; then
mkdir "$LOGS_DIR"
mkdir "$LOGS_DIR/training_configs"
fi
cp ./*.py "$LOGS_DIR"/training_configs
cp ./*.yaml "$LOGS_DIR"/training_configs
cp -r ./src "$LOGS_DIR"/training_configs
python train.py \
--device_target="GPU" \
--device_id=0 \
--logs_dir="$LOGS_DIR" \
--ckpt_url="$CKPT_URL" \
--data_dir="$DATA_DIR" \
--bg_dir="$BG_DIR" \
> "$LOGS_DIR"/standalone_train.log 2>&1 &
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import argparse
import ast
from pathlib import Path
from pprint import pformat
import yaml
class Config:
"""
Configuration namespace, convert dictionary to members.
Args:
cfg_dict (dict): Config parameters.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser (argparse.ArgumentParser): Parent parser.
cfg (dict): Base configuration.
helper (dict): Helper description.
choices (dict): Choices.
cfg_path (str): Path to default_config.yaml.
Returns:
args: Parsed args from default_config.yaml.
"""
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else f"Please reference to {cfg_path}"
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path (str): Path to yaml config.
Returns:
cfg: Config parameters values.
cfg_helper: Config parameters descriptions.
cfg_choices: Config parameters choices.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs_raw = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = []
for cf in cfgs_raw:
cfgs.append(cf)
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
        except ValueError as err:
            raise ValueError("Failed to parse yaml file.") from err
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
Returns:
cfg: Merged arguments.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
Returns:
config: Parsed and merged config arguments from argparse and yaml config.
"""
parser = argparse.ArgumentParser(description="IndexNet config.", add_help=False)
curr_dir = Path(__file__).resolve().parent
parser.add_argument("--config_path", type=str, default=str(curr_dir / '../../default_config.yaml'),
help="Path to config.")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset scripts."""
import random
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from scipy.ndimage import morphology
def resize_image_alpha(image, alpha, nh, nw):
"""
Resize image and alpha concatenated input to the same size.
Args:
image (np.array): Input 4 channels image.
alpha (np.array): Stacked alpha, mask, fg, bg and image.
nh (int): Height resize size.
nw (int): Width resize size.
Returns:
image (np.array): Resized input 4 channels image.
alpha (np.array): Resized stacked alpha, mask, fg, bg and image.
"""
alpha_chn = alpha.shape[2]
trimap = image[:, :, 3]
image = image[:, :, 0:3]
mask = alpha[:, :, 1]
if alpha_chn > 2:
fg = alpha[:, :, 2:5]
bg = alpha[:, :, 5:8]
ori_image = alpha[:, :, 8:11]
alpha = alpha[:, :, 0]
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
trimap = cv2.resize(trimap, (nw, nh), interpolation=cv2.INTER_NEAREST)
alpha = cv2.resize(alpha, (nw, nh), interpolation=cv2.INTER_CUBIC)
mask = cv2.resize(mask, (nw, nh), interpolation=cv2.INTER_NEAREST)
if alpha_chn > 2:
fg = cv2.resize(fg, (nw, nh), interpolation=cv2.INTER_CUBIC)
bg = cv2.resize(bg, (nw, nh), interpolation=cv2.INTER_CUBIC)
ori_image = cv2.resize(ori_image, (nw, nh), interpolation=cv2.INTER_CUBIC)
trimap = trimap.reshape(trimap.shape[0], trimap.shape[1], 1)
alpha = alpha.reshape(alpha.shape[0], alpha.shape[1], 1)
mask = mask.reshape(mask.shape[0], mask.shape[1], 1)
image = np.concatenate((image, trimap), axis=2)
if alpha_chn > 2:
alpha = np.concatenate((alpha, mask, fg, bg, ori_image), axis=2)
else:
alpha = np.concatenate((alpha, mask), axis=2)
return image, alpha
class RandomCrop:
"""
Randomly crop the image.
Args:
output_size (int): Desired output size. If int, square crop is made.
scales (list): Desired scales
"""
def __init__(self, output_size, scales):
assert isinstance(output_size, int)
self.output_size = output_size
self.scales = scales
def __call__(self, sample):
image, alpha = sample['image'], sample['alpha']
h, w = image.shape[:2]
if min(h, w) < self.output_size:
s = (self.output_size + 180) / min(h, w)
nh, nw = int(np.floor(h * s)), int(np.floor(w * s))
image, alpha = resize_image_alpha(image, alpha, nh, nw)
h, w = image.shape[:2]
crop_size = np.floor(self.output_size * np.array(self.scales)).astype('int')
crop_size = crop_size[crop_size < min(h, w)]
crop_size = int(random.choice(crop_size))
c = int(np.ceil(crop_size / 2))
mask = np.equal(image[:, :, 3], 128).astype(np.uint8)
if mask[c:h - c + 1, c:w - c + 1].sum() != 0:
mask_center = np.zeros((h, w), dtype=np.uint8)
mask_center[c:h - c + 1, c:w - c + 1] = 1
mask = (mask & mask_center)
idh, idw = np.where(mask == 1)
ids = random.choice(range(len(idh)))
hc, wc = idh[ids], idw[ids]
h1, w1 = hc - c, wc - c
else:
idh, idw = np.where(mask == 1)
ids = random.choice(range(len(idh)))
hc, wc = idh[ids], idw[ids]
h1, w1 = np.clip(hc - c, 0, h), np.clip(wc - c, 0, w)
h2, w2 = h1 + crop_size, w1 + crop_size
h1 = h - crop_size if h2 > h else h1
w1 = w - crop_size if w2 > w else w1
image = image[h1:h1 + crop_size, w1:w1 + crop_size, :]
alpha = alpha[h1:h1 + crop_size, w1:w1 + crop_size, :]
if crop_size != self.output_size:
nh = nw = self.output_size
image, alpha = resize_image_alpha(image, alpha, nh, nw)
return {'image': image, 'alpha': alpha}
class RandomFlip:
"""
Randomly flip the image and alpha.
"""
def __init__(self):
pass
def __call__(self, sample):
image, alpha = sample['image'], sample['alpha']
do_mirror = np.random.randint(2)
if do_mirror:
image = cv2.flip(image, 1)
alpha = cv2.flip(alpha, 1)
return {'image': image, 'alpha': alpha}
class Transpose:
"""
Transpose arrays to the input view.
"""
def __init__(self):
pass
def __call__(self, sample):
image, alpha = sample['image'], sample['alpha']
# swap color axis
# from numpy image: H x W x C
# to numpy image: C X H X W
image = image.transpose((2, 0, 1))
alpha = alpha.transpose((2, 0, 1))
return {'image': image, 'alpha': alpha}
class ResizePad:
"""
Resize pad and transpose input image (if necessary).
If image is vertical, transpose, resize, pad.
Args:
size (tuple): Input image size.
"""
def __init__(self, size):
self.h = size[0]
self.w = size[1]
def __call__(self, sample):
image, alpha = sample['image'], sample['alpha']
h, w = image.shape[:2]
transp = False
# Flip image to horizontal state
if w < h:
transp = True
image = image.transpose(1, 0, 2)
alpha = alpha.transpose(1, 0, 2)
h, w = w, h
if self.w < w or self.h < h:
aspect_ratio = min(self.h / h, self.w / w)
else:
aspect_ratio = max(w / self.w, h / self.h)
nh, nw = int(h * aspect_ratio), int(w * aspect_ratio)
image, alpha = resize_image_alpha(image, alpha, nh, nw)
image_mask = np.ones_like(image[:, :, 0])
image = np.pad(image, ((0, self.h - nh), (0, self.w - nw), (0, 0)))
image_mask = np.pad(image_mask, ((0, self.h - nh), (0, self.w - nw))).astype(np.bool)
return {'image': image, 'alpha': alpha}, transp, image_mask, (nh, nw)
class Normalize:
"""
Normalize image.
Args:
scale (float): Scale for image.
mean (np.array): 4 dims mean to normalize every axis.
std (np.array): 4 dims std to normalize every axis.
"""
def __init__(self, scale, mean, std):
self.scale = scale
self.mean = mean
self.std = std
def __call__(self, sample):
image, alpha = sample['image'], sample['alpha']
image, alpha = image.astype('float32'), alpha.astype('float32')
image = (self.scale * image - self.mean) / self.std
alpha[:, :, 0] = self.scale * alpha[:, :, 0]
if alpha.shape[2] > 2:
alpha[:, :, 2:11] = self.scale * alpha[:, :, 2:11]
return {'image': image.astype('float32'), 'alpha': alpha.astype('float32')}
class BaseDataset:
"""
Base dataset class for Image Matting.
Args:
data_dir (str): Path to the dataset folder.
config: Config parameters.
sub_folder (str): Name of sub folder to train/test part of dataset.
data_file (str): Name of data schema.
"""
def __init__(self, data_dir, config, sub_folder='train', data_file='data.txt'):
self.data_dir = Path(data_dir)
if not self.data_dir.is_dir():
raise NotADirectoryError(f'Not valid path to merged dataset sub folder: {Path(data_dir, sub_folder)}')
with Path(data_dir, sub_folder, data_file).open() as file:
self.datalist = [name.split('|') for name in file.read().splitlines()]
self.scales = config.scales
self.img_scale = 1. / float(config.img_scale)
self.img_mean = np.array(config.img_mean).reshape((1, 1, 4))
self.img_std = np.array(config.img_std).reshape((1, 1, 4))
def __len__(self):
return len(self.datalist)
@staticmethod
def _read_image(image_path):
"""
Read image.
Args:
image_path: Image path.
Returns:
img_arr (np.array): Loaded image.
"""
img_arr = np.array(Image.open(image_path))
if len(img_arr.shape) == 2: # grayscale
img_arr = np.tile(img_arr, [3, 1, 1]).transpose((1, 2, 0))
return img_arr
class ImageMattingDatasetTrain(BaseDataset):
"""
Image Matting train dataset.
Args:
data_dir (str): Path to the dataset folder.
bg_dir (str): Path to the backgrounds dataset folder.
config: Config parameters.
sub_folder (str): Name of sub folder to train part of dataset.
data_file (str): Name of data schema.
"""
def __init__(self, data_dir, bg_dir, config, sub_folder, data_file):
super().__init__(data_dir, config, sub_folder, data_file)
self.bg_dir = Path(bg_dir)
self.crop_size = config.input_size
if not self.bg_dir.is_dir():
raise NotADirectoryError(f'Not valid path to backgrounds: {self.bg_dir}')
@staticmethod
def _generate_trimap(alpha):
"""
Generate trimap with random line width.
Args:
alpha (np.array): Image alpha channel (mask).
Returns:
trimap (np.array): Generated trimap.
"""
# alpha \in [0, 1] should be taken into account
# be careful when dealing with regions of alpha=0 and alpha=1
fg = np.array(np.equal(alpha, 255).astype(np.float32))
unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) # unknown = alpha > 0
unknown = unknown - fg
# image dilation implemented by Euclidean distance transform
unknown = morphology.distance_transform_edt(unknown == 0) <= np.random.randint(1, 20)
trimap = fg * 255
trimap[unknown] = 128
return trimap.astype(np.uint8)
def __getitem__(self, idx):
image_name = self.data_dir / self.datalist[idx][0]
alpha_name = self.data_dir / self.datalist[idx][1]
fg_name = self.data_dir / self.datalist[idx][2]
bg_name = self.bg_dir / self.datalist[idx][3]
image = self._read_image(image_name)
fg = self._read_image(fg_name)
bg = self._read_image(bg_name)
alpha = np.array(Image.open(alpha_name))
if alpha.ndim != 2:
alpha = alpha[:, :, 0]
fgh, fgw = fg.shape[0:2]
bgh, bgw = bg.shape[0:2]
rh, rw = fgh / float(bgh), fgw / float(bgw)
r = rh if rh > rw else rw
nh, nw = int(np.ceil(bgh * r)), int(np.ceil(bgw * r))
bg = cv2.resize(bg, (nw, nh), interpolation=cv2.INTER_CUBIC)
bg = bg[0:fgh, 0:fgw, :]
trimap = self._generate_trimap(alpha)
mask = np.equal(trimap, 128).astype(np.uint8)
alpha = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))
trimap = trimap.reshape((trimap.shape[0], trimap.shape[1], 1))
mask = mask.reshape((mask.shape[0], mask.shape[1], 1))
alpha = np.concatenate((alpha, mask, fg, bg, image), axis=2)
image = np.concatenate((image, trimap), axis=2)
sample = {'image': image, 'alpha': alpha}
# Apply training transforms
sample = RandomCrop(self.crop_size, self.scales)(sample)
sample = RandomFlip()(sample)
sample = Normalize(self.img_scale, self.img_mean, self.img_std)(sample)
sample = Transpose()(sample)
image = sample['image']
alpha = sample['alpha'][0, :, :]
mask = sample['alpha'][1, :, :]
fg = sample['alpha'][2:5, :, :]
bg = sample['alpha'][5:8, :, :]
c_g = sample['alpha'][8:11, :, :]
return image, mask, alpha, fg, bg, c_g
class ImageMattingDatasetVal(BaseDataset):
"""
Image Matting validation dataset.
Args:
data_dir (str): Path to the dataset folder.
config: Config parameters.
sub_folder (str): Name of sub folder to test part of dataset.
data_file (str): Name of data schema.
"""
def __init__(self, data_dir, config, sub_folder, data_file):
super().__init__(data_dir, config, sub_folder, data_file)
self.img_size = config.img_size
def __getitem__(self, idx):
image_name = self.data_dir / self.datalist[idx][0]
alpha_name = self.data_dir / self.datalist[idx][1]
trimap_name = self.data_dir / self.datalist[idx][2]
image = self._read_image(image_name)
alpha = np.array(Image.open(alpha_name))
if alpha.ndim != 2:
alpha = alpha[:, :, 0]
trimap = np.array(Image.open(trimap_name))
alpha = alpha[:, :, 0] if alpha.ndim == 3 else alpha
alpha = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))
trimap = trimap.reshape((trimap.shape[0], trimap.shape[1], 1))
image = np.concatenate((image, trimap), axis=2)
alpha = np.concatenate((alpha, trimap), axis=2)
sample = {'image': image, 'alpha': alpha}
# Apply validation transforms
sample, transposed, pad_mask, clear_size = ResizePad(self.img_size)(sample)
sample = Normalize(self.img_scale, self.img_mean, self.img_std)(sample)
sample = Transpose()(sample)
image = sample['image']
alpha = sample['alpha'][0, :, :]
mask = sample['alpha'][1, :, :]
output = image, alpha, mask, transposed, pad_mask, clear_size
return output
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model layers."""
from mindspore import nn
def depth_sep_dilated_conv_3x3_bn(inp, oup, padding, dilation):
"""
Dilated depthwise separable convolution block with BN, ReLU.
Args:
inp (int): Input channels of block.
oup (int): Output channels of block.
padding (int): Padding of depthwise conv.
dilation (int): Dilation of depthwise conv.
Returns:
block: Dilated depthwise separable conv block.
"""
return nn.SequentialCell(
[
nn.Conv2d(
in_channels=inp,
out_channels=inp,
kernel_size=3,
stride=1,
pad_mode='pad',
padding=padding,
dilation=dilation,
group=inp,
has_bias=False,
),
nn.BatchNorm2d(num_features=inp),
nn.ReLU6(),
nn.Conv2d(
in_channels=inp,
out_channels=oup,
kernel_size=1,
stride=1,
pad_mode='pad',
padding=0,
has_bias=False,
),
nn.BatchNorm2d(num_features=oup),
nn.ReLU6()
]
)
def dep_sep_conv_bn(inp, oup, k, s):
"""
Depthwise separable convolution block with BN, ReLU.
Args:
inp (int): Input channels of block.
oup (int): Output channels of block.
k (int): Kernel size of depthwise conv.
s (int): Stride of depthwise conv.
Returns:
block: Depthwise separable conv block.
"""
return nn.SequentialCell(
[
nn.Conv2d(
in_channels=inp,
out_channels=inp,
kernel_size=k,
stride=s,
pad_mode='pad',
padding=k // 2,
group=inp,
has_bias=False,
),
nn.BatchNorm2d(num_features=inp),
nn.ReLU6(),
nn.Conv2d(
in_channels=inp,
out_channels=oup,
kernel_size=1,
stride=1,
pad_mode='pad',
padding=0,
has_bias=False,
),
nn.BatchNorm2d(num_features=oup),
nn.ReLU6()
]
)
def conv_bn(inp, oup, k, s):
"""
Conv, BN, ReLU block.
Args:
inp (int): Input channels of block.
oup (int): Output channels of block.
k (int): Kernel size of conv.
s (int): Stride of conv.
Returns:
block: Conv, BN, activation block.
"""
return nn.SequentialCell(
[
nn.Conv2d(
in_channels=inp,
out_channels=oup,
kernel_size=k,
stride=s,
pad_mode='pad',
padding=k // 2,
has_bias=False,
),
nn.BatchNorm2d(num_features=oup),
nn.ReLU6()
]
)
def pred(inp, oup, conv_operator, k):
"""
Output conv block.
Args:
inp (int): Input channels of block.
oup (int): Output channels of block.
conv_operator (str): Type of conv operator to use as input conv block.
k (int): Kernel size of convs.
Returns:
block: Last convs mask prediction block.
"""
# the last 1x1 convolutional layer is very important
hlconv2d = hlconv[conv_operator]
return nn.SequentialCell(
[
hlconv2d(inp, oup, k, 1),
nn.Conv2d(
in_channels=oup,
out_channels=oup,
kernel_size=k,
stride=1,
pad_mode='pad',
padding=k // 2,
has_bias=False,
)
]
)
hlconv = {
'std_conv': conv_bn,
'dep_sep_conv': dep_sep_conv_bn,
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model script."""
import math
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import nn
from mindspore import ops
from src.layers import conv_bn
from src.layers import pred
from src.modules import ASPP
from src.modules import DepthwiseM2OIndexBlock
from src.modules import IndexedUpsamlping
from src.modules import InvertedResidual
class MobileNetV2UNetDecoderIndexLearning(nn.Cell):
"""
IndexNet with MobileNetV2 backbone, UNet architecture.
Args:
encoder_rate: Encoder rate.
encoder_current_stride: Encoder stride.
encoder_settings: Encoder blocks settings.
output_stride (int): Output image stride.
width_mult (float): Width multiplication for mobilenetv2 blocks.
conv_operator (str): Conv operator for decoder.
decoder_kernel_size (int): Decoder conv kernel size.
apply_aspp (bool): Use ASPP.
use_nonlinear (bool): Use nonlinear in index blocks.
use_context (bool): Use context in index blocks.
"""
def __init__(
self,
encoder_rate,
encoder_current_stride,
encoder_settings,
output_stride=32,
width_mult=1.,
conv_operator='std_conv',
decoder_kernel_size=5,
apply_aspp=True,
use_nonlinear=True,
use_context=True,
):
super().__init__()
self.width_mult = width_mult
self.output_stride = output_stride
self.encoder_settings = encoder_settings
# ENCODER
# building the first layer
initial_channel = int(self.encoder_settings[0][1] * width_mult)
self.layer0 = conv_bn(4, initial_channel, 3, 1)
current_stride = encoder_current_stride * 2
# building bottleneck layers
for i, setting in enumerate(self.encoder_settings):
s = setting[4]
self.encoder_settings[i][4] = 1 # change stride
if current_stride == output_stride:
rate = encoder_rate * s
self.encoder_settings[i][5] = rate
else:
current_stride *= s
self.layer1 = self._build_layer(InvertedResidual, self.encoder_settings[0])
self.layer2 = self._build_layer(InvertedResidual, self.encoder_settings[1], downsample=True)
self.layer3 = self._build_layer(InvertedResidual, self.encoder_settings[2], downsample=True)
self.layer4 = self._build_layer(InvertedResidual, self.encoder_settings[3], downsample=True)
self.layer5 = self._build_layer(InvertedResidual, self.encoder_settings[4])
self.layer6 = self._build_layer(InvertedResidual, self.encoder_settings[5], downsample=True)
self.layer7 = self._build_layer(InvertedResidual, self.encoder_settings[6])
# freeze backbone batch norm layers
self.freeze_bn()
# define index blocks
self.index0 = DepthwiseM2OIndexBlock(32, use_nonlinear, use_context)
self.index2 = DepthwiseM2OIndexBlock(24, use_nonlinear, use_context)
self.index3 = DepthwiseM2OIndexBlock(32, use_nonlinear, use_context)
self.index4 = DepthwiseM2OIndexBlock(64, use_nonlinear, use_context)
self.index6 = DepthwiseM2OIndexBlock(160, use_nonlinear, use_context)
# context aggregation
if apply_aspp:
self.dconv_pp = ASPP(320, 160, output_stride=output_stride, width_mult=width_mult)
else:
self.dconv_pp = conv_bn(320, 160, k=1, s=1)
# DECODER
self.decoder_layer6 = IndexedUpsamlping(160 * 2, 96, conv_operator, decoder_kernel_size)
self.decoder_layer5 = IndexedUpsamlping(96 * 2, 64, conv_operator, decoder_kernel_size)
self.decoder_layer4 = IndexedUpsamlping(64 * 2, 32, conv_operator, decoder_kernel_size)
self.decoder_layer3 = IndexedUpsamlping(32 * 2, 24, conv_operator, decoder_kernel_size)
self.decoder_layer2 = IndexedUpsamlping(24 * 2, 16, conv_operator, decoder_kernel_size)
self.decoder_layer1 = IndexedUpsamlping(16 * 2, 32, conv_operator, decoder_kernel_size)
self.decoder_layer0 = IndexedUpsamlping(32 * 2, 32, conv_operator, decoder_kernel_size)
self.pred = pred(32, 1, conv_operator, decoder_kernel_size)
self.avg_pool = ops.AvgPool(pad_mode='same', kernel_size=(2, 2), strides=(2, 2))
self._initialize_weights()
def _build_layer(self, block, layer_setting, downsample=False):
"""
Build MobileNetV2 block.
Args:
block: Encoder block.
layer_setting (list): Encoder block settings.
downsample (bool): Downsample at this block.
Returns:
block: Inited encoder block.
"""
t, p, c, n, s, d = layer_setting
input_channel = int(p * self.width_mult)
output_channel = int(c * self.width_mult)
layers = []
for i in range(n):
if i == 0:
d0 = d
if downsample:
d0 = d // 2 if d > 1 else 1
layers.append(block(input_channel, output_channel, s, d0, expand_ratio=t))
else:
layers.append(block(input_channel, output_channel, 1, d, expand_ratio=t))
input_channel = output_channel
return nn.SequentialCell([*layers])
def _initialize_weights(self):
"""Init model weights."""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
n = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels
weight = np.random.normal(loc=0, scale=math.sqrt(2. / n), size=cell.weight.shape)
cell.weight.set_data(Tensor(weight, mstype.float32))
def freeze_bn(self):
"""Freeze batch norms."""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.BatchNorm2d):
cell.beta.requires_grad = False
cell.gamma.requires_grad = False
cell.use_batch_statistics = False
def construct(self, x):
"""
Feed forward.
Args:
x: Input 4-channel image (RGB concatenated with trimap).
Returns:
alpha_channel: Predicted alpha mask.
"""
# encode
l0 = self.layer0(x) # 32x320x320 (input x is 4x320x320, for default crop size=320)
idx0_en, idx0_de = self.index0(l0)
l0 = idx0_en * l0
l0p = 4 * self.avg_pool(l0) # 32x160x160
l1 = self.layer1(l0p) # 16x160x160
l2 = self.layer2(l1) # 24x160x160
idx2_en, idx2_de = self.index2(l2)
l2 = idx2_en * l2
l2p = 4 * self.avg_pool(l2) # 24x80x80
l3 = self.layer3(l2p) # 32x80x80
idx3_en, idx3_de = self.index3(l3)
l3 = idx3_en * l3
l3p = 4 * self.avg_pool(l3) # 32x40x40
l4 = self.layer4(l3p) # 64x40x40
idx4_en, idx4_de = self.index4(l4)
l4 = idx4_en * l4
l4p = 4 * self.avg_pool(l4) # 64x20x20
l5 = self.layer5(l4p) # 96x20x20
l6 = self.layer6(l5) # 160x20x20
idx6_en, idx6_de = self.index6(l6)
l6 = idx6_en * l6
l6p = 4 * self.avg_pool(l6) # 160x10x10
l7 = self.layer7(l6p) # 320x10x10
# pyramid pooling
l_up = self.dconv_pp(l7) # 160x10x10
# decode
l_up = self.decoder_layer6(l_up, l6, idx6_de)
l_up = self.decoder_layer5(l_up, l5)
l_up = self.decoder_layer4(l_up, l4, idx4_de)
l_up = self.decoder_layer3(l_up, l3, idx3_de)
l_up = self.decoder_layer2(l_up, l2, idx2_de)
l_up = self.decoder_layer1(l_up, l1)
l_up = self.decoder_layer0(l_up, l0, idx0_de)
alpha_channel = self.pred(l_up)
return alpha_channel
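# A minimal usage sketch (not part of the training pipeline): build the network with the
# settings from src/cfg/config.py and run a forward pass on a dummy 4-channel crop.
# The attribute names below mirror the train script; treat the snippet as an assumption,
# not a tested entry point.
#
# import numpy as np
# from mindspore import Tensor
# from src.cfg.config import config
#
# net = MobileNetV2UNetDecoderIndexLearning(
#     encoder_rate=config.rate,
#     encoder_current_stride=config.current_stride,
#     encoder_settings=config.inverted_residual_setting,
# )
# dummy = Tensor(np.zeros((1, 4, 320, 320), np.float32))
# alpha = net(dummy)  # expected shape: (1, 1, 320, 320)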
class LossWrapper(nn.Cell):
"""
Train wrapper to the model.
Args:
model (nn.Cell): Prediction model.
loss_function (func): Loss computation between ground-truth and predictions.
"""
def __init__(self, model, loss_function):
super().__init__()
self.model = model
self.weighted_loss = loss_function
def construct(
self,
inp,
mask_gt,
alpha_gt,
foreground_gt,
background_gt,
merged_gt,
):
"""
Get predictions and compute loss.
Args:
inp: Input 4-channel image (RGB concatenated with trimap).
mask_gt: Loss weighting mask (unknown trimap region).
alpha_gt: Ground-truth alpha matte.
foreground_gt: Original foreground image.
background_gt: Original background image.
merged_gt: Ground-truth composite (foreground over background by alpha).
Returns:
loss: Computed weighted loss.
"""
pd = self.model(inp)
loss = self.weighted_loss(pd, mask_gt, alpha_gt, foreground_gt, background_gt, merged_gt)
return loss
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model modules."""
from mindspore import nn
from mindspore import ops
from src.layers import depth_sep_dilated_conv_3x3_bn
from src.layers import hlconv
class DepthwiseM2OIndexBlock(nn.Cell):
"""
Depthwise many-to-one IndexNet block.
Args:
inp (int): Input channels.
use_nonlinear (bool): Use a nonlinearity inside the index blocks.
use_context (bool): Use a larger (4x4) context window in the index blocks.
"""
def __init__(self, inp, use_nonlinear, use_context):
super().__init__()
self.indexnet1 = self._build_index_block(inp, use_nonlinear, use_context)
self.indexnet2 = self._build_index_block(inp, use_nonlinear, use_context)
self.indexnet3 = self._build_index_block(inp, use_nonlinear, use_context)
self.indexnet4 = self._build_index_block(inp, use_nonlinear, use_context)
self.sigmoid = ops.Sigmoid()
self.reshape = ops.Reshape()
self.concat = ops.Concat(axis=2)
self.transpose = ops.Transpose()
self.unsqueeze = ops.ExpandDims()
self.softmax = ops.Softmax(axis=2)
@staticmethod
def _build_index_block(inp, use_nonlinear, use_context):
"""
Build IndexNet block.
Args:
inp (int): Input channels.
use_nonlinear (bool): Use a nonlinearity inside the index block.
use_context (bool): Use a larger (4x4) context window in the index block.
Returns:
block: Initialized index block.
"""
if use_context:
k_s, pad = 4, 1
else:
k_s, pad = 2, 0
if use_nonlinear:
return nn.SequentialCell(
[
nn.Conv2d(inp, inp, kernel_size=k_s, stride=2, pad_mode='pad', padding=pad, has_bias=False),
nn.BatchNorm2d(inp),
nn.ReLU6(),
nn.Conv2d(inp, inp, kernel_size=1, stride=1, pad_mode='pad', padding=0, has_bias=False)
]
)
return nn.SequentialCell(
[
nn.Conv2d(inp, inp, kernel_size=k_s, stride=2, pad_mode='pad', padding=pad, has_bias=False)
]
)
def depth_to_space(self, input_x, kh=2, kw=2):
"""
Rearrange channel groups into spatial blocks (depth-to-space).
BS x C x H x W -> BS x C/kh/kw x H*kh x W*kw.
Args:
input_x: Input tensor. Shape BS x C x H x W.
kh: Scaling for height dim.
kw: Scaling for width dim.
Returns:
output_x: Output tensor. Shape BS x C/kh/kw x H*kh x W*kw.
"""
_, c, h, w = input_x.shape
nc = c // kh // kw
output_x = self.reshape(input_x, (-1, nc, kh, kw, h, w))
output_x = self.transpose(output_x, (0, 1, 4, 2, 5, 3))
output_x = self.reshape(output_x, (-1, nc, h * kh, w * kw))
return output_x
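# Worked example of the reshape/transpose above with kh = kw = 2: a tensor of shape
# (BS, 4C, H, W) is viewed as (BS, C, 2, 2, H, W), transposed to (BS, C, H, 2, W, 2)
# and flattened to (BS, C, 2H, 2W), so every group of four channels becomes one 2x2
# spatial block, i.e. the inverse of the stride-2 index predictors above.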
def construct(self, x):
"""
Block feed forward.
Args:
x: Input feature map.
Returns:
idx_en: Predicted indices for the encoder stage (indexed pooling).
idx_de: Predicted indices for the decoder stage (indexed upsampling).
"""
bs, c, h, w = x.shape
x1 = self.unsqueeze(self.indexnet1(x), 2)
x2 = self.unsqueeze(self.indexnet2(x), 2)
x3 = self.unsqueeze(self.indexnet3(x), 2)
x4 = self.unsqueeze(self.indexnet4(x), 2)
x = self.concat((x1, x2, x3, x4))
# normalization
y = self.sigmoid(x)
z = self.softmax(y)
# pixel shuffling
y = self.reshape(y, (bs, c * 4, h // 2, w // 2))
z = self.reshape(z, (bs, c * 4, h // 2, w // 2))
idx_en = self.depth_to_space(z)
idx_de = self.depth_to_space(y)
return idx_en, idx_de
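# Usage sketch showing how the two index maps are consumed (this mirrors the encoder/
# decoder wiring in MobileNetV2UNetDecoderIndexLearning; `feat` is an assumed feature
# map with an even spatial size):
#
# idx_en, idx_de = index_block(feat)
# pooled = 4 * ops.AvgPool(pad_mode='same', kernel_size=2, strides=2)(idx_en * feat)  # indexed pooling
# # ...later the decoder multiplies its nearest-neighbour-upsampled features by idx_de
# # before concatenating them with the matching encoder feature map.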
class InvertedResidual(nn.Cell):
"""
Inverted residual block.
Args:
inp (int): Block input channels.
oup (int): Block output channels.
stride (int): Depthwise conv stride.
dilation (int): Depthwise conv dilation.
expand_ratio (int): Hidden channels ratio.
"""
def __init__(self, inp, oup, stride, dilation, expand_ratio):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = round(inp * expand_ratio)
self.use_res_connect = self.stride == 1 and inp == oup
self.kernel_size = 3
self.dilation = dilation
if expand_ratio == 1:
self.conv = nn.SequentialCell(
[
# dw
nn.Conv2d(
hidden_dim,
hidden_dim,
kernel_size=3,
stride=stride,
pad_mode='pad',
padding=0,
dilation=dilation,
group=hidden_dim,
has_bias=False
),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
# pw-linear
nn.Conv2d(
hidden_dim,
oup,
kernel_size=1,
stride=1,
pad_mode='pad',
padding=0,
has_bias=False
),
nn.BatchNorm2d(oup),
]
)
else:
self.conv = nn.SequentialCell(
[
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, pad_mode='pad', padding=0, has_bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
# dw
nn.Conv2d(
hidden_dim,
hidden_dim,
3,
stride,
pad_mode='pad',
padding=0,
dilation=dilation,
group=hidden_dim,
has_bias=False,
),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
# pw-linear
nn.Conv2d(
hidden_dim,
oup,
1,
1,
pad_mode='pad',
padding=0,
has_bias=False,
),
nn.BatchNorm2d(oup),
]
)
@staticmethod
def fixed_padding(inputs, kernel_size, dilation):
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
padded_inputs = nn.Pad(paddings=((0, 0), (0, 0), (pad_beg, pad_end), (pad_beg, pad_end)))(inputs)
return padded_inputs
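# Padding arithmetic example for the helper above: kernel_size=3 with dilation=2 gives
# an effective kernel of 3 + (3 - 1) * (2 - 1) = 5, so pad_total = 4 and the input gets
# (2, 2) padding on each spatial side, which keeps the output size equal to the input
# size for a stride-1 depthwise convolution.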
def construct(self, x):
x_pad = self.fixed_padding(x, self.kernel_size, self.dilation)
if self.use_res_connect:
return x + self.conv(x_pad)
return self.conv(x_pad)
class _ASPPModule(nn.Cell):
"""ASPP module."""
def __init__(self, inp, planes, kernel_size, padding, dilation):
super().__init__()
if kernel_size == 1:
self.atrous_conv = nn.SequentialCell(
[
nn.Conv2d(inp, planes, 1, 1, pad_mode='pad', padding=padding, dilation=dilation, has_bias=False),
nn.BatchNorm2d(planes),
nn.ReLU6()
]
)
elif kernel_size == 3:
# depth-wise separable convolution to save the number of parameters
self.atrous_conv = depth_sep_dilated_conv_3x3_bn(inp, planes, padding, dilation)
def construct(self, x):
x = self.atrous_conv(x)
return x
class ASPP(nn.Cell):
"""
ASPP block.
Args:
inp (int): Block input channels.
oup (int): Block output channels.
output_stride (int): Output image stride.
width_mult (float): Channel width multiplier for the hidden layers.
"""
def __init__(self, inp, oup, output_stride, width_mult):
super().__init__()
if output_stride == 32:
dilations = [1, 2, 4, 8]
elif output_stride == 16:
dilations = [1, 6, 12, 18]
elif output_stride == 8:
dilations = [1, 12, 24, 36]
else:
raise NotImplementedError
self.aspp1 = _ASPPModule(inp, int(256 * width_mult), 1, padding=0, dilation=dilations[0])
self.aspp2 = _ASPPModule(inp, int(256 * width_mult), 3, padding=dilations[1], dilation=dilations[1])
self.aspp3 = _ASPPModule(inp, int(256 * width_mult), 3, padding=dilations[2], dilation=dilations[2])
self.aspp4 = _ASPPModule(inp, int(256 * width_mult), 3, padding=dilations[3], dilation=dilations[3])
self.adaptive_pooling = ops.AdaptiveAvgPool2D((1, 1))
self.global_avg_pool = nn.SequentialCell(
[
nn.Conv2d(inp, int(256 * width_mult), 1, stride=1, pad_mode='pad', padding=0, has_bias=False),
nn.BatchNorm2d(int(256 * width_mult)),
nn.ReLU6()
]
)
self.bottleneck_conv = nn.SequentialCell(
[
nn.Conv2d(int(256 * width_mult) * 5, oup, 1, stride=1, pad_mode='pad', padding=0, has_bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6()
]
)
self.dropout = nn.Dropout(0.5)
self.concat = ops.Concat(axis=1)
def construct(self, x):
"""
ASPP block feed forward.
Args:
x: Input feature map.
Returns:
x: Output feature map of ASPP block.
"""
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x_5 = self.adaptive_pooling(x)
x5 = self.global_avg_pool(x_5)
x5 = ops.ResizeNearestNeighbor(size=x4.shape[2:])(x5)
x = self.concat((x1, x2, x3, x4, x5))
x = self.bottleneck_conv(x)
return self.dropout(x)
class IndexedUpsamlping(nn.Cell):
"""
Index-guided upsampling block.
Args:
inp (int): Block input channels.
oup (int): Block output channels.
conv_operator (str): Name of block conv operator.
kernel_size (int): Kernel size of block convs.
"""
def __init__(self, inp, oup, conv_operator, kernel_size):
super().__init__()
self.oup = oup
hlconv2d = hlconv[conv_operator]
# inp, oup, kernel_size, stride, batch_norm
self.dconv = hlconv2d(inp, oup, kernel_size, 1)
self.concat = ops.Concat(axis=1)
def construct(self, l_encode, l_low, indices=None):
if indices is not None:
l_encode = indices * ops.ResizeNearestNeighbor(size=l_low.shape[2:])(l_encode)
l_cat = self.concat((l_encode, l_low))
return self.dconv(l_cat)
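# In words: when decoder indices are provided the decoder features are resized (nearest
# neighbour) to the skip-connection size and multiplied by them (indexed upsampling);
# otherwise the spatial sizes already match. The result is concatenated with the encoder
# features and fused down to `oup` channels by a single conv block.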
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils scripts."""
import cv2
import numpy as np
from mindspore import numpy as msnp
from scipy.ndimage import gaussian_filter
from skimage.measure import label
from skimage.measure import regionprops
def compute_sad_loss(pd, gt, mask):
"""
Compute the SAD error given a prediction, a ground truth and a mask.
Args:
pd (np.array): Predicted alpha mask.
gt (np.array): Groundtruth alpha mask.
mask (np.array): Unknown region of trimap mask.
Returns:
loss (float): Computed SAD loss.
"""
cv2.normalize(pd, pd, 0.0, 255.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
cv2.normalize(gt, gt, 0.0, 255.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
error_map = np.abs(pd - gt) / 255.
loss = np.sum(error_map * mask)
# the loss is scaled by 1000 due to the large images
loss = loss / 1000
return loss
def compute_mse_loss(pd, gt, mask):
"""
Compute the MSE error.
Args:
pd (np.array): Predicted alpha mask.
gt (np.array): Groundtruth alpha mask.
mask (np.array): Unknown region of trimap mask.
Returns:
loss (float): Computed MSE loss.
"""
cv2.normalize(pd, pd, 0.0, 255.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
cv2.normalize(gt, gt, 0.0, 255.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
error_map = (pd - gt) / 255.
loss = np.sum(np.square(error_map) * mask) / np.sum(mask)
return loss
def compute_gradient_loss(pd, gt, mask):
"""
Compute the gradient error.
Args:
pd (np.array): Predicted alpha mask.
gt (np.array): Groundtruth alpha mask.
mask (np.array): Unknown region of trimap mask.
Returns:
loss (float): Computed Grad loss.
"""
cv2.normalize(pd, pd, 0.0, 255.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
cv2.normalize(gt, gt, 0.0, 255.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
pd = pd / 255.
gt = gt / 255.
pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32)
pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32)
gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32)
gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32)
pd_mag = np.sqrt(pd_x ** 2 + pd_y ** 2)
gt_mag = np.sqrt(gt_x ** 2 + gt_y ** 2)
error_map = np.square(pd_mag - gt_mag)
loss = np.sum(error_map * mask) / 10
return loss
def compute_connectivity_loss(pd, gt, mask, step=0.1):
"""
Compute the connectivity error.
Args:
pd (np.array): Predicted alpha mask.
gt (np.array): Groundtruth alpha mask.
mask (np.array): Unknown region of trimap mask.
step (float): Threshold steps.
Returns:
loss (float): Computed Conn loss.
"""
cv2.normalize(pd, pd, 0, 255, cv2.NORM_MINMAX)
cv2.normalize(gt, gt, 0, 255, cv2.NORM_MINMAX)
pd = pd / 255.
gt = gt / 255.
h, w = pd.shape
thresh_steps = np.arange(0, 1.1, step)
l_map = -1 * np.ones((h, w), dtype=np.float32)
for i in range(1, thresh_steps.size):
pd_th = pd >= thresh_steps[i]
gt_th = gt >= thresh_steps[i]
label_image = label(pd_th & gt_th, connectivity=1)
cc = regionprops(label_image)
size_vec = np.array([c.area for c in cc])
if size_vec.size == 0:
continue
max_id = np.argmax(size_vec)
coords = cc[max_id].coords
omega = np.zeros((h, w), dtype=np.float32)
omega[coords[:, 0], coords[:, 1]] = 1
flag = (l_map == -1) & (omega == 0)
l_map[flag == 1] = thresh_steps[i - 1]
l_map[l_map == -1] = 1
# the definition of lambda is ambiguous
d_pd = pd - l_map
d_gt = gt - l_map
phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
loss = np.sum(np.abs(phi_pd - phi_gt) * mask) / 1000
return loss
def image_alignment(x, output_stride, odd=False):
"""
Resize inputs so that their size corresponds to the output stride.
Args:
x: Raw model inputs.
output_stride: Output image stride.
odd (bool): Make the resized size an output-stride multiple plus one.
Returns:
new_x: Resized inputs.
"""
imsize = np.asarray(x.shape[:2], dtype=float)  # np.float was removed in newer NumPy versions
if odd:
new_imsize = np.ceil(imsize / output_stride) * output_stride + 1
else:
new_imsize = np.ceil(imsize / output_stride) * output_stride
h, w = int(new_imsize[0]), int(new_imsize[1])
x1 = x[:, :, 0:3]
x2 = x[:, :, 3]
new_x1 = cv2.resize(x1, dsize=(w, h), interpolation=cv2.INTER_CUBIC)
new_x2 = cv2.resize(x2, dsize=(w, h), interpolation=cv2.INTER_NEAREST)
new_x2 = np.expand_dims(new_x2, axis=2)
new_x = np.concatenate((new_x1, new_x2), axis=2)
return new_x
def image_rescale(x, scale):
"""
Rescale inputs.
Args:
x (np.array): Raw model input.
scale (float): Scale ratio.
Returns:
new_x (np.array): Resized input.
"""
x1 = x[:, :, 0:3]
x2 = x[:, :, 3]
new_x1 = cv2.resize(x1, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
new_x2 = cv2.resize(x2, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST)
new_x2 = np.expand_dims(new_x2, axis=2)
new_x = np.concatenate((new_x1, new_x2), axis=2)
return new_x
def compute_loss(diff, diff_mask, bs, epsilon=1e-6):
"""
Compute MRSE loss with mask.
Args:
diff (Tensor): Masked difference between prediction and ground truth.
diff_mask (Tensor): Mask used to normalize the loss.
bs (int): Batch size.
epsilon (float): Small constant that keeps the square root smooth near zero.
Returns:
loss: Computed MRSE loss.
"""
loss = msnp.sqrt(diff * diff + epsilon ** 2)
loss = loss.sum(axis=2).sum(axis=2) / diff_mask.sum(axis=2).sum(axis=2)
loss = loss.sum() / bs
return loss
def weighted_loss(pd, mask, alpha, fg, bg, c_g, wl=0.5):
"""
Compute weighted loss of alpha prediction and the composition during training.
Args:
pd: Predicted alpha matte.
mask: Loss weighting mask (unknown trimap region).
alpha: Ground-truth alpha matte.
fg: Original foreground image.
bg: Original background image.
c_g: Ground-truth composite (foreground over background by alpha).
wl (float): Weight of the alpha loss; the composition loss gets (1 - wl).
Returns:
output: Weighted loss.
"""
bs, _, h, w = pd.shape
mask = mask.reshape((bs, 1, h, w))
alpha_gt = alpha.reshape((bs, 1, h, w))
diff_alpha = (pd - alpha_gt) * mask
c_p = pd * fg + (1 - pd) * bg
diff_color = (c_p - c_g) * mask
loss_alpha = compute_loss(diff_alpha, mask, bs)
loss_composition = compute_loss(diff_color, mask, bs)
return wl * loss_alpha + (1 - wl) * loss_composition
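# Example: with the default wl = 0.5 the alpha-prediction term and the composition term
# (the rebuilt image c_p = pd * fg + (1 - pd) * bg compared against the ground-truth
# composite c_g) contribute equally to the training loss.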
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training script."""
import numpy as np
from mindspore import Model
from mindspore import Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore import load_checkpoint
from mindspore import load_param_into_net
from mindspore import nn
from mindspore import ops
from mindspore.common import set_seed
from mindspore.communication.management import get_group_size
from mindspore.communication.management import get_rank
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.dataset import GeneratorDataset
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import TimeMonitor
from src.cfg.config import config as default_config
from src.dataset import ImageMattingDatasetTrain
from src.model import LossWrapper
from src.model import MobileNetV2UNetDecoderIndexLearning
from src.utils import weighted_loss
def set_lr(cfg, steps_per_epoch):
"""
Set lr for each step of training.
Args:
cfg: Config parameters.
steps_per_epoch (int): Number of batches in one epoch on one device.
Returns:
lr_each_step (np.array): Learning rate for every step during training.
"""
base_lr = cfg.learning_rate
total_steps = int(cfg.epochs * steps_per_epoch)
milestone_1 = cfg.milestones[0]
milestone_2 = cfg.milestones[1]
lr_decay = cfg.lr_decay
lr_each_step = []
for i in range(total_steps):
if i < steps_per_epoch * milestone_1:
lr5 = base_lr
elif steps_per_epoch * milestone_1 <= i < steps_per_epoch * (milestone_1 + 1):
lr5 = base_lr * lr_decay * 0.1
elif steps_per_epoch * (milestone_1 + 1) <= i < steps_per_epoch * milestone_2:
lr5 = base_lr * lr_decay
elif steps_per_epoch * milestone_2 <= i < steps_per_epoch * (milestone_2 + 1):
lr5 = base_lr * lr_decay ** 2 * 0.1
elif steps_per_epoch * (milestone_2 + 1) <= i:
lr5 = base_lr * lr_decay ** 2
lr_each_step.append(lr5)
return np.array(lr_each_step, np.float32)
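# Illustration of the schedule above (the real values come from the config; the numbers
# here are assumptions for the sake of the example): with learning_rate=0.01,
# milestones=[20, 26] and lr_decay=0.1 the per-step lr is 0.01 up to epoch 20, dips to
# 1e-4 for one transition epoch, runs at 1e-3 until epoch 26, dips to 1e-5 for one
# epoch and then stays at 1e-4 for the remaining epochs.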
def load_pretrained(network, backbone_names, other_names, ckpt_url):
"""
Load weights from pretrained backbone.
Args:
network: Inited model.
backbone_names (list): Backbone parameters names.
other_names (list): Network parameters names except backbone names.
ckpt_url (str): Path to pretrained backbone checkpoint.
"""
model_inited_params = dict(network.parameters_and_names())
mobilenet_params = load_checkpoint(ckpt_url, filter_prefix=['moments', 'head'])
mobilenet_names = list(mobilenet_params.keys())[:-3]
clear_names = []
for name in mobilenet_names:
if name.startswith('features.18'):
continue
clear_names.append(name)
strict_names = []
for name in clear_names:
if name.endswith('beta'):
strict_names.append(name.replace('beta', 'moving_mean'))
strict_names.append(name.replace('beta', 'moving_variance'))
strict_names.append(name.replace('beta', 'gamma'))
strict_names.append(name)
elif name.endswith('weight'):
strict_names.append(name)
state_dict = {}
for net_name, mobil_name in zip(backbone_names, strict_names):
weight = mobilenet_params[mobil_name][:]
if mobil_name == 'features.0.features.0.weight':
expand_weight = ops.Zeros()((32, 1, 3, 3), mstype.float32)
weight = ops.Concat(axis=1)((weight, expand_weight))
model_param = Parameter(weight, name=net_name)
state_dict[net_name] = model_param
for name in other_names:
state_dict[name] = model_inited_params[name]
load_param_into_net(network, state_dict, strict_load=False)
def set_context(cfg):
"""
Set process context.
Args:
cfg: Config parameters.
Returns:
dev_num (int): Number of devices participating in the process.
dev_id (int): Current process device id.
"""
dev_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=dev_target)
if dev_target == 'GPU':
if cfg.is_distributed:
init(backend_name='nccl')
dev_num = get_group_size()
dev_id = get_rank()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(
device_num=dev_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
)
else:
dev_num = 1
dev_id = cfg.device_id
context.set_context(device_id=dev_id)
else:
raise ValueError("Unsupported platform.")
return dev_num, dev_id
def init_callbacks(cfg, batch_number, dev_id, network):
"""
Initialize training callbacks.
Args:
cfg: Config parameters.
batch_number: Number of batches in one epoch on one device.
dev_id: Current process device id.
network: Network to be saved into the checkpoint.
Returns:
cbs: Initialized callbacks.
"""
loss_cb = LossMonitor(per_print_times=100)
time_cb = TimeMonitor(data_size=batch_number)
if cfg.is_distributed and dev_id != cfg.device_start:
cbs = [loss_cb, time_cb]
else:
config_ck = CheckpointConfig(
save_checkpoint_steps=batch_number,
keep_checkpoint_max=cfg.keep_checkpoint_max,
saved_network=network,
)
ckpt_cb = ModelCheckpoint(
prefix="IndexNet",
directory=cfg.logs_dir,
config=config_ck,
)
cbs = [loss_cb, time_cb, ckpt_cb]
return cbs
def train(config):
"""
Init model, dataset, run training.
Args:
config: Config parameters.
"""
rank_size, rank_id = set_context(config)
data = ImageMattingDatasetTrain(
data_dir=config.data_dir,
bg_dir=config.bg_dir,
config=config,
sub_folder='train',
data_file='data.txt',
)
net = MobileNetV2UNetDecoderIndexLearning(
encoder_rate=config.rate,
encoder_current_stride=config.current_stride,
encoder_settings=config.inverted_residual_setting,
output_stride=config.output_stride,
width_mult=config.width_mult,
conv_operator=config.conv_operator,
decoder_kernel_size=config.decoder_kernel_size,
apply_aspp=config.apply_aspp,
use_nonlinear=config.use_nonlinear,
use_context=config.use_context,
)
net_with_loss = LossWrapper(model=net, loss_function=weighted_loss)
net_with_loss.set_train(True)
dataloader = GeneratorDataset(
source=data,
column_names=['image', 'mask', 'alpha', 'fg', 'bg', 'c_g'],
shuffle=True,
num_parallel_workers=config.num_workers,
python_multiprocessing=True,
num_shards=rank_size,
shard_id=rank_id,
)
dataloader = dataloader.batch(config.batch_size, True)
batch_num = dataloader.get_dataset_size()
pretrained_params = []
pretrained_names = []
learning_params = []
learning_names = []
for p in net_with_loss.parameters_and_names():
if 'dconv' in p[0] or 'pred' in p[0] or 'index' in p[0]:
if p[1].requires_grad:
learning_params.append(p[1])
learning_names.append(p[0])
else:
if p[1].requires_grad:
pretrained_params.append(p[1])
pretrained_names.append(p[0])
load_pretrained(net_with_loss, pretrained_names, learning_names, config.ckpt_url)
lr_steps = set_lr(config, batch_num)
opt = nn.Adam(
[
{'params': learning_params, 'lr': lr_steps},
{'params': pretrained_params, 'lr': lr_steps / config.backbone_lr_mult},
],
learning_rate=lr_steps
)
model = Model(net_with_loss, optimizer=opt)
callbacks = init_callbacks(config, batch_num, rank_id, net)
model.train(epoch=config.epochs, train_dataset=dataloader, callbacks=callbacks, dataset_sink_mode=False)
print("train success")
if __name__ == '__main__':
set_seed(1)
train(config=default_config)