Commit 81852fd3 authored by ZJUTER0126

[add] add Ascend910 training and Ascend310 inference for DDRNet

parent 662f35f0
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TrainOneStepWithLossScaleCellGlobalNormClip"""
import mindspore.nn as nn
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
return RowTensor(grad.indices,
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape)
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
class TrainOneStepWithLossScaleCellGlobalNormClip(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class for DDRNet training with loss scaling.

    Appends an optimizer to the training network; afterwards the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should already be wrapped in.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_sense (Union[Cell, Tensor, Number]): The loss scaling value or loss-scale update cell. Default: 1.0.
        clip_global_norm (bool): Whether to apply global-norm gradient clipping before the optimizer. Default: True.
        clip_global_norm_value (float): The clipping norm value. Default: 1.0.
    """

    def __init__(self, network, optimizer, scale_sense=1.0, clip_global_norm=True,
                 clip_global_norm_value=1.0):
        super(TrainOneStepWithLossScaleCellGlobalNormClip, self).__init__(network, optimizer, scale_sense)
        self.clip_global_norm = clip_global_norm
        self.clip_global_norm_value = clip_global_norm_value
        self.print = P.Print()
    def construct(self, *inputs):
        """construct"""
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # get the overflow buffer
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # if there is no overflow, do optimize
        if not overflow:
            if self.clip_global_norm:
                grads = C.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
            loss = F.depend(loss, self.optimizer(grads))
        else:
            self.print("=============Over Flow, skipping=============")
        return loss
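The cell above is used the same way as MindSpore's built-in TrainOneStepWithLossScaleCell. The sketch below is illustrative only and is not part of this commit: DummyNetWithLoss, the dummy tensors and the optimizer settings are placeholders standing in for the repository's NetWithLoss, get_optimizer and get_train_one_step helpers, and the snippet assumes an Ascend/GPU target where loss-scale overflow detection is supported.

import numpy as np
from mindspore import Tensor, nn


class DummyNetWithLoss(nn.Cell):
    """Toy network that already returns a scalar loss (placeholder for NetWithLoss)."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(8, 2)
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

    def construct(self, x, label):
        return self.loss(self.dense(x), label)


net_with_loss = DummyNetWithLoss()
opt = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.01, momentum=0.9)
# dynamic loss scaling; the update cell is passed as `scale_sense`
scale_update = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16, scale_factor=2, scale_window=2000)
train_net = TrainOneStepWithLossScaleCellGlobalNormClip(net_with_loss, opt,
                                                        scale_sense=scale_update,
                                                        clip_global_norm=True,
                                                        clip_global_norm_value=1.0)
train_net.set_train()
x = Tensor(np.random.randn(4, 8).astype(np.float32))
label = Tensor(np.array([0, 1, 0, 1]).astype(np.int32))
loss = train_net(x, label)  # one forward/backward step with overflow check and global-norm clipping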
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train
code: https://github.com/ydhongHIT/DDRNet
paper: https://arxiv.org/pdf/2101.06085.pdf
Acc: ImageNet1k-75.9%
"""
import os
import numpy as np
from mindspore import Model
from mindspore import context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.args import args
from src.tools.callback import EvaluateCallBack
from src.tools.cell import cast_amp
from src.tools.criterion import get_criterion, NetWithLoss
from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
from src.tools.optimizer import get_optimizer
def main():
    set_seed(args.seed)
    mode = {
        0: context.GRAPH_MODE,
        1: context.PYNATIVE_MODE
    }
    context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
    context.set_context(enable_graph_kernel=False)
    if args.device_target == "Ascend":
        context.set_context(enable_auto_mixed_precision=True)
    rank = set_device(args)

    # get the model and cast it to the requested AMP level
    net = get_model(args)
    params_num = 0
    for param in net.trainable_params():
        params_num += np.prod(param.shape)
    print(f"=> params_num: {params_num}")
    cast_amp(net)
    criterion = get_criterion(args)
    net_with_loss = NetWithLoss(net, criterion)
    if args.pretrained:
        pretrained(args, net)

    data = get_dataset(args)
    batch_num = data.train_dataset.get_dataset_size()
    optimizer = get_optimizer(args, net, batch_num)
    # wrap the loss network and optimizer into a single training-step cell
    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)

    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
    eval_indexes = [0, 1, 2]
    model = Model(net_with_loss, metrics={"acc", "loss"},
                  eval_network=eval_network,
                  eval_indexes=eval_indexes)

    config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size(),
                                 keep_checkpoint_max=args.keep_checkpoint_max)
    time_cb = TimeMonitor(data_size=data.train_dataset.get_dataset_size())
    ckpt_save_dir = "./ckpt_" + str(rank)
    if args.run_modelarts:
        ckpt_save_dir = "/cache/ckpt_" + str(rank)
    ckpoint_cb = ModelCheckpoint(prefix=args.arch + str(rank), directory=ckpt_save_dir,
                                 config=config_ck)
    loss_cb = LossMonitor()
    eval_cb = EvaluateCallBack(model, eval_dataset=data.val_dataset, src_url=ckpt_save_dir,
                               train_url=os.path.join(args.train_url, "ckpt_" + str(rank)),
                               save_freq=args.save_every)

    print("begin train")
    model.train(int(args.epochs - args.start_epoch), data.train_dataset,
                callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
                dataset_sink_mode=True)
    print("train success")

    if args.run_modelarts:
        import moxing as mox
        mox.file.copy_parallel(src_url=ckpt_save_dir, dst_url=os.path.join(args.train_url, "ckpt_" + str(rank)))


if __name__ == '__main__':
    main()