diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index 87056544ef6df66a98ff808336dd4cf8ea0263b6..36e86870fa21699f9df151e46b7208f56f23975b 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -18,7 +18,7 @@ "models/research/cv/FaceQualityAssessment/infer/util/plugins/MxpiTransposePlugin.cpp" "runtime/references" "models/research/cv/stgcn/infer/mxbase/src/stgcnUtil.h" "runtime/references" -"models/research/cv/stgcn/infer/mxbase/src/stgcnUtil.cpp" "runtime/references" +"models/research/cv/stgcn/infer/mxbase/src/stgcnUtil.cpp" "runtime/references""models/official/cv/resnext/infer/mxbase/src/resnext50Classify.h" "runtime/references" "models/research/cv/stgcn/infer/mxbase/src/main.cpp" "runtime/references" @@ -32,3 +32,4 @@ "models/research/cv/squeezenet1_1/infer/mxbase/main_opencv.cpp" "runtime/references" "models/research/cv/squeezenet1_1/infer/mxbase/Squeezenet1_1ClassifyOpencv.cpp" "runtime/references" "models/official/cv/vgg16/infer/mxbase/src/Vgg16Classify.h" "runtime/references" +"models/official/cv/resnext/infer/mxbase/src/resnext50Classify.h" "runtime/references" diff --git a/official/cv/resnext/infer/Dockerfile b/official/cv/resnext/infer/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f522a400cc040f7c9e44f2b0b9b5267076a6b40b --- /dev/null +++ b/official/cv/resnext/infer/Dockerfile @@ -0,0 +1,20 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +ARG FROM_IMAGE_NAME +#Base image +FROM $FROM_IMAGE_NAME +ARG SDK_PKG +#Copy the SDK installation package to the image +COPY ./$SDK_PKG . diff --git a/official/cv/resnext/infer/README.md b/official/cv/resnext/infer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7baef14657a2d80f3929762bf35433e36defb381 --- /dev/null +++ b/official/cv/resnext/infer/README.md @@ -0,0 +1,628 @@ +# 离线推理过程 + +## MindX SDK推理 + +1. 准备容器环境。 + +- 准备SDK安装包和Dockerfile文件(resnext50 /Dockerfile),将其一同上传至服务器任意目录(如/home/data) + Dockerfile文件内容: + + ~~~ Dockerfile + ARG FROM_IMAGE_NAME + #基础镜像 + FROM $FROM_IMAGE_NAME + ARG SDK_PKG + #将SDK安装包拷贝到镜像中 + COPY ./$SDK_PKG . + ~~~ + +- 下载MindX SDK开发套件(mxManufacture)。 + + 将下载的软件包上传到代码目录“/home/data/resnext50/infer”下并执行安装命令:./Ascend-mindxsdk-mxManufacture_xxx_linux-x86_64.run --install --install-path=安装路径(安装路径为源代码存放路径,如第一点所说的/home/data/infer)。 + +- 编译推理镜像: + + ```shell + #非root权限,需在指令前面加"sudo" + docker build -t infer_image --build-arg FROM_IMAGE_NAME=base_image:tag --build-arg SDK_PKG=sdk_pkg . 
+ ``` + + | 参数 | 说明 | + | ------------- | ------------------------------------------------------------ | + | *infer_image* | 推理镜像名称,根据实际写入。 | + | *base_image* | 基础镜像,可从Ascend Hub上下载,如ascendhub.huawei.com/public-ascendhub/ascend-infer-x86。 | + | *tag* | 镜像tag,请根据实际配置,如:21.0.2。 | + | sdk_pkg | 下载的mxManufacture包名称,如Ascend-mindxsdk-mxmanufacture_*{version}*_linux-*{arch}*.run | + 注:指令末尾的”.“一定不能省略,这代表当前目录。 + +- 执行以下命令,启动容器实例: + + ```shell + bash docker_start_infer.sh docker_image:tag model_dir + ``` + + | 参数 | 说明 | + | -------------- | -------------------------------------------- | + | *docker_image* | 推理镜像名称,推理镜像请从Ascend Hub上下载。 | + | *tag* | 镜像tag,请根据实际配置,如:21.0.2。 | + | *model_dir* | 推理容器挂载路径,本例中为/home/data | + 启动容器时会将推理芯片和数据路径挂载到容器中。 + 其中docker_start_infer.sh的内容如下: + + ```shell + #!/bin/bash + docker_image=$1 + model_dir=$2 + if [ -z "${docker_image}" ]; then + echo "please input docker_image" + exit 1 + fi + if [ ! -d "${model_dir}" ]; then + echo "please input model_dir" + exit 1 + fi + docker run -it \ + --device=/dev/davinci0 \ #请根据芯片的情况更改 + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${model_dir}:${model_dir} \ + ${docker_image} \ + /bin/bash + ``` + +2. 
将源代码(resnext50文件夹)上传至服务器任意目录(如:/home/data/,后续示例将以/home/data/resnext50为例),并进入该目录。 + 源码目录结构如下图所示: + +```text +/home/data/resnext50 +├── infer # MindX高性能预训练模型新增 +│ └── README.md # 离线推理文档 +│ ├── convert # 转换om模型命令,AIPP +│ │ ├──aipp.config +│ │ └──atc.sh +│ ├── data # 包括模型文件、模型输入数据集、模型相关配置文件(如label、SDK的pipeline) +│ │ ├── model +│ │ ├── input +│ │ └── config +│ │ │ ├──imagenet1000_clsidx_to_labels.names +│ │ │ ├──resnext50.cfg +│ │ │ └──resnext50.pipeline +│ ├── mxbase # 基于mxbase推理 +│ │ ├── build +│ │ ├── src +│ │ │ ├── resnext50Classify.cpp +│ │ │ ├── resnext50Classify.h +│ │ │ ├── main.cpp +│ │ │ └── include #包含运行所需库 +│ │ ├── CMakeLists.txt +│ │ └── build.sh +│ └── sdk # 基于sdk run包推理;如果是C++实现,存放路径一样 +│ │ ├── main.py +│ │ └── run.sh +│ └── util # 精度验证脚本 +│ │ ├──imagenet2012_val.txt +│ │ └──task_metric.py +``` + +3. 将air模型放入resnext50/infer/data/model目录下。 + +4. 准备AIPP配置文件。 + +AIPP需要配置aipp.config文件,在ATC转换的过程中插入AIPP算子,即可与DVPP处理后的数据无缝对接,AIPP参数配置请参见《[CANN 开发辅助工具指南 (推理)](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373?category=developer-documents&subcategory=auxiliary-development-tools)》中“ATC工具使用指南”。 +aipp.config文件内容如下,该文件放在resnext50/infer/convert目录下: + +~~~ config +aipp_op{ + aipp_mode: static + input_format : RGB888_U8 + rbuv_swap_switch : true + mean_chn_0 : 0 + mean_chn_1 : 0 + mean_chn_2 : 0 + min_chn_0 : 123.675 + min_chn_1 : 116.28 + min_chn_2 : 103.53 + var_reci_chn_0: 0.0171247538316637 + var_reci_chn_1: 0.0175070028011204 + var_reci_chn_2: 0.0174291938997821 +} +~~~ + +5. 
进入resnext50/infer/convert目录,并执行: + +```shell +#本例中模型名称为resnext50.air +bash atc.sh ../data/model/resnext50.air +``` + +利用ATC工具将air模型转换为om模型,om模型会自动放在resnext50/infer/data/model/文件夹下。 +其中atc.sh的内容如下: + +~~~ sh +model=$1 +atc \ + --model=$model \ + --framework=1 \ + --output=../data/model/resnext50 \ + --output_type=FP32 \ + --soc_version=Ascend310 \ + --input_shape="input:1,224,224,3" \ + --log=info \ + --insert_op_conf=aipp.config +~~~ + +参数说明: + +- --model:待转换的air模型所在路径。 +- --framework:1代表MindSpore框架。 +- --output:转换后输出的om模型存放路径以及名称。 +- --soc_version:生成的模型支持的推理环境。 +- --input_shape:输入数据的shape。 +- --log=info:打印转换过程中info级别的日志。 +- --insert_op_conf:模型转换使用的AIPP配置文件。 + +6. 数据准备。 + 将推离图片数据集放在resnext50/infer/data/input目录下。 + +- 推理使用的数据集是[ImageNet2012](http://www.image-net.org/)中的验证集,input目录下已有其中十张测试图片。 +- 测试集:6.4 GB,50, 000张图像。 + +7. 准备模型推理所需文件。 + + 1. 在“/home/data/resnext50/infer/data/config/”目录下编写pipeline文件。 + 根据实际情况修改resnext50.pipeline文件中图片规格、模型路径、配置文件路径和标签路径。 + 更多介绍请参见《[mxManufacture 用户指南](https://ascend.huawei.com/#/software/mindx-sdk/sdk-detail)》中“基础开发”章节。 + resnext50.pipeline + +```pipeline +{ + "im_resnext50": { + "stream_config": { + "deviceId": "0" + }, + "mxpi_imagedecoder0": { + "props": { + "handleMethod": "opencv" + }, + "factory": "mxpi_imagedecoder", + "next": "mxpi_imageresize0" + }, + "mxpi_imageresize0": { + "props": { + "handleMethod": "opencv", + "resizeHeight": "256", + "resizeWidth": "256", + "resizeType": "Resizer_Stretch" + }, + "factory": "mxpi_imageresize", + "next": "mxpi_imagecrop0:1" + }, + "mxpi_imagecrop0": { + "props": { + "dataSource": "appsrc1", + "dataSourceImage": "mxpi_imageresize0", + "handleMethod": "opencv" + }, + "factory": "mxpi_imagecrop", + "next": "mxpi_tensorinfer0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource": "mxpi_imagecrop0", + "modelPath": "../data/model/resnext50.om", + "waitingTime": "1", + "outputDeviceId": "-1" + }, + "factory": "mxpi_tensorinfer", + "next": "mxpi_classpostprocessor0" + }, + 
"mxpi_classpostprocessor0": { + "props": { + "dataSource": "mxpi_tensorinfer0", + "postProcessConfigPath": "../data/config/resnext50.cfg", + "labelPath": "../data/config/imagenet1000_clsidx_to_labels.names", + "postProcessLibPath": "/home/data/zjut_mindx/SDK_2.0.2/mxManufacture/lib/modelpostprocessors/libresnet50postprocess.so" + }, + "factory": "mxpi_classpostprocessor", + "next": "mxpi_dataserialize0" + }, + "mxpi_dataserialize0": { + "props": { + "outputDataKeys": "mxpi_classpostprocessor0" + }, + "factory": "mxpi_dataserialize", + "next": "appsink0" + }, + "appsrc1": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_imagecrop0:0" + }, + "appsrc0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_imagedecoder0" + }, + "appsink0": { + "props": { + "blocksize": "4096000" + }, + "factory": "appsink" + } + } +} +``` + +参数说明: + +- resizeHeight:图片缩放后高度,请根据实际需求尺寸输入。 +- resizeWidth:图片缩放后宽度,请根据实际需求尺寸输入。 +- modelPath:模型路径,请根据模型实际路径修改。 +- postProcessConfigPath:模型配置文件路径,请根据模型配置文件的实际路径修改。 +- labelPath:标签文件路径,请根据标签文件的实际路径修改。 + + 2. 在“/home/data/resnext50/infer/data/config/”目录下编写resnext50.cfg配置文件。 + 配置文件resnext50.cfg内容如下: + +```python +CLASS_NUM=1000 +SOFTMAX=false +TOP_K=5 +``` + +8. 根据实际情况修改main.py文件中裁剪图片的位置和**pipeline**文件路径。 + +其中main.py的内容如下: + +```python +... + def _predict_gen_protobuf(self): + object_list = MxpiDataType.MxpiObjectList() + object_vec = object_list.objectVec.add() + object_vec.x0 = 16 + object_vec.y0 = 16 + object_vec.x1 = 240 + object_vec.y1 = 240 +... 
+ def main(): + pipeline_conf = "../data/config/resnext50.pipeline" + stream_name = b'im_resnext50' + + args = parse_args() + result_fname = get_file_name(args.result_file) + pred_result_file = f"{result_fname}.txt" + dataset = GlobDataLoader(args.glob, limit=None) + with ExitStack() as stack: + predictor = stack.enter_context(Predictor(pipeline_conf, stream_name)) + result_fd = stack.enter_context(open(pred_result_file, 'w')) + + for fname, pred_result in predictor.predict(dataset): + result_fd.write(result_encode(fname, pred_result)) + + print(f"success, result in {pred_result_file}") +... +``` + +9. 进入resnext50/infer/sdk文件夹,执行bash run.sh,推理结果保存在当前目录下的resnext50_sdk_pred_result.txt文件中。 + +其中run.sh的内容如下: + +~~~ bash +set -e +CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run.sh" ; exit ; } ; pwd) +# Simple log helper functions +info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } +warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } +export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} +export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner +export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins +#to set PYTHONPATH, import the StreamManagerApi.py +export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python +python3.7 main.py "../data/input/*.JPEG" resnext50_sdk_pred_result.txt +exit 0 +~~~ + +推理结果: + +```shell +ILSVRC2012_val_00000008 415,928,850,968,911 +ILSVRC2012_val_00000001 65,62,58,54,56 +ILSVRC2012_val_00000009 674,338,333,106,337 +ILSVRC2012_val_00000007 334,361,375,7,8 +ILSVRC2012_val_00000010 332,338,153,204,190 +ILSVRC2012_val_00000005 520,516,431,797,564 +ILSVRC2012_val_00000006 62,60,57,67,65 +ILSVRC2012_val_00000003 230,231,226,169,193 +ILSVRC2012_val_00000002 795,970,537,796,672 +ILSVRC2012_val_00000004 
809,968,969,504,967 +...... +``` + +10. 验证精度,进入resnext50/infer/util目录下并执行: + +```shell +python3.7 task_metric.py ../sdk/resnext50_sdk_pred_result.txt imagenet2012_val.txt resnext50_sdk_pred_result_acc.json 5 +``` + +参数说明: + +- 第一个参数(../sdk/resnext50_sdk_pred_result.txt):推理结果保存路径。 +- 第二个参数(imagenet2012_val.txt):验证集标签文件。 +- 第三个参数(resnext50_sdk_pred_result_acc.json):结果文件 +- 第四个参数(5):"1"表示TOP-1准确率,“5”表示TOP-5准确率。 + +11. 查看推理性能和精度 + +- 打开性能统计开关。将“enable_ps”参数设置为true,“ps_interval_time”参数设置为6。 + + **vi /usr/local/sdk_home/mxManufacture/config/sdk.conf** + + ```bash + # MindX SDK configuration file + + # whether to enable performance statistics, default is false [dynamic config] + enable_ps=true + ... + ps_interval_time=6 + ... + ``` + +- 执行run.sh脚本 **bash run.sh** + +- 在日志目录"/usr/local/sdk_home/mxManufacture/logs"查看性能统计结果。 + + ```bash + performance--statistics.log.e2e.xxx + performance--statistics.log.plugin.xxx + performance--statistics.log.tpr.xxx + ``` + + 其中e2e日志统计端到端时间,plugin日志统计单插件时间。 + +查看推理精度 + +~~~ bash +cat resnext50_sdk_pred_result_acc.json +~~~ + +top-5推理精度: + +```json + "total": 50000, + "accuracy": [ + 0.78386, + 0.8793, + 0.91304, + 0.92962, + 0.94016 + ] +``` + +## mxBase推理 + +1. 准备容器环境。 + +- 准备SDK安装包和Dockerfile文件(resnext50/Dockerfile),将其一同上传至服务器任意目录(如/home/data) + Dockerfile文件内容: + + ```dockerfile + ARG FROM_IMAGE_NAME + #基础镜像 + FROM $FROM_IMAGE_NAME + ARG SDK_PKG + #将SDK安装包拷贝到镜像中 + COPY ./$SDK_PKG . + ``` + +- 下载MindX SDK开发套件(mxManufacture)。 + + 将下载的软件包上传到代码目录“/home/data/resnext50/infer”下并执行安装命令:./Ascend-mindxsdk-mxManufacture_xxx_linux-x86_64.run --install --install-path=安装路径(安装路径为源代码存放路径,如第一点所说的/home/data/infer)。 + +- 编译推理镜像。 + + ```shell + #非root权限,需在指令前面加"sudo" + docker build -t infer_image --build-arg FROM_IMAGE_NAME= base_image:tag --build-arg SDK_PKG= sdk_pkg . 
+ ``` + + | 参数 | 说明 | + | ------------- | ------------------------------------------------------------ | + | *infer_image* | 推理镜像名称,根据实际写入。 | + | *base_image* | 基础镜像,可从Ascend Hub上下载。 | + | *tag* | 镜像tag,请根据实际配置,如:21.0.2。 | + | sdk_pkg | 下载的mxManufacture包名称,如Ascend-mindxsdk-mxmanufacture_*{version}*_linux-*{arch}*.run | + +- 执行以下命令,启动容器实例。 + + ```shell + bash docker_start_infer.sh docker_image:tag model_dir + ``` + + | 参数 | 说明 | + | -------------- | -------------------------------------------- | + | *docker_image* | 推理镜像名称,推理镜像请从Ascend Hub上下载。 | + | *tag* | 镜像tag,请根据实际配置,如:21.0.2。 | + | *model_dir* | 推理容器挂载路径,本例中为/home/data | + 启动容器时会将推理芯片和数据路径挂载到容器中。 + 其中docker_start_infer.sh的内容如下: + + ```shell + #!/bin/bash + docker_image=$1 + model_dir=$2 + if [ -z "${docker_image}" ]; then + echo "please input docker_image" + exit 1 + fi + if [ ! -d "${model_dir}" ]; then + echo "please input model_dir" + exit 1 + fi + docker run -it \ + --device=/dev/davinci0 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${model_dir}:${model_dir} \ + ${docker_image} \ + /bin/bash + ``` + +2. 
将源代码(resnext50文件夹)上传至服务器任意目录(如:/home/data/,后续示例将以/home/data/resnext50为例),并进入该目录。 + 源码目录结构如下图所示: + +```text +/home/data/resnext50 +├── infer # MindX高性能预训练模型新增 +│ └── README.md # 离线推理文档 +│ ├── convert # 转换om模型命令,AIPP +│ │ ├──aipp.config +│ │ └──atc.sh +│ ├── data # 包括模型文件、模型输入数据集、模型相关配置文件(如label、SDK的pipeline) +│ │ ├── model +│ │ ├── input +│ │ └── config +│ │ │ ├──imagenet1000_clsidx_to_labels.names +│ │ │ ├──resnext50.cfg +│ │ │ └──resnext50.pipeline +│ ├── mxbase # 基于mxbase推理 +│ │ ├── build +│ │ ├── src +│ │ │ ├── resnext50Classify.cpp +│ │ │ ├── resnext50Classify.h +│ │ │ ├── main.cpp +│ │ │ └── include #包含运行所需库 +│ │ ├── CMakeLists.txt +│ │ └── build.sh +│ └── sdk # 基于sdk run包推理;如果是C++实现,存放路径一样 +│ │ ├── main.py +│ │ └── run.sh +│ └── util # 精度验证脚本 +│ │ ├──imagenet2012_val.txt +│ │ └──task_metric.py +``` + +3. 添加环境变量。 + +通过vi ~/.bashrc命令打开~/.bashrc文件,将下面的环境变量添加进当前环境,添加好环境变量以后退出文件编辑,执行source ~/.bashrc使环境变量生效。 + +```bash +export ASCEND_HOME="/usr/local/Ascend" +export ASCEND_VERSION="nnrt/latest" +export ARCH_PATTERN="." +export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib/modelpostprocessors:${LD_LIBRARY_PATH}" +export MXSDK_OPENSOURCE_DIR="${MX_SDK_HOME}/opensource" +``` + +4. 进行模型转换,详细步骤见上一章MindX SDK推理。 +5. 进入resnext50/infer/mxbase目录,并执行: + +```shell +bash build.sh +``` + +其中build.sh的内容如下: + +```sh +path_cur=$(dirname $0) +function check_env() +{ + # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user + if [ ! "${ASCEND_VERSION}" ]; then + export ASCEND_VERSION=ascend-toolkit/latest + echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}" + else + echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user" + fi + + if [ ! 
"${ARCH_PATTERN}" ]; then + # set ARCH_PATTERN to ./ when it was not specified by user + export ARCH_PATTERN=./ + echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}" + else + echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user" + fi +} +function build_resnext50() +{ + cd $path_cur + rm -rf build + mkdir -p build + cd build + cmake .. + make + ret=$? + if [ ${ret} -ne 0 ]; then + echo "Failed to build resnext50." + exit ${ret} + fi + make install +} +check_env +build_resnext50 +``` + +6. 在当前目录下执行。 + +```shell +#假设图片数量为10,推理结果保存在当前目录下的mx_pred_result.txt文件下 +./resnext ../data/input 5000 +``` + +- 第一个参数(../data/input):图片输入路径。 +- 第二个参数(50000):图片输入数量。 + +mx_pred_result.txt(推理结果): + +```shell +ILSVRC2012_val_00047110 221,267,220,206,266, +ILSVRC2012_val_00014552 550,505,503,899,804, +ILSVRC2012_val_00006604 276,275,287,286,212, +ILSVRC2012_val_00016859 2,3,148,5,19, +ILSVRC2012_val_00020009 336,649,972,299,356, +ILSVRC2012_val_00025515 917,921,454,446,620, +ILSVRC2012_val_00046794 541,632,412,822,686, +ILSVRC2012_val_00035447 856,866,595,586,864, +ILSVRC2012_val_00016392 54,68,56,60,66, +ILSVRC2012_val_00023902 50,44,60,61,54, +ILSVRC2012_val_00000719 268,151,171,237,285, +...... +``` + +7. 验证精度,进入resnext50/infer/util目录下,并执行: + +```python +python3.7 task_metric.py ../mxbase/mx_pred_result.txt imagenet2012_val.txt resnext50_mxbase_pred_result_acc.json 5 +``` + +参数说明: + +- 第一个参数(../mxbase/mx_pred_result.txt):推理结果保存路径。 +- 第二个参数(imagenet2012_val.txt):验证集标签文件。 +- 第三个参数(resnext50_mxbase_pred_result_acc.json):结果文件 +- 第四个参数(5):"1"表示TOP-1准确率,“5”表示TOP-5准确率。 + +8. 
查看推理精度结果 + +~~~ bash +cat resnext50_mxbase_pred_result_acc.json +~~~ + +top-5推理精度: + +```json +"accuracy": [ + 0.78386, + 0.8793, + 0.91304, + 0.92962, + 0.94016 + ] +``` \ No newline at end of file diff --git a/official/cv/resnext/infer/convert/aipp.config b/official/cv/resnext/infer/convert/aipp.config new file mode 100644 index 0000000000000000000000000000000000000000..6cc99495f0f8b8aa4bf7297e7e221762ef4e9b4d --- /dev/null +++ b/official/cv/resnext/infer/convert/aipp.config @@ -0,0 +1,15 @@ +aipp_op{ + aipp_mode: static + input_format : RGB888_U8 + rbuv_swap_switch : true + + mean_chn_0 : 0 + mean_chn_1 : 0 + mean_chn_2 : 0 + min_chn_0 : 123.675 + min_chn_1 : 116.28 + min_chn_2 : 103.53 + var_reci_chn_0: 0.0171247538316637 + var_reci_chn_1: 0.0175070028011204 + var_reci_chn_2: 0.0174291938997821 +} \ No newline at end of file diff --git a/official/cv/resnext/infer/convert/atc.sh b/official/cv/resnext/infer/convert/atc.sh new file mode 100644 index 0000000000000000000000000000000000000000..705eaa37b5b76c6753fd83867e7d3b61f4a6ea08 --- /dev/null +++ b/official/cv/resnext/infer/convert/atc.sh @@ -0,0 +1,25 @@ +#!/usr/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +model=$1 +/usr/local/Ascend/atc/bin/atc \ + --model=$model \ + --framework=1 \ + --output=../data/model/resnext50 \ + --output_type=FP32 \ + --soc_version=Ascend310 \ + --input_shape="input:1,224,224,3" \ + --log=info \ + --insert_op_conf=aipp.config diff --git a/official/cv/resnext/infer/data/config/resnext50.cfg b/official/cv/resnext/infer/data/config/resnext50.cfg new file mode 100644 index 0000000000000000000000000000000000000000..581fc76d3d75445323ea9a387f7152a72bedd1d3 --- /dev/null +++ b/official/cv/resnext/infer/data/config/resnext50.cfg @@ -0,0 +1,3 @@ +CLASS_NUM=1000 +SOFTMAX=false +TOP_K=5 diff --git a/official/cv/resnext/infer/data/config/resnext50.pipeline b/official/cv/resnext/infer/data/config/resnext50.pipeline new file mode 100644 index 0000000000000000000000000000000000000000..e421a48983c02ab69ea7aa9a062a82d9673e0b88 --- /dev/null +++ b/official/cv/resnext/infer/data/config/resnext50.pipeline @@ -0,0 +1,80 @@ +{ + "im_resnext50": { + "stream_config": { + "deviceId": "0" + }, + "mxpi_imagedecoder0": { + "props": { + "handleMethod": "opencv" + }, + "factory": "mxpi_imagedecoder", + "next": "mxpi_imageresize0" + }, + "mxpi_imageresize0": { + "props": { + "handleMethod": "opencv", + "resizeHeight": "256", + "resizeWidth": "256", + "resizeType": "Resizer_Stretch" + }, + "factory": "mxpi_imageresize", + "next": "mxpi_imagecrop0:1" + }, + "mxpi_imagecrop0": { + "props": { + "dataSource": "appsrc1", + "dataSourceImage": "mxpi_imageresize0", + "handleMethod": "opencv" + }, + "factory": "mxpi_imagecrop", + "next": "mxpi_tensorinfer0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource": "mxpi_imagecrop0", + "modelPath": "../data/model/resnext50.om", + "waitingTime": "1", + "outputDeviceId": "-1" + }, + "factory": "mxpi_tensorinfer", + "next": "mxpi_classpostprocessor0" + }, + "mxpi_classpostprocessor0": { + "props": { + "dataSource": "mxpi_tensorinfer0", + 
"postProcessConfigPath": "../data/config/resnext50.cfg", + "labelPath": "../data/model/imagenet1000_clsidx_to_labels.names", + "postProcessLibPath": "/usr/local/sdk_home/mxManufacture/lib/modelpostprocessors/libresnet50postprocess.so" + }, + "factory": "mxpi_classpostprocessor", + "next": "mxpi_dataserialize0" + }, + "mxpi_dataserialize0": { + "props": { + "outputDataKeys": "mxpi_classpostprocessor0" + }, + "factory": "mxpi_dataserialize", + "next": "appsink0" + }, + "appsrc1": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_imagecrop0:0" + }, + "appsrc0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_imagedecoder0" + }, + "appsink0": { + "props": { + "blocksize": "4096000" + }, + "factory": "appsink" + } + } +} \ No newline at end of file diff --git a/official/cv/resnext/infer/data/input/ILSVRC2012_val_00000001.JPEG b/official/cv/resnext/infer/data/input/ILSVRC2012_val_00000001.JPEG new file mode 100644 index 0000000000000000000000000000000000000000..fd3a93f59385d6ff632483646e6caee300b56d09 Binary files /dev/null and b/official/cv/resnext/infer/data/input/ILSVRC2012_val_00000001.JPEG differ diff --git a/official/cv/resnext/infer/docker_start_infer.sh b/official/cv/resnext/infer/docker_start_infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..014147b45771af2c7338e870e525f6c2d7d6cbe8 --- /dev/null +++ b/official/cv/resnext/infer/docker_start_infer.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +docker_image=$1 +model_dir=$2 + +if [ -z "${docker_image}" ]; then + echo "please input docker_image" + exit 1 +fi + +if [ ! -d "${model_dir}" ]; then + echo "please input model_dir" + exit 1 +fi + +docker run -it \ + --device=/dev/davinci2 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${model_dir}:${model_dir} \ + ${docker_image} \ + /bin/bash \ No newline at end of file diff --git a/official/cv/resnext/infer/mxbase/CMakeLists.txt b/official/cv/resnext/infer/mxbase/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1a1897d51ee97b82f459071885ea008d3f557be5 --- /dev/null +++ b/official/cv/resnext/infer/mxbase/CMakeLists.txt @@ -0,0 +1,67 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +cmake_minimum_required(VERSION 3.10.0) +project(resnext) + +set(TARGET resnext) + +add_definitions(-DENABLE_DVPP_INTERFACE) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-Dgoogle=mindxsdk_private) + +add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall) +add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie) + +# Check environment variable +if(NOT DEFINED ENV{ASCEND_HOME}) + message(FATAL_ERROR "please define environment variable:ASCEND_HOME") +endif() +if(NOT DEFINED ENV{ASCEND_VERSION}) + message(WARNING "please define environment variable:ASCEND_VERSION") +endif() +if(NOT DEFINED ENV{ARCH_PATTERN}) + message(WARNING "please define environment variable:ARCH_PATTERN") +endif() +set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include) +set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64) + +set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME}) +set(MXBASE_INC ${MXBASE_ROOT_DIR}/include) +set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib) +set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors) +set(MXBASE_POST_PROCESS_DIR ${PROJECT_SOURCE_DIR}/src/include) +if(DEFINED ENV{MXSDK_OPENSOURCE_DIR}) + set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR}) +else() + set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource) +endif() + +include_directories(${ACL_INC_DIR}) +include_directories(${OPENSOURCE_DIR}/include) +include_directories(${OPENSOURCE_DIR}/include/opencv4) + +include_directories(${MXBASE_INC}) +include_directories(${MXBASE_POST_PROCESS_DIR}) + +link_directories(${ACL_LIB_DIR}) +link_directories(${OPENSOURCE_DIR}/lib) +link_directories(${MXBASE_LIB_DIR}) +link_directories(${MXBASE_POST_LIB_DIR}) + +add_executable(${TARGET} ./src/main.cpp ./src/resnext50Classify.cpp) +target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world stdc++fs) + +install(TARGETS ${TARGET} 
RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/) \ No newline at end of file diff --git a/official/cv/resnext/infer/mxbase/build.sh b/official/cv/resnext/infer/mxbase/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..c003e99e1fdaf8d5faefed03a408f335b820a6eb --- /dev/null +++ b/official/cv/resnext/infer/mxbase/build.sh @@ -0,0 +1,54 @@ +#!/usr/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +path_cur=$(dirname $0) + +function check_env() +{ + # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user + if [ ! "${ASCEND_VERSION}" ]; then + export ASCEND_VERSION=ascend-toolkit/latest + echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}" + else + echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user" + fi + + if [ ! "${ARCH_PATTERN}" ]; then + # set ARCH_PATTERN to ./ when it was not specified by user + export ARCH_PATTERN=./ + echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}" + else + echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user" + fi +} + +function build_resnext50() +{ + cd $path_cur + rm -rf build + mkdir -p build + cd build + cmake .. + make + ret=$? + if [ ${ret} -ne 0 ]; then + echo "Failed to build vgg16." 
+ exit ${ret} + fi + make install +} + +check_env +build_resnext50 \ No newline at end of file diff --git a/official/cv/resnext/infer/mxbase/src/main.cpp b/official/cv/resnext/infer/mxbase/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..49290b0c6569b9afbfe8815840d47356288267ad --- /dev/null +++ b/official/cv/resnext/infer/mxbase/src/main.cpp @@ -0,0 +1,70 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================ + */ +#include <experimental/filesystem> +#include "resnext50Classify.h" +#include "MxBase/Log/Log.h" + +namespace fs = std::experimental::filesystem; +namespace { +const uint32_t CLASS_NUM = 1000; +} +std::vector<double> g_inferCost; + +int main(int argc, char* argv[]) +{ if (argc <= 1) { + LogWarn << "Please input image path, such as './resnext ../data/input 10'."; + return APP_ERR_OK; + } + InitParam initParam = {}; + initParam.deviceId = 0; + initParam.classNum = CLASS_NUM; + initParam.labelPath = "../data/model/imagenet1000_clsidx_to_labels.names"; + initParam.topk = 5; + initParam.softmax = false; + initParam.checkTensor = true; + initParam.modelPath = "../data/model/resnext50.om"; + auto resnext50 = std::make_shared<resnext50Classify>(); + APP_ERROR ret = resnext50->Init(initParam); + if (ret != APP_ERR_OK) { + LogError << "resnext50Classify init failed, ret=" << ret << "."; + return ret; + } + std::string imgDir = argv[1]; + int limit = std::strtol(argv[2], nullptr, 0); + int index = 0; + for (auto & entry : fs::directory_iterator(imgDir)) { + if (limit > 0 && index == limit) { + break; + } + index++; + LogInfo << "read image path " << entry.path(); + ret = resnext50->Process(entry.path()); + if (ret != APP_ERR_OK) { + LogError << "resnext50Classify process failed, ret=" << ret << "."; + resnext50->DeInit(); + return ret; + } + } + resnext50->DeInit(); + double costSum = 0; + for (unsigned int i = 0; i < g_inferCost.size(); i++) { + costSum += g_inferCost[i]; + } + LogInfo << "Infer images sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms."; + LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " images/sec."; + return APP_ERR_OK; +} diff --git a/official/cv/resnext/infer/mxbase/src/resnext50Classify.cpp b/official/cv/resnext/infer/mxbase/src/resnext50Classify.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..20fb67cec72f6788f5b100fdae22e63e2346383f --- /dev/null +++ b/official/cv/resnext/infer/mxbase/src/resnext50Classify.cpp @@ -0,0 +1,225 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================ + */ +#include <unistd.h> +#include <sys/stat.h> +#include <map> for map<> +#include <memory> for make_shared<> +#include <vector> for vector<> +#include <string> for string +#include "resnext50Classify.h" +#include "MxBase/DeviceManager/DeviceManager.h" +#include "MxBase/Log/Log.h" + +namespace { + const uint32_t YUV_BYTE_NU = 3; + const uint32_t YUV_BYTE_DE = 2; + const uint32_t VPC_H_ALIGN = 2; +} + +APP_ERROR resnext50Classify::Init(const InitParam &initParam) { + deviceId_ = initParam.deviceId; + APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices(); + if (ret != APP_ERR_OK) { + LogError << "Init devices failed, ret=" << ret << "."; + return ret; + } + ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId); + if (ret != APP_ERR_OK) { + LogError << "Set context failed, ret=" << ret << "."; + return ret; + } + model_ = std::make_shared<MxBase::ModelInferenceProcessor>(); + ret = model_->Init(initParam.modelPath, modelDesc_); + if (ret != APP_ERR_OK) { + LogError << "ModelInferenceProcessor init failed, ret=" << ret << "."; + return ret; + } + MxBase::ConfigData 
configData; + const std::string softmax = initParam.softmax ? "true" : "false"; + const std::string checkTensor = initParam.checkTensor ? "true" : "false"; + + configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum)); + configData.SetJsonValue("TOP_K", std::to_string(initParam.topk)); + configData.SetJsonValue("SOFTMAX", softmax); + configData.SetJsonValue("CHECK_MODEL", checkTensor); + + auto jsonStr = configData.GetCfgJson().serialize(); + std::map<std::string, std::shared_ptr<void>> config; + config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr); + config["labelPath"] = std::make_shared<std::string>(initParam.labelPath); + + post_ = std::make_shared<MxBase::Resnet50PostProcess>(); + ret = post_->Init(config); + if (ret != APP_ERR_OK) { + LogError << "Resnet50PostProcess init failed, ret=" << ret << "."; + return ret; + } + tfile_.open("mx_pred_result.txt"); + if (!tfile_) { + LogError << "Open result file failed."; + return APP_ERR_COMM_OPEN_FAIL; + } + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::DeInit() { + model_->DeInit(); + post_->DeInit(); + MxBase::DeviceManager::GetInstance()->DestroyDevices(); + tfile_.close(); + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::ReadImage(const std::string &imgPath, cv::Mat &imageMat) { + imageMat = cv::imread(imgPath, cv::IMREAD_COLOR); + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::Resize(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) { + static constexpr uint32_t resizeHeight = 256; + static constexpr uint32_t resizeWidth = 256; + cv::resize(srcImageMat, dstImageMat, cv::Size(resizeHeight, resizeWidth)); + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::Crop(const cv::Mat &srcMat, cv::Mat &dstMat) { + static cv::Rect rectOfImg(16, 16, 224, 224); + dstMat = srcMat(rectOfImg).clone(); + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) { + const uint32_t dataSize = 
imageMat.cols * imageMat.rows * imageMat.channels(); + MxBase::MemoryData MemoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_); + MxBase::MemoryData MemoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC); + APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(MemoryDataDst, MemoryDataSrc); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Memory malloc failed."; + return ret; + } + std::vector<uint32_t> shape = {static_cast<uint32_t>(imageMat.rows), + static_cast<uint32_t>(imageMat.cols), static_cast<uint32_t>(imageMat.channels())}; + tensorBase = MxBase::TensorBase(MemoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8); + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::Inference(const std::vector<MxBase::TensorBase> &inputs, + std::vector<MxBase::TensorBase> &outputs) { + auto dtypes = model_->GetOutputDataType(); + for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) { + std::vector<uint32_t> shape = {}; + for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) { + shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]); + } + MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_); + APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor); + if (ret != APP_ERR_OK) { + LogError << "TensorBaseMalloc failed, ret=" << ret << "."; + return ret; + } + outputs.push_back(tensor); + } + MxBase::DynamicInfo dynamicInfo = {}; + dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH; + auto startTime = std::chrono::high_resolution_clock::now(); + APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo); + auto endTime = std::chrono::high_resolution_clock::now(); + double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); + g_inferCost.push_back(costMs); + if (ret != APP_ERR_OK) { + LogError << "ModelInference failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} 
+ +APP_ERROR resnext50Classify::PostProcess(const std::vector<MxBase::TensorBase> &inputs, + std::vector<std::vector<MxBase::ClassInfo>> &clsInfos) { + APP_ERROR ret = post_->Process(inputs, clsInfos); + if (ret != APP_ERR_OK) { + LogError << "Process failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +APP_ERROR resnext50Classify::Process(const std::string &imgPath) { + cv::Mat imageMat; + APP_ERROR ret = ReadImage(imgPath, imageMat); + if (ret != APP_ERR_OK) { + LogError << "ReadImage failed, ret=" << ret << "."; + return ret; + } + ret = Resize(imageMat, imageMat); + if (ret != APP_ERR_OK) { + LogError << "Resize failed, ret=" << ret << "."; + return ret; + } + cv::Mat cropImage; + ret = Crop(imageMat, cropImage); + if (ret != APP_ERR_OK) { + LogError << "Crop failed, ret=" << ret << "."; + return ret; + } + MxBase::TensorBase tensorBase; + ret = CVMatToTensorBase(cropImage, tensorBase); + if (ret != APP_ERR_OK) { + LogError << "CVMatToTensorBase failed, ret=" << ret << "."; + return ret; + } + std::vector<MxBase::TensorBase> inputs = {}; + std::vector<MxBase::TensorBase> outputs = {}; + inputs.push_back(tensorBase); + ret = Inference(inputs, outputs); + if (ret != APP_ERR_OK) { + LogError << "Inference failed, ret=" << ret << "."; + return ret; + } + std::vector<std::vector<MxBase::ClassInfo>> BatchClsInfos = {}; + ret = PostProcess(outputs, BatchClsInfos); + if (ret != APP_ERR_OK) { + LogError << "PostProcess failed, ret=" << ret << "."; + return ret; + } + ret = SaveResult(imgPath, BatchClsInfos); + if (ret != APP_ERR_OK) { + LogError << "Export result to file failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} + + +APP_ERROR resnext50Classify::SaveResult(const std::string &imgPath, \ + const std::vector< std::vector<MxBase::ClassInfo>> &BatchClsInfos) { + uint32_t batchIndex = 0; + std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1); + size_t dot = fileName.find_last_of("."); + for (const auto 
&clsInfos : BatchClsInfos) { + std::string resultStr; + for (const auto &clsInfo : clsInfos) { + resultStr += std::to_string(clsInfo.classId) + ","; + } + tfile_ << fileName.substr(0, dot) << " " << resultStr << std::endl; + if (tfile_.fail()) { + LogError << "Failed to write the result to file."; + return APP_ERR_COMM_WRITE_FAIL; + } + batchIndex++; + } + return APP_ERR_OK; +} diff --git a/official/cv/resnext/infer/mxbase/src/resnext50Classify.h b/official/cv/resnext/infer/mxbase/src/resnext50Classify.h new file mode 100644 index 0000000000000000000000000000000000000000..8e0c0ba1eeb727fb037043ef556adfa9f72db7bd --- /dev/null +++ b/official/cv/resnext/infer/mxbase/src/resnext50Classify.h @@ -0,0 +1,60 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================ + */ +#ifndef MXBASE_resnext50CLASSIFY_H +#define MXBASE_resnext50CLASSIFY_H + +#include <string> for string +#include <vector> for vector<> +#include <memory> for shared_ptr<> +#include <opencv2/opencv.hpp> +#include "MxBase/ModelInfer/ModelInferenceProcessor.h" +#include "MxBase/postprocess/include/ClassPostProcessors/Resnet50PostProcess.h" +#include "MxBase/Tensor/TensorContext/TensorContext.h" + +extern std::vector<double> g_inferCost; + +struct InitParam { + uint32_t deviceId; + std::string labelPath; + uint32_t classNum; + uint32_t topk; + bool softmax; + bool checkTensor; + std::string modelPath; +}; + +class resnext50Classify { + public: + APP_ERROR Init(const InitParam &initParam); + APP_ERROR DeInit(); + APP_ERROR ReadImage(const std::string &imgPath, cv::Mat &imageMat); + APP_ERROR Resize(const cv::Mat &srcImageMat, cv::Mat &dstImageMat); + APP_ERROR Crop(const cv::Mat &srcMat, cv::Mat &dstMat); + APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase); + APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs); + APP_ERROR PostProcess(const std::vector<MxBase::TensorBase> &inputs, + std::vector<std::vector<MxBase::ClassInfo>> &clsInfos); + APP_ERROR Process(const std::string &imgPath); + APP_ERROR SaveResult(const std::string &imgPath, const std::vector< std::vector<MxBase::ClassInfo>> &BatchClsInfos); + private: + std::shared_ptr<MxBase::ModelInferenceProcessor> model_; + std::shared_ptr<MxBase::Resnet50PostProcess> post_; + MxBase::ModelDesc modelDesc_; + uint32_t deviceId_ = 0; + std::ofstream tfile_; +}; +#endif diff --git a/official/cv/resnext/infer/sdk/main.py b/official/cv/resnext/infer/sdk/main.py new file mode 100644 index 0000000000000000000000000000000000000000..70811248d891f3314801690c109d6347736e4e81 --- /dev/null +++ b/official/cv/resnext/infer/sdk/main.py @@ -0,0 +1,171 @@ 
+#/usr/bin/env python +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +import argparse +import glob +import json +import os +from contextlib import ExitStack + +from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, \ + MxProtobufIn +import MxpiDataType_pb2 as MxpiDataType + + +class GlobDataLoader(): + def __init__(self, glob_pattern, limit=None): + self.glob_pattern = glob_pattern + self.limit = limit + self.file_list = self.get_file_list() + self.cur_index = 0 + + def get_file_list(self): + return glob.iglob(self.glob_pattern) + + def __iter__(self): + return self + + def __next__(self): + if self.cur_index == self.limit: + raise StopIteration() + label = None + file_path = next(self.file_list) + with open(file_path, 'rb') as fd: + data = fd.read() + + self.cur_index += 1 + return get_file_name(file_path), label, data + + +class Predictor(): + def __init__(self, pipeline_conf, stream_name): + self.pipeline_conf = pipeline_conf + self.stream_name = stream_name + + def __enter__(self): + self.stream_manager_api = StreamManagerApi() + ret = self.stream_manager_api.InitManager() + if ret != 0: + raise Exception(f"Failed to init Stream manager, ret={ret}") + + # create streams by pipeline config file + with open(self.pipeline_conf, 'rb') as f: + pipeline_str = f.read() + ret = 
self.stream_manager_api.CreateMultipleStreams(pipeline_str) + if ret != 0: + raise Exception(f"Failed to create Stream, ret={ret}") + self.data_input = MxDataInput() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # destroy streams + self.stream_manager_api.DestroyAllStreams() + + def predict(self, dataset): + print("Start predict........") + print('>' * 30) + for name, label, data in dataset: + self.data_input.data = data + yield self._predict(name, self.data_input) + label0 = label + print(label0) + print("predict end.") + print('<' * 30) + + def _predict(self, name, data): + plugin_id = 0 + protobuf_data = self._predict_gen_protobuf() + self._predict_send_protobuf(self.stream_name, 1, protobuf_data) + unique_id = self._predict_send_data(self.stream_name, plugin_id, data) + result = self._predict_get_result(self.stream_name, unique_id) + return name, json.loads(result.data.decode()) + + def _predict_gen_protobuf(self): + object_list = MxpiDataType.MxpiObjectList() + object_vec = object_list.objectVec.add() + object_vec.x0 = 16 + object_vec.y0 = 16 + object_vec.x1 = 240 + object_vec.y1 = 240 + + protobuf = MxProtobufIn() + protobuf.key = b'appsrc1' + protobuf.type = b'MxTools.MxpiObjectList' + protobuf.protobuf = object_list.SerializeToString() + protobuf_vec = InProtobufVector() + protobuf_vec.push_back(protobuf) + return protobuf_vec + + def _predict_send_protobuf(self, stream_name, in_plugin_id, data): + self.stream_manager_api.SendProtobuf(stream_name, in_plugin_id, data) + + def _predict_send_data(self, stream_name, in_plugin_id, data_input): + unique_id = self.stream_manager_api.SendData(stream_name, in_plugin_id, + data_input) + if unique_id < 0: + raise Exception("Failed to send data to stream") + return unique_id + + def _predict_get_result(self, stream_name, unique_id): + result = self.stream_manager_api.GetResult(stream_name, unique_id) + if result.errorCode != 0: + raise Exception( + f"GetResultWithUniqueId error." 
+ f"errorCode={result.errorCode}, msg={result.data.decode()}") + return result + + +def get_file_name(file_path): + return os.path.splitext(os.path.basename(file_path.rstrip('/')))[0] + + +def result_encode(file_name, result): + sep = ',' + pred_class_ids = sep.join( + str(i.get('classId')) for i in result.get("MxpiClass", [])) + return f"{file_name} {pred_class_ids}\n" + + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('glob', help='img pth glob pattern.') + parser.add_argument('result_file', help='result file') + return parser.parse_args() + + +def main(): + pipeline_conf = "../data/config/resnext50.pipeline" + stream_name = b'im_resnext50' + + args = parse_args() + result_fname = get_file_name(args.result_file) + pred_result_file = f"{result_fname}.txt" + dataset = GlobDataLoader(args.glob, limit=None) + with ExitStack() as stack: + predictor = stack.enter_context(Predictor(pipeline_conf, stream_name)) + result_fd = stack.enter_context(open(pred_result_file, 'w')) + + for fname, pred_result in predictor.predict(dataset): + result_fd.write(result_encode(fname, pred_result)) + + print(f"success, result in {pred_result_file}") + + +if __name__ == "__main__": + main() diff --git a/official/cv/resnext/infer/sdk/run.sh b/official/cv/resnext/infer/sdk/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..518ae4d48d32cc72a2568ca0cfefc347525232c1 --- /dev/null +++ b/official/cv/resnext/infer/sdk/run.sh @@ -0,0 +1,30 @@ +#!/usr/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +set -e + +# Simple log helper functions +info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } +warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } + +export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} +export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner +export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins + +#to set PYTHONPATH, import the StreamManagerApi.py +export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python + +python3.7 main.py "../data/input/*.JPEG" resnext50_sdk_pred_result.txt +exit 0 diff --git a/official/cv/resnext/infer/util/task_metric.py b/official/cv/resnext/infer/util/task_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..5dadaf18c62b4e8ce1ff0dbe077d8392f0533d52 --- /dev/null +++ b/official/cv/resnext/infer/util/task_metric.py @@ -0,0 +1,108 @@ +# coding: utf-8 +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compute top-k accuracy of an inference result file against ground truth."""


