Skip to content
Snippets Groups Projects
Commit 26babce0 authored by 185******25's avatar 185******25 Committed by 13488604207
Browse files

update official/cv/resnet/golden_stick/pruner/scop/train.py.

update official/cv/resnet/golden_stick/pruner/scop/resnet50_cifar10_config.yaml.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

add official/cv/resnet/golden_stick/scripts/ run_eval.sh.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

fix ci

update official/cv/resnet/golden_stick/pruner/scop/train.py.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

update official/cv/resnet/golden_stick/pruner/scop/resnet50_cifar10_config.yaml.

update official/cv/resnet/golden_stick/pruner/scop/train.py.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

update official/cv/resnet/golden_stick/pruner/scop/train.py.

update official/cv/resnet/golden_stick/scripts/run_eval.sh.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.

update official/cv/resnet/golden_stick/pruner/scop/train.py.

update .jenkins/check/config/filter_pylint.txt.

update official/cv/resnet/golden_stick/pruner/scop/eval.py.
parent 9a13f02c
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@
"models/official/cv" "c-extension-no-member"
"models/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name"
"models/official/cv/resnet/golden_stick/pruner/scop/train.py" "protected-access"
"models/official/cv/resnet/golden_stick/pruner/scop/eval.py" "protected-access"
# research
"models/research/" "missing-docstring"
......
......@@ -14,9 +14,11 @@
# ============================================================================
"""eval resnet."""
import os
import numpy as np
import mindspore as ms
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore_gs import PrunerKfCompressAlgo, PrunerFtCompressAlgo
from mindspore_gs.pruner.scop.scop_pruner import KfConv2d, MaskedConv2dbn
from src.CrossEntropySmooth import CrossEntropySmooth
from src.resnet import resnet50 as resnet
from src.model_utils.config import config
......@@ -46,10 +48,17 @@ def eval_net():
# define net
net = resnet(class_num=config.class_num)
net = PrunerKfCompressAlgo({}).apply(net)
out_index = []
param_dict = ms.load_checkpoint(config.checkpoint_file_path)
for key in param_dict.keys():
if 'out_index' in key:
out_index.append(param_dict[key])
for _, (_, module) in enumerate(net.cells_and_names()):
if isinstance(module, KfConv2d):
module.out_index = out_index.pop(0)
net = PrunerFtCompressAlgo({}).apply(net)
# load checkpoint
param_dict = ms.load_checkpoint(config.checkpoint_file_path)
ms.load_param_into_net(net, param_dict)
net.set_train(False)
......@@ -68,7 +77,24 @@ def eval_net():
# eval model
res = model.eval(dataset)
print("result:", res, "prune_rate=", config.prune_rate, "ckpt=", config.checkpoint_file_path)
masked_conv_list = []
for imd, (nam, module) in enumerate(net.cells_and_names()):
if isinstance(module, MaskedConv2dbn):
masked_conv_list.append((nam, module))
for imd in range(len(masked_conv_list)):
if 'conv2' in masked_conv_list[imd][0] or 'conv3' in masked_conv_list[imd][0]:
masked_conv_list[imd][1].in_index = masked_conv_list[imd - 1][1].out_index
# Only use when calculate params, next version will provide the interface.
net = PrunerFtCompressAlgo({})._pruning_conv(net)
# calculate params
total_params = 0
for param in net.trainable_params():
total_params += np.prod(param.shape)
print("result:", res, "prune_rate=", config.prune_rate,
"ckpt=", config.checkpoint_file_path, "params=", total_params)
if __name__ == '__main__':
......
......@@ -19,12 +19,12 @@ checkpoint_file_path: ""
optimizer: "Momentum"
infer_label: ""
class_num: 10
batch_size: 32
batch_size: 64
loss_scale: 1024
momentum: 0.9
weight_decay: 0.0001
epoch_kf: 90
epoch_ft: 400
epoch_ft: 200
pretrain_epoch_size: 0
save_checkpoint: True
save_checkpoint_epochs: 5
......
......@@ -257,16 +257,6 @@ def train_net():
net_train_step = nn.TrainOneStepCell(net_with_loss, optimizer)
if config.pre_trained:
for _, (_, module) in enumerate(model.cells_and_names()):
if isinstance(module, KfConv2d):
module.score = module.bn.gamma.data.abs() * ops.Squeeze()(
module.kfscale.data - (1 - module.kfscale.data))
module.prune_rate = config.prune_rate
for _, (_, module) in enumerate(model.cells_and_names()):
if isinstance(module, KfConv2d):
_, index = ops.Sort()(module.score)
num_pruned_channel = int(module.prune_rate * module.score.shape[0])
module.out_index = index[num_pruned_channel:]
for param in model.get_parameters():
param.requires_grad = True
train_ft(model, dataset)
......@@ -304,15 +294,9 @@ def train_kf(dataset, net_train_step, model, kfconv_list, kfscale_list):
for param in model.get_parameters():
param.requires_grad = True
for kfscale in kfscale_list[10]:
print(ops.Squeeze()(kfscale).asnumpy())
for kfscale_last in kfscale_list:
print(ops.Squeeze()(kfscale_last[-1]).asnumpy())
for _, (_, module) in enumerate(model.cells_and_names()):
if isinstance(module, KfConv2d):
module.score = module.bn.gamma.data.abs() * ops.Squeeze()(module.kfscale.data - (1 - module.kfscale.data))
for kfconv in kfconv_list:
kfconv.prune_rate = config.prune_rate
for _, (_, module) in enumerate(model.cells_and_names()):
......@@ -326,10 +310,20 @@ def train_kf(dataset, net_train_step, model, kfconv_list, kfscale_list):
def train_ft(model, dataset):
"""train finetune."""
algo_ft = PrunerFtCompressAlgo({})
model = algo_ft.apply(model)
if config.pre_trained:
pre_ckpt = ms.load_checkpoint(config.pre_trained)
out_index = []
param_dict = ms.load_checkpoint(config.checkpoint_file_path)
for key in param_dict.keys():
if 'out_index' in key:
out_index.append(param_dict[key])
for _, (_, module) in enumerate(model.cells_and_names()):
if isinstance(module, KfConv2d):
module.out_index = out_index.pop(0)
model = algo_ft.apply(model)
ms.load_param_into_net(model, pre_ckpt)
else:
model = algo_ft.apply(model)
lr_ft_new = ms.Tensor(get_lr(lr_init=config.lr_init,
lr_end=config.lr_end_ft,
lr_max=config.lr_max_ft,
......
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
CURPATH="$(dirname "$0")"
if [ $# != 4 ]
then
echo "Usage: bash run_eval.sh [PYTHON_PATH] [CONFIG_FILE] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "PYTHON_PATH represents path to directory of 'eval.py'."
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PYTHON_PATH=$(get_real_path $1)
CONFIG_FILE=$(get_real_path $2)
DATASET_PATH=$(get_real_path $3)
CKPT_FILE=$(get_real_path $4)
if [ ! -d $PYTHON_PATH ]
then
echo "error: PYTHON_PATH=$PYTHON_PATH is not a directory"
exit 1
fi
if [ ! -f $CONFIG_FILE ]
then
echo "error: CONFIG_FILE=$CONFIG_FILE is not a file"
exit 1
fi
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $CKPT_FILE ]
then
echo "error: CKPT_FILE=$CKPT_FILE is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ${PYTHON_PATH}/*.py ./eval
cp -r ${CURPATH}/../../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --config_path=$CONFIG_FILE --data_path=$DATASET_PATH --checkpoint_file_path=$CKPT_FILE \
--device_target="Ascend" &> log &
cd ..
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment