diff --git a/official/cv/resnet/README.md b/official/cv/resnet/README.md
index 796b8ae11a222d0e54a14872447922f0cc32b741..32a55437cf381a22d552a013dd9f2416e2ac45a0 100644
--- a/official/cv/resnet/README.md
+++ b/official/cv/resnet/README.md
@@ -810,7 +810,10 @@ Total data: 50000, top1 accuracy: 0.76844, top5 accuracy: 0.93522.
 # Apply algorithm in MindSpore Golden Stick
 
 MindSpore Golden Stick is a compression algorithm set for MindSpore. We usually apply algorithm in Golden Stick before training for smaller model size, lower power consuming or faster inference process.
-MindSpore Golden Stick provides SimQAT algorithm for ResNet50. SimQAT is a quantization-aware training algorithm that trains the quantization parameters of certain layers in the network by introducing fake-quantization nodes, so that the model can perform inference with less power consumption or higher performance during the deployment phase.
+
+MindSpore Golden Stick provides the SimQAT and SCOP algorithms for ResNet50. SimQAT is a quantization-aware training algorithm that trains the quantization parameters of certain layers in the network by introducing fake-quantization nodes, so that the model can perform inference with less power consumption or higher performance during the deployment phase. SCOP is a reliable pruning algorithm that reduces the influence of all potential irrelevant factors by constructing a scientific control mechanism, and removes nodes in a given proportion, thereby shrinking the model.
+
+MindSpore Golden Stick provides the SLB algorithm for ResNet18. SLB, developed by Huawei Noah's Ark Lab, is a quantization algorithm based on low-bit weight searching: it regards the discrete weights of an arbitrary quantized neural network as searchable variables and uses a differentiable method to search them accurately. In particular, each weight is represented as a probability distribution over the discrete value set; the probabilities are optimized during training, and the values with the highest probability are selected to establish the desired quantized network. Compared with SimQAT, SLB has a greater advantage in low-bit quantization.
 
 ## Training Process
 
@@ -850,8 +853,13 @@ bash run_standalone_train_gpu.sh ../quantization/simqat/ ../quantization/simqat/
 # standalone training example, apply SimQAT and train from pretrained checkpoint
 cd ./golden_stick/scripts/
 bash run_standalone_train_gpu.sh ../quantization/simqat/ ../quantization/simqat/resnet50_cifar10_config.yaml /path/to/dataset PRETRAINED /path/to/pretrained_ckpt
+
+# To apply a different algorithm, just replace PYTHON_PATH and CONFIG_FILE. Take the SLB algorithm as an example:
+bash run_standalone_train_gpu.sh ../quantization/slb/ ../quantization/slb/resnet18_cifar10_config.yaml /path/to/dataset
 ```
 
+- SLB only supports standalone training now, and does not support training from a full-precision checkpoint.
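+
+The snippet below is a minimal sketch of how SLB is wired in before training, based on the `slb.py` and `train.py` helpers added under `golden_stick/quantization/slb/` in this change. It is illustrative only: the dtype settings shown correspond to the default `quant_type: "W1A8"` in `resnet18_cifar10_config.yaml`, and the `src.resnet` import assumes this example's directory layout.
+
+```python
+from mindspore_gs.quantization.slb import SlbQuantAwareTraining
+from mindspore_gs.quantization.constant import QuantDtype
+from src.resnet import resnet18
+
+# Build the float network first, then let the algorithm rewrite it with
+# weight-searching fake-quantization structures before training starts.
+net = resnet18(class_num=10)
+algo = SlbQuantAwareTraining()
+algo.set_weight_quant_dtype(QuantDtype.INT1)  # "W1": search 1-bit weights
+algo.set_act_quant_dtype(QuantDtype.INT8)     # "A8": 8-bit activations
+net = algo.apply(net)
+
+# The converted network is then trained as usual; algo.callback() is appended
+# to the training callbacks so the algorithm can update its state (see train.py).
+```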
+
 ## Evaluation Process
 
 ### Running on GPU
@@ -865,18 +873,51 @@ bash run_eval_gpu.sh [PYTHON_PATH] [CONFIG_FILE] [DATASET_PATH] [CHECKPOINT_PATH
 # evaluation example
 cd ./golden_stick/scripts/
 bash run_eval_gpu.sh ../quantization/simqat/ ../quantization/simqat/resnet50_cifar10_config.yaml ./cifar10/train/ ./checkpoint/resnet-90.ckpt
+
+# To apply a different algorithm, just replace PYTHON_PATH and CONFIG_FILE. Take the SLB algorithm as an example:
+bash run_eval_gpu.sh ../quantization/slb/ ../quantization/slb/resnet18_cifar10_config.yaml ./cifar10/train/ ./checkpoint/resnet-100.ckpt
 ```
 
 ### Result
 
 Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the following in log.
 
-- Apply SimQAT on ResNet50, and evaluating with CIFAR-10 dataset
+- Apply SimQAT on ResNet50 and evaluate with the CIFAR-10 dataset:
 
-```bash
+```text
 result:{'top_1_accuracy': 0.9354967948717948, 'top_5_accuracy': 0.9981971153846154} ckpt=~/resnet50_cifar10/train_parallel0/resnet-180_195.ckpt
 ```
 
+- Apply SCOP on ResNet50 and evaluate with the CIFAR-10 dataset:
+
+```text
+result:{'top_1_accuracy': 0.9273838141025641} prune_rate=0.45 ckpt=~/resnet50_cifar10/train_parallel0/resnet-400_390.ckpt
+```
+
+- Apply SLB on ResNet18 with W4A8 and evaluate with the CIFAR-10 dataset. W4A8 means quantizing weights to 4 bits and activations to 8 bits:
+
+```text
+result:{'top_1_accuracy': 0.9285857371794872, 'top_5_accuracy': 0.9959935897435898} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+```
+
+- Apply SLB on ResNet18 with W2A8 and evaluate with the CIFAR-10 dataset. W2A8 means quantizing weights to 2 bits and activations to 8 bits:
+
+```text
+result:{'top_1_accuracy': 0.9207732371794872, 'top_5_accuracy': 0.9955929487179487} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+```
+
+- Apply SLB on ResNet18 with W1A8 and evaluate with the CIFAR-10 dataset. W1A8 means quantizing weights to 1 bit and activations to 8 bits:
+
+```text
+result:{'top_1_accuracy': 0.8976362179487182, 'top_5_accuracy': 0.9923878205128205} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+```
+
+- Apply SLB on ResNet18 with W1A4 and evaluate with the CIFAR-10 dataset. W1A4 means quantizing weights to 1 bit and activations to 4 bits:
+
+```text
+result:{'top_1_accuracy': 0.8845152243589743, 'top_5_accuracy': 0.9914863782051282} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+```
+
 ## Inference Process
 
 ### Export MindIR
diff --git a/official/cv/resnet/README_CN.md b/official/cv/resnet/README_CN.md
index 60a922023b98bf5404a4714db511da6ae761b696..f77b5f97cf73f9bb00ac143ddf56e16e7056fb85 100644
--- a/official/cv/resnet/README_CN.md
+++ b/official/cv/resnet/README_CN.md
@@ -773,6 +773,8 @@ Total data: 50000, top1 accuracy: 0.76844, top5 accuracy: 0.93522.
针对ResNet50,金箍棒提供了SimQAT和SCOP算法,SimQAT是一种量化感知训练算法,通过引入伪量化节点来训练网络中的某些层的量化参数,从而在部署阶段,模型得以以更小的功耗或者更高的性能进行推理。SCOP算法提出一种可靠剪枝方法,通过构建一种科学控制机制减少所有潜在不相关因子的影响,有效的按比例进行节点删除,从而实现模型小型化。 +针对ResNet18,金箍棒引入了华为自研量化算法SLB,SLB是一种基于权值搜索的低比特量化算法,利用连续松弛策略搜索离散权重,训练时优化离散权重的分布,最后根据概率挑选离散权重实现量化。与传统的量化算法相比,规避了不准确的梯度更新过程,在极低比特量化中更有优势。 + ## 训练过程 ### GPU处理器环境运行 @@ -795,10 +797,6 @@ bash run_distribute_train_gpu.sh ../quantization/simqat/ ../quantization/simqat/ cd ./golden_stick/scripts/ bash run_distribute_train_gpu.sh ../quantization/simqat/ ../quantization/simqat/resnet50_cifar10_config.yaml /path/to/dataset PRETRAINED /path/to/pretrained_ckpt -# 分布式训练示例(应用SCOP算法进行剪枝训练) -cd ./golden_stick/scripts/ -bash run_distribute_train_gpu.sh ../pruner/scop/ ../pruner/scop/resnet50_cifar10_config.yaml ./cifar10/train/ PRETRAINED /path/to/pretrained_ckpt - # 单机训练 cd ./golden_stick/scripts/ # PYTHON_PATH 表示需要应用的算法的'train.py'脚本所在的目录。 @@ -815,8 +813,13 @@ bash run_standalone_train_gpu.sh ../quantization/simqat/ ../quantization/simqat/ # 单机训练示例(应用SimQAT算法并加载上次量化训练的checkoutpoint,继续进行量化训练) cd ./golden_stick/scripts/ bash run_standalone_train_gpu.sh ../quantization/simqat/ ../quantization/simqat/resnet50_cifar10_config.yaml /path/to/dataset PRETRAINED /path/to/pretrained_ckpt + +# 针对不同的量化算法,只需替换PYTHON_PATH CONFIG_FILE即可,以SLB算法为例: +bash run_standalone_train_gpu.sh ../quantization/slb/ ../quantization/slb/resnet18_cifar10_config.yaml ./cifar10/train/ ``` +- 当前SLB只支持单机训练,且不支持加载预训练的全精度checkpoint + ## 评估过程 ### GPU处理器环境运行 @@ -832,6 +835,9 @@ bash run_eval_gpu.sh [PYTHON_PATH] [CONFIG_FILE] [DATASET_PATH] [CHECKPOINT_PATH # 评估示例 cd ./golden_stick/scripts/ bash run_eval_gpu.sh ../quantization/simqat/ ../quantization/simqat/resnet50_cifar10_config.yaml ./cifar10/train/ ./checkpoint/resnet-90.ckpt + +# 针对不同的量化算法,只需替换PYTHON_PATH CONFIG_FILE即可,以SLB算法为例: +bash run_eval_gpu.sh ../quantization/slb/ ../quantization/slb/resnet18_cifar10_config.yaml ./cifar10/train/ ./checkpoint/resnet-100.ckpt ``` ### 结果 @@ -840,16 +846,40 @@ bash run_eval_gpu.sh ../quantization/simqat/ ../quantization/simqat/resnet50_cif - 使用SimQAT算法量化ResNet50,并使用CIFAR-10数据集评估: -```bash +```text result:{'top_1_accuracy': 0.9354967948717948, 'top_5_accuracy': 0.9981971153846154} ckpt=~/resnet50_cifar10/train_parallel0/resnet-180_195.ckpt ``` -- 使用SCOP算法剪枝ResNet,并使用CIFAR-10数据集评估: +- 使用SCOP算法剪枝ResNet50,并使用CIFAR-10数据集评估: ```text result:{'top_1_accuracy': 0.9273838141025641} prune_rate=0.45 ckpt=~/resnet50_cifar10/train_parallel0/resnet-400_390.ckpt ``` +- 使用SLB算法对ResNet18做W4A8量化,并使用CIFAR-10数据集评估,W4A8表示weight量化为4bit,activation量化为8bit: + +```text +result:{'top_1_accuracy': 0.9285857371794872, 'top_5_accuracy': 0.9959935897435898} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt +``` + +- 使用SLB算法对ResNet18做W2A8量化,并使用CIFAR-10数据集评估,W2A8表示weight量化为2bit,activation量化为8bit: + +```text +result:{'top_1_accuracy': 0.9207732371794872, 'top_5_accuracy': 0.9955929487179487} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt +``` + +- 使用SLB算法对ResNet18做W1A8量化,并使用CIFAR-10数据集评估,W1A8表示weight量化为1bit,activation量化为8bit: + +```text +result:{'top_1_accuracy': 0.8976362179487182, 'top_5_accuracy': 0.9923878205128205} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt +``` + +- 使用SLB算法对ResNet18做W1A4量化,并使用CIFAR-10数据集评估,W1A4表示weight量化为1bit,activation量化为4bit: + +```text +result:{'top_1_accuracy': 0.8845152243589743, 'top_5_accuracy': 0.9914863782051282} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt +``` + ## 推理过程 ### 导出MindIR diff --git 
a/official/cv/resnet/golden_stick/quantization/slb/eval.py b/official/cv/resnet/golden_stick/quantization/slb/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..6a39726527c67b45d4320910b00557dc5a2a17e2 --- /dev/null +++ b/official/cv/resnet/golden_stick/quantization/slb/eval.py @@ -0,0 +1,68 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""eval resnet.""" +import mindspore as ms +import mindspore.log as logger +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from slb import create_slb +from src.resnet import resnet18 as resnet +from src.model_utils.config import config + +if config.dataset == "cifar10": + from src.dataset import create_dataset1 as create_dataset +else: + from src.dataset import create_dataset2 as create_dataset + +ms.set_seed(1) + +def eval_net(): + """eval net""" + target = config.device_target + if target != "GPU": + logger.warning("SLB only support GPU now!") + + # init context + if config.mode_name == "GRAPH": + ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False) + else: + ms.set_context(mode=ms.PYNATIVE_MODE, device_target=target, save_graphs=False) + + # create dataset + dataset = create_dataset(dataset_path=config.data_path, do_train=False, batch_size=config.batch_size, + eval_image_size=config.eval_image_size, target=target) + + # define net + net = resnet(class_num=config.class_num) + algo = create_slb(config.quant_type) + net = algo.apply(net) + + # load checkpoint + param_dict = ms.load_checkpoint(config.checkpoint_file_path) + ms.load_param_into_net(net, param_dict) + net.set_train(False) + + # define loss + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + + # define model + model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) + + # eval model + res = model.eval(dataset) + print("result:", res, "ckpt=", config.checkpoint_file_path) + + +if __name__ == '__main__': + eval_net() diff --git a/official/cv/resnet/golden_stick/quantization/slb/resnet18_cifar10_config.yaml b/official/cv/resnet/golden_stick/quantization/slb/resnet18_cifar10_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea592bc9b9f69426ca9b4a39e015794b65900ba4 --- /dev/null +++ b/official/cv/resnet/golden_stick/quantization/slb/resnet18_cifar10_config.yaml @@ -0,0 +1,104 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +run_distribute: False +enable_profiling: False +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: "Ascend" +checkpoint_path: "./checkpoint/" +checkpoint_file_path: "" + +# ============================================================================== +# Training options +optimizer: 
"Momentum" +infer_label: "" +class_num: 10 +batch_size: 32 +loss_scale: 1024 +momentum: 0.9 +weight_decay: 0.0001 +epoch_size: 100 +save_checkpoint: True +save_checkpoint_epochs: 5 +keep_checkpoint_max: 10 +warmup_epochs: 5 +use_label_smooth: True +label_smooth_factor: 0.1 +lr_decay_mode: "poly" +lr_init: 0.01 +lr_end: 0.00001 +lr_max: 0.1 +lars_epsilon: 0.0 +lars_coefficient: 0.001 + +net_name: "resnet18" +dataset: "cifar10" +device_num: 1 +pre_trained: "" +fp32_ckpt: "" +run_eval: False +eval_dataset_path: "" +parameter_server: False +filter_weight: False +save_best_ckpt: True +eval_start_epoch: 40 +eval_interval: 1 +enable_cache: False +cache_session_id: "" +mode_name: "GRAPH" +boost_mode: "O0" +conv_init: "XavierUniform" +dense_init: "RandomNormal" +train_image_size: 224 +eval_image_size: 224 + +# Golden-stick options +comp_algo: "SLB" +quant_type: "W1A8" +t_start_val: 1.0 +t_start_time: 0.2 +t_end_time: 0.6 +t_factor: 1.2 + +# Export options +device_id: 0 +width: 224 +height: 224 +file_name: "resnet18" +file_format: "MINDIR" +ckpt_file: "" +network_dataset: "resnet18_cifar10" + +# Retrain options +save_graphs: False +save_graphs_path: "./graphs" +has_trained_epoch: 0 +has_trained_step: 0 + +# postprocess resnet inference +result_path: '' +label_path: '' + +--- +# Help description for each configuration +enable_modelarts: "Whether training on modelarts, default: False" +data_url: "Dataset url for obs" +checkpoint_url: "The location of checkpoint for obs" +data_path: "Dataset path for local" +output_path: "Training output path for local" +load_path: "The location of checkpoint for obs" +device_target: "Target device type, available: [Ascend, GPU, CPU]" +enable_profiling: "Whether enable profiling while training, default: False" +num_classes: "Class for dataset" +batch_size: "Batch size for training and evaluation" +epoch_size: "Total training epochs." +checkpoint_path: "The location of the checkpoint file." +checkpoint_file_path: "The location of the checkpoint file." +save_graphs: "Whether save graphs during training, default: False." +save_graphs_path: "Path to save graphs." diff --git a/official/cv/resnet/golden_stick/quantization/slb/slb.py b/official/cv/resnet/golden_stick/quantization/slb/slb.py new file mode 100644 index 0000000000000000000000000000000000000000..2a791578084d5d3ef6506a36808b1d2b09d8850a --- /dev/null +++ b/official/cv/resnet/golden_stick/quantization/slb/slb.py @@ -0,0 +1,34 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Create SLB-QAT algorithm instance.""" + +from mindspore_gs.quantization.slb import SlbQuantAwareTraining as QBNNQAT +from mindspore_gs.quantization.constant import QuantDtype + +def create_slb(quant_type="graph_W4A8"): + algo = QBNNQAT() + if "W4A8" in quant_type: + algo.set_weight_quant_dtype(QuantDtype.INT4) + algo.set_act_quant_dtype(QuantDtype.INT8) + elif "W2A8" in quant_type: + algo.set_weight_quant_dtype(QuantDtype.INT2) + algo.set_act_quant_dtype(QuantDtype.INT8) + elif "W1A8" in quant_type: + algo.set_weight_quant_dtype(QuantDtype.INT1) + algo.set_act_quant_dtype(QuantDtype.INT8) + elif "W1A4" in quant_type: + algo.set_weight_quant_dtype(QuantDtype.INT1) + algo.set_act_quant_dtype(QuantDtype.INT4) + return algo diff --git a/official/cv/resnet/golden_stick/quantization/slb/train.py b/official/cv/resnet/golden_stick/quantization/slb/train.py new file mode 100644 index 0000000000000000000000000000000000000000..55a47dbe05c022e5a7897147fbc091b38219d16e --- /dev/null +++ b/official/cv/resnet/golden_stick/quantization/slb/train.py @@ -0,0 +1,286 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train resnet.""" +import os +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.train.callback as callback +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.communication.management import init, get_rank +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from slb import create_slb +from src.lr_generator import get_lr +from src.CrossEntropySmooth import CrossEntropySmooth +from src.metric import DistAccuracy +from src.resnet import conv_variance_scaling_initializer +from src.resnet import resnet18 as resnet +from src.model_utils.config import config + +if config.dataset == "cifar10": + from src.dataset import create_dataset1 as create_dataset +else: + if config.mode_name == "GRAPH": + from src.dataset import create_dataset2 as create_dataset + else: + from src.dataset import create_dataset_pynative as create_dataset + +ms.set_seed(1) + + +class LossCallBack(LossMonitor): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. 
+    """
+
+    def __init__(self, has_trained_epoch=0):
+        super(LossCallBack, self).__init__()
+        self.has_trained_epoch = has_trained_epoch
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], ms.Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, ms.Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + int(self.has_trained_epoch),
+                                                      cur_step_in_epoch, loss), flush=True)
+
+
+def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
+    """remove useless parameters according to filter_list"""
+    for key in list(origin_dict.keys()):
+        for name in param_filter:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del origin_dict[key]
+                break
+
+
+def set_parameter():
+    """set_parameter"""
+    target = config.device_target
+
+    # init context
+    if config.mode_name == "GRAPH":
+        ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
+    else:
+        ms.set_context(mode=ms.PYNATIVE_MODE, device_target=target, save_graphs=False)
+
+    if config.run_distribute:
+        # GPU target
+        init()
+        ms.set_auto_parallel_context(device_num=config.device_num,
+                                     parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+                                     gradients_mean=True)
+
+
+def init_weight(net):
+    """init_weight"""
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, nn.Conv2d):
+            if config.conv_init == "XavierUniform":
+                cell.weight.set_data(ms.common.initializer.initializer(ms.common.initializer.XavierUniform(),
+                                                                       cell.weight.shape,
+                                                                       cell.weight.dtype))
+            elif config.conv_init == "TruncatedNormal":
+                weight = conv_variance_scaling_initializer(cell.in_channels,
+                                                           cell.out_channels,
+                                                           cell.kernel_size[0])
+                cell.weight.set_data(weight)
+        if isinstance(cell, nn.Dense):
+            if config.dense_init == "TruncatedNormal":
+                cell.weight.set_data(ms.common.initializer.initializer(ms.common.initializer.TruncatedNormal(),
+                                                                       cell.weight.shape,
+                                                                       cell.weight.dtype))
+            elif config.dense_init == "RandomNormal":
+                in_channel = cell.in_channels
+                out_channel = cell.out_channels
+                weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
+                weight = ms.Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=cell.weight.dtype)
+                cell.weight.set_data(weight)
+
+
+def load_pretrained_ckpt(net):
+    if config.pre_trained:
+        if os.path.isfile(config.pre_trained):
+            ckpt = ms.load_checkpoint(config.pre_trained)
+            if ckpt.get("epoch_num") and ckpt.get("step_num"):
+                config.has_trained_epoch = int(ckpt["epoch_num"].data.asnumpy())
+                config.has_trained_step = int(ckpt["step_num"].data.asnumpy())
+            else:
+                config.has_trained_epoch = 0
+                config.has_trained_step = 0
+
+            if config.filter_weight:
+                filter_list = [x.name for x in net.end_point.get_parameters()]
+                filter_checkpoint_parameter_by_list(ckpt, filter_list)
+            ms.load_param_into_net(net, ckpt)
+        else:
+            print(f"Invalid pre_trained {config.pre_trained} parameter.")
+
+
+def init_group_params(net):
+    decayed_params = []
+
no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + + group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + return group_params + + +def set_save_ckpt_dir(): + """set save ckpt dir""" + ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path) + if config.run_distribute: + ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/" + return ckpt_save_dir + + +def init_loss_scale(): + if config.dataset == "imagenet2012" or config.dataset == "cifar10": + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + else: + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + return loss + + +class TemperatureScheduler(callback.Callback): + """ + TemperatureScheduler for SLB. + """ + def __init__(self, model, epoch_size=100, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2): + super().__init__() + self.epochs = epoch_size + self.t_start_val = t_start_val + self.t_start_time = t_start_time + self.t_end_time = t_end_time + self.t_factor = t_factor + self.model = model + + def epoch_begin(self, run_context): + """ + Epoch_begin. + """ + cb_params = run_context.original_args() + epoch = cb_params.cur_epoch_num + # Compute temperature value + t = self.t_start_val + t_start_epoch = int(self.epochs*self.t_start_time) + t_end_epoch = int(self.epochs*self.t_end_time) + if epoch > t_start_epoch: + t *= self.t_factor**(min(epoch, t_end_epoch) - t_start_epoch) + # Assign new value to temperature parameter + for _, cell in self.model.train_network.cells_and_names(): + if cell.cls_name == 'QBNNFakeQuantizerPerLayer': # for QBNN + cell.set_temperature(t) + if epoch == t_end_epoch: + cell.set_temperature_end_flag() + print('Temperature stops changing. 
Start applying one-hot to latent weights.') + + +def train_net(): + """train net""" + print("Train configure: {}".format(config)) + target = config.device_target + if target != "GPU": + raise NotImplementedError("SLB only support running on GPU now!") + set_parameter() + dataset = create_dataset(dataset_path=config.data_path, do_train=True, + batch_size=config.batch_size, train_image_size=config.train_image_size, + eval_image_size=config.eval_image_size, target=target, + distribute=config.run_distribute) + step_size = dataset.get_dataset_size() + net = resnet(class_num=config.class_num) + + init_weight(net) + algo = create_slb(config.quant_type) + net = algo.apply(net) + load_pretrained_ckpt(net) + + lr = get_lr(lr_init=config.lr_init, + lr_end=config.lr_end, + lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, + total_epochs=config.epoch_size, + steps_per_epoch=step_size, + lr_decay_mode=config.lr_decay_mode) + if config.pre_trained: + lr = lr[config.has_trained_epoch * step_size:] + lr = ms.Tensor(lr) + # define optimizer + group_params = init_group_params(net) + if config.optimizer == 'Momentum': + opt = nn.Momentum(group_params, lr, config.momentum, weight_decay=config.weight_decay, + loss_scale=config.loss_scale) + + loss = init_loss_scale() + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + + metrics = {"acc"} + if config.run_distribute: + metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.device_num)} + model = ms.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, + amp_level="O0", boost_level=config.boost_mode, keep_batchnorm_fp32=False, + boost_config_dict={"grad_freeze": {"total_steps": config.epoch_size * step_size}}) + + # define callbacks + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossCallBack(config.has_trained_epoch) + + cb = [time_cb, loss_cb] + if algo: + algo_cb = algo.callback() + cb.append(algo_cb) + cb.append(TemperatureScheduler(model, config.epoch_size, config.t_start_val, + config.t_start_time, config.t_end_time, config.t_factor)) + ckpt_save_dir = set_save_ckpt_dir() + if config.save_checkpoint: + ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}] + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max, + append_info=ckpt_append_info) + ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + # train model + dataset_sink_mode = target != "CPU" + model.train(config.epoch_size - config.has_trained_epoch, dataset, callbacks=cb, + sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode) + + +if __name__ == '__main__': + train_net()
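For reference, the temperature schedule that `TemperatureScheduler` applies in `epoch_begin` can be reproduced standalone. The sketch below only mirrors the arithmetic of that callback with the defaults from `resnet18_cifar10_config.yaml` (`epoch_size: 100`, `t_start_val: 1.0`, `t_start_time: 0.2`, `t_end_time: 0.6`, `t_factor: 1.2`); it does not touch any MindSpore objects.

```python
def slb_temperature(epoch, epochs=100, t_start_val=1.0,
                    t_start_time=0.2, t_end_time=0.6, t_factor=1.2):
    """Per-epoch temperature, mirroring TemperatureScheduler.epoch_begin."""
    t_start_epoch = int(epochs * t_start_time)  # epoch 20 with the defaults
    t_end_epoch = int(epochs * t_end_time)      # epoch 60 with the defaults
    t = t_start_val
    if epoch > t_start_epoch:
        # Grow geometrically after t_start_epoch, then stay flat after t_end_epoch.
        t *= t_factor ** (min(epoch, t_end_epoch) - t_start_epoch)
    return t


if __name__ == "__main__":
    for e in (1, 20, 30, 60, 100):
        print(e, slb_temperature(e))  # -> 1.0, 1.0, ~6.19, ~1469.8, ~1469.8
```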