import argparse
import json
import os

import numpy as np


def get_file_name(file_path):
    """Basename without extension ('/a/b/x.JPEG' -> 'x')."""
    return os.path.splitext(os.path.basename(file_path.rstrip('/')))[0]


def load_gt(gt_file):
    """Read ground truth: one '<image> <label>' pair per line, keyed by stem."""
    gt = {}
    with open(gt_file, 'r') as fd:
        for line in fd.readlines():
            img_name, img_label_index = line.strip().split(" ", 1)
            gt[get_file_name(img_name)] = img_label_index
        return gt


def load_pred(pred_file):
    """Read predictions: '<image> <id,id,...>' per line, keyed by stem."""
    pred = {}
    with open(pred_file, 'r') as fd:
        for line in fd.readlines():
            ret = line.strip().split(" ", 1)
            if len(ret) < 2:
                print(f"Warning: load pred, no result, line:{line}")
                continue
            img_name, ids = ret
            img_name = get_file_name(img_name)
            pred[img_name] = [x.strip() for x in ids.split(',')]
    return pred


def calc_accuracy(gt_map, pred_map, top_k=5):
    """Return {'total', 'accuracy', 'miss'} where accuracy[i] is top-(i+1) hit rate.

    Images without a ground-truth entry are skipped (warning printed) and do
    not count toward ``total``.
    """
    hits = [0] * top_k
    miss_match = []
    total = 0
    for img, preds in pred_map.items():
        gt = gt_map.get(img)
        if not gt:
            print(f"Warning: {img}'s gt is not exists.")
            continue
        try:
            # index() raises ValueError when gt is not in the first top_k preds.
            index = preds.index(gt, 0, top_k)
            hits[index] += 1
        except ValueError:
            miss_match.append({'img': img, 'gt': gt, 'prediction': preds})
        finally:
            total += 1

    if total == 0:
        # Fix: the original divided by zero when no prediction matched any
        # ground-truth key (empty files, disjoint name sets).
        return {'total': 0, 'accuracy': [0.0] * top_k, 'miss': miss_match}

    top_k_hit = np.cumsum(hits)
    return {
        'total': total,
        # tolist() yields plain Python floats for clean JSON serialization.
        'accuracy': (top_k_hit / total).tolist(),
        'miss': miss_match,
    }


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('prediction', help='prediction result file')
    parser.add_argument('gt', help='ground true result file')
    parser.add_argument('result_json', help='metric result file')
    parser.add_argument('top_k', help='top k', type=int)
    return parser.parse_args()


