Commit 148f07b1 authored by wukesong

310 infer

modify

format

u2net modify

delete unused mode

add one space

u2net modify

small modify
parent b66839a3
......@@ -9,7 +9,9 @@
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Run On Modelarts](#run-on-modelarts)
- [Model Export](#model-export)
- [Training Process](#training-process)
- [Ascend310 Inference Process](#ascend310-inference-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
......@@ -54,7 +56,8 @@ To train U<sup>2</sup>-Net, We use the dataset [DUTS-TR](http://saliencydetectio
```shell
U-2-Net
├─ README.md # descriptions about U-2-Net
├─ scripts
  ├─ run_infer_310.sh # 310 inference
  └─ run_distribute_train.sh # launch Ascend training (8 Ascend)
├─ assets # save pics for README.MD
├─ ckpts # save ckpt
......@@ -65,7 +68,11 @@ U-2-Net
├─ train_modelarts.py # train script for online train
├─ test.py # generate detection images
├─ eval.py # eval script
├─ train.py # train script
├─ ascend310_infer # 310 inference source code
├─ export.py # export checkpoint to MINDIR/AIR/ONNX
├─ preprocess.py # convert images to .bin inputs for 310 inference
└─ postprocess.py # turn 310 inference outputs back into pictures
```
## [Script Parameters](#contents)
......@@ -166,6 +173,30 @@ bash run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_
1. hccl.json, which is specified by RANK_TABLE_FILE, is needed when you run a distributed task. You can generate it
with the [hccl_tools](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools), for example:
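A minimal sketch, assuming the hccl_tools checkout from the link above and an 8-device machine; the exact script location and the `--device_num` syntax depend on your checkout:

```bash
# generate an hccl json file describing devices 0-7 (run inside the hccl_tools directory)
python hccl_tools.py --device_num "[0,8)"
```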
## [Model Export](#contents)
```bash
python export.py --ckpt_file [/path/to/ckpt_file]
```
## [Ascend310 Inference Process](#contents)
### Export MINDIR file
```bash
python export.py --ckpt_file [/path/to/ckpt_file]
```
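All of the arguments accepted by `export.py` can also be passed explicitly; the checkpoint path below is only an illustration:

```bash
# hypothetical checkpoint path; file_name and file_format fall back to "u2net" and "MINDIR" when omitted
python export.py --ckpt_file ./ckpts/u2net.ckpt --file_name u2net --file_format MINDIR --device_id 0
```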
### Ascend310 Inference
- Run `run_infer_310.sh` for Ascend310 inference.
```bash
# infer
bash run_infer_310.sh [MINDIR_PATH] [CONTENT_PATH] [LABEL_PATH] [DEVICE_ID]
```
Semantically segmented pictures are stored in the `postprocess_Result` directory, and the evaluation result is stored in `evaluation.log`.
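For instance, assuming the DUTS-TE images and labels are unpacked locally (the directory names below are placeholders), the whole pipeline of preprocessing, compiling the 310 application, inference, postprocessing, and evaluation can be launched with:

```bash
# paths are illustrative; DEVICE_ID defaults to 0 when omitted
bash run_infer_310.sh ./u2net.mindir ./DUTS-TE/DUTS-TE-Image ./DUTS-TE/DUTS-TE-Mask 0
```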
# [Model Description](#contents)
......@@ -191,7 +222,7 @@ bash run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_
| Parameters | single Ascend |
| ----------------- | ------------------------------------------------ |
| Model Version | U-2-Net |
| Resource | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192cores |
| MindSpore Version | 1.3.0 |
| Dataset | content images |
......@@ -203,7 +234,7 @@ bash run_distribute_train.sh [/path/to/content] [/path/to/label] [/path/to/RANK_
| Parameters | single Ascend |
| ----------------- | ------------------------------------------------ |
| Model Version | U-2-Net |
| Resource | Red Hat 8.3.1; Ascend 910; CPU 2.60GHz; 192cores |
| MindSpore Version | 1.3.0 |
| Dataset | DUTS-TE |
......
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
find_package(gflags REQUIRED)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include <gflags/gflags.h>
#include <sys/time.h>
#include <algorithm>
#include <fstream>
#include <iosfwd>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "inc/utils.h"
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/types.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
using mindspore::Context;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::Model;
using mindspore::ModelType;
using mindspore::MSTensor;
using mindspore::Serialization;
using mindspore::Status;
using mindspore::dataset::Execute;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input_path, ".", "input path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> model_inputs = model.GetInputs();
if (model_inputs.empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
auto input_files = GetAllFiles(FLAGS_input_path);
if (input_files.empty()) {
std::cout << "ERROR: input data empty." << std::endl;
return 1;
}
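// map of per-inference start time -> end time (in milliseconds), used to report the average latency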
std::map<double, double> costTime_map;
size_t size = input_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input_files[i] << std::endl;
auto input0 = ReadFileToTensor(input_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(),
model_inputs[0].Shape(), input0.Data().get(),
input0.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(input_files[i], outputs);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: " << average
<< " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: " << average
<< "ms of infer_count " << inferCount << std::endl;
std::string fileName =
"./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <algorithm>
#include <climits>
#include <fstream>
#include <iostream>
using mindspore::DataType;
using mindspore::MSTensor;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string &imageFile,
const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
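// each output tensor of an image is dumped as <image name>_<output index>.bin under ./result_Files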
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'),
'_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "File name is empty" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " does not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << " open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8,
{static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << "dirName is empty!" << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}
......@@ -75,7 +75,7 @@ if __name__ == '__main__':
pred = np.array(Image.open(pred_path), dtype='float32')
pic_name = content_list[i].replace(".jpg", "").replace(".png", "").replace(".JPEG", "")
print("%d / %d , %s \n" % (i, len(content_list), pic_name))
print("%d / %d , %s \n" % (i+1, len(content_list), pic_name))
label_path = os.path.join(label_directory, pic_name) + ".png"
label = np.array(Image.open(label_path), dtype='float32')
if len(label.shape) > 2:
......
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export U-2-Net model"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.blocks import U2NET
parser = argparse.ArgumentParser(description='checkpoint export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="u2net",
help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if __name__ == '__main__':
context.set_context(device_id=args.device_id)
net = U2NET()
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_data = Tensor(np.zeros([1, 3, 320, 320], np.float32))
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""postprocess"""
import argparse
import os
import cv2
import imageio
import numpy as np
from PIL import Image
parser = argparse.ArgumentParser()
parser.add_argument("--bin_path", type=str, help='bin_path, path to binary files generated by 310 model, default: None')
parser.add_argument("--content_path", type=str, help='content_path, default: None')
parser.add_argument("--output_dir", type=str, default='output_dir',
help='output_path, path to store output, default: None')
args = parser.parse_args()
if __name__ == "__main__":
bin_path = args.bin_path
original_dir = args.content_path
content_list = os.listdir(args.bin_path)
def normPRED(d):
"""rescale the value of tensor to between 0 and 1"""
ma = d.max()
mi = d.min()
dn = (d - mi) / (ma - mi)
return dn
for i in range(0, len(content_list)):
pic_path = os.path.join(args.bin_path, content_list[i])
b = np.fromfile(pic_path, dtype=np.float32, count=320 * 320)
b = np.reshape(b, (320, 320))
file_path = os.path.join(original_dir, content_list[i]).replace("_0.bin", ".jpg")
original = np.array(Image.open(file_path), dtype='float32')
shape = original.shape
b = normPRED(b)
image = b
content_name = content_list[i].replace("_0.bin", "")
image = cv2.resize(image, dsize=(0, 0), fx=shape[1] / image.shape[1], fy=shape[0] / image.shape[0])
image_path = os.path.join(args.output_dir, content_name) + ".png"
imageio.imsave(image_path, image)
print("%d / %d , %s \n" % (i, len(content_list), content_name))
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess"""
import argparse
import os
import cv2
import numpy as np
from PIL import Image
parser = argparse.ArgumentParser('preprocess')
parser.add_argument("--content_path", type=str, help='content_path, default: None')
parser.add_argument('--output_path', type=str, default="./preprocess_Result/", help='path to store the preprocessed .bin files')
args = parser.parse_args()
if __name__ == "__main__":
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
def normalize(img, im_type):
"""normalize tensor"""
if im_type == "label":
return img
if len(img.shape) == 3:
img[:, :, 0] = (img[:, :, 0] - 0.485) / 0.229
img[:, :, 1] = (img[:, :, 1] - 0.456) / 0.224
img[:, :, 2] = (img[:, :, 2] - 0.406) / 0.225
else:
img = (img - 0.485) / 0.229
return img
def crop_and_resize(img_path, im_type, size=320):
"""crop and resize tensors"""
img = np.array(Image.open(img_path), dtype='float32')
img = img / 255
img = normalize(img, im_type)
h, w = img.shape[:2]
img = cv2.resize(img, dsize=(0, 0), fx=size / w, fy=size / h)
if len(img.shape) == 2:
img = np.expand_dims(img, 2).repeat(3, axis=2)  # replicate the single channel to match the 3-channel network input
im = img
im = np.swapaxes(im, 1, 2)
im = np.swapaxes(im, 0, 1)
im = np.reshape(im, (1, im.shape[0], im.shape[1], im.shape[2]))
return im
content_list = os.listdir(args.content_path)
for j in range(0, len(content_list)):
pic_path = os.path.join(args.content_path, content_list[j])
content_pic = crop_and_resize(pic_path, im_type="content", size=320)
file_name = content_list[j].replace(".jpg", "") + ".bin"
image_path = os.path.join(args.output_path, file_name)
content_pic.tofile(image_path)
print("Export bin files finished!")
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [CONTENT_PATH] [LABEL_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
content_path=$(get_real_path $2)
label_path=$(get_real_path $3)
device_id=0
if [ $# == 4 ]; then
device_id=$4
fi
echo "mindir name: "$model
echo "content path: "$content_path
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function preprocess_data()
{
if [ -d preprocess_Result ]; then
rm -rf ./preprocess_Result
fi
mkdir preprocess_Result
python3.7 ../preprocess.py --content_path $content_path --output_path='./preprocess_Result/'
}
function compile_app()
{
cd ../ascend310_infer/ || exit
bash build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/out/main --mindir_path=$model --input_path=./preprocess_Result --device_id=$device_id &> infer.log
}
function post_process()
{
if [ -d postprocess_Result ]; then
rm -rf ./postprocess_Result
fi
mkdir postprocess_Result
python3.7 ../postprocess.py --bin_path='./result_Files' --content_path $content_path --output_dir='./postprocess_Result/' &> postprocess.log
}
function evaluation()
{
python3.7 ../eval.py --pred_dir='./postprocess_Result/' --label_dir $label_path &> evaluation.log
}
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess dataset failed"
exit 1
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo "execute inference failed"
exit 1
fi
post_process
if [ $? -ne 0 ]; then
echo "execute post_process failed"
exit 1
fi
evaluation
if [ $? -ne 0 ]; then
echo "execute evaluation failed"
exit 1
fi
......@@ -82,7 +82,7 @@ if __name__ == '__main__':
return img
def resize_im(img_path, size=320):
"""resize images to the network input size"""
img = np.array(Image.open(img_path), dtype='float32')
img = img / 255
......@@ -105,7 +105,7 @@ if __name__ == '__main__':
start_time = time.time()
for j in range(0, len(content_list)):
pic_path = os.path.join(local_dataset_dir, content_list[j])
content_pic = resize_im(pic_path, size=320)
image = net(Tensor(content_pic))
content_name = content_list[j].replace(".jpg", "")
content_name = content_name.replace(".png", "")
......