# Contents
- [Contents](#contents)
- [MIMO-UNet Description](#mimo-unet-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Environmental requirements](#environmental-requirements)
- [Quick start](#quick-start)
- [Script description](#script-description)
    - [Script and sample code](#script-and-sample-code)
    - [Training process](#training-process)
        - [Standalone training](#standalone-training)
        - [Distributed training](#distributed-training)
    - [Evaluation process](#evaluation-process)
        - [Evaluate](#evaluate)
    - [Inference process](#inference-process)
        - [Export MindIR](#export-mindir)
- [Model description](#model-description)
    - [Performance](#performance)
        - [Training performance](#training-performance)
        - [Evaluation performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo homepage](#modelzoo-homepage)
# [MIMO-UNet Description](#contents)
Coarse-to-fine strategies have been extensively used for the architecture design of single image deblurring networks.
Conventional methods typically stack sub-networks with multi-scale input images and gradually improve sharpness of
images from the bottom sub-network to the top sub-network, yielding inevitably high computational costs. Toward a fast
and accurate deblurring network design, we revisit the coarse-to-fine strategy and present a multi-input multi-output
U-net (MIMO-UNet). First, the single encoder of the MIMO-UNet takes multi-scale input images to ease the difficulty
of training. Second, the single decoder of the MIMO-UNet outputs multiple deblurred images with different scales to
mimic multi-cascaded U-nets using a single U-shaped network. Last, asymmetric feature fusion is introduced to merge
multi-scale features in an efficient manner. Extensive experiments on the GoPro and RealBlur datasets demonstrate that
the proposed network outperforms the state-of-the-art methods in terms of both accuracy and computational complexity.
[Paper](https://arxiv.org/abs/2108.05054): Rethinking Coarse-to-Fine Approach in Single Image Deblurring.
[Reference github repository](https://github.com/chosj95/MIMO-UNet)
# [Model architecture](#contents)
The architecture of MIMO-UNet is based on a single U-Net with significant modifications for efficient multi-scale
deblurring. The encoder and decoder of MIMO-UNet are composed of three encoder blocks (EBs) and decoder blocks (DBs)
that use convolutional layers to extract features from different stages.
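For orientation, the following minimal sketch (an editorial illustration, not one of the repository scripts) builds the network from `src/mimo_unet.py`, runs it on a random input, and prints the three output scales; the 256x256 input shape is chosen only for illustration:

```python
import numpy as np
from mindspore import Tensor

from src.mimo_unet import MIMOUNet

net = MIMOUNet()
# Random 256x256 RGB batch; the single decoder emits deblurred images
# at 1/4, 1/2 and full resolution.
x = Tensor(np.random.rand(1, 3, 256, 256).astype(np.float32))
outputs = net(x)
print([o.shape for o in outputs])  # [(1, 3, 64, 64), (1, 3, 128, 128), (1, 3, 256, 256)]
```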
# [Dataset](#contents)
## Dataset used
Dataset link (Google Drive): [GOPRO_Large](https://drive.google.com/file/d/1y4wvPdOG3mojpFCHTqLgriexhbjoWVkK/view?usp=sharing)
The GOPRO_Large dataset is proposed for dynamic scene deblurring. The training and test sets are publicly available.
- Dataset size: ~6.2G
    - Train: 3.9G, 2103 image pairs
    - Test: 2.3G, 1111 image pairs
- Data format: images
- Note: data will be processed in `src/data_augment.py` and `src/data_load.py`
## Dataset organize way
```text
.
└─ GOPRO_Large
   ├─ train
   │  ├─ GOPR0xxx_xx_xx
   │  │  ├─ blur
   │  │  │  ├─ ***.png
   │  │  │  └─ ...
   │  │  ├─ blur_gamma
   │  │  │  ├─ ***.png
   │  │  │  └─ ...
   │  │  ├─ sharp
   │  │  │  ├─ ***.png
   │  │  │  └─ ...
   │  │  └─ frames X offset X.txt
   │  └─ ...
   └─ test
      ├─ GOPR0xxx_xx_xx
      │  ├─ blur
      │  │  ├─ ***.png
      │  │  └─ ...
      │  ├─ blur_gamma
      │  │  ├─ ***.png
      │  │  └─ ...
      │  └─ sharp
      │     ├─ ***.png
      │     └─ ...
      └─ ...
```
## Dataset preprocessing
After downloading the dataset (organized as shown above), run the `preprocessing.py` script located in the `src` folder.
Parameter description:
- `--root_src` - Path to the original dataset root, containing `train/` and `test/` folders.
- `--root_dst` - Path to the directory, where the pre-processed dataset will be stored.
```bash
python src/preprocessing.py --root_src /path/to/original/dataset/root --root_dst /path/to/preprocessed/dataset/root
```
### Dataset organize way after preprocessing
In the example above, after the preprocessing script is executed, the pre-processed images are stored under
the `/path/to/preprocessed/dataset/root` path. Below is the file structure of the preprocessed dataset.
```text
.
└─ GOPRO_preprocessed
   ├─ train
   │  ├─ blur
   │  │  ├─ 1.png
   │  │  ├─ ...
   │  │  └─ 2103.png
   │  └─ sharp
   │     ├─ 1.png
   │     ├─ ...
   │     └─ 2103.png
   └─ test
      ├─ blur
      │  ├─ 1.png
      │  ├─ ...
      │  └─ 1111.png
      └─ sharp
         ├─ 1.png
         ├─ ...
         └─ 1111.png
```
# [Environmental requirements](#contents)
- Hardware (GPU)
- Prepare hardware environment with GPU processor
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- For details, see the following resources:
- [MindSpore Tutorial](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
- Additional python packages:
- Pillow
- scikit-image
- PyYAML
Install the additional packages manually or with the `pip install -r requirements.txt` command run in the model directory.
# [Quick start](#contents)
After installing MindSpore via the official website and the additional packages listed above, you can start training and
evaluation as follows:
- Running on GPU
```bash
# run the training example
python ./train.py --dataset_root /path/to/dataset/root --ckpt_save_directory /save/checkpoint/directory
# or
bash scripts/run_standalone_train_gpu.sh /path/to/dataset/root /save/checkpoint/directory
# run the distributed training example
bash scripts/run_distribute_train_gpu.sh /path/to/dataset/root /save/checkpoint/directory
# run the evaluation example
python ./eval.py --dataset_root /path/to/dataset/root \
--ckpt_file /path/to/eval/checkpoint.ckpt \
--img_save_directory /path/to/result/images
# or
bash scripts/run_eval_gpu.sh /path/to/dataset/root /path/to/eval/checkpoint.ckpt /path/to/result/images
```
# [Script description](#contents)
## [Script and sample code](#contents)
```text
.
└─ cv
   └─ MIMO-UNet
      ├─ configs
      │  └─ gpu_config.yaml              # Config for training on GPU
      ├─ scripts
      │  ├─ run_distribute_train_gpu.sh  # Shell script for distributed training on GPU
      │  ├─ run_standalone_train_gpu.sh  # Shell script for single-GPU training
      │  └─ run_eval_gpu.sh              # Shell script for evaluation on GPU
      ├─ src
      │  ├─ __init__.py                  # Package init
      │  ├─ config.py                    # Configuration parsing
      │  ├─ data_augment.py              # Augmentations
      │  ├─ data_load.py                 # Dataloader
      │  ├─ init_weights.py              # Weight initializers
      │  ├─ layers.py                    # Model layers
      │  ├─ loss.py                      # Loss function
      │  ├─ metric.py                    # Metrics
      │  ├─ mimo_unet.py                 # MIMO-UNet architecture
      │  └─ preprocessing.py             # Dataset preprocessing script
      ├─ eval.py                         # Evaluation script
      ├─ train.py                        # Training script
      ├─ export.py                       # Export script
      ├─ requirements.txt                # Requirements file
      └─ README.md                       # MIMO-UNet description (English)
```
## [Training process](#contents)
### [Standalone training](#contents)
- Running on GPU
Description of parameters:
- `--dataset_root` - Path to the dataset root, containing `train/` and `test/` folders
- `--ckpt_save_directory` - Output directory, where the data from the train process will be stored
```bash
python ./train.py --dataset_root /path/to/dataset/root --ckpt_save_directory /save/checkpoint/directory
# or
bash scripts/run_standalone_train_gpu.sh [DATASET_PATH] [OUTPUT_CKPT_DIR]
```
- DATASET_PATH - Path to the dataset root, containing `train/` and `test/` folders
- OUTPUT_CKPT_DIR - Output directory, where the data from the train process will be stored
### [Distributed training](#contents)
- Running on GPU
```bash
bash scripts/run_distribute_train_gpu.sh [DATASET_PATH] [OUTPUT_CKPT_DIR]
```
- DATASET_PATH - Path to the dataset root, containing `train/` and `test/` folders
- OUTPUT_CKPT_DIR - Output directory, where the data from the train process will be stored
## [Evaluation process](#contents)
### [Evaluate](#contents)
Calculates the PSNR metric and saves the deblurred images.
When evaluating, select the last generated checkpoint and pass it to the corresponding parameter of the evaluation script.
- Running on GPU
Description of parameters:
- `--dataset_root` - Path to the dataset root, containing `train/` and `test/` folders
- `--ckpt_file` - Path to the checkpoint containing the weights of the trained model.
- `--img_save_directory` - Output directory, where the images from the validation process will be stored.
Optional parameter. If not specified, validation images will not be saved.
```bash
python ./eval.py --dataset_root /path/to/dataset/root \
--ckpt_file /path/to/eval/checkpoint.ckpt \
--img_save_directory /path/to/result/images # save validation images
# or
python ./eval.py --dataset_root /path/to/dataset/root \
--ckpt_file /path/to/eval/checkpoint.ckpt # don't save validation images
# or
bash scripts/run_eval_gpu.sh [DATASET_PATH] [CKPT_PATH] [SAVE_IMG_DIR] # save validation images
# or
bash scripts/run_eval_gpu.sh [DATASET_PATH] [CKPT_PATH] # don't save validation images
```
- DATASET_PATH - Path to the dataset root, containing `train/` and `test/` folders
- CKPT_PATH - Path to the checkpoint containing the weights of the trained model.
- SAVE_IMG_DIR - Output directory, where the images from the validation process will be stored. Optional parameter. If not specified, validation images will not be saved.
After the evaluation script is executed, the deblurred images are stored in `/path/to/result/images/` if the path was specified.
## [Inference process](#contents)
### [Export MindIR](#contents)
```bash
python export.py --ckpt_file /path/to/mimounet/checkpoint.ckpt --export_device_target GPU --export_file_format MINDIR
```
The script will generate the corresponding MINDIR file in the current directory.
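As a quick sanity check, the exported graph can be loaded back and run on a dummy input (a minimal sketch, assuming the default export name `MIMO-UNet.mindir` and `export_batch_size=1`):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

# Load the exported graph and wrap it into a cell that can be called directly.
graph = ms.load("MIMO-UNet.mindir")
net = nn.GraphCell(graph)

x = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 256, 256)).astype(np.float32))
outputs = net(x)  # the graph returns the three multi-scale outputs
print([o.shape for o in outputs])
```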
# [Model description](#contents)
## [Performance](#contents)
### [Training Performance](#contents)
| Parameters | MIMO-UNet (1xGPU) | MIMO-UNet (8xGPU) |
|----------------------------|-------------------------------------------------------|-------------------------------------------------------|
| Model Version | MIMO-UNet | MIMO-UNet |
| Resources | 1x NV RTX3090-24G | 8x NV RTX3090-24G |
| Uploaded Date | 04 / 12 / 2022 (month/day/year) | 04 / 12 / 2022 (month/day/year) |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | GOPRO_Large | GOPRO_Large |
| Training Parameters | batch_size=4, lr=0.0001, halved every 500 epochs | batch_size=4, lr=0.0005, halved every 500 epochs |
| Optimizer | Adam | Adam |
| Outputs | images | images |
| Speed | 132 ms/step | 167 ms/step |
| Total time | 5d 6h 4m | 9h 15m |
| Checkpoint for Fine tuning | 26MB(.ckpt file) | 26MB(.ckpt file) |
### [Evaluation Performance](#contents)
| Parameters | MIMO-UNet (1xGPU) |
|-------------------|---------------------------------|
| Model Version | MIMO-UNet |
| Resources | 1x NV RTX3090-24G |
| Uploaded Date | 04 / 12 / 2022 (month/day/year) |
| MindSpore Version | 1.6.1 |
| Datasets | GOPRO_Large |
| Batch_size | 1 |
| Outputs | images |
| PSNR metric | 1p: 31.47, 8p: 31.27 |
# [Description of Random Situation](#contents)
In `train.py`, we set the random seeds inside the `train` function.
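The corresponding lines from the `train` function are:

```python
random.seed(1)
set_seed(1)
np.random.seed(1)
```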
# [ModelZoo homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
# Built-in configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
# ==============================================================================
# options
# Context options
device_target: "GPU"
device_id: 0
is_train_distributed: False
# Model options
model_name: "MIMO-UNet"
# Dataset options
dataset_root: "/path/to/dataset/"
train_batch_size: 4
# Logging options
ckpt_save_frequency: 100
ckpt_save_directory: "saving/ckpt/directory"
img_save_directory: ""
# Training options
learning_rate: 0.0001
num_worker: 8
epochs_num: 3000
train_use_data_sink: False
# Evaluation and export options
ckpt_file: "/path/to/trained/checkpoint.ckpt"
eval_use_data_sink: False
eval_batch_size: 1
export_batch_size: 1
export_file_name: "MIMO-UNet"
export_file_format: "MINDIR"
export_device_target: "GPU"
---
# Help description for each configuration
# Context options
device_target: "Device type which will be used for graph computations"
device_id: "Id of device which will be used for graph computations"
is_train_distributed: "Whether the training process is distributed among several devices"
# Model options
model_name: "Name of the model"
# Dataset options
dataset_root: "Path to the dataset root, containing train and test folders"
train_batch_size: "The batch size to be used for training"
# Logging options
ckpt_save_frequency: "Number of epochs that must pass before a checkpoint is saved."
ckpt_save_directory: "Output directory, where the data from the train process will be stored."
img_save_directory: "Output directory, where the data from the validation process will be stored."
# Training options
learning_rate: "Learning rate"
num_worker: "Number of parallel workers for data loading"
epochs_num: "Number of training epochs"
train_use_data_sink: "Use data sink mode during the model training."
# Evaluation and export options
ckpt_file: "Path to the checkpoint containing the weights of the trained model."
eval_use_data_sink: "Use data sink mode during the model evaluation."
eval_batch_size: "Batch size used for evaluation"
export_batch_size: "Batch size used for the exported model"
export_file_name: "Exported model file name"
export_file_format: "Format of the exported model"
export_device_target: "Device type which will be used for export"
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval MIMO_UNet
"""
import random
from pathlib import Path
import numpy as np
from PIL import Image
from mindspore import context
from mindspore import dataset as ds
from mindspore.common import set_seed
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.mimo_unet import MIMOUNet
from src.config import config
from src.data_load import create_dataset_generator
from src.loss import ContentLoss
from src.metric import PSNR

def run_eval(args):
    """eval"""
    context.set_context(mode=context.GRAPH_MODE)
    random.seed(1)
    set_seed(1)
    np.random.seed(1)

    eval_dataset_generator = create_dataset_generator(Path(args.dataset_root, 'test'))
    eval_dataset = ds.GeneratorDataset(eval_dataset_generator, ["image", "label"],
                                       shuffle=False, num_parallel_workers=args.num_worker)
    eval_dataset = eval_dataset.batch(batch_size=args.eval_batch_size, drop_remainder=True)

    net = MIMOUNet()
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)

    content_loss = ContentLoss()
    model = Model(net, loss_fn=content_loss, metrics={"PSNR": PSNR()})

    print("eval...")
    results = model.eval(eval_dataset, dataset_sink_mode=False)
    print(results)

    if args.img_save_directory:
        print("saving images...")
        Path(args.img_save_directory).mkdir(parents=True, exist_ok=True)
        ds_iter = eval_dataset.create_tuple_iterator()
        for num, (image, _) in enumerate(ds_iter):
            # Take the full-resolution output (index 2) of the three scales.
            pred = net(image)[2]
            pred = pred.clip(0, 1)
            pred += 0.5 / 255
            pred = pred.asnumpy()
            pred = (pred.squeeze().transpose(1, 2, 0) * 255).astype(np.uint8)
            im_pred = Image.fromarray(pred, 'RGB')
            im_pred.save(Path(args.img_save_directory, f"{num}_pred.png"))

            image = image.asnumpy()
            image = (image.squeeze().transpose(1, 2, 0) * 255).astype(np.uint8)
            image = Image.fromarray(image, 'RGB')
            image.save(Path(args.img_save_directory, f"{num}_blur.png"))

    return results


if __name__ == '__main__':
    run_eval(config)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MIMO-UNet export mindir.
"""
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import export
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_param_into_net
from src.config import config
from src.mimo_unet import MIMOUNet

def run_export(args):
    """run export"""
    context.set_context(mode=context.GRAPH_MODE, device_target=args.export_device_target)
    context.set_context(device_id=args.device_id)

    net = MIMOUNet()
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)

    # Trace the network with a dummy input of the export batch size.
    input_shp = [args.export_batch_size, 3, 256, 256]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net, input_array, file_name=args.export_file_name, file_format=args.export_file_format)


if __name__ == '__main__':
    run_export(config)
scikit-image==0.19.2
PyYAML==6.0
Pillow
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
    echo "===================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_distribute_train_gpu.sh [DATASET_PATH] [OUTPUT_CKPT_DIR]"
    echo "for example: bash scripts/run_distribute_train_gpu.sh /path/to/dataset/root /save/checkpoint/directory"
    echo "===================================================================================================="
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

DATASET_PATH=$(get_real_path "$1")
OUTPUT_CKPT_DIR=$(get_real_path "$2")

if [ ! -d "$DATASET_PATH" ] ; then
    echo "Cannot find the specified dataset directory: $DATASET_PATH"
    exit 1
fi

if [ -d logs ]
then
    rm -r logs
fi
mkdir logs

mpirun --output-filename logs \
    -np 8 --allow-run-as-root \
    python ./train.py \
    --dataset_root "$DATASET_PATH" \
    --ckpt_save_directory "$OUTPUT_CKPT_DIR" \
    --is_train_distributed True \
    --learning_rate 0.0005 \
    > ./logs/train.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] && [ $# != 2 ]
then
    echo "===================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_eval_gpu.sh [DATASET_PATH] [CKPT_PATH] [SAVE_IMG_DIR](optional)"
    echo "for example:"
    echo "bash scripts/run_eval_gpu.sh /path/to/dataset/root /path/to/eval/checkpoint.ckpt /path/to/result/images"
    echo "or"
    echo "bash scripts/run_eval_gpu.sh /path/to/dataset/root /path/to/eval/checkpoint.ckpt"
    echo "===================================================================================================="
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

DATASET_PATH=$(get_real_path "$1")
CKPT_PATH=$(get_real_path "$2")

if [ $# == 3 ]
then
    SAVE_IMG_DIR=$(get_real_path "$3")
else
    SAVE_IMG_DIR=""
fi

if [ ! -d "$DATASET_PATH" ] ; then
    echo "Cannot find the specified dataset directory: $DATASET_PATH"
    exit 1
fi

if [ ! -f "$CKPT_PATH" ] ; then
    echo "Cannot find the specified checkpoint: $CKPT_PATH"
    exit 1
fi

if [ -d eval_logs ]
then
    rm -r eval_logs
fi
mkdir eval_logs

python eval.py --dataset_root "$DATASET_PATH" \
    --ckpt_file "$CKPT_PATH" \
    --img_save_directory "$SAVE_IMG_DIR" \
    > ./eval_logs/eval.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
    echo "===================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_standalone_train_gpu.sh [DATASET_PATH] [OUTPUT_CKPT_DIR]"
    echo "for example: bash scripts/run_standalone_train_gpu.sh /path/to/dataset/root /save/checkpoint/directory"
    echo "===================================================================================================="
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

DATASET_PATH=$(get_real_path "$1")
OUTPUT_CKPT_DIR=$(get_real_path "$2")

if [ ! -d "$DATASET_PATH" ] ; then
    echo "Cannot find the specified dataset directory: $DATASET_PATH"
    exit 1
fi

if [ -d logs ]
then
    rm -r logs
fi
mkdir logs

python ./train.py --dataset_root "$DATASET_PATH" \
    --ckpt_save_directory "$OUTPUT_CKPT_DIR" \
    --is_train_distributed False \
    --learning_rate 0.0001 \
    > ./logs/train.log 2>&1 &
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .layers import BasicConv, ResBlock
from .data_augment import PairRandomCrop, PairRandomHorizontalFlip
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import argparse
import ast
from os import path
from pprint import pformat
import yaml

class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Argument choices.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = list(cfgs)
        except yaml.YAMLError:
            raise ValueError("Failed to parse yaml")
    if len(cfgs) == 1:
        cfg = cfgs[0]
        cfg_helper = {}
        cfg_choices = {}
    elif len(cfgs) == 2:
        cfg, cfg_helper = cfgs
        cfg_choices = {}
    elif len(cfgs) == 3:
        cfg, cfg_helper, cfg_choices = cfgs
    else:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
    print(cfg_helper)
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    config_path = path.join(
        path.dirname(path.abspath(__file__)),
        "../configs",
        "gpu_config.yaml",
    )
    parser.add_argument("--config_path", type=str, default=path.abspath(config_path),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)


config = get_config()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Augmentation
"""
import random

class PairRandomCrop:
    """pair random crop"""
    def __init__(self, size=(256, 256)):
        self.size = size

    def __call__(self, image, label):
        def _input_to_factor(img, size):
            """Pick a random crop window that fits inside the image."""
            img_height, img_width, _ = img.shape
            height, width = size
            if height > img_height or width > img_width:
                raise ValueError(f"Crop size {size} is larger than input image size {(img_height, img_width)}.")
            if width == img_width and height == img_height:
                return 0, 0, img_height, img_width
            top = random.randint(0, img_height - height)
            left = random.randint(0, img_width - width)
            return top, left, height, width

        y, x, h, w = _input_to_factor(image, self.size)
        image, label = image[y:y + h, x:x + w], label[y:y + h, x:x + w]
        assert image.shape == label.shape
        return image, label


class PairRandomHorizontalFlip:
    """pair random horizontal flip"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img, label):
        """
        Args:
            img (numpy.ndarray): Image to be flipped.
            label (numpy.ndarray): Label to be flipped together with the image.

        Returns:
            Tuple of the (possibly) flipped image and label.
        """
        if random.random() < self.prob:
            return img[::, ::-1], label[::, ::-1]
        return img, label
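
# Illustrative usage (an editorial sketch, not part of the original file):
# apply the paired augmentations to a random image/label pair and check that
# both arrays stay aligned.
if __name__ == '__main__':
    import numpy as np

    image = np.random.rand(300, 400, 3).astype(np.float32)
    label = np.random.rand(300, 400, 3).astype(np.float32)
    flip = PairRandomHorizontalFlip(prob=0.5)
    crop = PairRandomCrop(size=(256, 256))
    image, label = flip(image, label)
    image, label = crop(image, label)
    print(image.shape, label.shape)  # (256, 256, 3) (256, 256, 3)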
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Dataloader
"""
import os
import numpy as np
from PIL import Image
from src.data_augment import PairRandomCrop, PairRandomHorizontalFlip

class DeblurDatasetGenerator:
    """DeblurDatasetGenerator"""
    def __init__(self, image_dir, make_aug=False):
        self.image_dir = image_dir
        self.image_list = os.listdir(os.path.join(image_dir, 'blur/'))
        self._check_image(self.image_list)
        self.image_list.sort()
        self.random_horizontal_flip = PairRandomHorizontalFlip()
        self.random_crop = PairRandomCrop()
        self.make_aug = make_aug

    def __len__(self):
        """get len"""
        return len(self.image_list)

    def __getitem__(self, idx):
        """get item"""
        image = Image.open(os.path.join(self.image_dir, 'blur', self.image_list[idx]))
        label = Image.open(os.path.join(self.image_dir, 'sharp', self.image_list[idx]))
        image = np.asarray(image)
        label = np.asarray(label)
        if self.make_aug:
            image, label = self.random_horizontal_flip(image, label)
            image, label = self.random_crop(image, label)
        image = image.astype(np.float32) / 255
        label = label.astype(np.float32) / 255
        image = image.transpose(2, 0, 1)  # transform to chw format
        label = label.transpose(2, 0, 1)  # transform to chw format
        return image, label

    @staticmethod
    def _check_image(lst):
        """check image format"""
        for x in lst:
            splits = x.split('.')
            if splits[-1] not in ['png', 'jpg', 'jpeg']:
                raise ValueError(f"{x} is not .png, .jpeg or .jpg image")


def create_dataset_generator(image_dir, make_aug=False):
    """create dataset generator"""
    dataset_generator = DeblurDatasetGenerator(
        image_dir,
        make_aug=make_aug
    )
    return dataset_generator
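
# Illustrative usage (an editorial sketch, not part of the original file):
# wrap the generator in a MindSpore GeneratorDataset, mirroring train.py and
# eval.py; the dataset path below is a placeholder.
if __name__ == '__main__':
    from mindspore import dataset as ds

    generator = create_dataset_generator('/path/to/preprocessed/dataset/train', make_aug=True)
    dataset = ds.GeneratorDataset(generator, ["image", "label"], shuffle=True)
    dataset = dataset.batch(batch_size=4, drop_remainder=True)
    print(dataset.get_dataset_size())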
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init weights"""
import math
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import _assignment
from mindspore.common.initializer import _calculate_correct_fan
from mindspore.common.initializer import _calculate_gain

class KaimingUniform(init.Initializer):
    """
    Initialize the array with He kaiming algorithm.

    Args:
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function, recommended to use only with
            ``'relu'`` or ``'leaky_relu'`` (default).
    """
    def __init__(self, a=math.sqrt(5), mode='fan_in', nonlinearity='leaky_relu'):
        super().__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        fan = _calculate_correct_fan(arr.shape, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        data = np.random.uniform(-bound, bound, arr.shape)
        _assignment(arr, data)


class UniformBias(init.Initializer):
    """bias uniform initializer"""
    def __init__(self, shape, mode):
        super().__init__()
        self.mode = mode
        self.shape = shape

    def _initialize(self, arr):
        fan_tgt = _calculate_correct_fan(self.shape, self.mode)
        bound = 1 / math.sqrt(fan_tgt)
        data = np.random.uniform(-bound, bound, arr.shape)
        _assignment(arr, data)
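
# Illustrative usage (an editorial sketch, not part of the original file):
# build initialized weight and bias tensors through
# mindspore.common.initializer.initializer.
if __name__ == '__main__':
    weight = init.initializer(KaimingUniform(), [32, 3, 3, 3])
    bias = init.initializer(UniformBias(shape=[32, 3, 3, 3], mode='fan_in'), [32])
    print(weight.shape, bias.shape)  # (32, 3, 3, 3) (32,)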
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Basic layers
"""
import mindspore.nn as nn
from src.init_weights import KaimingUniform, UniformBias

class Conv2dLikeTorch(nn.Conv2d):
    """Conv2d with PyTorch-like default initialization"""
    def __init__(self, in_channel, out_channel, kernel_size, pad_mode, padding, stride, has_bias):
        initializer = KaimingUniform()
        bias_initializer = UniformBias(shape=(out_channel, in_channel, kernel_size, kernel_size), mode='fan_in')
        super().__init__(in_channel, out_channel, kernel_size, pad_mode=pad_mode, weight_init=initializer,
                         padding=padding, stride=stride, has_bias=has_bias, bias_init=bias_initializer)


class Conv2dTransposeLikeTorch(nn.Conv2dTranspose):
    """Conv2dTranspose with PyTorch-like default initialization"""
    def __init__(self, in_channel, out_channel, kernel_size, pad_mode, padding, stride, has_bias):
        initializer = KaimingUniform(mode='fan_in')
        bias_initializer = UniformBias(shape=(out_channel, in_channel, kernel_size, kernel_size), mode='fan_in')
        super().__init__(in_channel, out_channel, kernel_size, pad_mode=pad_mode, weight_init=initializer,
                         padding=padding, stride=stride, has_bias=has_bias, bias_init=bias_initializer)


class BasicConv(nn.Cell):
    """basic conv block"""
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super().__init__()
        if bias and norm:
            bias = False
        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 - 1
            layers.append(
                Conv2dTransposeLikeTorch(in_channel, out_channel, kernel_size, pad_mode='pad',
                                         padding=padding, stride=stride, has_bias=bias)
            )
        else:
            layers.append(
                Conv2dLikeTorch(in_channel, out_channel, kernel_size, pad_mode='pad',
                                padding=padding, stride=stride, has_bias=bias)
            )
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.ReLU())
        self.main = nn.SequentialCell(layers)

    def construct(self, x):
        """construct basic conv block"""
        return self.main(x)


class ResBlock(nn.Cell):
    """residual block"""
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.main = nn.SequentialCell(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def construct(self, x):
        """construct residual block"""
        return self.main(x) + x
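
# Illustrative usage (an editorial sketch, not part of the original file):
# at stride 1 both blocks preserve the spatial size of the input.
if __name__ == '__main__':
    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.random.rand(1, 32, 64, 64).astype(np.float32))
    conv = BasicConv(32, 32, kernel_size=3, stride=1)
    block = ResBlock(32, 32)
    print(conv(x).shape, block(x).shape)  # (1, 32, 64, 64) (1, 32, 64, 64)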
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Loss function
"""
import mindspore.nn as nn

class ContentLoss(nn.Cell):
    """ContentLoss"""
    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()
        self.nn_interpolate = nn.ResizeBilinear()

    def interpolate_downscale(self, x, scale_factor):
        """downscale"""
        _, _, h, w = x.shape
        h = h // scale_factor
        w = w // scale_factor
        return self.nn_interpolate(x, size=(h, w))

    def construct(self, pred_img, label_img):
        """construct ContentLoss"""
        # L1 distance between each of the three predicted scales and the
        # correspondingly downscaled label.
        label_img2 = self.interpolate_downscale(label_img, scale_factor=2)
        label_img4 = self.interpolate_downscale(label_img, scale_factor=4)
        l1 = self.criterion(pred_img[0], label_img4)
        l2 = self.criterion(pred_img[1], label_img2)
        l3 = self.criterion(pred_img[2], label_img)
        return l1 + l2 + l3
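
# Illustrative usage (an editorial sketch, not part of the original file):
# the loss expects the list of three multi-scale predictions produced by
# MIMO-UNet together with the full-resolution label.
if __name__ == '__main__':
    import numpy as np
    from mindspore import Tensor

    label = Tensor(np.random.rand(1, 3, 256, 256).astype(np.float32))
    preds = [Tensor(np.random.rand(1, 3, s, s).astype(np.float32)) for s in (64, 128, 256)]
    print(ContentLoss()(preds, label))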
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Metrics
"""
from mindspore import nn
from skimage.metrics import peak_signal_noise_ratio


class PSNR(nn.Metric):
    """peak signal-noise ratio"""
    def __init__(self):
        super().__init__()
        self.psnr = 0.
        self.total_num = 0

    def clear(self):
        """clear"""
        self.psnr = 0.
        self.total_num = 0

    def eval(self):
        """eval"""
        if self.total_num == 0:
            return 0
        return self.psnr / self.total_num

    def update(self, pred, label):
        """update"""
        # pred is the list of three multi-scale outputs; index 2 holds the
        # full-resolution image.
        pred_numpy = pred[2].asnumpy()
        label_numpy = label.asnumpy()
        self.psnr += peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
        self.total_num += 1
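
# Illustrative usage (an editorial sketch, not part of the original file):
# the metric follows the standard clear -> update -> eval protocol and reads
# the full-resolution prediction from index 2 of the model output.
if __name__ == '__main__':
    import numpy as np
    from mindspore import Tensor

    metric = PSNR()
    metric.clear()
    label = Tensor(np.random.rand(1, 3, 64, 64).astype(np.float32))
    metric.update([None, None, label], label)  # identical images -> infinite PSNR
    print(metric.eval())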
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MIMO-UNet architecture
"""
from mindspore import nn
from mindspore import ops

from src.layers import BasicConv, ResBlock


class EBlock(nn.Cell):
    """EBlock"""
    def __init__(self, out_channel, num_res=8):
        super().__init__()
        layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
        self.layers = nn.SequentialCell(*layers)

    def construct(self, x):
        """construct EBlock"""
        return self.layers(x)


class DBlock(nn.Cell):
    """DBlock"""
    def __init__(self, channel, num_res=8):
        super().__init__()
        layers = [ResBlock(channel, channel) for _ in range(num_res)]
        self.layers = nn.SequentialCell(*layers)

    def construct(self, x):
        """construct DBlock"""
        return self.layers(x)


class AFF(nn.Cell):
    """AFF"""
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.SequentialCell(
            BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )
        self.cat = ops.Stack(axis=1)

    def construct(self, x1, x2, x4):
        """construct AFF"""
        x = ops.Concat(1)([x1, x2, x4])
        return self.conv(x)

class SCM(nn.Cell):
    """SCM"""
    def __init__(self, out_plane):
        super().__init__()
        self.main = nn.SequentialCell(
            BasicConv(3, out_plane // 4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane - 3, kernel_size=1, stride=1, relu=True)
        )
        self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)
        self.cat = ops.Stack(axis=1)

    def construct(self, x):
        """construct SCM"""
        y = self.main(x)
        x = ops.Concat(1)([x, y])
        return self.conv(x)


class FAM(nn.Cell):
    """FAM"""
    def __init__(self, channel):
        super().__init__()
        self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)

    def construct(self, x1, x2):
        """construct FAM"""
        x = x1 * x2
        out = x1 + self.merge(x)
        return out

class MIMOUNet(nn.Cell):
    """MIMOUNet"""
    def __init__(self, num_res=8):
        super().__init__()
        base_channel = 32
        self.Encoder = nn.CellList([
            EBlock(base_channel, num_res),
            EBlock(base_channel * 2, num_res),
            EBlock(base_channel * 4, num_res),
        ])
        self.feat_extract = nn.CellList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])
        self.Decoder = nn.CellList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])
        self.Convs = nn.CellList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])
        self.ConvsOut = nn.CellList([
            BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
            BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
        ])
        self.AFFs = nn.CellList([
            AFF(base_channel * 7, base_channel * 1),
            AFF(base_channel * 7, base_channel * 2)
        ])
        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)
        self.cat = ops.Stack(axis=1)
        self.nn_interpolate = nn.ResizeBilinear()

    def interpolate(self, x, scale_factor):
        """interpolate (generic resize helper; the original Cast-based size
        computation was broken and is replaced with plain integer arithmetic)"""
        _, _, h, w = x.shape
        return self.nn_interpolate(x, size=(int(h * scale_factor), int(w * scale_factor)))

    def interpolate_upscale(self, x, scale_factor):
        """upscale"""
        _, _, h, w = x.shape
        h = h * scale_factor
        w = w * scale_factor
        return self.nn_interpolate(x, size=(h, w))

    def interpolate_downscale(self, x, scale_factor):
        """downscale"""
        _, _, h, w = x.shape
        h = h // scale_factor
        w = w // scale_factor
        return self.nn_interpolate(x, size=(h, w))

    def construct(self, x):
        """construct MIMOUNet"""
        x_2 = self.interpolate_downscale(x, scale_factor=2)
        x_4 = self.interpolate_downscale(x_2, scale_factor=2)
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)
        outputs = []

        # Encoder: full, 1/2 and 1/4 resolution stages.
        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)
        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)
        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)

        # Asymmetric feature fusion across scales.
        z12 = self.interpolate_downscale(res1, scale_factor=2)
        z21 = self.interpolate_upscale(res2, scale_factor=2)
        z42 = self.interpolate_upscale(z, scale_factor=2)
        z41 = self.interpolate_upscale(z42, scale_factor=2)
        res2 = self.AFFs[1](z12, res2, z42)
        res1 = self.AFFs[0](res1, z21, z41)

        # Decoder: emit a deblurred image at each scale.
        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        z = self.feat_extract[3](z)
        outputs.append(z_ + x_4)

        z = ops.Concat(1)([z, res2])
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        z = self.feat_extract[4](z)
        outputs.append(z_ + x_2)

        z = ops.Concat(1)([z, res1])
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z + x)

        return outputs

class MIMOUNetPlus(nn.Cell):
    """MIMOUNetPlus"""
    def __init__(self, num_res=20):
        super().__init__()
        base_channel = 32
        self.Encoder = nn.CellList([
            EBlock(base_channel, num_res),
            EBlock(base_channel * 2, num_res),
            EBlock(base_channel * 4, num_res),
        ])
        self.feat_extract = nn.CellList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])
        self.Decoder = nn.CellList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])
        self.Convs = nn.CellList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])
        self.ConvsOut = nn.CellList([
            BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
            BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
        ])
        self.AFFs = nn.CellList([
            AFF(base_channel * 7, base_channel * 1),
            AFF(base_channel * 7, base_channel * 2)
        ])
        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)
        # nn.Dropout takes the keep probability in this MindSpore version, so
        # keep_prob=0.9 corresponds to the 10% dropout rate of the reference
        # implementation.
        self.drop1 = nn.Dropout(keep_prob=0.9)
        self.drop2 = nn.Dropout(keep_prob=0.9)
        self.cat = ops.Stack(axis=1)
        self.nn_interpolate = nn.ResizeBilinear()

    # The resize helpers mirror MIMOUNet and pass explicit sizes, since
    # nn.ResizeBilinear expects integer sizes rather than fractional
    # scale factors.
    def interpolate_upscale(self, x, scale_factor):
        """upscale"""
        _, _, h, w = x.shape
        return self.nn_interpolate(x, size=(h * scale_factor, w * scale_factor))

    def interpolate_downscale(self, x, scale_factor):
        """downscale"""
        _, _, h, w = x.shape
        return self.nn_interpolate(x, size=(h // scale_factor, w // scale_factor))

    def construct(self, x):
        """construct MIMOUNetPlus"""
        x_2 = self.interpolate_downscale(x, scale_factor=2)
        x_4 = self.interpolate_downscale(x_2, scale_factor=2)
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)
        outputs = []

        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)
        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)
        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)

        z12 = self.interpolate_downscale(res1, scale_factor=2)
        z21 = self.interpolate_upscale(res2, scale_factor=2)
        z42 = self.interpolate_upscale(z, scale_factor=2)
        z41 = self.interpolate_upscale(z42, scale_factor=2)
        res2 = self.AFFs[1](z12, res2, z42)
        res1 = self.AFFs[0](res1, z21, z41)
        res2 = self.drop2(res2)
        res1 = self.drop1(res1)

        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        z = self.feat_extract[3](z)
        outputs.append(z_ + x_4)

        z = ops.Concat(1)([z, res2])
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        z = self.feat_extract[4](z)
        outputs.append(z_ + x_2)

        z = ops.Concat(1)([z, res1])
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z + x)

        return outputs

def build_net(model_name):
    """build network"""
    if model_name == "MIMO-UNetPlus":
        return MIMOUNetPlus()
    if model_name == "MIMO-UNet":
        return MIMOUNet()
    raise ValueError('Wrong Model!\nYou should choose MIMO-UNetPlus or MIMO-UNet.')
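
# Illustrative usage (an editorial sketch, not part of the original file):
# build the network by its configuration name.
if __name__ == '__main__':
    net = build_net("MIMO-UNet")
    print(type(net).__name__)  # MIMOUNet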
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
preprocess dataset
"""
import argparse
import shutil
from pathlib import Path

def copy(src, dst):
    """copy"""
    dst_blur = Path(dst, 'blur')
    dst_blur.mkdir(parents=True, exist_ok=True)
    dst_sharp = Path(dst, 'sharp')
    dst_sharp.mkdir(parents=True, exist_ok=True)
    src = Path(src)
    # Flatten the per-scene folders into sequentially numbered image pairs.
    for num, f_path in enumerate(sorted(src.rglob('*blur/*'))):
        print(f_path, f_path.name)
        print(f_path.parts[-3])
        shutil.copy(f_path, dst_blur / f'{num + 1}.png')
    for num, f_path in enumerate(sorted(src.rglob('*sharp/*'))):
        print(f_path, f_path.name)
        print(f_path.parts[-3])
        shutil.copy(f_path, dst_sharp / f'{num + 1}.png')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_src', type=str,
                        help='Path to the original dataset root, containing train/ and test/ folders')
    parser.add_argument('--root_dst', type=str,
                        help='Path to the directory where the pre-processed dataset will be stored')
    args = parser.parse_args()
    copy(Path(args.root_src, 'train'), Path(args.root_dst, 'train'))
    copy(Path(args.root_src, 'test'), Path(args.root_dst, 'test'))
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train MIMO_UNet
"""
import random
from pathlib import Path
import numpy as np
from mindspore import context
from mindspore import dataset as ds
from mindspore import nn
from mindspore.common import set_seed
from mindspore.communication.management import get_group_size
from mindspore.communication.management import get_rank
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import TimeMonitor
from src.config import config
from src.data_load import create_dataset_generator
from src.loss import ContentLoss
from src.metric import PSNR
from src.mimo_unet import MIMOUNet

def prepare_context(args):
    """prepare context"""
    context.set_context(mode=context.GRAPH_MODE)
    if args.is_train_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          device_num=args.group_size,
                                          gradients_mean=True)
    else:
        args.rank = 0
        args.group_size = 1
        context.set_context(device_id=args.device_id)


def prepare_dataset(args):
    """prepare dataset"""
    train_dataset_generator = create_dataset_generator(Path(args.dataset_root, 'train'),
                                                       make_aug=True)
    args.train_dataset_len = len(train_dataset_generator)
    train_dataset = ds.GeneratorDataset(train_dataset_generator, ["image", "label"],
                                        shuffle=True, num_parallel_workers=args.num_worker,
                                        num_shards=args.group_size, shard_id=args.rank)
    train_dataset = train_dataset.batch(batch_size=args.train_batch_size, drop_remainder=True)
    return train_dataset


def prepare_optimizer(net, args):
    """prepare optimizer"""
    # Per-step learning-rate schedule: the rate is halved every 500 epochs.
    lr = args.learning_rate
    lr_list = []
    for n_epoch in range(args.epochs_num):
        for _ in range(args.train_dataset_len // args.train_batch_size // args.group_size):
            lr_list.append(lr)
        if (n_epoch + 1) % 500 == 0:
            lr /= 2
    optim = nn.Adam(net.trainable_params(), beta1=0.9, beta2=0.999, learning_rate=lr_list)
    return optim


def prepare_callbacks(net, args):
    """prepare callbacks"""
    step_per_epoch = (args.train_dataset_len // args.group_size // args.train_batch_size)
    if args.rank == 0:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_frequency * step_per_epoch,
                                       saved_network=net,
                                       keep_checkpoint_max=40)
        ckpoint_cb = ModelCheckpoint(prefix='MIMO-UNet', directory=args.ckpt_save_directory,
                                     config=ckpt_config)
        train_callbacks = [
            LossMonitor(step_per_epoch),
            TimeMonitor(),
            ckpoint_cb,
        ]
    else:
        train_callbacks = [
            LossMonitor(step_per_epoch),
            TimeMonitor(),
        ]
    return train_callbacks


def train(args):
    """train"""
    random.seed(1)
    set_seed(1)
    np.random.seed(1)

    prepare_context(args)
    print(f"info rank {args.rank}, groupsize {args.group_size}")

    net = MIMOUNet()
    content_loss = ContentLoss()
    train_dataset = prepare_dataset(args)
    optim = prepare_optimizer(net, args)
    model = Model(net, content_loss, optim, metrics={"PSNR": PSNR()})
    train_callbacks = prepare_callbacks(net, args)

    print("train...")
    model.train(args.epochs_num, train_dataset,
                callbacks=train_callbacks,
                dataset_sink_mode=True)


if __name__ == '__main__':
    train(config)