def main():
    args = parse_args()
    prediction_file = args.prediction
    gt_file = args.gt
    top_k = args.top_k
    result_json = args.result_json

    gt = load_gt(gt_file)
    prediction = load_pred(prediction_file)
    result = calc_accuracy(gt, prediction, top_k)
    result.update({
        'prediction_file': prediction_file,
        'gt_file': gt_file,
    })
    with open(result_json, 'w') as fd:
        json.dump(result, fd, indent=2)
    print(f"\nsuccess, result in {result_json}")


if __name__ == '__main__':
    main()

# ============================================================================
# file: official/cv/resnext/modelart/start.py (license header)
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train ImageNet (ModelArts entry): trains resnext, then exports the last
checkpoint and copies the workspace back to the training bucket."""
import os
import glob
import time
import datetime
import numpy as np
import moxing as mox

import mindspore.nn as nn
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.common import set_seed
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
import mindspore.common.initializer as weight_init
from mindspore.common import dtype as mstype

from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.lr_generator import get_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.image_classification import get_network
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.utils.auto_mixed_precision import auto_mixed_precision

set_seed(1)


class BuildTrainNetwork(nn.Cell):
    """Wrap a backbone and a criterion into one cell producing the loss."""

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)
        loss = self.criterion(output, label)
        return loss


class ProgressMonitor(Callback):
    """Log per-epoch loss/throughput and newly created checkpoint files."""

    def __init__(self, args):
        super(ProgressMonitor, self).__init__()
        self.me_epoch_start_time = 0
        self.me_epoch_start_step_num = 0
        self.args = args
        self.ckpt_history = []  # ckpt paths already logged

    def begin(self, run_context):
        self.args.logger.info('start network train...')

    def epoch_begin(self, run_context):
        pass

    def epoch_end(self, run_context, *me_args):
        cb_params = run_context.original_args()
        me_step = cb_params.cur_step_num - 1

        real_epoch = me_step // self.args.steps_per_epoch
        time_used = time.time() - self.me_epoch_start_time
        fps_mean = self.args.per_batch_size * (me_step - self.me_epoch_start_step_num) * self.args.group_size / time_used
        self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}'
                              'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))

        if self.args.rank_save_ckpt_flag:
            # NOTE(review): ModelCheckpoint writes into outputs_dir/ckpt_<rank>/,
            # so this glob on outputs_dir itself may never match — verify.
            ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
            for ckpt in ckpts:
                ckpt_fn = os.path.basename(ckpt)
                if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
                    continue
                if ckpt in self.ckpt_history:
                    continue
                self.ckpt_history.append(ckpt)
                self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
                                      'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))

        self.me_epoch_start_step_num = me_step
        self.me_epoch_start_time = time.time()

    def step_begin(self, run_context):
        pass

    def step_end(self, run_context, *me_args):
        pass

    def end(self, run_context):
        self.args.logger.info('end network train...')


def set_parameters():
    """Set context/distribution, rank flags and logger on the global `config`."""
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=config.device_target, save_graphs=False)
    # init distributed
    if config.run_distribute:
        init()
        config.rank = get_rank()
        config.group_size = get_group_size()
    else:
        config.rank = 0
        config.group_size = 1
        # NOTE(review): init() is also called in the non-distributed branch;
        # confirm this is intended for the target runtime.
        init()

    if config.is_dynamic_loss_scale == 1:
        config.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt

    # select for master rank save ckpt or all rank save, compatible for model parallel
    config.rank_save_ckpt_flag = 0
    if config.is_save_on_master:
        if config.rank == 0:
            config.rank_save_ckpt_flag = 1
    else:
        config.rank_save_ckpt_flag = 1

    # logger
    config.outputs_dir = os.path.join(config.output_path,
                                      datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    config.logger = get_logger(config.outputs_dir, config.rank)
    return config


def set_graph_kernel_context(device_target):
    """Enable graph-kernel fusion (GPU only)."""
    if device_target == "GPU":
        context.set_context(enable_graph_kernel=True)


def _get_last_ckpt(ckpt_dir):
    """Return the path of the last *.ckpt in `ckpt_dir`, or None if absent."""
    ckpt_files = [ckpt_file for ckpt_file in os.listdir(ckpt_dir)
                  if ckpt_file.endswith('.ckpt')]
    if not ckpt_files:
        print("No ckpt file found.")
        return None
    # NOTE(review): lexicographic sort, not mtime — '10-...' sorts before '9-...'.
    return os.path.join(ckpt_dir, sorted(ckpt_files)[-1])


def run_export(ckpt_dir):
    """Export the last checkpoint in `ckpt_dir` and copy results to train_url."""
    checkpoint_file_path = _get_last_ckpt(ckpt_dir)
    if checkpoint_file_path is None:
        # Fix: the original passed None straight into load_checkpoint().
        return
    network = get_network(network=config.network, num_classes=config.num_classes, platform=config.device_target)

    param_dict = load_checkpoint(checkpoint_file_path)
    load_param_into_net(network, param_dict)
    if config.device_target == "Ascend":
        network.to_float(mstype.float16)
    else:
        auto_mixed_precision(network)
    network.set_train(False)
    input_shp = [config.batch_size, 3, config.height, config.width]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(network, input_array, file_name=config.file_name, file_format=config.file_format)
    mox.file.copy_parallel(os.getcwd(), config.train_url)


def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """remove useless parameters according to filter_list"""
    for key in list(origin_dict.keys()):
        for name in param_filter:
            if name in key:
                print("Delete parameter from checkpoint: ", key)
                del origin_dict[key]
                break


def init_weight(net):
    """Load weights from config.checkpoint_file_path when present (fine-tune),
    otherwise initialize conv/dense layers from scratch."""
    if os.path.exists(config.checkpoint_file_path):
        param_dict = load_checkpoint(config.checkpoint_file_path)
        # Drop the classification head so the checkpoint can be fine-tuned on a
        # dataset with a different class count. (Removed the always-true
        # `filter_weight` flag and the `print(1111111)` debug leftover.)
        filter_list = ['head.fc.weight', 'head.fc.bias']
        filter_checkpoint_parameter_by_list(param_dict, filter_list)
        load_param_into_net(net, param_dict)
    else:
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    weight_init.initializer(weight_init.XavierUniform(),
                                            cell.weight.shape,
                                            cell.weight.dtype))
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    weight_init.initializer(weight_init.TruncatedNormal(),
                                            cell.weight.shape,
                                            cell.weight.dtype))


@moxing_wrapper()
def train():
    """training process"""
    set_parameters()
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    set_graph_kernel_context(config.device_target)

    # init distributed
    if config.run_distribute:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.group_size,
                                          gradients_mean=True)
    # dataloader
    de_dataset = classification_dataset(config.data_path, config.image_size,
                                        config.per_batch_size, 1,
                                        config.rank, config.group_size, num_parallel_workers=8)
    config.steps_per_epoch = de_dataset.get_dataset_size()

    config.logger.save_args(config)

    # network
    config.logger.important_info('start create network')
    network = get_network(network=config.network, num_classes=config.num_classes, platform=config.device_target)
    init_weight(network)

    # lr scheduler
    lr = get_lr(config)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=config.momentum,
                   weight_decay=config.weight_decay,
                   loss_scale=config.loss_scale)

    # loss
    if not config.label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)

    if config.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
                  metrics={'acc'}, amp_level="O3")

    # checkpoint save
    progress_cb = ProgressMonitor(config)
    callbacks = [progress_cb]
    # Fix: save_ckpt_path was bound only inside the flag branch but used
    # unconditionally after training, raising NameError on non-saving ranks.
    save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '/')
    if config.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval * config.steps_per_epoch,
                                       keep_checkpoint_max=config.ckpt_save_max)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(config.rank))
        callbacks.append(ckpt_cb)

    model.train(config.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
    # Only ranks that actually saved checkpoints have anything to export.
    if config.rank_save_ckpt_flag:
        run_export(save_ckpt_path)


if __name__ == "__main__":
    train()