Commit a231bda1 authored by hangangqiang

remove activation quantization of slb

parent 767f6278
@@ -892,28 +892,22 @@ result:{'top_1_accuracy': 0.9354967948717948, 'top_5_accuracy': 0.99819711538461
 result:{'top_1_accuracy': 0.9273838141025641} prune_rate=0.45 ckpt=~/resnet50_cifar10/train_parallel0/resnet-400_390.ckpt
 ```
-- Apply SLB on ResNet18 with W4A8, and evaluating with CIFAR-10 dataset. W4A8 means quantize weight with 4bit and activation with 8bit:
+- Apply SLB on ResNet18 with W4, and evaluating with CIFAR-10 dataset. W4 means quantize weight with 4bit:
 ```text
-result:{'top_1_accuracy': 0.9285857371794872, 'top_5_accuracy': 0.9959935897435898} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+result:{'top_1_accuracy': 0.9534254807692307, 'top_5_accuracy': 0.9969951923076923} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
 ```
-- Apply SLB on ResNet18 with W2A8, and evaluating with CIFAR-10 dataset. W2A8 means quantize weight with 2bit and activation with 8bit:
+- Apply SLB on ResNet18 with W2, and evaluating with CIFAR-10 dataset. W2 means quantize weight with 2bit:
 ```text
-result:{'top_1_accuracy': 0.9207732371794872, 'top_5_accuracy': 0.9955929487179487} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+result:{'top_1_accuracy': 0.9503205128205128, 'top_5_accuracy': 0.9966947115384616} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
 ```
-- Apply SLB on ResNet18 with W1A8, and evaluating with CIFAR-10 dataset. W1A8 means quantize weight with 1bit and activation with 8bit:
+- Apply SLB on ResNet18 with W1, and evaluating with CIFAR-10 dataset. W1 means quantize weight with 1bit:
 ```text
-result:{'top_1_accuracy': 0.8976362179487182, 'top_5_accuracy': 0.9923878205128205} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
-```
-- Apply SLB on ResNet18 with W1A4, and evaluating with CIFAR-10 dataset. W1A4 means quantize weight with 1bit and activation with 4bit:
-```text
-result:{'top_1_accuracy': 0.8845152243589743, 'top_5_accuracy': 0.9914863782051282} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+result:{'top_1_accuracy': 0.9485176282051282, 'top_5_accuracy': 0.9965945512820513} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
 ```
 # [Model Description](#contents)
......
@@ -720,7 +720,7 @@ bash run_infer_310.sh [MINDIR_PATH] [NET_TYPE] [DATASET] [DATA_PATH] [CONFIG_PAT
 - Evaluate ResNet18 with the CIFAR-10 dataset
 ```bash
-Total data: 10000, top1 accuracy: 0.94.26, top5 accuracy: 0.9987.
+Total data: 10000, top1 accuracy: 0.9426, top5 accuracy: 0.9987.
 ```
 - Evaluate ResNet18 with the ImageNet2012 dataset
@@ -854,28 +854,22 @@ result:{'top_1_accuracy': 0.9354967948717948, 'top_5_accuracy': 0.99819711538461
 result:{'top_1_accuracy': 0.9273838141025641} prune_rate=0.45 ckpt=~/resnet50_cifar10/train_parallel0/resnet-400_390.ckpt
 ```
-- Apply the SLB algorithm to ResNet18 with W4A8 quantization and evaluate on the CIFAR-10 dataset. W4A8 means the weights are quantized to 4 bits and the activations to 8 bits
+- Apply the SLB algorithm to ResNet18 with W4 quantization and evaluate on the CIFAR-10 dataset. W4 means the weights are quantized to 4 bits:
 ```text
-result:{'top_1_accuracy': 0.9285857371794872, 'top_5_accuracy': 0.9959935897435898} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+result:{'top_1_accuracy': 0.9534254807692307, 'top_5_accuracy': 0.9969951923076923} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
 ```
-- Apply the SLB algorithm to ResNet18 with W2A8 quantization and evaluate on the CIFAR-10 dataset. W2A8 means the weights are quantized to 2 bits and the activations to 8 bits
+- Apply the SLB algorithm to ResNet18 with W2 quantization and evaluate on the CIFAR-10 dataset. W2 means the weights are quantized to 2 bits:
 ```text
-result:{'top_1_accuracy': 0.9207732371794872, 'top_5_accuracy': 0.9955929487179487} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+result:{'top_1_accuracy': 0.9503205128205128, 'top_5_accuracy': 0.9966947115384616} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
 ```
-- Apply the SLB algorithm to ResNet18 with W1A8 quantization and evaluate on the CIFAR-10 dataset. W1A8 means the weights are quantized to 1 bit and the activations to 8 bits
+- Apply the SLB algorithm to ResNet18 with W1 quantization and evaluate on the CIFAR-10 dataset. W1 means the weights are quantized to 1 bit:
 ```text
-result:{'top_1_accuracy': 0.8976362179487182, 'top_5_accuracy': 0.9923878205128205} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
-```
-- Apply the SLB algorithm to ResNet18 with W1A4 quantization and evaluate on the CIFAR-10 dataset. W1A4 means the weights are quantized to 1 bit and the activations to 4 bits:
-```text
-result:{'top_1_accuracy': 0.8845152243589743, 'top_5_accuracy': 0.9914863782051282} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
+result:{'top_1_accuracy': 0.9485176282051282, 'top_5_accuracy': 0.9965945512820513} ckpt=~/resnet18_cifar10/train_parallel/resnet-100_1562.ckpt
 ```
 # Model Description
......
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""eval resnet."""
import mindspore as ms
import mindspore.log as logger
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
......
@@ -60,7 +60,7 @@ eval_image_size: 224
 # Golden-stick options
 comp_algo: "SLB"
-quant_type: "W1A8"
+quant_type: "W1"
 t_start_val: 1.0
 t_start_time: 0.2
 t_end_time: 0.6
......
@@ -14,21 +14,15 @@
 # ============================================================================
 """Create SLB-QAT algorithm instance."""
-from mindspore_gs.quantization.slb import SlbQuantAwareTraining as QBNNQAT
+from mindspore_gs.quantization.slb import SlbQuantAwareTraining as SlbQAT
 from mindspore_gs.quantization.constant import QuantDtype
-def create_slb(quant_type="graph_W4A8"):
-    algo = QBNNQAT()
-    if "W4A8" in quant_type:
+def create_slb(quant_type="W1"):
+    algo = SlbQAT()
+    if "W4" in quant_type:
         algo.set_weight_quant_dtype(QuantDtype.INT4)
-        algo.set_act_quant_dtype(QuantDtype.INT8)
-    elif "W2A8" in quant_type:
+    elif "W2" in quant_type:
         algo.set_weight_quant_dtype(QuantDtype.INT2)
-        algo.set_act_quant_dtype(QuantDtype.INT8)
-    elif "W1A8" in quant_type:
+    elif "W1" in quant_type:
         algo.set_weight_quant_dtype(QuantDtype.INT1)
-        algo.set_act_quant_dtype(QuantDtype.INT8)
-    elif "W1A4" in quant_type:
-        algo.set_weight_quant_dtype(QuantDtype.INT1)
-        algo.set_act_quant_dtype(QuantDtype.INT4)
     return algo
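The new `create_slb` branches select the weight dtype by substring match, so older strings such as `"W1A8"` or `"graph_W4A8"` still hit a branch, while a string that matches none of them falls through without setting a dtype. A minimal standalone sketch of that dispatch is shown below. It does not import `mindspore_gs`; the helper name `weight_bits_for` is hypothetical, and plain integers stand in for the `QuantDtype` members as an illustration only.

```python
from typing import Optional

# Substring tags in the same order as the branches in the new create_slb.
# Integers are stand-ins for QuantDtype.INT4 / INT2 / INT1 (illustrative assumption).
_WEIGHT_BITS = (("W4", 4), ("W2", 2), ("W1", 1))


def weight_bits_for(quant_type: str) -> Optional[int]:
    """Return the weight bit-width chosen for quant_type, or None if no branch matches."""
    for tag, bits in _WEIGHT_BITS:
        if tag in quant_type:  # same substring test as `if "W4" in quant_type`
            return bits
    return None  # mirrors create_slb falling through without setting a dtype


if __name__ == "__main__":
    for name in ("W4", "W2", "W1", "W1A8", "graph_W4A8", "W8"):
        print(name, "->", weight_bits_for(name))
```

Because the match is a substring test, a config that still carries an activation suffix is accepted and the suffix is silently ignored; a stricter parser would be a separate design choice.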
@@ -13,9 +13,9 @@
# limitations under the License.
# ============================================================================
"""train resnet."""
import os
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train.callback as callback
@@ -130,7 +130,7 @@ def load_pretrained_ckpt(net):
     if config.pre_trained:
         if os.path.isfile(config.pre_trained):
             ckpt = ms.load_checkpoint(config.pre_trained)
-            if False and ckpt.get("epoch_num") and ckpt.get("step_num"):
+            if ckpt.get("epoch_num") and ckpt.get("step_num"):
                 config.has_trained_epoch = int(ckpt["epoch_num"].data.asnumpy())
                 config.has_trained_step = int(ckpt["step_num"].data.asnumpy())
             else:
@@ -208,7 +208,7 @@ class TemperatureScheduler(callback.Callback):
             t *= self.t_factor**(min(epoch, t_end_epoch) - t_start_epoch)
         # Assign new value to temperature parameter
         for _, cell in self.model.train_network.cells_and_names():
-            if cell.cls_name == 'QBNNFakeQuantizerPerLayer': # for QBNN
+            if cell.cls_name == 'SlbFakeQuantizerPerLayer': # for SLB
                 cell.set_temperature(t)
                 if epoch >= t_end_epoch:
                     cell.set_temperature_end_flag()
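The temperature update visible in this hunk can be sketched as a pure function, which makes the schedule easy to inspect without a training run. This is a rough illustration under stated assumptions, not the repository's implementation: the function name `slb_temperature` is hypothetical, the start and end epochs are assumed to be `int(epoch_size * t_start_time)` and `int(epoch_size * t_end_time)`, the `epoch > t_start_epoch` guard is assumed, and `t_factor=1.2` is a placeholder value. Only `t_start_val`, `t_start_time`, and `t_end_time` defaults come from the config hunk above.

```python
def slb_temperature(epoch, epoch_size, t_start_val=1.0, t_start_time=0.2,
                    t_end_time=0.6, t_factor=1.2):
    """Return (temperature, end_flag) for a given epoch.

    Assumptions (not confirmed by the diff): how the start/end epochs are
    derived, the guard before scaling, and the t_factor default.
    """
    t_start_epoch = int(epoch_size * t_start_time)
    t_end_epoch = int(epoch_size * t_end_time)
    t = t_start_val
    if epoch > t_start_epoch:
        # Same expression as in the hunk: growth stops once epoch reaches t_end_epoch.
        t *= t_factor ** (min(epoch, t_end_epoch) - t_start_epoch)
    # The callback calls set_temperature_end_flag() when this condition holds.
    return t, epoch >= t_end_epoch


if __name__ == "__main__":
    for e in (1, 25, 50, 70, 90):
        print(e, slb_temperature(e, epoch_size=100))
```

Under these assumptions the temperature stays at `t_start_val` early in training, grows geometrically in the middle, and is frozen once the end epoch is reached.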
......