# Contents

- [CBAM Description](#cbam-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training Usage](#training-usage)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation Usage](#evaluation-usage)
        - [Evaluation Results](#evaluation-results)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# CBAM Description

CBAM (Convolutional Block Attention Module) is a lightweight attention module proposed in 2018. It applies attention sequentially along the channel dimension and the spatial dimension of a feature map.

[Paper](https://arxiv.org/abs/1807.06521): Sanghyun Woo, Jongchan Park, Joon-Young Lee, In So Kweon. CBAM: Convolutional Block Attention Module.
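As formulated in the paper, CBAM refines a feature map F by applying a channel attention map and a spatial attention map in sequence:

```math
M_c(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big)
M_s(F) = \sigma\big(f^{7 \times 7}([\mathrm{AvgPool}(F);\ \mathrm{MaxPool}(F)])\big)
F' = M_c(F) \otimes F, \qquad F'' = M_s(F') \otimes F'
```

where $\sigma$ is the sigmoid function, $f^{7 \times 7}$ is a 7x7 convolution, and $\otimes$ denotes element-wise multiplication.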
# Model Architecture

The overall network architecture of CBAM is described in the paper:

[Link](https://arxiv.org/abs/1807.06521)
# Dataset

Dataset used: [RML2016.10A](https://www.xueshufan.com/publication/2562146178)

- Dataset size: 611 MB, 220,000 samples in total.
    - Training set: 110,000 samples.
    - Test set: the data at a single signal-to-noise ratio, 5,500 samples.
- Data format: IQ (In-phase and Quadrature), 2*128.
- Note: the data is processed in src/data.py; a loading sketch follows below.
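A minimal loading sketch (an illustration, not part of the repository scripts), assuming the pickle `RML2016.10a_dict.pkl` is placed under `./data/` as `src/data.py` expects:

```python
import pickle

# Keys of the pickle are (modulation, snr) tuples; values are arrays of IQ frames.
with open('./data/RML2016.10a_dict.pkl', 'rb') as f:
    all_data = pickle.load(f, encoding='iso-8859-1')

mods = sorted({key[0] for key in all_data})  # 11 modulation types
snrs = sorted({key[1] for key in all_data})  # 20 SNR levels
print(all_data[(mods[0], snrs[0])].shape)    # (1000, 2, 128): 1000 frames of 2 x 128 IQ data
```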
# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Script Description

## Script and Sample Code

```path
.
└─CBAM
  ├─src
  │ ├─data.py                  # dataset processing
  │ ├─model.py                 # CBAM network definition
  │ ├─get_lr.py                # learning-rate generation
  │ └─model_utils
  │   ├─config.py              # parameter configuration
  │   ├─device_adapter.py      # adapter for cloud or local runs
  │   ├─local_adapter.py       # local configuration
  │   └─moxing_adapter.py      # cloud (ModelArts) configuration
  ├──eval.py                   # evaluation network
  ├──train.py                  # training network
  ├──default_config.yaml       # parameter configuration
  └──README.md                 # README file
```
## Script Parameters

Parameters for both training and evaluation can be configured in default_config.yaml.

```python
"batch_size": 32,             # batch size of the input tensor
"epoch_size": 70,             # number of training epochs
"lr_init": 0.001,             # initial learning rate
"save_checkpoint": True,      # whether to save checkpoints
"save_checkpoint_epochs": 1,  # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max": 10,    # keep only the last keep_checkpoint_max checkpoints
"warmup_epochs": 5,           # number of warm-up epochs
```
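Every scalar key in `default_config.yaml` is also exposed as a command-line flag by `src/model_utils/config.py`, so individual values can be overridden at launch, for example:

```shell
python train.py --batch_size=32 --epoch_size=70 --lr_init=0.001
```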
## Training Process

### Training Usage

First set the hyperparameters in `default_config.yaml`.
You can start training on HUAWEI CLOUD (ModelArts) with a configuration such as:

```shell
Ascend:
training input: data_url = /cbam/dataset,
training output: train_url = /cbam/train_output,
log output: /cbam/train_logs
```
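For a local run without ModelArts (a sketch; assumes the dataset pickle is under `../data/`):

```shell
python train.py --enable_modelarts=False --data_path=../data/ --device_target=Ascend
```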
### Training Results

Training logs on Ascend are saved under `/cbam/train_logs`. You can find results similar to the following in the log.
```log
epoch: 1 step: 3437, loss is 0.7258548
epoch: 2 step: 3437, loss is 0.6980165
epoch: 3 step: 3437, loss is 0.6887816
epoch: 4 step: 3437, loss is 0.7017617
epoch: 5 step: 3437, loss is 0.694684
```
## Evaluation Process

### Evaluation Usage

As with training, set the hyperparameters in `default_config.yaml` and run evaluation on HUAWEI CLOUD:

```shell
Ascend:
evaluation input: data_url = /cbam/dataset,
checkpoint file: ckpt_file = /cbam/train_output/cbam_train-70_3437.ckpt,
log output: /cbam/eval_logs
```
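The equivalent local run (a sketch; `eval.py` loads the checkpoint from `ckpt_path`, and the path below assumes the checkpoint was copied to the working directory):

```shell
python eval.py --enable_modelarts=False --data_path=../data/ --ckpt_path=./cbam_train-70_3437.ckpt
```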
### Evaluation Results

Evaluation results on Ascend are saved under `/cbam/eval_logs`. You can find results similar to the following in the log.
```log
result: {'Accuracy': 0.8494152046783626}
```
# Model Description

## Performance

### Training Performance

| Parameters                 | Ascend 910                |
| -------------------------- | ------------------------- |
| Resource                   | Ascend 910                |
| Uploaded Date              | 2022-05-31                |
| MindSpore Version          | 1.5.1                     |
| Dataset                    | RML2016.10A               |
| Training Parameters        | default_config.yaml       |
| Optimizer                  | Adam                      |
| Loss Function              | BCEWithLogitsLoss         |
| Loss                       | 0.6702158                 |
| Accuracy                   | 84.9%                     |
| Total Time                 | 41 minutes (1 device)     |
| Checkpoint for Fine tuning | 5.80 MB (.ckpt file)      |
# Description of Random Situation

The random seed is set in train.py.

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/models).
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: True
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data/"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/cbam_train-70_3437.ckpt'
device_target: Ascend
enable_profiling: False
ckpt_path: "/cache/train"
ckpt_file: "/cache/train/cbam_train-70_3437.ckpt"
# ==============================================================================
# Training options
lr: 0.01
lr_init: 0.01
lr_max: 0.1
lr_epochs: '30, 60, 90, 120'
lr_scheduler: "piecewise"
warmup_epochs: 5
epoch_size: 70
max_epoch: 70
momentum: 0.9
loss_scale: 1.0
label_smooth: 0
label_smooth_factor: 0
weight_decay: 0.0005
batch_size: 32
keep_checkpoint_max: 10
MINDIR_name: 'cbam.MINDIR'
ckpt_file_path: './cbam_train-70_3437.ckpt'
num_classes: 4
dataset_name: 'rate'
dataset_sink_mode: True
device_id: 0
save_checkpoint: True
save_checkpoint_epochs: 1
local_data_path: '../data'
run_distribute: False
batch_norm: False
initialize_mode: "KaimingNormal"
padding: 1
pad_mode: 'pad'
has_bias: False
has_dropout: True
# Model Description
model_name: CBAM
file_name: 'CBAM'
file_format: 'MINDIR'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
---
device_target: ['Ascend', 'GPU', 'CPU']
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test CBAM example on dataset#################
python eval.py
"""
from mindspore import context
from mindspore.train import Model
from mindspore.communication.management import init
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.metrics import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.data import create_dataset
from src.model import resnet50_cbam
def modelarts_process():
config.ckpt_path = config.ckpt_file
@moxing_wrapper(pre_process=modelarts_process)
def eval_():
""" model eval """
device_num = get_device_num()
if device_num > 1:
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
init()
elif config.device_target == "GPU":
init()
print("================init finished=====================")
ds_eval = create_dataset(data_path=config.data_path, batch_size=config.batch_size,
training=False, snr=2, target=config.device_target)
if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size.")
print("ds_eval_size", ds_eval.get_dataset_size())
print("===============create dataset finished==============")
net = resnet50_cbam(phase="test")
param_dict = load_checkpoint(config.ckpt_path)
print("load checkpoint from [{}].".format(config.ckpt_path))
load_param_into_net(net, param_dict)
net.set_train(False)
loss = SoftmaxCrossEntropyWithLogits()
model = Model(net, loss_fn=loss, metrics={"Accuracy": Accuracy()})
result = model.eval(ds_eval, dataset_sink_mode=False)
print("===================result: {}==========================".format(result))
if __name__ == '__main__':
eval_()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data preprocessing"""
import os
import pickle
import numpy as np
from mindspore.communication.management import get_rank, get_group_size
import mindspore.dataset as de
import mindspore.common.dtype as mstype
import mindspore.dataset.transforms.c_transforms as C
def _get_rank_info(run_distribute):
"""get rank size and rank id"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if run_distribute:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id
class Get_Data():
"""
    The data is preprocessed before being converted to a MindSpore dataset.
"""
def __init__(self, data_path, snr=None, training=True):
self.data_path = data_path + 'RML2016.10a_dict.pkl'
self.snr = snr
self.do_train = training
        with open(self.data_path, 'rb') as data_file:
            self.all_data = pickle.load(data_file, encoding='iso-8859-1')
self.snrs, self.mods = map(lambda j: sorted(list(set(map(lambda x: x[j], self.all_data.keys())))), [1, 0])
def to_one_hot(self, label_index):
"""generate one hot label"""
one_hot_label = np.zeros([len(label_index), max(label_index) + 1])
one_hot_label[np.arange(len(label_index)), label_index] = 1
return one_hot_label
def get_loader(self):
train_data = []
train_label = []
test_data = []
test_label = []
train_index = []
for key in self.all_data.keys():
v = self.all_data[key]
train_data.append(v[:v.shape[0]//2])
test_data.append(v[v.shape[0]//2:])
for i in range(self.all_data[key].shape[0]//2):
train_label.append(key)
test_label.append(key)
train_data = np.vstack(train_data)
test_data = dict(zip(self.all_data.keys(), test_data))
        train_index = list(range(0, 110000))
ds_label = []
if self.do_train:
ds = train_data
ds = np.expand_dims(ds, axis=1)
ds_label = self.to_one_hot(list(map(lambda x: self.mods.index(train_label[x][0]), train_index)))
else:
ds = []
for mod in self.mods:
ds.append((test_data[(mod, self.snr)]))
ds_label += [mod] * test_data[(mod, self.snr)].shape[0]
ds = np.vstack(ds)
ds = np.expand_dims(ds, axis=1)
ds_label = self.to_one_hot(list(self.mods.index(x) for x in ds_label))
return ds, ds_label
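# Descriptive note (not in the original file), shapes inferred from the code above:
#   training:   ds has shape (110000, 1, 2, 128); ds_label is one-hot with shape (110000, 11)
#   evaluation: all 11 modulations at a single SNR, i.e. 5500 frames of shape (1, 2, 128)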
def create_dataset(data_path,
batch_size=1,
training=True,
snr=None,
target="Ascend",
run_distribute=False):
"""create dataset for train or eval"""
if target == "Ascend":
device_num, rank_id = _get_rank_info(run_distribute)
if training:
getter = Get_Data(data_path=data_path, snr=None, training=training)
data = getter.get_loader()
else:
getter = Get_Data(data_path=data_path, snr=snr, training=training)
data = getter.get_loader()
dataset_column_names = ["data", "label"]
if target != "Ascend" or device_num == 1:
if training:
ds = de.NumpySlicesDataset(data=data,
column_names=dataset_column_names,
shuffle=True)
else:
ds = de.NumpySlicesDataset(data=data,
column_names=dataset_column_names,
shuffle=False)
else:
if training:
ds = de.NumpySlicesDataset(data=data,
column_names=dataset_column_names,
shuffle=True,
num_shards=device_num,
shard_id=rank_id)
else:
ds = de.NumpySlicesDataset(data=data,
column_names=dataset_column_names,
shuffle=False,
num_shards=device_num,
shard_id=rank_id)
ds_label = [
C.TypeCast(mstype.float32)
]
ds = ds.map(operations=ds_label, input_columns=["label"])
ds = ds.batch(batch_size, drop_remainder=True)
return ds
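# A hypothetical call mirroring train.py (illustration only, not part of the original file):
#   ds_train = create_dataset(data_path='../data/', batch_size=32, training=True,
#                             snr=None, target="Ascend", run_distribute=False)
#   print(ds_train.get_dataset_size())  # 110000 // 32 = 3437 steps per epoch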
if __name__ == '__main__':
data_getter = Get_Data(data_path='../data/', snr=2, training=False)
ms_ds = data_getter.get_loader()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get learning rate"""
import numpy as np
def get_lr(init_lr, total_epoch, step_per_epoch, anneal_step=250):
"""warmup lr schedule"""
total_step = total_epoch * step_per_epoch
lr_step = []
for step in range(total_step):
lambda_lr = anneal_step ** 0.5 * \
min((step + 1) * anneal_step ** -1.5, (step + 1) ** -0.5)
lr_step.append(init_lr * lambda_lr)
learning_rate = np.array(lr_step).astype(np.float32)
return learning_rate
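# A hypothetical usage sketch (not part of the original file). train.py calls
# get_lr(0.001, epoch_size, steps_per_epoch, steps_per_epoch * 2), which yields a
# Noam-style schedule: a ramp-up for roughly anneal_step steps, then (step + 1) ** -0.5 decay.
#   lr = get_lr(init_lr=0.001, total_epoch=70, step_per_epoch=3437, anneal_step=6874)
#   print(lr[0], lr.argmax(), lr[-1])  # the peak learning rate is reached at step anneal_step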
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate network."""
import math
from mindspore import nn
from mindspore import ops as P
import mindspore.common.initializer as weight_init
class ChannelAttention(nn.Cell):
"""
ChannelAttention: Since each channel of the feature map is considered as a feature detector, it is meaningful
for the channel to focus on the "what" of a given input image;In order to effectively calculate channel attention,
the method of compressing the spatial dimension of input feature mapping is adopted.
"""
def __init__(self, in_channel):
super(ChannelAttention, self).__init__()
self.avg_pool = P.ReduceMean(keep_dims=True)
self.max_pool = P.ReduceMax(keep_dims=True)
self.fc = nn.SequentialCell(nn.Conv2d(in_channels=in_channel, out_channels=in_channel // 16, kernel_size=1,
has_bias=False),
nn.ReLU(),
nn.Conv2d(in_channels=in_channel // 16, out_channels=in_channel, kernel_size=1,
has_bias=False))
self.sigmoid = nn.Sigmoid()
def construct(self, x):
avg_out = self.avg_pool(x, -1)
avg_out = self.fc(avg_out)
max_out = self.max_pool(x, -1)
max_out = self.fc(max_out)
out = avg_out + max_out
return self.sigmoid(out)
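# Descriptive note (not in the original file): per the paper, channel attention is
# M_c(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F))) with a shared 1x1-conv MLP and a
# reduction ratio of 16. Note this implementation pools only over the last axis (axis -1)
# rather than over both spatial axes.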
class SpatialAttention(nn.Cell):
"""
SpatialAttention: Different from the channel attention module, the spatial attention module focuses on the
"where" of the information part as a supplement to the channel attention module.
"""
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, pad_mode='pad', has_bias=False)
self.concat = P.Concat(axis=1)
        self.sigmoid = nn.Sigmoid()
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.max_pool = P.ReduceMax(keep_dims=True)
def construct(self, x):
avg_out = self.reduce_mean(x, 1)
max_out = self.max_pool(x, 1)
x = self.concat((avg_out, max_out))
x = self.conv1(x)
        return self.sigmoid(x)
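# Descriptive note (not in the original file): per the paper, spatial attention is
# M_s(F) = sigmoid(conv7x7([AvgPool_c(F); MaxPool_c(F)])), where the mean and max are
# taken along the channel axis (axis 1) and concatenated before the 7x7 convolution.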
def conv3x3(in_channels, out_channels, stride=1):
"""
3x3 convolution with padding.
"""
    # MindSpore requires pad_mode='pad' when padding > 0
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     pad_mode='pad', padding=1, has_bias=False)
class Bottleneck(nn.Cell):
"""
Residual structure.
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride,
pad_mode='pad', padding=1, has_bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, has_bias=False)
self.bn3 = nn.BatchNorm2d(out_channels * 4)
self.relu = nn.ReLU()
self.ca = ChannelAttention(out_channels * 4)
self.sa = SpatialAttention()
        self.downsample = downsample
self.stride = stride
def construct(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out = self.ca(out) * out
out = self.sa(out) * out
        if self.downsample is not None:
            residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
Overall network architecture.
"""
def __init__(self, block, layers, num_classes=11, phase="train"):
self.in_channels = 64
super(ResNet, self).__init__()
self.phase = phase
self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, pad_mode='pad', padding=3, has_bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block, 8, layers[0])
self.layer2 = self._make_layer(block, 16, layers[1], stride=2)
self.layer3 = self._make_layer(block, 32, layers[2], stride=2)
self.layer4 = self._make_layer(block, 64, layers[3], stride=2)
self.avgpool = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.Linear = nn.Dense(64 * block.expansion, num_classes)
dropout_ratio = 0.5
self.dropout = nn.Dropout(dropout_ratio)
self.softmax = nn.Softmax()
self.print = P.Print()
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x, 3)
        x = self.flatten(x)  # flatten (N, 256, 1, 1) -> (N, 256) without hard-coding the batch size
x = self.Linear(x)
x = self.softmax(x)
return x
def _make_layer(self, block, out_channels, blocks, stride=1):
downsample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
downsample = nn.SequentialCell(
nn.Conv2d(self.in_channels, out_channels * block.expansion,
kernel_size=1, stride=stride, has_bias=False),
nn.BatchNorm2d(out_channels * block.expansion)
)
layers = []
layers.append(block(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
return nn.SequentialCell(*layers)
def custom_init_weight(self):
"""
Init the weight of Conv2d and Batchnorm2D in the net.
:return:
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
n = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels
cell.weight.set_data(weight_init.initializer(weight_init.Normal(sigma=math.sqrt(2. / n), mean=0),
cell.weight.shape,
cell.weight.dtype))
elif isinstance(cell, nn.BatchNorm2d):
cell.weight.set_data(weight_init.initializer(weight_init.One(),
cell.weight.shape,
cell.weight.dtype))
def resnet50_cbam(phase="train", **kwargs):
"""
Constructs a ResNet-50 model.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], phase=phase, **kwargs)
return model
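# A hypothetical usage sketch (not part of the original file). With the pipeline in
# src/data.py, the network expects inputs of shape (batch, 1, 2, 128) and returns
# softmax scores over num_classes (11 by default):
#   net = resnet50_cbam(phase="test")
#   # out = net(Tensor(np.ones((32, 1, 2, 128), np.float32)))  # -> shape (32, 11)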
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.
    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Choices for each option.
        cfg_path: Path to the default yaml config.
    """
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
        except yaml.YAMLError as err:
            raise ValueError("Failed to parse yaml") from err
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
pprint(final_config)
print("Please check the above information for the configurations", flush=True)
return Config(final_config)
config = get_config()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from src.model_utils.config import config
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
    job_id = os.getenv('JOB_ID')
    return job_id if job_id else "default"
def sync_data(from_path, to_path):
"""
    Download data from remote OBS to a local directory when the first URL is remote and the second is a local path.
    Upload data from the local directory to remote OBS in the opposite case.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
    # Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############train CBAM example on dataset#################
python train.py
"""
from mindspore.common import set_seed
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.nn.optim import Adam
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.nn import BCEWithLogitsLoss
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
from src.data import create_dataset
from src.model import resnet50_cbam
from src.get_lr import get_lr
set_seed(1)
def modelarts_pre_process():
pass
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
"""train function"""
print('device id:', get_device_id())
print('device num:', get_device_num())
print('rank id:', get_rank_id())
print('job id:', get_job_id())
device_target = config.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
context.set_context(save_graphs=False)
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
device_num = get_device_num()
if config.run_distribute:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
if device_target == "Ascend":
context.set_context(device_id=get_device_id())
init()
elif device_target == "GPU":
init()
else:
context.set_context(device_id=get_device_id())
print("init finished.")
ds_train = create_dataset(data_path=config.data_path, batch_size=config.batch_size, training=True,
snr=None, target=config.device_target, run_distribute=config.run_distribute)
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size.")
print("create dataset finished.")
step_per_size = ds_train.get_dataset_size()
print("train dataset size:", step_per_size)
net = resnet50_cbam(phase="train")
loss = BCEWithLogitsLoss()
lr = get_lr(0.001, config.epoch_size, step_per_size, step_per_size * 2)
opt = Adam(net.trainable_params(),
learning_rate=lr,
beta1=0.9,
beta2=0.999,
eps=1e-7,
weight_decay=0.0,
loss_scale=1.0)
metrics = {"Accuracy": Accuracy()}
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics)
time_cb = TimeMonitor()
loss_cb = LossMonitor(per_print_times=step_per_size)
callbacks_list = [time_cb, loss_cb]
config_ck = CheckpointConfig(save_checkpoint_steps=step_per_size * 10,
keep_checkpoint_max=100)
if get_rank_id() == 0:
ckpoint_cb = ModelCheckpoint(prefix='cbam_train', directory=config.ckpt_path, config=config_ck)
callbacks_list.append(ckpoint_cb)
print("train start!")
model.train(epoch=config.epoch_size,
train_dataset=ds_train,
callbacks=callbacks_list,
dataset_sink_mode=config.dataset_sink_mode)
if __name__ == '__main__':
run_train()