# Offline Inference Process
## Preparing the Container Environment
1. Upload the source code (vgg16_mindspore_1.3.0_code) to any directory on the server (for example, /home/data/) and go to that directory.
The source tree is organized as follows:
```bash
/home/data/vgg16_mindspore_1.3.0_code
├── infer                              # added for the MindX high-performance pretrained model
│   ├── README.md                      # offline inference documentation
│   ├── convert                        # om model conversion command and AIPP configuration
│   │   ├── aipp_vgg16_rgb.config
│   │   └── atc.sh
│   ├── data                           # model files, model input dataset, and model configs (e.g. labels, SDK pipeline)
│   │   ├── input
│   │   │   └── ILSVRC2012_val_00000001.JPEG
│   │   └── config
│   │       ├── vgg16.cfg
│   │       └── vgg16.pipeline
│   ├── mxbase                         # inference based on mxBase
│   │   ├── src
│   │   │   ├── Vgg16Classify.cpp
│   │   │   ├── Vgg16Classify.h
│   │   │   └── main.cpp
│   │   ├── CMakeLists.txt
│   │   └── build.sh
│   ├── sdk                            # inference based on the sdk.run package; a C++ implementation would use the same layout
│   │   ├── main.py
│   │   └── run.sh
│   ├── util                           # accuracy verification script
│   │   └── task_metric.py
│   ├── Dockerfile                     # container file
│   └── docker_start_infer.sh          # script for starting the container
```
2. Start the container.
Run the following command to start a container instance.
```bash
bash docker_start_infer.sh docker_image:tag model_dir
```
| Parameter | Description |
| -------------- | -------------------------------------------- |
| *docker_image* | Name of the inference image; download it from Ascend Hub. |
| *tag* | Image tag; set it according to your environment, for example 21.0.2. |
| *model_dir* | Path mounted into the inference container, /home/data in this example. |
Starting the container mounts the inference chip and the data path into the container. For example, with an image named infer_image and tag 21.0.2 (illustrative values): **bash docker_start_infer.sh infer_image:21.0.2 /home/data**
The content of docker_start_infer.sh (vgg16_mindspore_1.3.0_code/infer/docker_start_infer.sh) is as follows.
```shell
#!/bin/bash
docker_image=$1
model_dir=$2

if [ -z "${docker_image}" ]; then
    echo "please input docker_image"
    exit 1
fi

if [ ! -d "${model_dir}" ]; then
    echo "please input model_dir"
    exit 1
fi

# Change /dev/davinci0 below according to your chip.
docker run -it \
    --device=/dev/davinci0 \
    --device=/dev/davinci_manager \
    --device=/dev/devmm_svm \
    --device=/dev/hisi_hdc \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
    -v ${model_dir}:${model_dir} \
    ${docker_image} \
    /bin/bash
```
## Converting the Model
1. Place the vgg16.air model in the /vgg16_mindspore_1.3.0_code/infer/data/model/ directory (create the model folder yourself; both the air model before conversion and the om model after conversion are kept there).
2. Prepare the AIPP configuration file.
AIPP inserts an AIPP operator during ATC conversion so that the converted model connects seamlessly to data preprocessed by DVPP. For the AIPP parameters, see "ATC Tool Instructions" in the [CANN Auxiliary Development Tool Guide (Inference)](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373?category=developer-documents&subcategory=auxiliary-development-tools).
The configuration content is shown below; the file (aipp_vgg16_rgb.config) is located in the /vgg16_mindspore_1.3.0_code/infer/convert directory.
~~~ config
aipp_op {
    aipp_mode: static
    input_format: RGB888_U8
    rbuv_swap_switch: true
    min_chn_0: 123.675
    min_chn_1: 116.28
    min_chn_2: 103.33
    var_reci_chn_0: 0.0171247538316637
    var_reci_chn_1: 0.0175070028011204
    var_reci_chn_2: 0.0174291938997821
}
~~~
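For orientation, these constants match the ImageNet per-channel normalization commonly used with VGG16: the min_chn_* values are channel means on the [0, 255] scale and the var_reci_chn_* values are reciprocals of the channel standard deviations. A quick check (illustrative snippet, not part of the delivered code):
```python
# Reciprocals of the common ImageNet stds scaled to [0, 255]:
# 255 * (0.229, 0.224, 0.225) = (58.395, 57.12, 57.375)
for std in (58.395, 57.12, 57.375):
    print(f"{1 / std:.16f}")
# -> 0.0171247538316637, 0.0175070028011204, 0.0174291938997821,
#    matching var_reci_chn_0/1/2 above
```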
3. Go to the vgg16_mindspore_1.3.0_code/infer/convert directory and run **bash atc.sh ../data/model/vgg16.air** (the model in this example is named vgg16.air). The ATC tool converts the air model to an om model, which is written to the vgg16_mindspore_1.3.0_code/infer/data/model/ folder automatically.
atc.sh
~~~ shell
model=$1
/usr/local/Ascend/atc/bin/atc \
    --model=$model \
    --framework=1 \
    --output=../data/model/vgg16 \
    --input_shape="input:1,224,224,3" \
    --enable_small_channel=1 \
    --log=error \
    --soc_version=Ascend310 \
    --insert_op_conf=aipp_vgg16_rgb.config
~~~
Parameter description:
- --model: path of the air model to be converted.
- --framework: 1 stands for the MindSpore framework.
- --output: output path and name of the converted om model. ATC appends the .om suffix, so --output=../data/model/vgg16 produces ../data/model/vgg16.om, the path that both the pipeline and the mxBase code expect.
- --input_shape: shape of the input data.
- --insert_op_conf: path of the AIPP configuration file.
## mxBase Inference
1. Add environment variables.
Open ~/.bashrc with **vi ~/.bashrc**, append the environment variables below, save and exit, then run **source ~/.bashrc** to make them take effect.
```bash
export ASCEND_HOME="/usr/local/Ascend"
export ASCEND_VERSION="nnrt/latest"
export ARCH_PATTERN="."
export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib/modelpostprocessors:${LD_LIBRARY_PATH}"
export MXSDK_OPENSOURCE_DIR="${MX_SDK_HOME}/opensource"
```
2. Go to the vgg16_mindspore_1.3.0_code/infer/mxbase directory and run **bash build.sh**.
build.sh
~~~ shell
path_cur=$(dirname $0)

function check_env()
{
    # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
    if [ ! "${ASCEND_VERSION}" ]; then
        export ASCEND_VERSION=ascend-toolkit/latest
        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
    else
        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
    fi

    if [ ! "${ARCH_PATTERN}" ]; then
        # set ARCH_PATTERN to ./ when it was not specified by user
        export ARCH_PATTERN=./
        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
    else
        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
    fi
}

function build_vgg16()
{
    cd $path_cur
    rm -rf build
    mkdir -p build
    cd build
    cmake ..
    make
    ret=$?
    if [ ${ret} -ne 0 ]; then
        echo "Failed to build vgg16."
        exit ${ret}
    fi
    make install
}

check_env
build_vgg16
~~~
3. Run **./vgg16 ../data/input 50000**. The inference results are saved to mx_pred_result.txt in the current directory (make install places the vgg16 binary in the mxbase directory).
- First argument (../data/input): input image path.
- Second argument (50000): number of input images to process.
Inference results:
~~~bash
ILSVRC2012_val_00047110 221,267,266,206,220,
ILSVRC2012_val_00014552 505,550,804,899,859,
ILSVRC2012_val_00006604 276,287,289,275,285,
ILSVRC2012_val_00016859 2,3,4,148,5,
ILSVRC2012_val_00020009 336,649,350,371,972,
ILSVRC2012_val_00025515 917,921,446,620,692,
ILSVRC2012_val_00046794 427,504,463,412,686,
ILSVRC2012_val_00035447 856,866,595,730,603,
ILSVRC2012_val_00016392 54,67,68,60,66,
ILSVRC2012_val_00023902 50,49,44,39,62,
ILSVRC2012_val_00000719 268,151,171,158,104,
......
~~~
4. Verify accuracy. Go to the vgg16_mindspore_1.3.0_code/infer/util directory and run **python3.7 task_metric.py ../mxbase/mx_pred_result.txt imagenet2012_val.txt vgg16_mx_pred_result_acc.json 5**
Parameter description:
- First argument (../mxbase/mx_pred_result.txt): path of the saved inference results.
- Second argument (imagenet2012_val.txt): validation-set label file; each line holds an image name and its label index, separated by a space.
- Third argument (vgg16_mx_pred_result_acc.json): result file.
- Fourth argument (5): "1" computes top-1 accuracy, "5" computes top-5 accuracy.
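In outline, the script counts a prediction as a top-k hit when the ground-truth label index appears among the first k predicted class ids. A simplified sketch of what calc_accuracy in task_metric.py computes (the full script is listed further below):
```python
def topk_accuracy(gt_map, pred_map, k=5):
    """Fraction of predictions whose ground-truth id is among the first k ids."""
    hits, total = 0, 0
    for img, pred_ids in pred_map.items():
        gt = gt_map.get(img)
        if gt is None:        # images without a ground-truth entry are skipped
            continue
        total += 1
        if gt in pred_ids[:k]:
            hits += 1
    return hits / total
```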
5. Check the inference accuracy result:
~~~ bash
cat vgg16_mx_pred_result_acc.json
~~~
Top-5 inference accuracy (the accuracy array lists cumulative top-1 through top-5 accuracy):
~~~ bash
"accuracy": [
    0.73328,
    0.83924,
    0.8786,
    0.90034,
    0.91496
]
...
~~~
## MindX SDK Inference
1. Prepare data.
Place the inference image dataset in the vgg16_mindspore_1.3.0_code/infer/data/input directory.
- This example uses the validation set of [ImageNet2012](http://www.image-net.org/); ten of its test images are already in the input directory.
- Full set: 6.4 GB, 50,000 images.
2. Prepare the files required for model inference.
(1) Write the pipeline file in the "/home/data/vgg16_mindspore_1.3.0_code/infer/data/config/" directory.
Adjust the image size, model path, configuration file path, and label path in vgg16.pipeline to your setup.
For more information, see "Basic Development" in the [mxManufacture User Guide](https://ascend.huawei.com/#/software/mindx-sdk/sdk-detail).
vgg16.pipeline
```pipeline
{
    "im_vgg16": {
        "stream_config": {
            "deviceId": "0"
        },
        "mxpi_imagedecoder0": {
            "props": {
                "handleMethod": "opencv"
            },
            "factory": "mxpi_imagedecoder",
            "next": "mxpi_imageresize0"
        },
        "mxpi_imageresize0": {
            "props": {
                "handleMethod": "opencv",
                "resizeHeight": "256",
                "resizeWidth": "256",
                "resizeType": "Resizer_Stretch"
            },
            "factory": "mxpi_imageresize",
            "next": "mxpi_imagecrop0:1"
        },
        "mxpi_imagecrop0": {
            "props": {
                "dataSource": "appsrc1",
                "dataSourceImage": "mxpi_imageresize0",
                "handleMethod": "opencv"
            },
            "factory": "mxpi_imagecrop",
            "next": "mxpi_tensorinfer0"
        },
        "mxpi_tensorinfer0": {
            "props": {
                "dataSource": "mxpi_imagecrop0",
                "modelPath": "../data/model/vgg16.om",
                "waitingTime": "1",
                "outputDeviceId": "-1"
            },
            "factory": "mxpi_tensorinfer",
            "next": "mxpi_classpostprocessor0"
        },
        "mxpi_classpostprocessor0": {
            "props": {
                "dataSource": "mxpi_tensorinfer0",
                "postProcessConfigPath": "../data/config/vgg16.cfg",
                "labelPath": "../data/config/imagenet1000_clsidx_to_labels.names",
                "postProcessLibPath": "/usr/local/sdk_home/mxManufacture/lib/modelpostprocessors/libresnet50postprocess.so"
            },
            "factory": "mxpi_classpostprocessor",
            "next": "mxpi_dataserialize0"
        },
        "mxpi_dataserialize0": {
            "props": {
                "outputDataKeys": "mxpi_classpostprocessor0"
            },
            "factory": "mxpi_dataserialize",
            "next": "appsink0"
        },
        "appsrc1": {
            "props": {
                "blocksize": "409600"
            },
            "factory": "appsrc",
            "next": "mxpi_imagecrop0:0"
        },
        "appsrc0": {
            "props": {
                "blocksize": "409600"
            },
            "factory": "appsrc",
            "next": "mxpi_imagedecoder0"
        },
        "appsink0": {
            "props": {
                "blocksize": "4096000"
            },
            "factory": "appsink"
        }
    }
}
```
Parameter description:
- resizeHeight: image height after resizing; set it to the size you need.
- resizeWidth: image width after resizing; set it to the size you need.
- modelPath: model path; change it to the actual path of your model.
- postProcessConfigPath: path of the model configuration file; change it to the actual path.
- labelPath: path of the label file; change it to the actual path.

Note that the stream has two inputs: appsrc0 feeds the JPEG image into mxpi_imagedecoder0, while appsrc1 feeds the crop region (an MxpiObjectList sent by main.py) into mxpi_imagecrop0.
(2) Write the vgg16.cfg configuration file in the "/home/data/vgg16_mindspore_1.3.0_code/infer/data/config/" directory.
The content of vgg16.cfg is as follows (the mxBase program builds the same configuration programmatically in Vgg16Classify::Init).
```cfg
CLASS_NUM=1000
SOFTMAX=false
TOP_K=5
```
(3) Go to the "/home/data/vgg16_mindspore_1.3.0_code/infer/sdk/" directory.
Adjust the image crop position and the **pipeline** file path in main.py to your setup.
```python
...
    def _predict_gen_protobuf(self):
        object_list = MxpiDataType.MxpiObjectList()
        object_vec = object_list.objectVec.add()
        object_vec.x0 = 16
        object_vec.y0 = 16
        object_vec.x1 = 240
        object_vec.y1 = 240
...
def main():
    pipeline_conf = "../data/config/vgg16.pipeline"
    stream_name = b'im_vgg16'

    args = parse_args()
    result_fname = get_file_name(args.result_file)
    pred_result_file = f"{result_fname}.txt"
    dataset = GlobDataLoader(args.glob, limit=None)
    with ExitStack() as stack:
        predictor = stack.enter_context(Predictor(pipeline_conf, stream_name))
        result_fd = stack.enter_context(open(pred_result_file, 'w'))

        for fname, pred_result in predictor.predict(dataset):
            result_fd.write(result_encode(fname, pred_result))

    print(f"success, result in {pred_result_file}")
...
```
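The crop coordinates are not arbitrary: the pipeline first resizes each image to 256 x 256, and the ROI from (16, 16) to (240, 240) is the 224 x 224 center crop that matches the model's --input_shape. A quick sanity check (illustrative snippet):
```python
resized = 256                     # resizeHeight / resizeWidth in vgg16.pipeline
cropped = 224                     # input height / width from --input_shape
offset = (resized - cropped) // 2
print(offset, offset + cropped)   # 16 240, i.e. x0 = y0 = 16 and x1 = y1 = 240
```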
3. Go to the vgg16_mindspore_1.3.0_code/infer/sdk folder and run **bash run.sh**. The inference results are saved to vgg16_sdk_pred_result.txt in the current directory.
run.sh
~~~ shell
set -e
CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run.sh" ; exit ; } ; pwd)
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3.7 main.py "../data/input/*" vgg16_sdk_pred_result.txt
exit 0
~~~
Inference results:
~~~bash
ILSVRC2012_val_00047110 221,267,266,206,220
ILSVRC2012_val_00014552 505,550,804,899,859
ILSVRC2012_val_00006604 276,287,289,275,285
ILSVRC2012_val_00016859 2,3,4,148,5
ILSVRC2012_val_00020009 336,649,350,371,972
ILSVRC2012_val_00025515 917,921,446,620,692
ILSVRC2012_val_00046794 427,504,463,412,686
ILSVRC2012_val_00035447 856,866,595,730,603
ILSVRC2012_val_00016392 54,67,68,60,66
......
~~~
4. Verify accuracy. Go to the vgg16_mindspore_1.3.0_code/infer/util directory and run **python3.7 task_metric.py ../sdk/vgg16_sdk_pred_result.txt imagenet2012_val.txt vgg16_sdk_pred_result_acc.json 5**
Parameter description:
- First argument (../sdk/vgg16_sdk_pred_result.txt): path of the saved inference results.
- Second argument (imagenet2012_val.txt): validation-set label file.
- Third argument (vgg16_sdk_pred_result_acc.json): result file.
- Fourth argument (5): "1" computes top-1 accuracy, "5" computes top-5 accuracy.
5. Check the inference performance and accuracy.
- Turn on the performance statistics switch by setting "enable_ps" to true and "ps_interval_time" to 6:
**vi /usr/local/sdk_home/mxManufacture/config/sdk.conf**
```bash
# MindX SDK configuration file
# whether to enable performance statistics, default is false [dynamic config]
enable_ps=true
...
ps_interval_time=6
...
```
- Run the run.sh script: **bash run.sh**
- Check the performance statistics in the log directory "/usr/local/sdk_home/mxManufacture/logs".
```bash
performance--statistics.log.e2e.xxx
performance--statistics.log.plugin.xxx
performance--statistics.log.tpr.xxx
```
The e2e logs record end-to-end latency and the plugin logs record per-plugin latency.
- Check the inference accuracy:
```bash
cat vgg16_sdk_pred_result_acc.json
```
Top-5 inference accuracy:
```bash
"total": 50000,
"accuracy": [
    0.73328,
    0.83924,
    0.8786,
    0.90034,
    0.91496
]
...
```
# ==== infer/mxbase/CMakeLists.txt ====
cmake_minimum_required(VERSION 3.10.0)
project(vgg16)

set(TARGET vgg16)

add_definitions(-DENABLE_DVPP_INTERFACE)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-Dgoogle=mindxsdk_private)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)

# Check environment variable
if(NOT DEFINED ENV{ASCEND_HOME})
    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
    message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
    message(WARNING "please define environment variable:ARCH_PATTERN")
endif()

set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)

set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${PROJECT_SOURCE_DIR}/src/include)

if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
    set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
else()
    set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
endif()

include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})

link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})

add_executable(${TARGET} ./src/main.cpp ./src/Vgg16Classify.cpp)
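# vgg16 classification reuses the SDK's generic ResNet50 post-processor
# (Resnet50PostProcess), so the resnet50postprocess library is linked below.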
target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world stdc++fs)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
// ==== infer/mxbase/src/Vgg16Classify.cpp ====
/*
 * Copyright 2021. Huawei Technologies Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <unistd.h>
#include <sys/stat.h>
#include <chrono>   // for high_resolution_clock (used in Inference)
#include <map>      // for map<>
#include <memory>   // for make_shared<>
#include <vector>   // for vector<>
#include <string>   // for string
#include "Vgg16Classify.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"

namespace {
    const uint32_t YUV_BYTE_NU = 3;
    const uint32_t YUV_BYTE_DE = 2;
    const uint32_t VPC_H_ALIGN = 2;
}
APP_ERROR Vgg16Classify::Init(const InitParam &initParam) {
    deviceId_ = initParam.deviceId;
    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
    if (ret != APP_ERR_OK) {
        LogError << "Init devices failed, ret=" << ret << ".";
        return ret;
    }
    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
    if (ret != APP_ERR_OK) {
        LogError << "Set context failed, ret=" << ret << ".";
        return ret;
    }
    dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
    ret = dvppWrapper_->Init();
    if (ret != APP_ERR_OK) {
        LogError << "DvppWrapper init failed, ret=" << ret << ".";
        return ret;
    }
    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
    ret = model_->Init(initParam.modelPath, modelDesc_);
    if (ret != APP_ERR_OK) {
        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
        return ret;
    }
    MxBase::ConfigData configData;
    const std::string softmax = initParam.softmax ? "true" : "false";
    const std::string checkTensor = initParam.checkTensor ? "true" : "false";
    configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum));
    configData.SetJsonValue("TOP_K", std::to_string(initParam.topk));
    configData.SetJsonValue("SOFTMAX", softmax);
    configData.SetJsonValue("CHECK_MODEL", checkTensor);
    auto jsonStr = configData.GetCfgJson().serialize();
    std::map<std::string, std::shared_ptr<void>> config;
    config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr);
    config["labelPath"] = std::make_shared<std::string>(initParam.labelPath);
    post_ = std::make_shared<MxBase::Resnet50PostProcess>();
    ret = post_->Init(config);
    if (ret != APP_ERR_OK) {
        LogError << "Resnet50PostProcess init failed, ret=" << ret << ".";
        return ret;
    }
    tfile_.open("mx_pred_result.txt");
    if (!tfile_) {
        LogError << "Open result file failed.";
        return APP_ERR_COMM_OPEN_FAIL;
    }
    return APP_ERR_OK;
}
APP_ERROR Vgg16Classify::DeInit() {
    dvppWrapper_->DeInit();
    model_->DeInit();
    post_->DeInit();
    MxBase::DeviceManager::GetInstance()->DestroyDevices();
    tfile_.close();
    return APP_ERR_OK;
}

APP_ERROR Vgg16Classify::ReadImage(const std::string &imgPath, cv::Mat &imageMat) {
    imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
    return APP_ERR_OK;
}

APP_ERROR Vgg16Classify::Resize(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
    static constexpr uint32_t resizeHeight = 256;
    static constexpr uint32_t resizeWidth = 256;
    cv::resize(srcImageMat, dstImageMat, cv::Size(resizeHeight, resizeWidth));
    return APP_ERR_OK;
}

APP_ERROR Vgg16Classify::Crop(const cv::Mat &srcMat, cv::Mat &dstMat) {
    static cv::Rect rectOfImg(16, 16, 224, 224);
    dstMat = srcMat(rectOfImg).clone();
    return APP_ERR_OK;
}
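// Copy the HWC uint8 image from host memory to device memory and wrap it
// in a TensorBase with shape {rows, cols, channels}.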
APP_ERROR Vgg16Classify::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) {
    const uint32_t dataSize = imageMat.cols * imageMat.rows * imageMat.channels();
    MxBase::MemoryData MemoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
    MxBase::MemoryData MemoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(MemoryDataDst, MemoryDataSrc);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Memory malloc failed.";
        return ret;
    }
    std::vector<uint32_t> shape = {static_cast<uint32_t>(imageMat.rows),
        static_cast<uint32_t>(imageMat.cols), static_cast<uint32_t>(imageMat.channels())};
    tensorBase = MxBase::TensorBase(MemoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
    return APP_ERR_OK;
}
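// Allocate device-side output tensors according to the model description,
// run inference, and record the per-image latency in g_inferCost.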
APP_ERROR Vgg16Classify::Inference(const std::vector<MxBase::TensorBase> &inputs,
                                   std::vector<MxBase::TensorBase> &outputs) {
    auto dtypes = model_->GetOutputDataType();
    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
        std::vector<uint32_t> shape = {};
        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
        }
        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
        if (ret != APP_ERR_OK) {
            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
            return ret;
        }
        outputs.push_back(tensor);
    }
    MxBase::DynamicInfo dynamicInfo = {};
    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
    auto startTime = std::chrono::high_resolution_clock::now();
    APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
    auto endTime = std::chrono::high_resolution_clock::now();
    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
    g_inferCost.push_back(costMs);
    if (ret != APP_ERR_OK) {
        LogError << "ModelInference failed, ret=" << ret << ".";
        return ret;
    }
    return APP_ERR_OK;
}
APP_ERROR Vgg16Classify::PostProcess(const std::vector<MxBase::TensorBase> &inputs,
                                     std::vector<std::vector<MxBase::ClassInfo>> &clsInfos) {
    APP_ERROR ret = post_->Process(inputs, clsInfos);
    if (ret != APP_ERR_OK) {
        LogError << "Process failed, ret=" << ret << ".";
        return ret;
    }
    return APP_ERR_OK;
}

APP_ERROR Vgg16Classify::Process(const std::string &imgPath) {
    cv::Mat imageMat;
    APP_ERROR ret = ReadImage(imgPath, imageMat);
    if (ret != APP_ERR_OK) {
        LogError << "ReadImage failed, ret=" << ret << ".";
        return ret;
    }
    ret = Resize(imageMat, imageMat);
    if (ret != APP_ERR_OK) {
        LogError << "Resize failed, ret=" << ret << ".";
        return ret;
    }
    cv::Mat cropImage;
    ret = Crop(imageMat, cropImage);
    if (ret != APP_ERR_OK) {
        LogError << "Crop failed, ret=" << ret << ".";
        return ret;
    }
    MxBase::TensorBase tensorBase;
    ret = CVMatToTensorBase(cropImage, tensorBase);
    if (ret != APP_ERR_OK) {
        LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
        return ret;
    }
    std::vector<MxBase::TensorBase> inputs = {};
    std::vector<MxBase::TensorBase> outputs = {};
    inputs.push_back(tensorBase);
    ret = Inference(inputs, outputs);
    if (ret != APP_ERR_OK) {
        LogError << "Inference failed, ret=" << ret << ".";
        return ret;
    }
    std::vector<std::vector<MxBase::ClassInfo>> BatchClsInfos = {};
    ret = PostProcess(outputs, BatchClsInfos);
    if (ret != APP_ERR_OK) {
        LogError << "PostProcess failed, ret=" << ret << ".";
        return ret;
    }
    ret = SaveResult(imgPath, BatchClsInfos);
    if (ret != APP_ERR_OK) {
        LogError << "Export result to file failed, ret=" << ret << ".";
        return ret;
    }
    return APP_ERR_OK;
}
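// Write one line per image in the form "<basename> <id1>,<id2>,...", the
// format that util/task_metric.py parses during accuracy verification.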
APP_ERROR Vgg16Classify::SaveResult(const std::string &imgPath,
                                    const std::vector<std::vector<MxBase::ClassInfo>> &BatchClsInfos) {
    uint32_t batchIndex = 0;
    std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1);
    size_t dot = fileName.find_last_of(".");
    for (const auto &clsInfos : BatchClsInfos) {
        std::string resultStr;
        for (const auto &clsInfo : clsInfos) {
            resultStr += std::to_string(clsInfo.classId) + ",";
        }
        tfile_ << fileName.substr(0, dot) << " " << resultStr << std::endl;
        if (tfile_.fail()) {
            LogError << "Failed to write the result to file.";
            return APP_ERR_COMM_WRITE_FAIL;
        }
        batchIndex++;
    }
    return APP_ERR_OK;
}
// ==== infer/mxbase/src/Vgg16Classify.h ====
/*
 * Copyright 2021. Huawei Technologies Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MXBASE_VGG16CLASSIFY_H
#define MXBASE_VGG16CLASSIFY_H

#include <string>    // for string
#include <vector>    // for vector<>
#include <memory>    // for shared_ptr<>
#include <fstream>   // for ofstream (tfile_)
#include <opencv2/opencv.hpp>
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/postprocess/include/ClassPostProcessors/Resnet50PostProcess.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"

extern std::vector<double> g_inferCost;

struct InitParam {
    uint32_t deviceId;
    std::string labelPath;
    uint32_t classNum;
    uint32_t topk;
    bool softmax;
    bool checkTensor;
    std::string modelPath;
};

class Vgg16Classify {
 public:
    APP_ERROR Init(const InitParam &initParam);
    APP_ERROR DeInit();
    APP_ERROR ReadImage(const std::string &imgPath, cv::Mat &imageMat);
    APP_ERROR Resize(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
    APP_ERROR Crop(const cv::Mat &srcMat, cv::Mat &dstMat);
    APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase);
    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
    APP_ERROR PostProcess(const std::vector<MxBase::TensorBase> &inputs,
                          std::vector<std::vector<MxBase::ClassInfo>> &clsInfos);
    APP_ERROR Process(const std::string &imgPath);
    APP_ERROR SaveResult(const std::string &imgPath, const std::vector<std::vector<MxBase::ClassInfo>> &BatchClsInfos);

 private:
    std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
    std::shared_ptr<MxBase::Resnet50PostProcess> post_;
    MxBase::ModelDesc modelDesc_;
    uint32_t deviceId_ = 0;
    std::ofstream tfile_;
};
#endif
// ==== infer/mxbase/src/main.cpp ====
/*
 * Copyright 2021. Huawei Technologies Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdlib>   // for std::strtol
#include <experimental/filesystem>
#include "Vgg16Classify.h"
#include "MxBase/Log/Log.h"

namespace fs = std::experimental::filesystem;

namespace {
    const uint32_t CLASS_NUM = 1000;
}
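// Per-image inference latency in milliseconds, filled by Vgg16Classify::Inference().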
std::vector<double> g_inferCost;

int main(int argc, char* argv[]) {
    if (argc <= 2) {  // both the image directory and the image count are required
        LogWarn << "Please input image path and image count, such as './vgg16 ../data/input 10'.";
        return APP_ERR_OK;
    }
    InitParam initParam = {};
    initParam.deviceId = 0;
    initParam.classNum = CLASS_NUM;
    initParam.labelPath = "../data/config/imagenet1000_clsidx_to_labels.names";
    initParam.topk = 5;
    initParam.softmax = false;
    initParam.checkTensor = true;
    initParam.modelPath = "../data/model/vgg16.om";
    auto vgg16 = std::make_shared<Vgg16Classify>();
    APP_ERROR ret = vgg16->Init(initParam);
    if (ret != APP_ERR_OK) {
        LogError << "Vgg16Classify init failed, ret=" << ret << ".";
        return ret;
    }
    std::string imgDir = argv[1];
    int limit = std::strtol(argv[2], nullptr, 0);
    int index = 0;
    for (auto & entry : fs::directory_iterator(imgDir)) {
        if (index == limit) {
            break;
        }
        index++;
        LogInfo << "read image path " << entry.path();
        ret = vgg16->Process(entry.path());
        if (ret != APP_ERR_OK) {
            LogError << "Vgg16Classify process failed, ret=" << ret << ".";
            vgg16->DeInit();
            return ret;
        }
    }
    vgg16->DeInit();
    double costSum = 0;
    for (unsigned int i = 0; i < g_inferCost.size(); i++) {
        costSum += g_inferCost[i];
    }
    LogInfo << "Infer images sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms.";
    LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " images/sec.";
    return APP_ERR_OK;
}
# ==== infer/sdk/main.py ====
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import glob
import json
import os
from contextlib import ExitStack

from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, \
    MxProtobufIn
import MxpiDataType_pb2 as MxpiDataType


class GlobDataLoader():
    def __init__(self, glob_pattern, limit=None):
        self.glob_pattern = glob_pattern
        self.limit = limit
        self.file_list = self.get_file_list()
        self.cur_index = 0

    def get_file_list(self):
        return glob.iglob(self.glob_pattern)

    def __iter__(self):
        return self

    def __next__(self):
        if self.cur_index == self.limit:
            raise StopIteration()
        label = None
        file_path = next(self.file_list)
        with open(file_path, 'rb') as fd:
            data = fd.read()
        self.cur_index += 1
        return get_file_name(file_path), label, data


class Predictor():
    def __init__(self, pipeline_conf, stream_name):
        self.pipeline_conf = pipeline_conf
        self.stream_name = stream_name

    def __enter__(self):
        self.stream_manager_api = StreamManagerApi()
        ret = self.stream_manager_api.InitManager()
        if ret != 0:
            raise Exception(f"Failed to init Stream manager, ret={ret}")
        # create streams by pipeline config file
        with open(self.pipeline_conf, 'rb') as f:
            pipeline_str = f.read()
        ret = self.stream_manager_api.CreateMultipleStreams(pipeline_str)
        if ret != 0:
            raise Exception(f"Failed to create Stream, ret={ret}")
        self.data_input = MxDataInput()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # destroy streams
        self.stream_manager_api.DestroyAllStreams()

    def predict(self, dataset):
        print("Start predict........")
        print('>' * 30)
        for name, _, data in dataset:  # the loader also yields a label, unused here
            self.data_input.data = data
            yield self._predict(name, self.data_input)
        print("predict end.")
        print('<' * 30)

    def _predict(self, name, data):
        plugin_id = 0
        protobuf_data = self._predict_gen_protobuf()
        self._predict_send_protobuf(self.stream_name, 1, protobuf_data)
        unique_id = self._predict_send_data(self.stream_name, plugin_id, data)
        result = self._predict_get_result(self.stream_name, unique_id)
        return name, json.loads(result.data.decode())
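    # Build the crop ROI sent to appsrc1: mxpi_imagecrop0 combines this
    # MxpiObjectList with the resized image to produce the 224x224 model input.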
    def _predict_gen_protobuf(self):
        object_list = MxpiDataType.MxpiObjectList()
        object_vec = object_list.objectVec.add()
        object_vec.x0 = 16
        object_vec.y0 = 16
        object_vec.x1 = 240
        object_vec.y1 = 240

        protobuf = MxProtobufIn()
        protobuf.key = b'appsrc1'
        protobuf.type = b'MxTools.MxpiObjectList'
        protobuf.protobuf = object_list.SerializeToString()
        protobuf_vec = InProtobufVector()
        protobuf_vec.push_back(protobuf)
        return protobuf_vec

    def _predict_send_protobuf(self, stream_name, in_plugin_id, data):
        self.stream_manager_api.SendProtobuf(stream_name, in_plugin_id, data)

    def _predict_send_data(self, stream_name, in_plugin_id, data_input):
        unique_id = self.stream_manager_api.SendData(stream_name, in_plugin_id,
                                                     data_input)
        if unique_id < 0:
            raise Exception("Failed to send data to stream")
        return unique_id

    def _predict_get_result(self, stream_name, unique_id):
        result = self.stream_manager_api.GetResult(stream_name, unique_id)
        if result.errorCode != 0:
            raise Exception(
                f"GetResultWithUniqueId error."
                f"errorCode={result.errorCode}, msg={result.data.decode()}")
        return result


def get_file_name(file_path):
    return os.path.splitext(os.path.basename(file_path.rstrip('/')))[0]


def result_encode(file_name, result):
    sep = ','
    pred_class_ids = sep.join(
        str(i.get('classId')) for i in result.get("MxpiClass", []))
    return f"{file_name} {pred_class_ids}\n"


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('glob', help='img pth glob pattern.')
    parser.add_argument('result_file', help='result file')
    return parser.parse_args()


def main():
    pipeline_conf = "../data/config/vgg16.pipeline"
    stream_name = b'im_vgg16'

    args = parse_args()
    result_fname = get_file_name(args.result_file)
    pred_result_file = f"{result_fname}.txt"
    dataset = GlobDataLoader(args.glob, limit=None)
    with ExitStack() as stack:
        predictor = stack.enter_context(Predictor(pipeline_conf, stream_name))
        result_fd = stack.enter_context(open(pred_result_file, 'w'))

        for fname, pred_result in predictor.predict(dataset):
            result_fd.write(result_encode(fname, pred_result))

    print(f"success, result in {pred_result_file}")


if __name__ == "__main__":
    main()
# ==== infer/util/task_metric.py ====
# coding: utf-8
"""
Copyright 2021 Huawei Technologies Co., Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import json
import os

import numpy as np


def get_file_name(file_path):
    return os.path.splitext(os.path.basename(file_path.rstrip('/')))[0]


def load_gt(gt_file):
    gt = {}
    with open(gt_file, 'r') as fd:
        for line in fd.readlines():
            img_name, img_label_index = line.strip().split(" ", 1)
            gt[get_file_name(img_name)] = img_label_index
    return gt


def load_pred(pred_file):
    pred = {}
    with open(pred_file, 'r') as fd:
        for line in fd.readlines():
            ret = line.strip().split(" ", 1)
            if len(ret) < 2:
                print(f"Warning: load pred, no result, line:{line}")
                continue
            img_name, ids = ret
            img_name = get_file_name(img_name)
            pred[img_name] = [x.strip() for x in ids.split(',')]
    return pred


def calc_accuracy(gt_map, pred_map, top_k=5):
    hits = [0] * top_k
    miss_match = []
    total = 0
    for img, preds in pred_map.items():
        gt = gt_map.get(img)
        if gt is None:
            print(f"Warning: {img}'s gt is not exists.")
            continue
        try:
            index = preds.index(gt, 0, top_k)
            hits[index] += 1
        except ValueError:
            miss_match.append({'img': img, 'gt': gt, 'prediction': preds})
        finally:
            total += 1

    top_k_hit = np.cumsum(hits)
    accuracy = top_k_hit / total
    return {
        'total': total,
        'accuracy': [acc for acc in accuracy],
        'miss': miss_match,
    }


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('prediction', help='prediction result file')
    parser.add_argument('gt', help='ground true result file')
    parser.add_argument('result_json', help='metric result file')
    parser.add_argument('top_k', help='top k', type=int)
    return parser.parse_args()


def main():
    args = parse_args()
    prediction_file = args.prediction
    gt_file = args.gt
    top_k = args.top_k
    result_json = args.result_json

    gt = load_gt(gt_file)
    prediction = load_pred(prediction_file)
    result = calc_accuracy(gt, prediction, top_k)
    result.update({
        'prediction_file': prediction_file,
        'gt_file': gt_file,
    })
    with open(result_json, 'w') as fd:
        json.dump(result, fd, indent=2)
    print(f"\nsuccess, result in {result_json}")


if __name__ == '__main__':
    main()
# ==== ModelArts training script (train and export) ====
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
"""
import datetime
import os
import time

import numpy as np
import moxing as mox

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint, export
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed

from src.dataset import vgg_create_dataset
from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.warmup_step_lr import warmup_step_lr
from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.warmup_step_lr import lr_steps
from src.utils.logging import get_logger
from src.utils.util import get_param_groups
from src.vgg import vgg16
from model_utils.moxing_adapter import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num

set_seed(1)
def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices as most.
        if config.device_target == "GPU":
            init()
            device_id = get_rank()
            device_num = get_group_size()
        elif config.device_target == "Ascend":
            device_id = get_device_id()
            device_num = get_device_num()
        else:
            raise ValueError("Not support device_target.")

        if device_id % min(device_num, 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(device_id, zip_file_1, save_dir_1))

    config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)
def _get_last_ckpt(ckpt_dir):
    ckpt_files = [ckpt_file for ckpt_file in os.listdir(ckpt_dir)
                  if ckpt_file.endswith('.ckpt')]
    if not ckpt_files:
        print("No ckpt file found.")
        return None
    return os.path.join(ckpt_dir, sorted(ckpt_files)[-1])


def run_export(ckpt_dir):
    ckpt_file = _get_last_ckpt(ckpt_dir)
    # config.image_size = list(map(int, config.image_size.split(',')))
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    if config.device_target == "Ascend":
        config.device_id = get_device_id()
        context.set_context(device_id=config.device_id)
    if config.dataset == "cifar10":
        net = vgg16(num_classes=config.num_classes, args=config)
    else:
        net = vgg16(config.num_classes, config, phase="test")

    load_checkpoint(ckpt_file, net=net)
    net.set_train(False)

    input_data = Tensor(np.zeros([config.batch_size, 3, config.image_size[0], config.image_size[1]]), mstype.float32)
    export(net, input_data, file_name=config.file_name, file_format=config.file_format)
    mox.file.copy_parallel(os.getcwd(), config.train_url)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    '''run train'''
    config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
    config.image_size = list(map(int, config.image_size.split(',')))
    config.per_batch_size = config.batch_size

    _enable_graph_kernel = config.device_target == "GPU"
    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=_enable_graph_kernel, device_target=config.device_target)
    config.rank = get_rank_id()
    config.device_id = get_device_id()
    config.group_size = get_device_num()

    if config.is_distributed:
        if config.device_target == "Ascend":
            init()
            context.set_context(device_id=config.device_id)
        elif config.device_target == "GPU":
            if not config.enable_modelarts:
                init()
            else:
                if not config.need_modelarts_dataset_unzip:
                    init()
        device_num = config.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, all_reduce_fusion_config=[2, 18])
    else:
        if config.device_target == "Ascend":
            context.set_context(device_id=config.device_id)

    # select for master rank save ckpt or all rank save, compatible for model parallel
    config.rank_save_ckpt_flag = 0
    if config.is_save_on_master:
        if config.rank == 0:
            config.rank_save_ckpt_flag = 1
    else:
        config.rank_save_ckpt_flag = 1

    # logger
    config.outputs_dir = os.path.join(config.ckpt_path,
                                      datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    config.logger = get_logger(config.outputs_dir, config.rank)

    if config.dataset == "cifar10":
        dataset = vgg_create_dataset(config.data_dir, config.image_size, config.per_batch_size,
                                     config.rank, config.group_size)
    else:
        dataset = classification_dataset(config.data_dir, config.image_size, config.per_batch_size,
                                         config.rank, config.group_size)

    batch_num = dataset.get_dataset_size()
    config.steps_per_epoch = dataset.get_dataset_size()
    config.logger.save_args(config)

    # network
    config.logger.important_info('start create network')

    # get network and init
    network = vgg16(config.num_classes, config)

    # pre_trained
    if config.pre_trained:
        load_param_into_net(network, load_checkpoint(config.pre_trained))

    # lr scheduler
    if config.lr_scheduler == 'exponential':
        lr = warmup_step_lr(config.lr,
                            config.lr_epochs,
                            config.steps_per_epoch,
                            config.warmup_epochs,
                            config.max_epoch,
                            gamma=config.lr_gamma,
                            )
    elif config.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(config.lr,
                                        config.steps_per_epoch,
                                        config.warmup_epochs,
                                        config.max_epoch,
                                        config.T_max,
                                        config.eta_min)
    elif config.lr_scheduler == 'step':
        lr = lr_steps(0, lr_init=config.lr_init, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
                      total_epochs=config.max_epoch, steps_per_epoch=batch_num)
    else:
        raise NotImplementedError(config.lr_scheduler)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=config.momentum,
                   weight_decay=config.weight_decay,
                   loss_scale=config.loss_scale)

    if config.dataset == "cifar10":
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
    else:
        if not config.label_smooth:
            config.label_smooth_factor = 0.0
        loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
        loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")

    # define callbacks
    time_cb = TimeMonitor(data_size=batch_num)
    loss_cb = LossMonitor(per_print_times=batch_num)
    callbacks = [time_cb, loss_cb]
    # defined unconditionally so run_export below always receives a valid directory
    save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '/')
    if config.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval * config.steps_per_epoch,
                                       keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(config.rank))
        callbacks.append(ckpt_cb)

    model.train(config.max_epoch, dataset, callbacks=callbacks)
    run_export(save_ckpt_path)


if __name__ == '__main__':
    run_train()