Skip to content
Snippets Groups Projects
Commit 47029de3 authored by hangangqiang's avatar hangangqiang
Browse files

fix bug in readme & resnet slb support pretrain

parent c1bb6424
No related branches found
No related tags found
No related merge requests found
......@@ -18,13 +18,11 @@
- [Infer on Ascend310](#infer-on-ascend310)
- [result](#result)
- [Apply algorithm in MindSpore Golden Stick](#apply-algorithm-in-mindspore-golden-stick)
- [Training Process](#Training Process-1)
- [Training Process](#training-process-1)
- [Running on GPU](#running-on-gpu-1)
- [Evaluation Process](#evaluation-process-1)
- [Running on GPU](#running-on-gpu-2)
- [Result](#resutl-3)
- [Inference Process](#inference-process-1)
- [Export MindIR](#export-mindir-1)
- [Result](#result-3)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
......@@ -340,12 +338,6 @@ Evaluation result will be stored in the example path, whose folder name is "eval
================ {'Accuracy': 0.9907852564102564} ================
```
## Inference Process
### Export MindIR
Not support exporting MindIR now.
## [Model Description](#contents)
### [Performance](#contents)
......
......@@ -25,8 +25,6 @@
- [评估过程](#评估过程-1)
- [GPU处理器环境运行](#gpu处理器环境运行-1)
- [结果](#结果-3)
- [推理过程](#推理过程-1)
- [导出MindIR](#导出mindir-1)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
......@@ -342,12 +340,6 @@ bash run_eval_gpu.sh ../quantization/simqat/ ../quantization/simqat/lenet_mnist_
================ {'Accuracy': 0.9907852564102564} ================
```
## 推理过程
### 导出MindIR
当前暂不支持导出MindIR。
## 模型描述
## 性能
......
......@@ -30,13 +30,11 @@
- [Infer on Ascend310](#infer-on-ascend310)
- [result](#result-2)
- [Apply algorithm in MindSpore Golden Stick](#apply-algorithm-in-mindspore-golden-stick)
- [Training Process](#Training Process-1)
- [Training Process](#training-process-1)
- [Running on GPU](#running-on-gpu-2)
- [Evaluation Process](#evaluation-process-1)
- [Running on GPU](#running-on-gpu-3)
- [Result](#resutl-3)
- [Inference Process](#inference-process-1)
- [Export MindIR](#export-mindir-1)
- [Result](#result-3)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
......@@ -918,12 +916,6 @@ result:{'top_1_accuracy': 0.8976362179487182, 'top_5_accuracy': 0.99238782051282
result:{'top_1_accuracy': 0.8845152243589743, 'top_5_accuracy': 0.9914863782051282} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
```
## Inference Process
### Export MindIR
Not support exporting MindIR now.
# [Model Description](#contents)
## [Performance](#contents)
......
......@@ -38,8 +38,6 @@
- [评估过程](#评估过程-1)
- [GPU处理器环境运行](#gpu处理器环境运行-3)
- [结果](#结果-3)
- [推理过程](#推理过程-1)
- [导出MindIR](#导出mindir-1)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
......@@ -880,12 +878,6 @@ result:{'top_1_accuracy': 0.8976362179487182, 'top_5_accuracy': 0.99238782051282
result:{'top_1_accuracy': 0.8845152243589743, 'top_5_accuracy': 0.9914863782051282} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
```
## 推理过程
### 导出MindIR
当前暂不支持导出MindIR。
# 模型描述
## 性能
......
......@@ -183,9 +183,11 @@ class TemperatureScheduler(callback.Callback):
"""
TemperatureScheduler for SLB.
"""
def __init__(self, model, epoch_size=100, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2):
def __init__(self, model, epoch_size=100, has_trained_epoch=0,
t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2):
super().__init__()
self.epochs = epoch_size
self.has_trained_epoch = has_trained_epoch
self.t_start_val = t_start_val
self.t_start_time = t_start_time
self.t_end_time = t_end_time
......@@ -197,7 +199,7 @@ class TemperatureScheduler(callback.Callback):
Epoch_begin.
"""
cb_params = run_context.original_args()
epoch = cb_params.cur_epoch_num
epoch = cb_params.cur_epoch_num + self.has_trained_epoch
# Compute temperature value
t = self.t_start_val
t_start_epoch = int(self.epochs*self.t_start_time)
......@@ -208,9 +210,8 @@ class TemperatureScheduler(callback.Callback):
for _, cell in self.model.train_network.cells_and_names():
if cell.cls_name == 'QBNNFakeQuantizerPerLayer': # for QBNN
cell.set_temperature(t)
if epoch == t_end_epoch:
if epoch >= t_end_epoch:
cell.set_temperature_end_flag()
print('Temperature stops changing. Start applying one-hot to latent weights.')
def train_net():
......@@ -266,7 +267,7 @@ def train_net():
if algo:
algo_cb = algo.callback()
cb.append(algo_cb)
cb.append(TemperatureScheduler(model, config.epoch_size, config.t_start_val,
cb.append(TemperatureScheduler(model, config.epoch_size, config.has_trained_epoch, config.t_start_val,
config.t_start_time, config.t_end_time, config.t_factor))
ckpt_save_dir = set_save_ckpt_dir()
if config.save_checkpoint:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment