Commit 0f1a5c5b authored by l00486551's avatar l00486551
add vgg19 310 infer

parent 5175ab29
......@@ -257,6 +257,14 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
├── vgg19
├── README.md // descriptions about vgg
├── README_CN.md // descriptions about vgg with Chinese
├── ascend310_infer
├── ├── CMakeLists.txt // CMake file
├── ├── build.sh // build shell script
├── ├── inc
├── ├── ├── utils.h // utils header file
├── ├── src
├── ├── ├── main.cc // main file
├── ├── ├── utils.cc // utils implementation
├── model_utils
│ ├── __init__.py // init file
│ ├── config.py // Parse arguments
......@@ -447,7 +455,7 @@ python train.py --config_path=/dir_to_code/imagenet2012_config.yaml --device_ta
```bash
# distributed training (4p)
bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train"
bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train
```
### [Evaluation Process](#contents)
......@@ -501,7 +509,8 @@ Current batch_Size for imagenet2012 dataset can only be set to 1.
The inference result is saved in the current path; you can find results like the following in the acc.log file.
```bash
'acc': 0.92
'top1 acc': 0.748
'top5 acc': 0.922
```
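The accuracy above is written to acc.log by the new `scripts/run_infer_310.sh`, which preprocesses the dataset, compiles and runs the `ascend310_infer` executable, and then calls `postprocess.py`. A typical invocation, sketched from the script's usage string (the model and dataset paths below are placeholders), looks like:
```bash
# Run from the scripts/ directory so the script's relative paths (../*_config.yaml, ../ascend310_infer) resolve.
cd scripts
# bash run_infer_310.sh [MINDIR_PATH] [DATASET_NAME] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
bash run_infer_310.sh /path/to/vgg19.mindir imagenet2012 /path/ImageNet2012/val y 0
```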
## [Model Description](#contents)
......@@ -545,4 +554,4 @@ In dataset.py, we set the seed inside the "create_dataset" function. We also use r
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
Please check the official [homepage](https://gitee.com/mindspore/models/tree/master/).
......@@ -272,16 +272,25 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
├── vgg19
├── README.md // descriptions about VGG
├── README_CN.md // descriptions about VGG in Chinese
├── ascend310_infer
├── ├── CMakeLists.txt // CMake file
├── ├── build.sh // build shell script
├── ├── inc
├── ├── ├── utils.h // utils header file
├── ├── src
├── ├── ├── main.cc // main file
├── ├── ├── utils.cc // utils implementation
├── model_utils
├── __init__.py // init file
├── config.py // parameter configuration
├── device_adapter.py // device adapter for ModelArts
├── local_adapter.py // local adapter
└── moxing_adapter.py // moxing adapter for ModelArts
├── ├── __init__.py // init file
├── ├── config.py // parameter configuration
├── ├── device_adapter.py // device adapter for ModelArts
├── ├── local_adapter.py // local adapter
├── └── moxing_adapter.py // moxing adapter for ModelArts
├── scripts
│ ├── run_distribute_train.sh // shell script for distributed training on Ascend
│ ├── run_distribute_train_gpu.sh // shell script for distributed training on GPU
│ ├── run_eval.sh // shell script for evaluation on Ascend
│ ├── run_infer_310.sh // shell script for inference on Ascend 310
├── src
│ ├── utils
│ │ ├── logging.py // logging format configuration
......@@ -463,7 +472,7 @@ python train.py --config_path=/dir_to_code/imagenet2012_config.yaml --device_tar
```bash
# distributed training (8p)
bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train"
bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train
```
### Evaluation Process
......@@ -517,7 +526,8 @@ python export.py --config_path [YAML_CONFIG_PATH] --ckpt_file [CKPT_PATH] --file
The inference result is saved in the current path where the script is executed; you can find the following accuracy results in acc.log.
```bash
'acc': 0.92
'top1 acc': 0.748
'top5 acc': 0.922
```
## Model Description
......@@ -562,4 +572,4 @@ In dataset.py, the seed is set inside the "create_dataset" function. It also uses
## ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
Please check the official [homepage](https://gitee.com/mindspore/models/tree/master/).
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
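# Standalone CMake project that builds the Ascend 310 inference executable against the MindSpore C++ API installed via pip.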
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
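# Locate the installed mindspore-ascend package via pip so CMake can find the MindSpore headers and libraries.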
cmake .. \
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
std::vector<std::string> GetAllFiles(std::string dir_name);
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
#endif
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/transforms.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_name, "cifar10", "['cifar10', 'imagenet2012']");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_int32(device_id, 0, "device id");
int load_model(Model *model, std::vector<MSTensor> *model_inputs, const std::string &mindir_path, int device_id) {
if (RealPath(mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(mindir_path, ModelType::kMindIR, &graph);
Status ret = model->Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
*model_inputs = model->GetInputs();
if (model_inputs->empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
return 0;
}
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
Model model;
std::vector<MSTensor> model_inputs;
if (load_model(&model, &model_inputs, FLAGS_mindir_path, FLAGS_device_id) != 0) {
return 1;
}
std::map<double, double> costTime_map;
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
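// cifar10 inputs are pre-generated .bin tensors (from preprocess.py); imagenet2012 images are decoded and preprocessed below before inference.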
if (FLAGS_dataset_name == "cifar10") {
auto input0_files = GetAllFiles(FLAGS_input0_path);
if (input0_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
size_t size = input0_files.size();
for (size_t i = 0; i < size; ++i) {
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input0_files[i] <<std::endl;
auto input0 = ReadFileToTensor(input0_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
input0.Data().get(), input0.DataSize());
gettimeofday(&start, nullptr);
Status ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(input0_files[i], outputs);
}
} else {
auto input0_files = GetAllInputData(FLAGS_input0_path);
if (input0_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
size_t size = input0_files.size();
for (size_t i = 0; i < size; ++i) {
for (size_t j = 0; j < input0_files[i].size(); ++j) {
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input0_files[i][j] <<std::endl;
auto decode = Decode();
auto resize = Resize({256, 256});
auto centercrop = CenterCrop({224, 224});
auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375});
auto hwc2chw = HWC2CHW();
Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});
auto imgDvpp = std::make_shared<MSTensor>();
SingleOp(ReadFileToTensor(input0_files[i][j]), imgDvpp.get());
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
imgDvpp->Data().get(), imgDvpp->DataSize());
gettimeofday(&start, nullptr);
Status ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input0_files[i][j] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(input0_files[i][j], outputs);
}
}
}
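// Average the per-image inference latency (ms) over all predictions and dump it to time_Result/test_perform_static.txt.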
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
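// GetAllInputData: for an imagenet-style layout (one sub-directory per class), return the image file paths grouped by class sub-directory.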
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
std::vector<std::vector<std::string>> ret;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
struct dirent *filename;
/* read all the files in the dir ~ */
std::vector<std::string> sub_dirs;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
// get rid of "." and ".."
if (d_name == "." || d_name == ".." || d_name.empty()) {
continue;
}
std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
struct stat s;
lstat(dir_path.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
continue;
}
sub_dirs.emplace_back(dir_path);
}
std::sort(sub_dirs.begin(), sub_dirs.end());
(void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
[](const std::string &d) { return GetAllFiles(d); });
return ret;
}
std::vector<std::string> GetAllFiles(std::string dir_name) {
struct dirent *filename;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
continue;
}
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
return res;
}
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
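// WriteResult: dump every output tensor of an image to ./result_Files/<image-name>_<output-index>.bin.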
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
if (outputFile == nullptr) {
std::cout << "Failed to open output file " << outFileName << std::endl;
return 1;
}
fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file, std::ios::binary);
if (!ifs.good()) {
std::cout << "File: " << file << " does not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}
......@@ -34,7 +34,7 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
"""export model in mindir/air/onnx"""
'''run_export function.'''
config.image_size = list(map(int, config.image_size.split(',')))
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
......
......@@ -14,7 +14,7 @@
# ============================================================================
"""hub config."""
from src.vgg import vgg19 as VGG19
from model_utils.moxing_adapter import config
def vgg19(*args, **kwargs):
return VGG19(*args, **kwargs)
......@@ -22,5 +22,5 @@ def vgg19(*args, **kwargs):
def create_network(name, *args, **kwargs):
if name == "vgg19":
return vgg19(*args, **kwargs)
return vgg19(args=config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")
......@@ -22,7 +22,7 @@ from model_utils.moxing_adapter import config
def create_label(result_path, dir_path):
"""create label for dataset"""
'''create_label function'''
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
dirs = os.listdir(dir_path)
file_list = []
......
easydict
numpy
pillow
pyyaml
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 4 || $# -gt 5 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_NAME] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
DATASET_NAME can choose from ['cifar10', 'imagenet2012'].
NEED_PREPROCESS indicates whether preprocessing is needed; its value must be 'y' or 'n'.
DEVICE_ID is optional; it can be set by the environment variable device_id, otherwise it defaults to zero."
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
if [ $2 == 'cifar10' ] || [ $2 == 'imagenet2012' ]; then
dataset_name=$2
else
echo "DATASET_NAME can choose from ['cifar10', 'imagenet2012']"
exit 1
fi
config_path=$(get_real_path "../${dataset_name}_config.yaml")
echo "config path is : ${config_path}"
dataset_path=$(get_real_path $3)
if [ "$4" == "y" ] || [ "$4" == "n" ];then
need_preprocess=$4
else
echo "weather need preprocess or not, it's value must be in [y, n]"
exit 1
fi
device_id=0
if [ $# == 5 ]; then
device_id=$5
fi
echo "mindir name: "$model
echo "dataset name: "$dataset_name
echo "dataset path: "$dataset_path
echo "need preprocess: "$need_preprocess
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
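# preprocess_data: convert the raw dataset into binary inputs for the 310 executable (preprocess.py writes them to ./preprocess_Result).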
function preprocess_data()
{
if [ -d preprocess_Result ]; then
rm -rf ./preprocess_Result
fi
mkdir preprocess_Result
python3.7 ../preprocess.py --config_path=$config_path --dataset=$dataset_name --data_dir=$dataset_path --result_path=./preprocess_Result/
}
function compile_app()
{
cd ../ascend310_infer/ || exit
bash build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
if [ "$dataset_name" == "cifar10" ]; then
../ascend310_infer/out/main --mindir_path=$model --dataset_name=$dataset_name --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
else
../ascend310_infer/out/main --mindir_path=$model --dataset_name=$dataset_name --input0_path=$dataset_path --device_id=$device_id &> infer.log
fi
}
function cal_acc()
{
if [ "$dataset_name" == "cifar10" ]; then
python3.7 ../postprocess.py --config_path=$config_path --result_dir=./result_Files --label_dir=./preprocess_Result/cifar10_label_ids.npy --dataset_name=$dataset_name &> acc.log
else
python3.7 ../postprocess.py --config_path=$config_path --result_dir=./result_Files --label_dir=./preprocess_Result/imagenet_label.json --dataset_name=$dataset_name &> acc.log
fi
}
if [ $need_preprocess == "y" ]; then
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess dataset failed"
exit 1
fi
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi
\ No newline at end of file
......@@ -16,11 +16,11 @@
dataset processing.
"""
import os
from PIL import Image, ImageFile
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from PIL import Image, ImageFile
from src.utils.sampler import DistributedSampler
ImageFile.LOAD_TRUNCATED_IMAGES = True
......
......@@ -127,9 +127,9 @@ cfg = {
}
def vgg19(num_classes=1000, args=None, phase="train"):
def vgg19(num_classes=1000, args=None, phase="train", **kwargs):
"""
Get Vgg19 neural network with batch normalization.
Get Vgg19 neural network with Batch Normalization.
Args:
num_classes (int): Class numbers. Default: 1000.
......@@ -137,14 +137,10 @@ def vgg19(num_classes=1000, args=None, phase="train"):
phase(str): train or test mode.
Returns:
Cell, cell instance of Vgg19 neural network with batch normalization.
Cell, cell instance of Vgg19 neural network with Batch Normalization.
Examples:
>>> vgg19(num_classes=1000, args=args)
>>> vgg19(num_classes=1000, args=args, **kwargs)
"""
if args is None:
from .config import cifar_cfg
args = cifar_cfg
net = Vgg(cfg['19'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
net = Vgg(cfg['19'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase, **kwargs)
return net