-
EvanBay authored
ff wnwn ewew rererere rere rerererere
fb17621b
get_misc.py 4.88 KiB
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""misc functions for program"""
import os
from mindspore import context
from mindspore import nn
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src import models, data
from src.data.data_utils.moxing_adapter import sync_data
from src.trainers import TrainClipGrad
def _init_data_parallel(backend_name, device_num):
    """Initialise collective communication in DATA_PARALLEL mode and return this process's rank."""
    init(backend_name=backend_name)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
    return get_rank()


def set_device(args):
    """Set device and ParallelMode (if device_num > 1).

    The number of devices is read from the ``DEVICE_NUM`` environment
    variable (default 1).

    Args:
        args: config object; reads ``device_target`` ("Ascend" or "GPU") and,
            for single-device runs, ``device_id``.

    Returns:
        int: rank of this process, or 0 for single-device runs.

    Raises:
        ValueError: if ``args.device_target`` is neither "Ascend" nor "GPU",
            or if ``DEVICE_NUM`` is not an integer.
        KeyError: for multi-device Ascend runs when ``DEVICE_ID`` is unset.
    """
    device_target = args.device_target
    device_num = int(os.environ.get("DEVICE_NUM", 1))

    if device_target == "Ascend":
        backend_name = 'hccl'
    elif device_target == "GPU":
        backend_name = 'nccl'
    else:
        raise ValueError("Unsupported platform.")

    if device_num <= 1:
        context.set_context(device_id=args.device_id)
        return 0

    if device_target == "Ascend":
        # Multi-card Ascend runs take the card id from the environment rather
        # than from args; GPU multi-card runs do not set device_id here.
        context.set_context(device_id=int(os.environ["DEVICE_ID"]))
    return _init_data_parallel(backend_name, device_num)
def get_dataset(args, training=True):
    """Get dataset according to args.set.

    Args:
        args: config object; ``args.set`` names a callable exported by
            ``src.data``, which is called as ``factory(args, training)``.
        training: forwarded unchanged to the dataset factory.

    Returns:
        Whatever the selected ``src.data`` factory returns.

    Raises:
        AttributeError: if ``src.data`` has no attribute named ``args.set``.
    """
    print(f"=> Getting {args.set} dataset")
    dataset_factory = getattr(data, args.set)
    return dataset_factory(args, training)
def get_model(args):
    """Get model according to args.arch.

    Args:
        args: config object; ``args.arch`` names an entry in the ``src.models``
            module namespace, which is called with ``args``.

    Returns:
        Whatever the selected ``src.models`` constructor returns.

    Raises:
        KeyError: if ``src.models`` has no attribute named ``args.arch``.
    """
    print(f"==> Creating model '{args.arch}'")
    # Module-namespace lookup: an unknown arch raises KeyError (not AttributeError).
    model_fn = models.__dict__[args.arch]
    return model_fn(args)
def _load_pretrained_weights(args, model):
    """Load the checkpoint at args.pretrained into model, dropping mismatched 'head' params."""
    print("=> loading pretrained weights from '{}'".format(args.pretrained))
    param_dict = load_checkpoint(args.pretrained)
    # Iterate over a snapshot because entries are deleted during the loop.
    for key, value in list(param_dict.items()):
        # Any parameter with 'head' in its name whose first dimension differs
        # from num_classes is not loaded; presumably this lets the checkpoint be
        # reused with a different label count (confirm with the model definition).
        if 'head' in key and value.shape[0] != args.num_classes:
            print(f'==> removing {key} with shape {value.shape}')
            del param_dict[key]
    load_param_into_net(model, param_dict)


def pretrained(args, model):
    """Load pretrained weights if args.pretrained is given.

    On ModelArts (``args.run_modelarts`` truthy), the directory containing
    ``args.pretrained`` is first synced to ``/cache/weight`` and
    ``args.pretrained`` is rewritten in place to the local copy. Otherwise the
    weights are loaded only if ``args.pretrained`` is an existing file. If it is
    not, a message is printed and ``model`` is left untouched.

    Args:
        args: config object; reads ``run_modelarts``, ``pretrained`` and
            (when 'head' parameters are present) ``num_classes``.
        model: network the checkpoint parameters are loaded into.
    """
    if args.run_modelarts:
        print('Download data.')
        local_data_path = '/cache/weight'
        # rpartition yields ('', '', path) when there is no '/', matching a
        # split/join on '/' for both the directory and file-name parts.
        remote_dir, _, name = args.pretrained.rpartition('/')
        sync_data(remote_dir, local_data_path, threads=128)
        args.pretrained = os.path.join(local_data_path, name)
        _load_pretrained_weights(args, model)
    elif os.path.isfile(args.pretrained):
        _load_pretrained_weights(args, model)
    else:
        print("=> no pretrained weights found at '{}'".format(args.pretrained))
def get_train_one_step(args, net_with_loss, optimizer):
    """Build the one-step training cell with loss scaling and gradient clipping.

    Args:
        args: config object; reads ``is_dynamic_loss_scale``, ``loss_scale``
            (fixed-scale mode only) and ``clip_global_norm_value``.
        net_with_loss: network that already includes the loss.
        optimizer: optimizer passed through to ``TrainClipGrad``.

    Returns:
        A ``TrainClipGrad`` instance wrapping ``net_with_loss`` with
        ``use_global_norm=True``.
    """
    if not args.is_dynamic_loss_scale:
        print(f"=> Using FixedLossScaleUpdateCell, loss_scale_value:{args.loss_scale}")
        loss_scale_cell = nn.wrap.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
    else:
        print("=> Using DynamicLossScaleUpdateCell")
        # Dynamic scaling uses hard-coded settings: initial scale 2**24,
        # factor 2, and a window of 2000 steps.
        loss_scale_cell = nn.wrap.loss_scale.DynamicLossScaleUpdateCell(
            loss_scale_value=2 ** 24, scale_factor=2, scale_window=2000)

    return TrainClipGrad(net_with_loss, optimizer, scale_sense=loss_scale_cell,
                         clip_global_norm_value=args.clip_global_norm_value,
                         use_global_norm=True)