Unverified commit c2ab2515, authored by i-robot, committed by Gitee

!3594 [Lanzhou University][University contribution][MindSpore][GhostNet] CPU model transfer training + inference submission

Merge pull request !3594 from GaoXL/branch1
parents bb1f56d7 331b74f3
......@@ -24,6 +24,14 @@
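- [Export MindIR](#导出MindIR)
- [Run inference on Ascend 310](#在Ascend310执行推理)
- [Results](#结果)
- [Transfer training process](#迁移训练过程)
- [Transfer dataset processing](#迁移数据集处理)
- [Obtaining the transfer-training checkpoints](#迁移训练Ckpt获取)
- [Usage](#用法)
- [Results](#结果)
- [Transfer training inference process](#迁移训练推理过程)
- [Usage](#用法)
- [Results](#结果)
- [Model description](#模型描述)
- [Performance](#性能)
- [Evaluation performance](#评估性能)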
......@@ -107,16 +115,18 @@ The overall network architecture of GhostNet is as follows: [link](https://arxiv.org/pdf/1911.11907.
├── CMakeLists.txt # Ascend 310 inference
├── main.cc # Ascend 310 inference
└── utils.cc # Ascend 310 inference
├── scripts
│   ├── run_distribute_train.sh # launch distributed training on Ascend (8 devices)
│   ├── run_eval.sh # launch evaluation on Ascend
│   ├── run_infer_310.sh # launch Ascend 310 inference
│   └── run_standalone_train.sh # launch standalone training on Ascend (single device)
├── src
│   ├── config.py # parameter configuration
│   ├── dataset.py # data preprocessing
│   ├── CrossEntropySmooth.py # loss definition for the ImageNet2012 dataset
│   ├── lr_generator.py # generates the learning rate for each step
│   ├── dense.py # adjusts the fully connected layer of the pretrained model
│   ├── data_split.py # script that splits the transfer dataset
│   ├── ghostnet600.py
│   ├── launch.py
│   └── ghostnet.py # GhostNet network
......@@ -125,6 +135,7 @@ The overall network architecture of GhostNet is as follows: [link](https://arxiv.org/pdf/1911.11907.
├── export.py # export the MindIR model
├── postprocess.py # post-processing for Ascend 310 inference
├── requirements.txt # requirements file
├── fine_tune.py # transfer-train the network
└── train.py # train the network
```
......@@ -255,6 +266,134 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
Total data: 50000, top1 accuracy: 0.73816, top5 accuracy: 0.9178.
```
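# Transfer Training Process
## Transfer Dataset Processing
[Download the dataset from the provided link](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz). After downloading and extracting it, place the dataset-splitting script src/data_split.py in the extracted flower_photos directory and run it. The script generates a train folder and a test folder. Move both folders into a new folder named dataset at the same level as fine_tune.py.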
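As a quick sanity check after the split, the sketch below counts the files per class under the train and test folders. The helper name `count_images` is hypothetical and not part of the repository; the `./dataset` root follows the layout described above.

```python
from pathlib import Path


def count_images(root):
    """Return {split: {class_name: file_count}} for the train/test folders under root."""
    counts = {}
    for split in ("train", "test"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            # Missing split folder: report it as empty rather than failing.
            counts[split] = {}
            continue
        counts[split] = {
            class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return counts


if __name__ == "__main__":
    for split, per_class in count_images("./dataset").items():
        print(split, per_class)
```

Since data_split.py copies the first three quarters of each class's files to train and the remainder to test, each class should show roughly a 3:1 ratio between the two splits.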
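## Obtaining the Transfer-Training Checkpoints
[Download the pretrained model files from the provided checkpoint link](https://download.mindspore.cn/model_zoo/research/cv/ghostnet/): fetch "ghostnet_1x_pets.ckpt", "ghostnet_nose_1x_pets.ckpt", and "ghostnet600M_pets.ckpt", and save them in a new pre_ckpt folder at the same level as fine_tune.py.
## Usage
Once the files are in place, activate your environment and start training.
The checkpoint repository provides three pretrained checkpoint files, so three networks are fine-tuned.
To load the pretrained model "ghostnet_1x_pets.ckpt", train as follows: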
```shell
python fine_tune.py --pre_trained ./pre_ckpt/ghostnet_1x_pets.ckpt
```
To load the pretrained model "ghostnet_nose_1x_pets.ckpt", train as follows:
```shell
python fine_tune.py --pre_trained ./pre_ckpt/ghostnet_nose_1x_pets.ckpt
```
To load the pretrained model "ghostnet600M_pets.ckpt", train as follows:
```shell
python fine_tune.py --pre_trained ./pre_ckpt/ghostnet600M_pets.ckpt
```
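The three commands above can also be queued from one script. The following is a minimal sketch that assumes fine_tune.py is in the current directory; `build_commands` and `run_all` are hypothetical helper names, and only the `--pre_trained` flag shown above is used.

```python
import subprocess
import sys

# Pretrained checkpoints named in this README.
PRETRAINED = [
    "./pre_ckpt/ghostnet_1x_pets.ckpt",
    "./pre_ckpt/ghostnet_nose_1x_pets.ckpt",
    "./pre_ckpt/ghostnet600M_pets.ckpt",
]


def build_commands(checkpoints, python=sys.executable):
    """Build one fine_tune.py command line per pretrained checkpoint."""
    return [[python, "fine_tune.py", "--pre_trained", ckpt] for ckpt in checkpoints]


def run_all(checkpoints, dry_run=True):
    """Run the fine-tuning jobs one after another; with dry_run, only print them."""
    for cmd in build_commands(checkpoints):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing job


if __name__ == "__main__":
    run_all(PRETRAINED, dry_run=True)
```

Running the jobs sequentially rather than in parallel keeps them from competing for the same CPU cores; set `dry_run=False` to actually start training.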
## Results
- Training GhostNet on the flower_photos dataset
```text
# ghostnet_1x transfer training results
epoch: 1 step: 21, loss is 1.0636098384857178
Train epoch time: 295952.507 ms, per step time: 14092.977 ms
epoch: 2 step: 21, loss is 1.007066011428833
Train epoch time: 20480.770 ms, per step time: 975.275 ms
epoch: 3 step: 21, loss is 0.9204861521720886
Train epoch time: 20673.888 ms, per step time: 984.471 ms
...
epoch: 498 step: 21, loss is 0.5347862839698792
Train epoch time: 19795.049 ms, per step time: 942.621 ms
epoch: 499 step: 21, loss is 0.49817660450935364
Train epoch time: 19959.692 ms, per step time: 950.462 ms
epoch: 500 step: 21, loss is 0.5028425455093384
Train epoch time: 20185.629 ms, per step time: 961.220 ms
```
```text
# ghostnet_nose_1x transfer training results
epoch: 1 step: 21, loss is 1.1746268272399902
Train epoch time: 94845.916 ms, per step time: 4516.472 ms
epoch: 2 step: 21, loss is 1.0321934223175049
Train epoch time: 37248.247 ms, per step time: 1773.726 ms
epoch: 3 step: 21, loss is 0.9764260053634644
Train epoch time: 37365.344 ms, per step time: 1779.302 ms
...
epoch: 498 step: 21, loss is 0.5118361711502075
Train epoch time: 36716.475 ms, per step time: 1748.404 ms
epoch: 499 step: 21, loss is 0.5035715103149414
Train epoch time: 37642.484 ms, per step time: 1792.499 ms
epoch: 500 step: 21, loss is 0.49066391587257385
Train epoch time: 36474.781 ms, per step time: 1736.894 ms
```
```text
# ghostnet_600m transfer training results
epoch: 1 step: 21, loss is 1.2935304641723633
Train epoch time: 296802.766 ms, per step time: 14133.465 ms
epoch: 2 step: 21, loss is 1.356112003326416
Train epoch time: 44871.251 ms, per step time: 2136.726 ms
epoch: 3 step: 21, loss is 1.1128544807434082
Train epoch time: 45124.813 ms, per step time: 2148.801 ms
...
epoch: 498 step: 21, loss is 0.4896056652069092
Train epoch time: 45314.303 ms, per step time: 2157.824 ms
epoch: 499 step: 21, loss is 0.5079032182693481
Train epoch time: 45675.234 ms, per step time: 2175.011 ms
epoch: 500 step: 21, loss is 0.5031487345695496
Train epoch time: 45935.200 ms, per step time: 2187.390 ms
```
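In these logs, each "Train epoch time" is the per-step time multiplied by the 21 steps in an epoch. The sketch below parses log text in the format shown above and recovers the step count; the regular expression and the helper name `parse_times` are illustrative assumptions, not part of the repository.

```python
import re

# Matches the timing line printed after each epoch in the logs above.
LOG_LINE = re.compile(
    r"Train epoch time: (?P<epoch>[\d.]+) ms, per step time: (?P<step>[\d.]+) ms"
)


def parse_times(log_text):
    """Return a list of (epoch_ms, step_ms) tuples found in the log text."""
    return [
        (float(m.group("epoch")), float(m.group("step")))
        for m in LOG_LINE.finditer(log_text)
    ]


sample = """epoch: 500 step: 21, loss is 0.5028425455093384
Train epoch time: 20185.629 ms, per step time: 961.220 ms"""

for epoch_ms, step_ms in parse_times(sample):
    steps = round(epoch_ms / step_ms)
    print(f"{steps} steps per epoch, {step_ms:.3f} ms per step")
    # prints "21 steps per epoch, 961.220 ms per step"
```

In all three runs the first epoch is several times slower than the later ones, so steady-state speed is better read from later epochs.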
# Transfer Training Inference Process
## Usage
Set the checkpoint file path, then start inference with the Python script.
ghostnet_1x network inference:
```shell
python eval.py --device_target CPU --checkpoint_path ./ckpt/ghostnet_1x_3-500_21.ckpt
```
ghostnet_nose_1x network inference:
```shell
python eval.py --device_target CPU --checkpoint_path ./ckpt/ghostnet_nose_1x-500_21.ckpt
```
ghostnet_600m network inference:
```shell
python eval.py --device_target CPU --checkpoint_path ./ckpt/ghostnet600M_1-500_21.ckpt
```
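The eval.py changes in this merge request choose the network by looking for a substring in the checkpoint path. The sketch below mirrors that rule as a standalone function. The name `select_network_name` is hypothetical, and two details are assumptions that differ from eval.py: only the file name is matched rather than the whole path, and an unknown name raises an error instead of leaving the network undefined.

```python
import os


def select_network_name(checkpoint_path):
    """Map a checkpoint file name to the network constructor name used for evaluation."""
    name = os.path.basename(checkpoint_path)
    # Same keys and order as the checks in eval.py; matching is case-sensitive.
    for key, network in (
        ("ghostnet_1x", "ghostnet_1x"),
        ("ghostnet_nose_1x", "ghostnet_nose_1x"),
        ("ghostnet600M", "ghostnet_600m"),
    ):
        if key in name:
            return network
    raise ValueError(f"Cannot infer network from checkpoint name: {name}")


if __name__ == "__main__":
    print(select_network_name("./ckpt/ghostnet_nose_1x-500_21.ckpt"))
```

Because the match is case-sensitive, a file name containing `ghostnet600m` with a lowercase "m" would not be recognized; the checkpoints written by fine_tune.py use the prefix `ghostnet600M`.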
## Results
- Inference results of each network on the flower_photos dataset
```text
result: {'top_5_accuracy': 1.0, 'top_1_accuracy': 0.9207589285714286} ckpt= ./ckpt/ghostnet_1x_3-500_21.ckpt
```
```text
result: {'top_1_accuracy': 0.9252232142857143, 'top_5_accuracy': 1.0} ckpt= ./ckpt/ghostnet_nose_1x-500_21.ckpt
```
```text
result: {'top_1_accuracy': 0.9308035714285714, 'top_5_accuracy': 1.0} ckpt= ./ckpt/ghostnet_600m_1-500_21.ckpt
```
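The flower_photos dataset has five classes, so top-5 accuracy is 1.0 by construction: every label is always among the five highest-scoring classes. The pure-Python sketch below illustrates the metric on made-up scores; it is not the MindSpore implementation, and `top_k_accuracy` is a hypothetical helper.

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


# Made-up scores for two samples over five classes.
scores = [
    [0.1, 0.6, 0.1, 0.1, 0.1],  # highest score: class 1
    [0.5, 0.2, 0.1, 0.1, 0.1],  # highest score: class 0
]
labels = [1, 2]
print(top_k_accuracy(scores, labels, k=1))  # 0.5
print(top_k_accuracy(scores, labels, k=5))  # 1.0
```

The reported top-1 values are consistent with 825, 829, and 834 correct predictions out of 896 images, that is, 7 full batches of 128 with `drop_remainder=True`. This is an inference from the numbers, not something the logs state.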
# Model Description
## Performance
......@@ -273,12 +412,30 @@ Total data: 50000, top1 accuracy: 0.73816, top5 accuracy: 0.9178.
| Loss function | Softmax cross-entropy |
| Output | Probability |
| Loss | 1.7887309 |
| Speed | 223.92 ms/step (8 devices) |
| Speed | 203.92 ms/step (8 devices) |
| Total time | 39 hours |
| Parameters (M) | 5.18 |
| Fine-tuning checkpoint | 42.05M (.ckpt file) |
| Script | [link](https://gitee.com/mindspore/models/tree/master/research/cv/ghostnet) |

| Parameter | GPU; CPU |
|:---:|:---:|
| Model version | GhostNet |
| Resources | GeForce RTX 3090; CPU: 3.60 GHz, 4 cores; memory: 8 GB |
| Upload date | 2022-09-05 |
| MindSpore version | 1.8.1 |
| Dataset | flower_photos |
| Training parameters | epoch=500, steps per epoch=21, batch_size=128 |
| Optimizer | Momentum |
| Loss function | Softmax cross-entropy |
| Output | Probability |
| Loss | 0.5028425455093384 |
| Speed | 961.220 ms/step (20185.629 ms/epoch) |
| Total time | 2.79 hours |
| Parameters (M) | 5.18 |
| Fine-tuning checkpoint | 29.9M (.ckpt file) |
| Script | [link](https://gitee.com/mindspore/models/tree/master/research/cv/ghostnet) |
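The total time in the table can be cross-checked against the ghostnet_1x training log. The sketch below multiplies the epoch count by the final epoch time from that log; `total_hours` is a hypothetical helper, and the estimate ignores the slower first epoch.

```python
def total_hours(epochs, epoch_ms):
    """Estimate total training time in hours from a per-epoch time in milliseconds."""
    return epochs * epoch_ms / 1000 / 3600


# Epoch time taken from the last ghostnet_1x log line (epoch 500).
print(f"{total_hours(500, 20185.629):.2f} h")  # 2.80 h
```

The estimate of about 2.8 hours is close to the 2.79 hours reported; an exact match is not expected because epoch times vary slightly from one epoch to the next.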
# Description of Random Situation
A seed is set inside the "create_dataset" function in dataset.py, and the random seed in train.py is also used.
......
......@@ -22,24 +22,35 @@ from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.ghostnet import ghostnet_1x
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
from src.ghostnet600 import ghostnet_600m
from src.config import config

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--num_classes', type=int, default=5, help='Classes number')
parser.add_argument('--device_target', type=str, default='CPU', help='Device platform')
parser.add_argument('--checkpoint_path', type=str, default='./ckpt/ghostnet_1x_2-500_21.ckpt',
                    help='Checkpoint file path')
parser.add_argument('--data_url', type=str, default='./dataset/', help='Dataset path')
args_opt = parser.parse_args()

if __name__ == '__main__':
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                        device_id=device_id, save_graphs=False)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id,
                        save_graphs=False)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net = ghostnet_1x()
    if 'ghostnet_1x' in args_opt.checkpoint_path:
        net = ghostnet_1x(num_classes=args_opt.num_classes)
    elif 'ghostnet_nose_1x' in args_opt.checkpoint_path:
        net = ghostnet_nose_1x(num_classes=args_opt.num_classes)
    elif 'ghostnet600M' in args_opt.checkpoint_path:
        net = ghostnet_600m(num_classes=args_opt.num_classes)
    dataset = create_dataset(dataset_path=args_opt.data_url, do_train=False)
    dataset = create_dataset(dataset_path=args_opt.data_url, do_train=False, batch_size=config.batch_size,
                             num_parallel_workers=None)
    step_size = dataset.get_dataset_size()
    if args_opt.checkpoint_path:
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train ghostnet."""
import os
import argparse
import ast

from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint
from mindspore.common import dtype as mstype
from mindspore.common import set_seed
from mindspore.nn.optim.momentum import Momentum
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode

from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.dataset import create_dataset
from src.config import config
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
from src.ghostnet600 import ghostnet_600m
from src.dense import init_weight

parser = argparse.ArgumentParser(description='GhostNet')
parser.add_argument('--num_classes', type=int, default=5, help='Classes number')
parser.add_argument('--device_target', type=str, default='CPU', help='Device platform')
parser.add_argument('--save_checkpoint_path1', type=str, default='./ckpt/', help='Save path of ckpt file')
parser.add_argument('--data_url', type=str, default='./dataset/', help='Dataset path')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--pre_trained', type=str, default='./pre_ckpt/ghostnet_1x_pets.ckpt',
                    help='Pretrained checkpoint path')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
args_opt = parser.parse_args()

set_seed(1)

if __name__ == '__main__':
    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
    if args_opt.run_distribute:
        device_id = int(os.getenv('DEVICE_ID'))
        rank_size = int(os.environ.get("RANK_SIZE", 1))
        print(rank_size)
        device_num = rank_size
        context.set_context(device_id=device_id)
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        init()
        args_opt.rank = get_rank()

    # select for master rank save ckpt or all rank save, compatible for model parallel
    args_opt.rank_save_ckpt_flag = 0
    if args_opt.is_save_on_master:
        if args_opt.rank == 0:
            args_opt.rank_save_ckpt_flag = 1
    else:
        args_opt.rank_save_ckpt_flag = 1

    # define net
    if 'ghostnet_1x_pets.ckpt' in args_opt.pre_trained:
        net = ghostnet_1x(num_classes=args_opt.num_classes)
    elif 'ghostnet_nose_1x_pets.ckpt' in args_opt.pre_trained:
        net = ghostnet_nose_1x(num_classes=args_opt.num_classes)
    elif 'ghostnet600M_pets.ckpt' in args_opt.pre_trained:
        net = ghostnet_600m(num_classes=args_opt.num_classes)
    net.to_float(mstype.float16)
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Dense):
            cell.to_float(mstype.float32)

    local_data_path = args_opt.data_url
    print('Download data:')
    dataset = create_dataset(dataset_path=local_data_path, do_train=True, batch_size=config.batch_size,
                             num_parallel_workers=None)
    step_size = dataset.get_dataset_size()
    print('steps:', step_size)

    # init weight
    ckpt_param_dict = load_checkpoint(args_opt.pre_trained)
    init_weight(net=net, param_dict=ckpt_param_dict)

    # init lr
    lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end,
                lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
                total_epochs=config.epoch_size, steps_per_epoch=step_size)
    lr = Tensor(lr)

    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(sparse=True, reduction="mean",
                              smooth_factor=config.label_smooth_factor, num_classes=args_opt.num_classes)
    opt = Momentum(net.trainable_params(), lr, config.momentum, loss_scale=config.loss_scale,
                   weight_decay=config.weight_decay)
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
                  metrics={'top_1_accuracy', 'top_5_accuracy'},
                  amp_level="O3", keep_batchnorm_fp32=False)

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        if args_opt.rank_save_ckpt_flag:
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            if 'ghostnet_1x_pets.ckpt' in args_opt.pre_trained:
                ckpt_cb = ModelCheckpoint(prefix="ghostnet_1x", directory=args_opt.save_checkpoint_path1,
                                          config=config_ck)
            elif 'ghostnet_nose_1x_pets.ckpt' in args_opt.pre_trained:
                ckpt_cb = ModelCheckpoint(prefix="ghostnet_nose_1x", directory=args_opt.save_checkpoint_path1,
                                          config=config_ck)
            elif 'ghostnet600M_pets.ckpt' in args_opt.pre_trained:
                ckpt_cb = ModelCheckpoint(prefix="ghostnet600M", directory=args_opt.save_checkpoint_path1,
                                          config=config_ck)
            cb += [ckpt_cb]

    # train model
    model.train(config.epoch_size, dataset, callbacks=cb, sink_size=dataset.get_dataset_size())
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""quick start"""
# This script visualizes the prediction data, runs the model to predict, and visualizes the prediction results.
import argparse

import numpy as np
import matplotlib.pyplot as plt
from mindspore import nn
from mindspore import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.dataset import create_dataset
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
from src.ghostnet600 import ghostnet_600m

parser = argparse.ArgumentParser(description='Quick start')
parser.add_argument('--num_classes', type=int, default=5, help='Classes number')
parser.add_argument('--ckpt_path', type=str, default='./ckpt/ghostnet_1x_2-500_21.ckpt', help='Checkpoint file path')
parser.add_argument('--data_path', type=str, default='./dataset/', help='Dataset path')
args_opt = parser.parse_args()

# class_name corresponds to label, and labels are marked in the order of folders
class_name = {0: "daisy", 1: "dandelion", 2: "roses", 3: "sunflowers", 4: "tulips"}


# define visual prediction data functions:
def visual_input_data(val_dataset):
    data = next(val_dataset.create_dict_iterator())
    images = data["image"]
    labels = data["label"]
    print("Tensor of image", images.shape)
    print("Labels:", labels)
    plt.figure(figsize=(15, 7))
    for i in range(len(labels)):
        # get the image and its corresponding label
        data_image = images[i].asnumpy()
        # data_label = labels[i]
        # process images for display
        data_image = np.transpose(data_image, (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        data_image = std * data_image + mean
        data_image = np.clip(data_image, 0, 1)
        # display image
        plt.subplot(4, 8, i+1)
        plt.imshow(data_image)
        plt.title(class_name[int(labels[i].asnumpy())], fontsize=10)
        plt.axis("off")
    plt.show()


# define visualize_model(), visualize model prediction
def visualize_model(ckpt_path, val_ds):
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    if 'ghostnet_1x' in args_opt.ckpt_path:
        net = ghostnet_1x(num_classes=args_opt.num_classes)
    elif 'ghostnet_nose_1x' in args_opt.ckpt_path:
        net = ghostnet_nose_1x(num_classes=args_opt.num_classes)
    elif 'ghostnet600M' in args_opt.ckpt_path:
        net = ghostnet_600m(num_classes=args_opt.num_classes)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
    data = next(val_ds.create_dict_iterator())
    images = data["image"].asnumpy()
    labels = data["label"].asnumpy()
    flower_class_name = {0: "daisy", 1: "dandelion", 2: "roses", 3: "sunflowers", 4: "tulips"}
    # prediction image category
    output = model.predict(Tensor(data['image']))
    pred = np.argmax(output.asnumpy(), axis=1)
    # display the image and the predicted value of the image
    plt.figure(figsize=(15, 7))
    for i in range(len(labels)):
        plt.subplot(4, 8, i + 1)
        # if the prediction is correct, it is displayed in blue; if the prediction is wrong, it is displayed in red
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(flower_class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')
    plt.show()


if __name__ == '__main__':
    ds = create_dataset(dataset_path=args_opt.data_path, do_train=False, batch_size=32, num_parallel_workers=None)
    visual_input_data(ds)
    visualize_model(args_opt.ckpt_path, ds)
easydict
numpy
matplotlib
\ No newline at end of file
......@@ -36,4 +36,5 @@ config = ed({
"save_checkpoint_epochs": 20,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"filter_weight": True,
})
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
cpu_cut_data.
"""
import os
import shutil


def generate_data():
    dirs = []
    path = "./"
    abs_path = None
    for abs_path, j, _ in os.walk(path):
        print("abs_path:", abs_path)
        if len(j).__trunc__() > 0:
            dirs.append(j)
    print(dirs)
    train_folder = os.path.exists("./train")
    if not train_folder:
        os.makedirs("./train")
    test_folder = os.path.exists("./test")
    if not test_folder:
        os.makedirs("./test")
    for di in dirs[0]:
        files = os.listdir(di)
        train_set = files[: int(len(files) * 3 / 4)]
        test_set = files[int(len(files) * 3 / 4):]
        for file in train_set:
            file_path = "./train/" + di + "/"
            folder = os.path.exists(file_path)
            if not folder:
                os.makedirs(file_path)
            src_file = "./" + di + "/" + file
            print("src_file:", src_file)
            dst_file = file_path + file
            print("dst_file:", dst_file)
            shutil.copyfile(src_file, dst_file)
        for file in test_set:
            file_path = "./test/" + di + "/"
            folder = os.path.exists(file_path)
            if not folder:
                os.makedirs(file_path)
            src_file = "./" + di + "/" + file
            dst_file = file_path + file
            shutil.copyfile(src_file, dst_file)


if __name__ == '__main__':
    generate_data()
......@@ -16,42 +16,41 @@
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms as C2
import mindspore.dataset.vision as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C


def create_dataset(dataset_path, do_train, repeat_num=1, infer_910=True, device_id=0, batch_size=128):
def create_dataset(dataset_path, do_train, infer_910=False, device_id=0, batch_size=128, num_parallel_workers=8):
    """
    create a train or eval dataset
    Args:
        batch_size:
        device_id:
        infer_910:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        rank (int): The shard ID within num_shards (default=None).
        group_size (int): Number of shards that the dataset should be divided into (default=None).
        repeat_num(int): the repeat times of dataset. Default: 1.
        dataset_path (string): The path of dataset.
        do_train (bool): Whether dataset is used for train or eval.
        infer_910 (bool): Whether to use Ascend 910.
        device_id (int): Device id.
        batch_size (int): Input image batch size.
        num_parallel_workers (int): Number of workers to read the data.
    Returns:
        dataset
    """
    device_num = 1
    device_id = device_id
    rank_id = os.getenv('RANK_ID', '0')
    if infer_910:
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
    if not do_train:
        dataset_path = os.path.join(dataset_path, 'val')
        dataset_path = os.path.join(dataset_path, 'test')
    else:
        dataset_path = os.path.join(dataset_path, 'train')
    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
......@@ -75,8 +74,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, infer_910=True, device_
    ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=num_parallel_workers)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adjust the fully connected layer and load the pretrained network"""
import mindspore as ms
from src.config import config


def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """remove useless parameters according to filter_list"""
    for key in list(origin_dict.keys()):
        for name in param_filter:
            if name in key:
                print("Delete parameter from checkpoint: ", key)
                del origin_dict[key]
                break


def init_weight(net, param_dict):
    """init_weight"""
    # if config.pre_trained:
    if param_dict:
        if config.filter_weight:
            filter_list = [x.name for x in net.classifier.get_parameters()]
            filter_checkpoint_parameter_by_list(param_dict, filter_list)
        ms.load_param_into_net(net, param_dict)
......@@ -404,21 +404,20 @@ class GhostNet(nn.Cell):
        for _, m in self.cells_and_names():
            if isinstance(m, (nn.Conv2d)):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                                    m.weight.data.shape).astype("float32")))
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_parameter_data(
                m.gamma.set_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_parameter_data(
                m.beta.set_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
            elif isinstance(m, nn.Dense):
                m.weight.set_parameter_data(Tensor(np.random.normal(
                m.weight.set_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
......