diff --git a/official/cv/vgg16/README.md b/official/cv/vgg16/README.md
index 06383bb60dae8c7c680946721815068d62da6a98..303d90c778cd60435af99fc10eb707e18e30da05 100644
--- a/official/cv/vgg16/README.md
+++ b/official/cv/vgg16/README.md
@@ -6,6 +6,7 @@
 - [Dataset](#dataset)
     - [Dataset used: CIFAR-10](#dataset-used-cifar-10)
     - [Dataset used: ImageNet2012](#dataset-used-imagenet2012)
+    - [Dataset used: Custom Dataset](#dataset-used-custom-dataset)
     - [Dataset organize way](#dataset-organize-way)
 - [Features](#features)
     - [Mixed Precision](#mixed-precision)
@@ -24,6 +25,11 @@
     - [Evaluation Process](#evaluation-process)
         - [Evaluation](#evaluation-1)
         - [ONNX Evaluation](#onnx-evaluation)
+    - [Migration process](#migration-process)
+        - [Dataset split](#dataset-split)
+        - [Migration](#migration)
+        - [Eval](#eval)
+        - [Model quick start](#model-quick-start)
     - [Inference Process](#inference-process)
         - [Export MindIR](#export-mindir)
         - [Infer on Ascend310](#infer-on-ascend310)
@@ -66,6 +71,11 @@ Note that you can run the scripts based on the dataset mentioned in original pap
 - Data format: RGB images
     - Note: Data will be processed in src/dataset.py
 
+### Dataset used: Custom Dataset
+
+- Data format: RGB images
+    - Note: Data will be processed in src/data_split.py, which splits it into a training set and a validation set.
+
 #### Dataset organize way
 
   CIFAR-10
@@ -89,6 +99,21 @@ Note that you can run the scripts based on the dataset mentioned in original pap
 > └─validation_preprocess # evaluate dataset
 > ```
 
+  Custom Dataset
+
+  > Extract the custom dataset to any path. The directory should contain one folder per class name, holding all images of that class, as shown below:
+  >
+  > ```bash
+  > .
+  > └─dataset
+  >   ├─class_name1 # class name
+  >     ├─xx.jpg # all images belonging to this class
+  >     ├─ ...
+  >     ├─xx.jpg
+  >   ├─class_name2
+  >     ├─ ...
+  > ```
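+
+  In this layout each sub-folder name is one class; the split and loading scripts read it with `mindspore.dataset.ImageFolderDataset`, which assigns one integer label per folder. A minimal sketch (`./dataset` is a placeholder path):
+
+  ```python
+  import mindspore.dataset as ds
+
+  # each sub-folder of ./dataset becomes one class label
+  data_set = ds.ImageFolderDataset("./dataset", shuffle=True)
+  print(data_set.get_dataset_size())  # number of images found
+  ```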
 
 ## [Features](#contents)
 
 ### Mixed Precision
 
@@ -141,6 +166,23 @@ bash scripts/run_distribute_train_gpu.sh [DATA_PATH] --dataset=[DATASET_TYPE]
 python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=[DATASET_TYPE] --data_dir=[DATA_PATH] --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 &
 ```
 
+- Running on CPU
+
+```bash
+
+# run dataset processing example
+python src/data_split.py --split_path [SPLIT_PATH]
+
+# run finetune example
+python fine_tune.py --config_path [YAML_CONFIG_PATH]
+
+# run eval example
+python eval.py --config_path [YAML_CONFIG_PATH]
+
+# quick start
+python quick_start.py --config_path [YAML_CONFIG_PATH]
+```
+
 - Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
 
 ```bash
@@ -300,6 +342,7 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
 │ │ ├── var_init.py // network parameter init method
 │ ├── crossentropy.py // loss calculation
 │ ├── dataset.py // creating dataset
+│ ├── data_split.py // CPU dataset split script
 │ ├── linear_warmup.py // linear leanring rate
 │ ├── warmup_cosine_annealing_lr.py // consine anealing learning rate
 │ ├── warmup_step_lr.py // step or multi step learning rate
@@ -307,11 +350,14 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
 ├── train.py // training script
 ├── eval.py // evaluation script
 ├── eval_onnx.py // ONNX evaluation script
+├── fine_tune.py // CPU transfer (fine-tune) script
+├── quick_start.py // CPU quick start script
 ├── postprocess.py // postprocess script
 ├── preprocess.py // preprocess script
 ├── mindspore_hub_conf.py // mindspore_hub_conf script
 ├── cifar10_config.yaml // Configurations for cifar10
 ├── imagenet2012_config.yaml // Configurations for imagenet2012
+├── cpu_config.yaml // Configurations for CPU transfer
 ├── export.py // model convert script
 └── requirements.txt // requirements
 ```
@@ -414,6 +460,29 @@ initialize_mode: "KaimingNormal" # conv2d init mode
 has_dropout: True # whether using Dropout layer
 ```
 
+- config for vgg16, custom dataset
+
+```bash
+num_classes: 5 # number of dataset categories
+lr: 0.001 # learning rate
+batch_size: 64 # batch size of input tensor
+num_epochs: 10 # number of training epochs
+momentum: 0.9 # momentum
+pad_mode: 'pad' # pad mode for conv2d
+padding: 1 # padding value for conv2d
+has_bias: False # whether conv2d has a bias
+batch_norm: False # whether conv2d uses batch_norm
+initialize_mode: "KaimingNormal" # conv2d init mode
+has_dropout: True # whether using Dropout layer
+ckpt_file: "./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt" # path of the pretrained weights used for migration
+save_file: "./vgg16.ckpt" # path where the fine-tuned weights are saved
+train_path: "./datasets/train/" # migration training set path
+eval_path: "./datasets/test/" # migration validation set path
+split_path: "./datasets/" # migration dataset path
+pre_trained: "./vgg16.ckpt" # weights used for CPU inference
+
+```
+
 ### [Training Process](#contents)
 
 #### Training
 
@@ -538,6 +607,40 @@ top-1 accuracy: 0.7332
 top-5 accuracy: 0.9149
 ```
 
+## Migration process
+
+### Dataset split
+
+- The dataset split works as follows: `train/` and `test/` folders are generated under the dataset directory, holding the training and validation images respectively (the expected layout is shown after the command below).
+
+```bash
+python src/data_split.py --split_path /dir_to_code/{SPLIT_PATH}
+```
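+
+- After the script finishes, the split directory is expected to look like the layout below (class names are placeholders; `src/data_split.py` copies roughly the first 80% of each class's files, in directory-listing order, to `train/` and the rest to `test/`):
+
+```bash
+.
+└─datasets
+  ├─class_name1          # original class folders are kept
+  ├─class_name2
+  ├─train                # generated training split
+  │  ├─class_name1
+  │  └─class_name2
+  └─test                 # generated validation split
+     ├─class_name1
+     └─class_name2
+```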
+
+### Migration
+
+- To run the migration, first download the pre-trained weight file [(https://download.mindspore.cn/models/r1.7/vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt)](https://download.mindspore.cn/models/r1.7/vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt) into the ./vgg16 folder. After training completes, the fine-tuned weights are saved as ./vgg16.ckpt by default.
+
+```bash
+python fine_tune.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### Eval
+
+- The evaluation process is as follows; specify the fine-tuned weight file to evaluate (./vgg16.ckpt by default).
+
+```bash
+python eval.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### Model quick start
+
+- The quick start process is as follows; specify the path of the trained weight file and the dataset path.
+
+```bash
+python quick_start.py --config_path /dir_to_code/cpu_config.yaml
+```
+
 ## Inference Process
 
 ### [Export MindIR](#contents)
diff --git a/official/cv/vgg16/README_CN.md b/official/cv/vgg16/README_CN.md
index deaaa6602d371e8ed24a330a67ea1ee6b830f2a2..355a3ea163179d1ac2f25c63aaa6cea7992abff0 100644
--- a/official/cv/vgg16/README_CN.md
+++ b/official/cv/vgg16/README_CN.md
@@ -8,6 +8,7 @@
 - [数据集](#数据集)
     - [使用的数据集:CIFAR-10](#使用的数据集cifar-10)
     - [使用的数据集:ImageNet2012](#使用的数据集imagenet2012)
+    - [使用的数据集:自定义数据集](#使用的数据集自定义数据集)
     - [数据集组织方式](#数据集组织方式)
 - [特性](#特性)
     - [混合精度](#混合精度)
@@ -25,6 +26,11 @@
     - [GPU处理器环境运行VGG16](#gpu处理器环境运行vgg16)
     - [评估过程](#评估过程)
         - [评估](#评估-1)
+    - [迁移过程](#迁移过程)
+        - [数据集划分](#数据集划分)
+        - [数据集迁移](#数据集迁移)
+        - [数据集评估](#数据集评估)
+        - [quick start](#quick-start)
     - [推理过程](#推理过程)
         - [导出MindIR](#导出mindir)
         - [在Ascend310执行推理](#在ascend310执行推理)
@@ -67,6 +72,11 @@ VGG 16网络主要由几个基本模块(包括卷积层和池化层)和三
 - 数据格式:RGB图像。
     - 注:数据在src/dataset.py中处理。
 
+### 使用的数据集:自定义数据集
+
+- 数据格式:RGB图像。
+    - 注:数据在src/data_split.py中处理,用来划分训练集和验证集。
+
 ### 数据集组织方式
 
   CIFAR-10
@@ -90,6 +100,21 @@ VGG 16网络主要由几个基本模块(包括卷积层和池化层)和三
 > └─validation_preprocess # 评估数据集
 > ```
 
+  自定义数据集
+
+  > 将自定义数据集解压到任意路径。文件夹结构应包含以类名命名的文件夹,以及该文件夹下这一类的所有图片,如下所示:
+  >
+  > ```bash
+  > .
+  > └─dataset
+  >   ├─class_name1 # 类名
+  >     ├─xx.jpg # 对应类名的所有图片
+  >     ├─ ...
+  >     ├─xx.jpg
+  >   ├─class_name2
+  >     ├─ ...
+  > ```
+
 ## 特性
 
 ### 混合精度
 
@@ -142,6 +167,23 @@ bash scripts/run_distribute_train_gpu.sh [DATA_PATH] --dataset=[DATASET_TYPE]
 python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=[DATASET_TYPE] --data_dir=[DATA_PATH] --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 &
 ```
 
+- CPU处理器环境运行
+
+```bash
+
+# 数据集处理示例
+python src/data_split.py --split_path [SPLIT_PATH]
+
+# 迁移示例
+python fine_tune.py --config_path [YAML_CONFIG_PATH]
+
+# 评估示例
+python eval.py --config_path [YAML_CONFIG_PATH]
+
+# quick start示例
+python quick_start.py --config_path [YAML_CONFIG_PATH]
+```
+
 - 在 ModelArts 进行训练 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
 
 ```bash
@@ -303,14 +345,18 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
 │ ├── linear_warmup.py // 线性学习率
 │ ├── warmup_cosine_annealing_lr.py // 余弦退火学习率
 │ ├── warmup_step_lr.py // 单次或多次迭代学习率
-│ ├──vgg.py // VGG架构
+│ ├── vgg.py // VGG架构
+│ ├── data_split.py // CPU迁移数据集划分脚本
 ├── train.py // 训练脚本
 ├── eval.py // 评估脚本
+├── fine_tune.py // CPU迁移脚本
+├── quick_start.py // CPU quick start脚本
 ├── postprocess.py // 后处理脚本
 ├── preprocess.py // 预处理脚本
 ├── mindspore_hub_conf.py // mindspore hub 脚本
 ├── cifar10_config.yaml // cifar10 配置文件
 ├── imagenet2012_config.yaml // imagenet2012 配置文件
+├── cpu_config.yaml // CPU迁移配置文件
 ├── export.py // 模型格式转换脚本
 └── requirements.txt // requirements
 ```
@@ -413,6 +459,29 @@ initialize_mode: "KaimingNormal" # conv2d init模式
 has_dropout: True # 是否使用Dropout层
 ```
 
+- 配置VGG16,自定义数据集
+
+```bash
+num_classes: 5 # 数据集类别数
+lr: 0.001 # 学习率
+batch_size: 64 # 输入张量批次大小
+num_epochs: 10 # 训练轮数
+momentum: 0.9 # 动量
+pad_mode: 'pad' # conv2d的填充方式
+padding: 1 # conv2d的填充值
+has_bias: False # conv2d是否有偏置
+batch_norm: False # conv2d中是否使用batch_norm
+initialize_mode: "KaimingNormal" # conv2d init模式
+has_dropout: True # 是否使用Dropout层
+ckpt_file: "./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt" # 迁移使用的预训练权重文件路径
+save_file: "./vgg16.ckpt" # 迁移后保存的权重文件路径
+train_path: "./datasets/train/" # 迁移数据集训练集路径
+eval_path: "./datasets/test/" # 迁移数据集验证集路径
+split_path: "./datasets/" # 迁移数据集路径
+pre_trained: "./vgg16.ckpt" # CPU推理使用的权重文件路径
+
+```
+
 ### 训练过程
 
 #### 训练
 
@@ -504,6 +573,40 @@ after allreduce eval: top1_correct=36636, tot=50000, acc=73.27%
 after allreduce eval: top5_correct=45582, tot=50000, acc=91.16%
 ```
 
+## 迁移过程
+
+### 数据集划分
+
+- 数据集划分过程如下:会在数据集目录下生成train和test文件夹,分别保存训练集和验证集图片。
+
+```bash
+python src/data_split.py --split_path /dir_to_code/{SPLIT_PATH}
+```
+
+### 数据集迁移
+
+- 迁移过程如下:需要先将预训练权重文件[(https://download.mindspore.cn/models/r1.7/vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt)](https://download.mindspore.cn/models/r1.7/vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt)下载到vgg16文件夹下。训练完成后,默认将权重文件保存为./vgg16.ckpt。
+
+```bash
+python fine_tune.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### 数据集评估
+
+- 评估过程如下:需要指定迁移完成的权重文件(默认为./vgg16.ckpt)。
+
+```bash
+python eval.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### quick start
+
+- quick start过程如下:需要指定训练完成的权重文件路径和数据集路径。
+
+```bash
+python quick_start.py --config_path /dir_to_code/cpu_config.yaml
+```
+
 ## 推理过程
 
 ### [导出MindIR](#contents)
diff --git a/official/cv/vgg16/cpu_config.yaml b/official/cv/vgg16/cpu_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c8d076c30dc36962dbfb66aae178d6badf5423ff
--- /dev/null
+++ b/official/cv/vgg16/cpu_config.yaml
@@ -0,0 +1,58 @@
+# ==============================================================================
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# device options
+device_target: "CPU"
+
+# dataset options
+train_path: "./datasets/train/"
+eval_path: "./datasets/test/"
+split_path: "./datasets/"
+
+# finetune options
+dataset: 'custom'
+image_size: '224,224'
+log_path: "outputs/"
+num_classes: 5
+lr: 0.001
+batch_size: 64
+num_epochs: 10
+momentum: 0.9
+ckpt_file: "./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt"
+save_file: "./vgg16.ckpt"
+initialize_mode: "KaimingNormal"
+pad_mode: 'pad'
+padding: 1
+has_bias: False
+batch_norm: False
+has_dropout: True
+
+# infer options
+pre_trained: "./vgg16.ckpt"
+
+
+---
+
+# Help description for each configuration
+
+# device options
+device_target: "device where the code will be implemented."
+
+# dataset options
+train_path: "the training dataset path"
+eval_path: "the eval dataset path"
+split_path: "the original dataset path to split"
+
+# finetune options
+num_classes: "number of classes in the dataset"
+lr: "learning rate"
+batch_size: "batch size"
+num_epochs: "number of train epochs"
+momentum: "momentum of the optimizer"
+ckpt_file: "the .ckpt file used for finetune"
+save_file: "the .ckpt file to save after finetune"
+
+# infer options
+pre_trained: "the .ckpt file path to infer"
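
The scripts load this file through `model_utils/config.py` (`get_config()`), selected with `--config_path`. A sketch of typical invocations; the per-key command-line overrides are an assumption based on the argparse wrapper that this model zoo's `model_utils` builds from the YAML, so verify against your checkout:

```bash
# use the defaults from cpu_config.yaml
python fine_tune.py --config_path ./cpu_config.yaml

# presumably override individual keys without editing the file
python fine_tune.py --config_path ./cpu_config.yaml --num_epochs 20 --lr 0.0005
```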
diff --git a/official/cv/vgg16/eval.py b/official/cv/vgg16/eval.py
index d62a04e6fcb1f2fb7d8082f76aaea15f6ce23f71..f4cb72126ee7cdf3131f632c7c01857b5a9311f7 100644
--- a/official/cv/vgg16/eval.py
+++ b/official/cv/vgg16/eval.py
@@ -29,16 +29,22 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 
 from src.utils.logging import get_logger
-from src.vgg import vgg16
+from src.vgg import vgg16, Vgg
 from src.dataset import vgg_create_dataset
 from src.dataset import classification_dataset
+from src.dataset import create_dataset
 
 from model_utils.moxing_adapter import config
 from model_utils.moxing_adapter import moxing_wrapper
 from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
+from model_utils.config import get_config
+from fine_tune import DenseHead, cfg
+
 
 class ParameterReduce(nn.Cell):
     """ParameterReduce"""
+
     def __init__(self):
         super(ParameterReduce, self).__init__()
         self.cast = P.Cast()
@@ -61,6 +67,7 @@ def get_top5_acc(top5_arg, gt_class):
 
 def modelarts_pre_process():
     '''modelarts pre process function.'''
+
     def unzip(zip_file, save_dir):
         import zipfile
         s_time = time.time()
@@ -132,7 +139,6 @@ def run_eval():
     config.rank = get_rank_id()
     config.group_size = get_device_num()
 
-
     _enable_graph_kernel = config.device_target == "GPU"
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=_enable_graph_kernel,
                         device_target=config.device_target, save_graphs=False)
@@ -168,11 +174,24 @@ def run_eval():
             config.models = sorted(models, key=f)
         else:
             config.models = [config.pre_trained,]
 
     for model in config.models:
-        dataset = classification_dataset(config.data_dir, config.image_size, config.per_batch_size, mode='eval')
+        if config.dataset == "custom":
+            # cpu_config.yaml stores image_size as a '224,224' string; CenterCrop needs an int
+            eval_image_size = int(config.image_size.split(',')[0])
+            dataset = create_dataset(dataset_path=config.eval_path, do_train=False,
+                                     batch_size=config.batch_size,
+                                     eval_image_size=eval_image_size,
+                                     enable_cache=False)
+            model_config = get_config()
+            network = Vgg(cfg['16'], num_classes=1000, args=model_config, batch_norm=True)
+
+            # replace the 1000-class ImageNet head with the fine-tuned head
+            src_head = network.classifier[6]
+            in_channels = src_head.in_channels
+            head = DenseHead(in_channels, config.num_classes)
+            network.classifier[6] = head
+        else:
+            dataset = classification_dataset(config.data_dir, config.image_size, config.per_batch_size, mode='eval')
+            network = vgg16(config.num_classes, config, phase="test")
 
         eval_dataloader = dataset.create_tuple_iterator(output_numpy=True, num_epochs=1)
-        network = vgg16(config.num_classes, config, phase="test")
         # pre_trained
         load_param_into_net(network, load_checkpoint(model))
diff --git a/official/cv/vgg16/fine_tune.py b/official/cv/vgg16/fine_tune.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb870088e7e2428e72660f6ecdfb8d7da410f3b0
--- /dev/null
+++ b/official/cv/vgg16/fine_tune.py
@@ -0,0 +1,232 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.train import Model
+from mindspore.train.callback import LossMonitor, TimeMonitor
+
+from model_utils.config import get_config
+from src.vgg import Vgg
+from src.dataset import create_dataset
+
+ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", save_graphs=False)
+ms.set_seed(21)
+
+
+def import_data(train_dataset_path="./datasets/train/", eval_dataset_path="./datasets/test/", batch_size=32):
+    """
+    Read the dataset.
+
+    Args:
+        train_dataset_path(string): the path of the train dataset.
+        eval_dataset_path(string): the path of the eval dataset.
+        batch_size(int): the batch size of the dataset. Default: 32
+
+    Returns:
+        dataset_train: the train dataset
+        dataset_val: the val dataset
+    """
+    dataset_train = create_dataset(dataset_path=train_dataset_path, do_train=True,
+                                   batch_size=batch_size, train_image_size=224,
+                                   eval_image_size=224,
+                                   enable_cache=False, cache_session_id=None)
+    dataset_val = create_dataset(dataset_path=eval_dataset_path, do_train=False,
+                                 batch_size=batch_size, train_image_size=224,
+                                 eval_image_size=224,
+                                 enable_cache=False, cache_session_id=None)
+    # print a sample batch shape and its labels
+    data = next(dataset_train.create_dict_iterator())
+    images = data["image"]
+    labels = data["label"]
+    print("Tensor of image", images.shape)  # e.g. (batch_size, 3, 224, 224)
+    print("Labels:", labels)
+
+    return dataset_train, dataset_val
+
+
+# define head layer
+class DenseHead(nn.Cell):
+    def __init__(self, input_channel, num_classes):
+        super(DenseHead, self).__init__()
+        self.dense = nn.Dense(input_channel, num_classes)
+
+    def construct(self, x):
+        return self.dense(x)
+
+
+def init_weight(net, param_dict):
+    """Load pretrained parameters into net and report the trained epoch/step if recorded."""
+    has_trained_epoch = 0
+    has_trained_step = 0
+    if param_dict:
+        if param_dict.get("epoch_num") and param_dict.get("step_num"):
+            has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy())
+            has_trained_step = int(param_dict["step_num"].data.asnumpy())
+
+        ms.load_param_into_net(net, param_dict)
+    print("has_trained_epoch:", has_trained_epoch)
+    print("has_trained_step:", has_trained_step)
+    return has_trained_epoch, has_trained_step
+
+
+def eval_net(model_config, checkpoint_path='./vgg16.ckpt',
+             train_dataset_path="./datasets/train/",
+             eval_dataset_path="./datasets/test/",
+             batch_size=32):
+    """
+    Eval the accuracy of vgg16 on the custom (flower) dataset.
+
+    Args:
+        model_config(Config in './model_utils/config.py'): vgg16 config
+        checkpoint_path(string): model checkpoint path (ends with '.ckpt'). Default: './vgg16.ckpt'
+        train_dataset_path(string): the train dataset path. Default: "./datasets/train/"
+        eval_dataset_path(string): the eval dataset path. Default: "./datasets/test/"
+        batch_size(int): the batch size of the dataset. Default: 32
+
+    Returns:
+        None
+    """
+    # define val dataset and model
+    _, data_val = import_data(train_dataset_path=train_dataset_path,
+                              eval_dataset_path=eval_dataset_path, batch_size=batch_size)
+    net = Vgg(cfg['16'], num_classes=1000, args=model_config, batch_norm=True)
+
+    # replace the head with one sized for the custom dataset
+    src_head = net.classifier[6]
+    in_channels = src_head.in_channels
+    head = DenseHead(in_channels, model_config.num_classes)
+    net.classifier[6] = head
+
+    # load the fine-tuned checkpoint
+    param_dict = ms.load_checkpoint(checkpoint_path)
+    ms.load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    # define loss
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    # define model
+    model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+
+    # eval step
+    res = model.eval(data_val)
+
+    # show accuracy
+    print("result:", res)
+
+
+cfg = {
+    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+def finetune_train(model_config,
+                   finetune_checkpoint_path=
+                   './vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt',
+                   save_checkpoint_path="./vgg16.ckpt",
+                   train_dataset_path="./datasets/train/",
+                   eval_dataset_path="./datasets/test/",
+                   class_num=5,
+                   num_epochs=10,
+                   learning_rate=0.001,
+                   momentum=0.9,
+                   batch_size=32
+                   ):
+    """
+    Finetune vgg16 on the custom (flower) dataset.
+
+    Args:
+        model_config(Config in './model_utils/config.py'): vgg16 config
+        finetune_checkpoint_path(string): checkpoint used to initialize the network.
+            Default: './vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt'
+        save_checkpoint_path(string): path to save the fine-tuned checkpoint (ends with '.ckpt'). Default: "./vgg16.ckpt"
+        train_dataset_path(string): the train dataset path. Default: "./datasets/train/"
+        eval_dataset_path(string): the eval dataset path. Default: "./datasets/test/"
+        class_num(int): the number of classes in the dataset. Default: 5
+        num_epochs(int): the number of training epochs. Default: 10
+        learning_rate(float): the finetune learning rate. Default: 0.001
+        momentum(float): the finetune momentum. Default: 0.9
+        batch_size(int): the batch size of the dataset. Default: 32
+
+    Returns:
+        None
+    """
+    # read train/val dataset
+    dataset_train, _ = import_data(train_dataset_path=train_dataset_path,
+                                   eval_dataset_path=eval_dataset_path,
+                                   batch_size=batch_size)
+
+    ckpt_param_dict = ms.load_checkpoint(finetune_checkpoint_path)
+    net = Vgg(cfg['16'], num_classes=1000, args=model_config, batch_norm=True)
+    init_weight(net=net, param_dict=ckpt_param_dict)
+    print("net parameter:")
+    for param in net.get_parameters():
+        print("param:", param)
+
+    # replace head
+    src_head = net.classifier[6]
+    print("classifier[6] before replacement:", src_head)
+    in_channels = src_head.in_channels
+    head = DenseHead(in_channels, class_num)
+    net.classifier[6] = head
+
+    # freeze all parameters except the new head
+    for param in net.get_parameters():
+        if param.name not in ["classifier.6.dense.weight", "classifier.6.dense.bias"]:
+            param.requires_grad = False
+
+    # define optimizer and loss
+    opt = nn.Momentum(params=net.trainable_params(), learning_rate=learning_rate, momentum=momentum)
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    # define model
+    model = Model(net, loss, opt, metrics={"Accuracy": nn.Accuracy()})
+
+    # define callbacks
+    batch_num = dataset_train.get_dataset_size()
+    time_cb = TimeMonitor(data_size=batch_num)
+    loss_cb = LossMonitor()
+    callbacks = [time_cb, loss_cb]
+
+    # do training and save the result
+    model.train(num_epochs, dataset_train, callbacks=callbacks)
+    ms.save_checkpoint(net, save_checkpoint_path)
+
+
+if __name__ == '__main__':
+    config = get_config()
+    print("config:", config)
+    # finetune
+    finetune_train(config,
+                   finetune_checkpoint_path=config.ckpt_file,
+                   save_checkpoint_path=config.save_file, train_dataset_path=config.train_path,
+                   eval_dataset_path=config.eval_path, num_epochs=config.num_epochs, class_num=config.num_classes,
+                   learning_rate=config.lr,
+                   momentum=config.momentum,
+                   batch_size=config.batch_size
+                   )
+
+    # eval
+    eval_net(config, checkpoint_path=config.save_file, train_dataset_path=config.train_path,
+             eval_dataset_path=config.eval_path,
+             batch_size=config.batch_size)  # reference accuracy on the flower dataset: 0.8505434782608695
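
`finetune_train` and `eval_net` can also be called from Python instead of through the `__main__` block; a minimal sketch under the default paths (note that `get_config()` still expects a valid `--config_path` on the command line):

```python
from model_utils.config import get_config
from fine_tune import finetune_train, eval_net

config = get_config()  # e.g. run the script with: --config_path ./cpu_config.yaml

# train only the replaced 5-class head for two epochs, then evaluate the saved checkpoint
finetune_train(config, num_epochs=2, save_checkpoint_path="./vgg16.ckpt")
eval_net(config, checkpoint_path="./vgg16.ckpt")
```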
diff --git a/official/cv/vgg16/quick_start.py b/official/cv/vgg16/quick_start.py
new file mode 100644
index 0000000000000000000000000000000000000000..672db76509584ee79f8b5909dbbfaf924c1ce6d4
--- /dev/null
+++ b/official/cv/vgg16/quick_start.py
@@ -0,0 +1,104 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""inference for CPU"""
+import matplotlib.pyplot as plt
+import numpy as np
+from mindspore import Tensor, load_checkpoint, load_param_into_net
+from mindspore.train import Model
+
+# reuse the head definition and VGG layer config from fine_tune.py instead of duplicating them
+from fine_tune import DenseHead, cfg, import_data
+from model_utils.moxing_adapter import config
+from src.vgg import Vgg
+
+# class names of the flower dataset
+class_name = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
+
+
+def visualize_model(best_ckpt_path, val_ds, num_classes):
+    """
+    Visualize predictions of the fine-tuned model on one validation batch.
+
+    Args:
+        best_ckpt_path(string): the .ckpt file for the model to infer
+        val_ds: eval dataset
+        num_classes(int): the number of classes
+
+    Returns:
+        None
+    """
+    net = Vgg(cfg['16'], num_classes=1000, args=config, batch_norm=True)
+
+    # replace head
+    src_head = net.classifier[6]
+    in_channels = src_head.in_channels
+    head = DenseHead(in_channels, num_classes)
+    net.classifier[6] = head
+
+    # load param
+    param_dict = load_checkpoint(best_ckpt_path)
+    load_param_into_net(net, param_dict)
+    net.set_train(False)
+    model = Model(net)
+
+    # skip a few batches, then take one batch of eval images for prediction
+    iterator = val_ds.create_dict_iterator()
+    for _ in range(5):
+        next(iterator)
+    data = next(iterator)
+    images = data["image"].asnumpy()
+    labels = data["label"].asnumpy()
+
+    output = model.predict(Tensor(data['image']))
+    pred = np.argmax(output.asnumpy(), axis=1)
+    print("\nAccuracy:", (pred == labels).sum() / len(labels))
+
+    # show images: blue title if the prediction is right, red otherwise
+    plt.figure(figsize=(15, 7))
+    for i in range(len(labels)):
+        plt.subplot(4, 8, i + 1)
+        color = 'blue' if pred[i] == labels[i] else 'red'
+        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
+        picture_show = np.transpose(images[i], (1, 2, 0))
+        # undo the ImageNet normalization applied in create_dataset
+        mean = np.array([0.485, 0.456, 0.406])
+        std = np.array([0.229, 0.224, 0.225])
+        picture_show = std * picture_show + mean
+        picture_show = np.clip(picture_show, 0, 1)
+        plt.imshow(picture_show)
+        plt.axis('off')
+    plt.show()
+
+
+if __name__ == '__main__':
+    _, dataset_val = import_data(train_dataset_path=config.train_path, eval_dataset_path=config.eval_path)
+
+    visualize_model(config.pre_trained, dataset_val, config.num_classes)
diff --git a/official/cv/vgg16/requirements.txt b/official/cv/vgg16/requirements.txt
index f77643b0fcb4183b7ffea63ade810bddd4921e27..186aee6f307ba653449abfe935bf6cb6ab64e0f1 100644
--- a/official/cv/vgg16/requirements.txt
+++ b/official/cv/vgg16/requirements.txt
@@ -1,4 +1,5 @@
-numpy
-onnxruntime-gpu
-pillow
-pyyaml
+numpy
+onnxruntime-gpu
+pillow
+pyyaml
+matplotlib
diff --git a/official/cv/vgg16/src/data_split.py b/official/cv/vgg16/src/data_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..c851166ac67f0a8b27e675cd6a33951912c4ee24
--- /dev/null
+++ b/official/cv/vgg16/src/data_split.py
@@ -0,0 +1,158 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""split for CPU dataset"""
+import os
+import shutil
+import multiprocessing
+
+import mindspore as ms
+import mindspore.dataset as ds
+
+
+def get_num_parallel_workers(num_parallel_workers):
+    """
+    Get num_parallel_workers used in dataset operations.
+    If num_parallel_workers > the real CPU cores number, set num_parallel_workers = the real CPU cores number.
+    """
+    cores = multiprocessing.cpu_count()
+    if isinstance(num_parallel_workers, int):
+        if cores < num_parallel_workers:
+            print("The num_parallel_workers {} is set too large, now set it {}".format(num_parallel_workers, cores))
+            num_parallel_workers = cores
+    else:
+        print("The num_parallel_workers {} is invalid, now set it {}".format(num_parallel_workers, min(cores, 8)))
+        num_parallel_workers = min(cores, 8)
+    return num_parallel_workers
+
+
+def create_dataset(dataset_path, do_train, batch_size=32, train_image_size=224, eval_image_size=224,
+                   enable_cache=False, cache_session_id=None):
+    """
+    Create a train or eval flower dataset for vgg16.
+
+    Args:
+        dataset_path(string): the path of the dataset.
+        do_train(bool): whether the dataset is used for training or eval.
+        batch_size(int): the batch size of the dataset. Default: 32
+        train_image_size(int): random-crop size for training images. Default: 224
+        eval_image_size(int): center-crop size for eval images. Default: 224
+        enable_cache(bool): whether the tensor caching service is used for eval. Default: False
+        cache_session_id(int): if enable_cache is set, a cache session_id must be provided. Default: None
+
+    Returns:
+        dataset
+    """
+    ds.config.set_prefetch_size(64)
+    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(12), shuffle=True)
+
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if do_train:
+        trans = [
+            ds.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            ds.vision.RandomHorizontalFlip(prob=0.5)
+        ]
+    else:
+        trans = [
+            ds.vision.Decode(),
+            ds.vision.Resize(256),
+            ds.vision.CenterCrop(eval_image_size)
+        ]
+    trans_norm = [ds.vision.Normalize(mean=mean, std=std), ds.vision.HWC2CHW()]
+
+    type_cast_op = ds.transforms.TypeCast(ms.int32)
+    trans_work_num = 24
+    data_set = data_set.map(operations=trans, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(trans_work_num))
+    data_set = data_set.map(operations=trans_norm, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(12))
+    # only enable cache for eval
+    if do_train:
+        enable_cache = False
+    if enable_cache:
+        if not cache_session_id:
+            raise ValueError("A cache session_id must be provided to use cache.")
+        eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12),
+                                cache=eval_cache)
+    else:
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12))
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    return data_set
+
+
+def generate_data(path="./"):
+    """Split each class folder under `path` into train/ (first 80%) and test/ (last 20%)."""
+    # `path` is expected to end with '/'
+    dirs = []
+    for abs_path, j, _ in os.walk(path):
+        print("abs_path:", abs_path)
+        if j:
+            dirs.append(j)
+    print(dirs)
+
+    if not os.path.exists(path + 'train'):
+        os.makedirs(path + 'train')
+    if not os.path.exists(path + 'test'):
+        os.makedirs(path + 'test')
+
+    for class_dir in dirs[0]:
+        print("path", path)
+        print("dir", class_dir)
+        files = os.listdir(path + class_dir)
+        # files are taken in directory-listing order, not shuffled
+        train_set = files[: int(len(files) * 0.8)]
+        test_set = files[int(len(files) * 0.8):]
+        for file in train_set:
+            file_path = path + "train/" + class_dir + "/"
+            if not os.path.exists(file_path):
+                os.makedirs(file_path)
+            src_file = path + class_dir + "/" + file
+            print("src_file:", src_file)
+            dst_file = file_path + file
+            print("dst_file:", dst_file)
+            shutil.copyfile(src_file, dst_file)
+
+        for file in test_set:
+            file_path = path + "test/" + class_dir + "/"
+            if not os.path.exists(file_path):
+                os.makedirs(file_path)
+            src_file = path + class_dir + "/" + file
+            dst_file = file_path + file
+            shutil.copyfile(src_file, dst_file)
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--split_path", help="the path of the dataset to be split")
+    args = parser.parse_args()
+
+    generate_data(path=args.split_path)
+
+    # smoke-test the generated split by building a training pipeline from it
+    create_dataset(dataset_path=args.split_path + "train/", do_train=True, batch_size=32, train_image_size=224,
+                   eval_image_size=224,
+                   enable_cache=False, cache_session_id=None)
+
+
+if __name__ == '__main__':
+    main()
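
The same `create_dataset` helper is also added to `src/dataset.py` (next hunk), so a split produced by this script can be consumed directly. A minimal sketch, assuming the split was generated under `./datasets/`:

```python
from src.dataset import create_dataset

# build train/eval pipelines exactly as fine_tune.py does
train_ds = create_dataset(dataset_path="./datasets/train/", do_train=True, batch_size=32)
val_ds = create_dataset(dataset_path="./datasets/test/", do_train=False, batch_size=32)
print(train_ds.get_dataset_size(), val_ds.get_dataset_size())  # batches per epoch
```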
diff --git a/official/cv/vgg16/src/dataset.py b/official/cv/vgg16/src/dataset.py
index fc632b1bec13ae96924c24addaaba2fd0f5cd070..6f1be68c49d734960191049491ca72572666cbc0 100644
--- a/official/cv/vgg16/src/dataset.py
+++ b/official/cv/vgg16/src/dataset.py
@@ -16,7 +16,9 @@
 """
 dataset processing.
 """
 import os
+import multiprocessing
 from PIL import Image, ImageFile
+import mindspore as ms
 from mindspore.common import dtype as mstype
 import mindspore.dataset as de
 import mindspore.dataset.transforms as C
@@ -25,6 +27,20 @@ from src.utils.sampler import DistributedSampler
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
+
+def get_num_parallel_workers(num_parallel_workers):
+    """
+    Get num_parallel_workers used in dataset operations.
+    If num_parallel_workers > the real CPU cores number, set num_parallel_workers = the real CPU cores number.
+    """
+    cores = multiprocessing.cpu_count()
+    if isinstance(num_parallel_workers, int):
+        if cores < num_parallel_workers:
+            print("The num_parallel_workers {} is set too large, now set it {}".format(num_parallel_workers, cores))
+            num_parallel_workers = cores
+    else:
+        print("The num_parallel_workers {} is invalid, now set it {}".format(num_parallel_workers, min(cores, 8)))
+        num_parallel_workers = min(cores, 8)
+    return num_parallel_workers
+
+
 def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, training=True):
     """Data operations."""
@@ -163,6 +179,66 @@ def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_s
 
     return de_dataset
 
+
+def create_dataset(dataset_path, do_train, batch_size=32, train_image_size=224, eval_image_size=224,
+                   enable_cache=False, cache_session_id=None):
+    """
+    Create a train or eval flower dataset for vgg16.
+
+    Args:
+        dataset_path(string): the path of the dataset.
+        do_train(bool): whether the dataset is used for training or eval.
+        batch_size(int): the batch size of the dataset. Default: 32
+        enable_cache(bool): whether the tensor caching service is used for eval. Default: False
+        cache_session_id(int): if enable_cache is set, a cache session_id must be provided. Default: None
+
+    Returns:
+        dataset
+    """
+    de.config.set_prefetch_size(64)
+    data_set = de.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(12), shuffle=True)
+
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if do_train:
+        trans = [
+            de.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            de.vision.RandomHorizontalFlip(prob=0.5)
+        ]
+    else:
+        trans = [
+            de.vision.Decode(),
+            de.vision.Resize(256),
+            de.vision.CenterCrop(eval_image_size)
+        ]
+    trans_norm = [de.vision.Normalize(mean=mean, std=std), de.vision.HWC2CHW()]
+
+    type_cast_op = de.transforms.TypeCast(ms.int32)
+    trans_work_num = 24
+    data_set = data_set.map(operations=trans, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(trans_work_num))
+    data_set = data_set.map(operations=trans_norm, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(12))
+    # only enable cache for eval
+    if do_train:
+        enable_cache = False
+    if enable_cache:
+        if not cache_session_id:
+            raise ValueError("A cache session_id must be provided to use cache.")
+        eval_cache = de.DatasetCache(session_id=int(cache_session_id), size=0)
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12),
+                                cache=eval_cache)
+    else:
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12))
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    return data_set
+
+
 class TxtDataset:
     """