diff --git a/research/cv/SiamFC/README.md b/research/cv/SiamFC/README.md
index 4a51504dce357a86575402d116c0c87b8bfac77e..f35a9ea37c65e9d1eebb879f7496f1f06cca465b 100644
--- a/research/cv/SiamFC/README.md
+++ b/research/cv/SiamFC/README.md
@@ -12,6 +12,7 @@
     - [Training](#training)
 - [Evaluation Process](#evaluation-process)
     - [Evaluation](#evaluation)
+    - [Ascend310 infer](#ascend310-infer)
 - [Model Description](#model-description)
     - [Performance](#performance)
         - [Evaluation Performance](#evaluation-performance)
@@ -90,10 +91,16 @@ After installing mindspree through the official website, you can follow the foll
 ```python
 ├── SiamFC
     ├── README.md                            // Notes on siamfc
+    ├── ascend310_infer                      // Implementation inference script on ascend310
+    │   ├──inc                               // Header files
+    │   ├──src                               // main.cc and utils.cc files
+    │   ├──build.sh                          // Build script
+    │   ├──CMakeLists.txt                    // Required library files
     ├── scripts
     │   ├──ma-pre-start.sh                   // Create environment before modelarts training
     │   ├──run_standalone_train_ascend.sh    // Single card training in ascend
     │   ├──run_distribution_ascend.sh        // Multi card distributed training in ascend
+    │   ├──run_infer_310.sh                  // 310 inference script
     ├── src
     │   ├──alexnet.py                        // Create dataset
     │   ├──config.py                         // Alexnet architecture
@@ -175,6 +182,22 @@ Check the checkpoint path used for evaluation before running the following comma
   SiamFC_159_50_6650.ckpt -prec_score:0.777 -succ_score:0.589 _succ_rate:0.754
 ```
 
+## Ascend310 infer
+
+Check the checkpoint path used for evaluation before running the following command.
+
+Running the inference script requires two different MINDIR files, which are exported as follows:
+
+```bash
+  python export.py --device_id=${DEVICE_ID} --model_path=${MODEL_PATH} --file_name_export1=${SAVE_MODEL_PATH1} --file_name_export2=${SAVE_MODEL_PATH2} --file_name=${FILE_FORMAT} --device_target=${DEVICE_TARGET}
+```
+
+- Running in ascend310 device processor environment
+
+```bash
+  bash run_infer_310.sh [MODEL_PATH1] [MODEL_PATH2] [DATASET_PATH] [CODE_PATH] [DEVICE_TARGET] [DEVICE_ID]
+```
+
 # [Model description](#Contents)
 
 ## performance
@@ -193,3 +216,18 @@ Check the checkpoint path used for evaluation before running the following comma
 |total time |about 5 hours |
 |Script URL |<https://gitee.com/mindspore/models/tree/master/research/cv/SiamFC> |
 |Random number seed |set_seed = 1234 |
+
+## performance
+
+### Inference on Ascend310 Performance
+
+|parameter | Ascend |
+| -------------------------- | --------------------------------------------|
+|Model Version | SiamFC |
+|Upload date |2021.11.1 |
+|mindspore version |mindspore1.3.0 |
+|Dataset | OTB2013 |
+|total time |about 5 minutes |
+|outputs |probability |
+|Accuracy |prec_score:0.779 -succ_score:0.588 _succ_rate:0.756 |
+
diff --git a/research/cv/SiamFC/ascend310_infer/CMakeLists.txt b/research/cv/SiamFC/ascend310_infer/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f76af274bcbb58cd8de3043a20027da52c28d60f
--- /dev/null
+++ b/research/cv/SiamFC/ascend310_infer/CMakeLists.txt
@@ -0,0 +1,39 @@
+cmake_minimum_required(VERSION 3.14.1)
+project(Ascend310Infer LANGUAGES CXX)
+
+find_package(OpenCV 2 REQUIRED)
+find_package(gflags REQUIRED)
+
+# MindSpore install directory, supplied by build.sh through -DMINDSPORE_PATH=...
+# option() only models booleans, so the path is declared as a typed cache entry.
+set(MINDSPORE_PATH "" CACHE PATH "mindspore install path")
+if(NOT MINDSPORE_PATH)
+  message(FATAL_ERROR "MINDSPORE_PATH is empty; configure with -DMINDSPORE_PATH=<mindspore dir>")
+endif()
+message(STATUS "mindsporelibs:${MINDSPORE_PATH}")
+
+set(PROJECT_SRC_ROOT "${CMAKE_CURRENT_LIST_DIR}/")
+
+find_library(MS_LIB libmindspore.so PATHS "${MINDSPORE_PATH}/lib")
+if(NOT MS_LIB)
+  message(FATAL_ERROR "libmindspore.so not found under ${MINDSPORE_PATH}/lib")
+endif()
+file(GLOB_RECURSE MD_LIB "${MINDSPORE_PATH}/_c_dataengine*")
+
+add_executable(main src/main.cc src/utils.cc)
+set_target_properties(main PROPERTIES
+  CXX_STANDARD 17
+  CXX_STANDARD_REQUIRED ON
+  CXX_EXTENSIONS OFF
+  POSITION_INDEPENDENT_CODE ON
+)
+target_compile_options(main PRIVATE -O2 -g -Werror -Wall)
+target_compile_definitions(main PRIVATE _GLIBCXX_USE_CXX11_ABI=0)
+target_link_options(main PRIVATE -Wl,--allow-shlib-undefined)
+target_include_directories(main PRIVATE
+  ${OpenCV_INCLUDE_DIRS}
+  "${MINDSPORE_PATH}"
+  "${MINDSPORE_PATH}/include"
+  "${PROJECT_SRC_ROOT}"
+)
+target_link_libraries(main PRIVATE ${MS_LIB} ${MD_LIB} ${OpenCV_LIBS} gflags)
diff --git a/research/cv/SiamFC/ascend310_infer/build.sh b/research/cv/SiamFC/ascend310_infer/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..285514e19f2a1878a7bf8f0eed3c99fbc73868c4
--- /dev/null
+++ b/research/cv/SiamFC/ascend310_infer/build.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Configures and builds the inference binary in a freshly created ./out directory.
+if [ -d out ]; then
+    rm -rf out
+fi
+
+mkdir out
+cd out || exit 1
+
+# Directory of the installed mindspore-ascend package, handed to CMake as MINDSPORE_PATH.
+MINDSPORE_PATH="$(pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath)"
+cmake .. -DMINDSPORE_PATH="${MINDSPORE_PATH}" || exit 1
+make
diff --git a/research/cv/SiamFC/ascend310_infer/inc/utils.h b/research/cv/SiamFC/ascend310_infer/inc/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..fe2e13493ece94e6ca92042d9dd098fcfc92c4a1
--- /dev/null
+++ b/research/cv/SiamFC/ascend310_infer/inc/utils.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_INFERENCE_UTILS_H_ +#define MINDSPORE_INFERENCE_UTILS_H_ +#include <opencv2/opencv.hpp> +#include <opencv2/core/core.hpp> +#include <opencv2/highgui/highgui.hpp> +#include <opencv2/imgproc/imgproc.hpp> +#include <opencv2/objdetect/objdetect.hpp> +#include <opencv2/imgproc/types_c.h> + +#include <sys/stat.h> +#include <dirent.h> +#include <vector> +#include <string> +#include <memory> +#include "include/api/types.h" + +std::vector<std::string> GetAllFiles(const std::string_view& dirName,const std::string& seq_name); +DIR *OpenDir(const std::string_view& dirName); +std::string RealPath(const std::string_view& path); +mindspore::MSTensor ReadFileToTensor(const std::string &file); +int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs); +cv::Mat BGRToRGB(cv::Mat img); +cv::Mat crop_and_pad(cv::Mat img, float cx, float cy, float size_z, float s_z); +std::vector<double> Getpos(const std::string &dirName); +float sumMat(cv::Mat& inputImg); +#endif diff --git a/research/cv/SiamFC/ascend310_infer/src/main.cc b/research/cv/SiamFC/ascend310_infer/src/main.cc new file mode 100644 index 0000000000000000000000000000000000000000..2094fb60843ceb36400530d2bfde8a77f4709597 --- /dev/null +++ b/research/cv/SiamFC/ascend310_infer/src/main.cc @@ -0,0 +1,533 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <dirent.h> +#include <gflags/gflags.h> +#include <opencv2/imgproc/types_c.h> +#include <sys/time.h> + +#include <algorithm> +#include <cmath> +#include <fstream> +#include <iosfwd> +#include <iostream> +#include <opencv2/core/core.hpp> +#include <opencv2/highgui/highgui.hpp> +#include <opencv2/imgproc/imgproc.hpp> +#include <opencv2/objdetect/objdetect.hpp> +#include <opencv2/opencv.hpp> +#include <sstream> +#include <string> +#include <vector> + +#include "inc/utils.h" +#include "include/api/context.h" +#include "include/api/model.h" +#include "include/api/serialization.h" +#include "include/api/types.h" +#include "include/dataset/execute.h" +#include "include/dataset/transforms.h" +#include "include/dataset/vision.h" +#include "include/dataset/vision_ascend.h" + +using mindspore::Context; +using mindspore::DataType; +using mindspore::Graph; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::Model; +using mindspore::ModelType; +using mindspore::MSTensor; +using mindspore::Serialization; +using mindspore::Status; +using mindspore::dataset::Execute; +using mindspore::dataset::TensorTransform; +using mindspore::dataset::transforms::TypeCast; +using mindspore::dataset::vision::Decode; +using mindspore::dataset::vision::HWC2CHW; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::Resize; +using namespace cv; +using namespace std; + +DEFINE_string(model_path1, "/home/siamfc/model1.mindir", "model path"); +DEFINE_string(model_path2, "/home/siamfc/model2_change.mindir", "model path"); +DEFINE_int32(device_id, 0, "device id"); +DEFINE_string(precision_mode, "allow_fp32_to_fp16", "precision mode"); +DEFINE_string(op_select_impl_mode, "", "op select impl mode"); +DEFINE_string(aipp_path, "./aipp.cfg", "aipp path"); +DEFINE_string(device_target, "Ascend310", "device target"); +DEFINE_string(code_path, "/home/Siamfc/", "code path"); +DEFINE_string(seq_root_path, "/home/siamfc/OTB2013/", "OTB route"); 
+std::vector<std::string> all_videos = { + "Basketball", "Bolt", "Boy", "Car4", "CarDark", + "CarScale", "Coke", "Couple", "Crossing", "David", + "David2", "David3", "Deer", "Dog1", "Doll", + "Dudek", "FaceOcc1", "FaceOcc2", "Fish", "FleetFace", + "Football", "Football1", "Football1", "Freeman1", "Freeman3", + "Freeman4", "Girl", "Ironman", "Jogging", "Jumping", + "Lemming", "Liquor", "Matrix", "Mhyang", "MotorRolling", + "MountainBike", "Shaking", "Singer1", "Singer2", "Skating1", + "Skiing", "Soccer", "Subway", "Suv", "Sylvester", + "Tiger1", "Tiger2", "Trellis", "Walking", "Walking2", + "Woman"}; + +struct param { + const int none = 1; + const int* one = &none; + size_t s_one = 4; + size_t size_s; + double init_x; + double init_y; + double init_w; + double init_h; + double target_position[2]; + double target_sz[2]; + double wc_z; + double hc_z; + double s_z; + double scale_z; + double penalty[3] = {0.9745, 1, 0.9745}; + double scales[3] = {0.96385542, 1.00, 1.0375}; + string dataset_path_txt; + string record_name; + string record_times; + double s_x; + double min_s_x; + double max_s_x; + double size_x_scales[3]; + vector<double> box; + vector<string> all_files; +}; +Mat hwc2chw(Mat dst, size_t resize_detection) { + std::vector<float> dst_data; + std::vector<cv::Mat> bgrChannels(3); + cv::split(dst, bgrChannels); + for (size_t i = 0; i < bgrChannels.size(); i++) { + std::vector<float> data = std::vector<float>(bgrChannels[i].reshape(1, 1)); + dst_data.insert(dst_data.end(), data.begin(), data.end()); + } + cv::Mat srcMat; + srcMat = cv::Mat(dst_data, true); + cv::Mat dst_img = srcMat.reshape(3, resize_detection); + return dst_img; +} +void pretreatment(cv::Mat src, cv::Mat& target, param config, int size, + double s_x) { + cv::Mat cropImg = crop_and_pad(src, config.target_position[0], + config.target_position[1], size, s_x); + cv::Mat exemplar_FLOAT; + cropImg.convertTo(exemplar_FLOAT, CV_32FC3); + target = hwc2chw(exemplar_FLOAT, size); +} +void 
init_position(param& config, string& temp_video) { + config.all_files = GetAllFiles(FLAGS_seq_root_path, temp_video); + config.box = Getpos(config.dataset_path_txt); + config.size_s = config.all_files.size(); + config.init_x = config.box[0] - 1; + config.init_y = config.box[1] - 1; + config.init_w = config.box[2]; + config.init_h = config.box[3]; + config.target_position[0] = config.init_x + (config.init_w - 1) / 2; + config.target_position[1] = config.init_y + (config.init_h - 1) / 2; + config.target_sz[0] = config.init_w; + config.target_sz[1] = config.init_h; + config.wc_z = config.init_w + 0.5 * (config.init_w + config.init_h); + config.hc_z = config.init_h + 0.5 * (config.init_w + config.init_h); + config.s_z = sqrt(config.wc_z * config.hc_z); + config.scale_z = 127 / config.s_z; + config.s_x = config.s_z + (255 - 127) / config.scale_z; + config.min_s_x = 0.2 * config.s_x; + config.max_s_x = 5 * config.s_x; +} + +void getPath(param& config, string& temp_video, int jogging_count) { + config.dataset_path_txt = + FLAGS_seq_root_path + "/" + temp_video + "/" + "groundtruth_rect.txt"; + config.record_name = + FLAGS_code_path + "/results/OTB2013/SiamFC/" + temp_video + ".txt"; + config.record_times = FLAGS_code_path + "/results/OTB2013/SiamFC/times/" + + temp_video + "_time.txt"; + if (temp_video == "Jogging") { + auto jogging_path = FLAGS_seq_root_path + "/" + temp_video + "/" + + "groundtruth_rect" + "." + + std::to_string(jogging_count) + ".txt"; + auto jogging_record = FLAGS_code_path + "/results/OTB2013/SiamFC/" + + temp_video + "." 
+ std::to_string(jogging_count) + + ".txt"; + config.dataset_path_txt = jogging_path; + config.record_name = jogging_record; + } +} + +void getSizeScales(param& config) { + for (int k = 0; k < 3; k++) { + config.size_x_scales[k] = config.s_x * config.scales[k]; + } +} + +void getExemplar(string& temp_video, vector<MSTensor>& outputs_exemplar, + vector<MSTensor>& inputs_exemplar, Model& model1, + param& config, int jogging_count) { + getPath(config, temp_video, jogging_count); + std::vector<MSTensor> model_inputs = model1.GetInputs(); + init_position(config, temp_video); + cv::Mat src = cv::imread(config.all_files[0], cv::IMREAD_COLOR); + cv::Mat exemplar; + pretreatment(src, exemplar, config, 127, config.s_z); + cout << "box :" << config.box[0] << " " << config.box[1] << " " + << config.box[2] << " " << config.box[3] << endl; + size_t size_buffer = exemplar.size().width * exemplar.size().height * 4 * 3; + mindspore::MSTensor image("x", mindspore::DataType::kNumberTypeFloat32, + {static_cast<int64_t>(3), static_cast<int64_t>(127), + static_cast<int64_t>(127)}, + exemplar.data, size_buffer); + std::vector<int64_t> shape = image.Shape(); + inputs_exemplar.clear(); + inputs_exemplar.emplace_back( + model_inputs[0].Name(), model_inputs[0].DataType(), + model_inputs[0].Shape(), image.Data().get(), image.DataSize()); + inputs_exemplar.emplace_back( + model_inputs[1].Name(), model_inputs[1].DataType(), + model_inputs[1].Shape(), config.one, config.s_one); + Status ret_instance; + ret_instance = + model1.Predict(inputs_exemplar, &outputs_exemplar); // get exemplar img + if (ret_instance != kSuccess) { + cout << " Failed predict" << endl; + } else { + cout << " Success predict" << endl; + } +} +void preInstance(vector<MSTensor>& input_exemplar, + vector<MSTensor>& outputs_exemplar, + vector<MSTensor>& output_exemplar, + vector<MSTensor>& model_inputs_instance, Model& model2, + MSTensor& instance) { + input_exemplar.clear(); + input_exemplar.emplace_back( + 
model_inputs_instance[0].Name(), model_inputs_instance[0].DataType(), + model_inputs_instance[0].Shape(), outputs_exemplar[0].Data().get(), + outputs_exemplar[0].DataSize()); + input_exemplar.emplace_back(model_inputs_instance[1].Name(), + model_inputs_instance[1].DataType(), + model_inputs_instance[1].Shape(), + instance.Data().get(), instance.DataSize()); + model2.Predict(input_exemplar, &output_exemplar); +} + +void getRetInstance(int instance_num, vector<MSTensor>& inputs, + vector<MSTensor>& outputs, + vector<MSTensor>& outputs_exemplar, Mat cos_window, + param& config, Model& model2) { + getSizeScales(config); + vector<MSTensor> model_inputs_instance = model2.GetInputs(); + cv::Mat instance_src; + instance_src = cv::imread(config.all_files[instance_num], cv::IMREAD_COLOR); + cv::Mat exemplar_img[3]; + cv::Mat inputs_instance[3]; + cv::Mat response_mapInit[3]; + cv::Mat response_map[3]; + double response_map_max[3]; + std::vector<MSTensor> input_exemplar; + std::vector<MSTensor> output_exemplar1; + std::vector<MSTensor> output_exemplar2; + std::vector<MSTensor> output_exemplar3; + for (int n = 0; n < 3; n++) { + pretreatment(instance_src, exemplar_img[n], config, 255, + config.size_x_scales[n]); + } + size_t size_buffer_instance = + exemplar_img[0].size().width * exemplar_img[0].size().height * 3 * 4; + mindspore::MSTensor instance1( + "y", mindspore::DataType::kNumberTypeFloat32, + {static_cast<int64_t>(3), static_cast<int64_t>(255), + static_cast<int64_t>(255)}, + exemplar_img[0].data, size_buffer_instance); + mindspore::MSTensor instance2( + "y", mindspore::DataType::kNumberTypeFloat32, + {static_cast<int64_t>(3), static_cast<int64_t>(255), + static_cast<int64_t>(255)}, + exemplar_img[1].data, size_buffer_instance); + mindspore::MSTensor instance3( + "y", mindspore::DataType::kNumberTypeFloat32, + {static_cast<int64_t>(3), static_cast<int64_t>(255), + static_cast<int64_t>(255)}, + exemplar_img[2].data, size_buffer_instance); + + preInstance(input_exemplar, 
outputs_exemplar, output_exemplar1, + model_inputs_instance, model2, instance1); + preInstance(input_exemplar, outputs_exemplar, output_exemplar2, + model_inputs_instance, model2, instance2); + preInstance(input_exemplar, outputs_exemplar, output_exemplar3, + model_inputs_instance, model2, instance3); + response_mapInit[0] = + cv::Mat(17, 17, CV_32FC1, output_exemplar1[0].MutableData()); + response_mapInit[1] = + cv::Mat(17, 17, CV_32FC1, output_exemplar2[0].MutableData()); + response_mapInit[2] = + cv::Mat(17, 17, CV_32FC1, output_exemplar3[0].MutableData()); + + double minValue = 0; + double maxValue = 0; + for (int n = 0; n < 3; n++) { + cv::resize(response_mapInit[n], response_map[n], Size(272, 272), 0, 0, + cv::INTER_CUBIC); + cv::minMaxIdx(response_map[n], &minValue, &maxValue, NULL, NULL); + response_map_max[n] = maxValue * config.penalty[n]; + } + int scale_index = std::max_element(response_map_max, response_map_max + 3) - + response_map_max; + cv::Mat response_map_up = response_map[scale_index]; + double minValue_response = 0; + double maxValue_response = 0; + cv::minMaxIdx(response_map_up, &minValue_response, &maxValue_response); + response_map_up = response_map_up - minValue_response; + Scalar sum_response = sum(response_map_up); + response_map_up = response_map_up / sum_response[0]; + response_map_up = (1 - 0.176) * response_map_up + 0.176 * cos_window; + cv::minMaxIdx(response_map_up, &minValue_response, &maxValue_response); + + cv::Point maxLoc; + cv::minMaxLoc(response_map_up, NULL, NULL, NULL, &maxLoc); + double maxLoc_x = static_cast<double>(maxLoc.x); + double maxLoc_y = static_cast<double>(maxLoc.y); + maxLoc_x -= (271 / 2); + maxLoc_y -= (271 / 2); + maxLoc_x /= 2; + maxLoc_y /= 2; + + double scale = config.scales[scale_index]; + maxLoc_x = maxLoc_x * (config.s_x * scale) / 255; + maxLoc_y = maxLoc_y * (config.s_x * scale) / 255; + config.target_position[0] += maxLoc_x; + config.target_position[1] += maxLoc_y; + cout << " target_position[0]: " 
<< config.target_position[0] + << " target_positon[1]:" << config.target_position[1] << endl; + config.s_x = (0.41 + 0.59 * scale) * config.s_x; + config.s_x = max(config.min_s_x, min(config.max_s_x, config.s_x)); + config.target_sz[0] = (0.41 + 0.59 * scale) * config.target_sz[0]; + config.target_sz[1] = (0.41 + 0.59 * scale) * config.target_sz[1]; + config.box[0] = config.target_position[0] + 1 - (config.target_sz[0]) / 2; + config.box[1] = config.target_position[1] + 1 - (config.target_sz[1]) / 2; + config.box[2] = config.target_sz[0]; + config.box[3] = config.target_sz[1]; +} + +void myCreateHanningWindow(OutputArray _dst, cv::Size winSize, int type) { + CV_Assert(type == CV_32FC1 || type == CV_64FC1); + _dst.create(winSize, type); + Mat dst = _dst.getMat(); + int rows = dst.rows; + int cols = dst.cols; + if (dst.depth() == CV_32F) { + if (rows == 1 && cols == 1) { + dst.at<float>(0, 0) = 1; + } else if (rows == 1 && cols > 1) { + float* dstData = dst.ptr<float>(0); + for (int j = 0; j < cols; j++) { + dstData[j] = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)j / (double)(cols - 1))); + } + } else if (rows > 1 && cols == 1) { + for (int i = 0; i < rows; i++) { + float* dstData = dst.ptr<float>(i); + dstData[0] = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)i / (double)(rows - 1))); + } + + } else { + for (int i = 0; i < rows; i++) { + float* dstData = dst.ptr<float>(i); + double wr = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)i / (double)(rows - 1))); + for (int j = 0; j < cols; j++) { + double wc = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)j / (double)(cols - 1))); + dstData[j] = (float)(wr * wc); + } + } + sqrt(dst, dst); + } + } else { + if (rows == 1 && cols == 1) { + dst.at<double>(0, 0) = 1; + } else if (rows == 1 && cols > 1) { + double* dstData = dst.ptr<double>(0); + for (int j = 0; j < cols; j++) { + dstData[j] = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)j / (double)(cols - 1))); + } + } else if (rows > 1 && cols == 1) { + for (int i = 0; i < rows; i++) { + 
double* dstData = dst.ptr<double>(i); + dstData[0] = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)i / (double)(rows - 1))); + } + } else { + for (int i = 0; i < rows; i++) { + double* dstData = dst.ptr<double>(i); + double wr = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)i / (double)(rows - 1))); + for (int j = 0; j < cols; j++) { + double wc = + 0.5 * (1.0 - cos(2.0 * CV_PI * (double)j / (double)(cols - 1))); + dstData[j] = (double)(wr * wc); + } + } + sqrt(dst, dst); + } + } +} +Mat createMulHanningWindow(cv::Size winSize, int type) { + int size1[2] = {1, winSize.width}; + cv::Mat selfhanning1(1, size1, CV_32FC1, cv::Scalar(0)); + myCreateHanningWindow(selfhanning1, cv::Size(1, winSize.width), CV_32FC1); + int size2[2] = {winSize.height, 1}; + cv::Mat selfhanning2(1, size2, CV_32FC1, cv::Scalar(0)); + myCreateHanningWindow(selfhanning2, cv::Size(winSize.height, 1), CV_32FC1); + cv::Mat mulHanning; + mulHanning = selfhanning1 * selfhanning2; + return mulHanning; +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_model_path1).empty()) { + std::cout << "Invalid model" << std::endl; + return 1; + } + + auto context = std::make_shared<Context>(); + auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); + ascend310_info->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310_info); + // load graph1 + Graph graph1; + Status ret = + Serialization::Load(FLAGS_model_path1, ModelType::kMindIR, &graph1); + cout << "Load model success" << endl; + if (ret != kSuccess) { + std::cout << "Load model failed." << std::endl; + return 1; + } + Model model1; + Status ret_build = model1.Build(GraphCell(graph1), context); + if (ret_build != kSuccess) { + std::cout << "ERROR: Build failed." 
<< std::endl; + return 1; + } else { + cout << " Build success " << endl; + } + // load graph2 + Graph graph2; + Status ret_graph2 = + Serialization::Load(FLAGS_model_path2, ModelType::kMindIR, &graph2); + if (ret_graph2 != kSuccess) { + cout << " load graph2 failed" << endl; + } else { + cout << " load graph2 Success" << endl; + } + Model model2; + Status ret_build2 = model2.Build(GraphCell(graph2), context); + if (ret_build2 != kSuccess) { + cout << " build graph2 failed" << endl; + } else { + cout << " build graph2 Success" << endl; + } + + auto all_files = GetAllFiles(FLAGS_seq_root_path, all_videos[0]); + if (all_files.empty()) { + std::cout << "ERROR: no input data." << std::endl; + return 1; + } + int jogging_count = 1; + std::map<double, double> costTime_map; + size_t size_v = all_videos.size(); + for (size_t i = 0; i < size_v; ++i) { + param config; + vector<MSTensor> inputs_exemplar; + vector<MSTensor> outputs_exemplar; + struct timeval start, end; + double startTime_ms, endTime_ms, useTime_ms; + gettimeofday(&start, NULL); + getExemplar(all_videos[i], outputs_exemplar, inputs_exemplar, model1, + config, jogging_count); + cout << "record:" << config.record_name << " " << config.record_times + << endl; + gettimeofday(&end, NULL); + costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms)); + ofstream outfile_record; + ofstream outfile_times; + outfile_times.open(config.record_times); + outfile_record.open(config.record_name); + startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + useTime_ms = endTime_ms - startTime_ms; + outfile_times << useTime_ms << std::endl; + outfile_record << config.box[0] << "," << config.box[1] << "," + << config.box[2] << "," << config.box[3] << endl; + cv::Mat hann; + hann = createMulHanningWindow(cv::Size(16 * 17, 16 * 17), CV_32FC1); + Scalar sum_hann = sum(hann); + cv::Mat cos_window = hann / sum_hann[0]; // create hanning + // 
load graph2 + std::vector<MSTensor> inputs; + std::vector<MSTensor> outputs; + for (size_t j = 1; j < config.size_s; j++) { + gettimeofday(&start, NULL); + getRetInstance(j, inputs, outputs, outputs_exemplar, cos_window, config, + model2); + gettimeofday(&end, NULL); + costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms)); + startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 100; + useTime_ms = endTime_ms - startTime_ms; + outfile_times << useTime_ms << std::endl; + outfile_record << config.box[0] << "," << config.box[1] << "," + << config.box[2] << "," << config.box[3] << endl; + } + if (all_videos[i] == "Jogging" && jogging_count == 1) { + i--; + jogging_count++; + } + } + double average = 0.0; + int infer_cnt = 0; + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = 0.0; + diff = iter->second - iter->first; + average += diff; + infer_cnt++; + } + + average = average / infer_cnt; + + std::stringstream timeCost; + timeCost << "NN inference cost average time: " << average + << " ms of infer_count " << infer_cnt << std::endl; + std::cout << "NN inference cost average time: " << average + << "ms of infer_count " << infer_cnt << std::endl; + std::string file_name = + "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream file_stream(file_name.c_str(), std::ios::trunc); + file_stream << timeCost.str(); + file_stream.close(); + costTime_map.clear(); + return 0; + cout << "End project" << endl; + return 0; +} diff --git a/research/cv/SiamFC/ascend310_infer/src/utils.cc b/research/cv/SiamFC/ascend310_infer/src/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..c054842005ead6941632a63f6bf53a9d07028bcf --- /dev/null +++ b/research/cv/SiamFC/ascend310_infer/src/utils.cc @@ -0,0 +1,256 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "inc/utils.h" + +#include <dirent.h> +#include <opencv2/imgproc/types_c.h> + +#include <algorithm> +#include <cmath> +#include <fstream> +#include <iostream> +#include <opencv2/core/core.hpp> +#include <opencv2/highgui/highgui.hpp> +#include <opencv2/imgproc/imgproc.hpp> +#include <opencv2/objdetect/objdetect.hpp> +#include <opencv2/opencv.hpp> +#include <sstream> +#include <string> +#include <string_view> +#include <vector> + +const int DAVID_DATA_SIZE = 471; +const int DAVID_DATA_BEGIN = 299; +const int DAVID_DATA_END = 770; +const int FOOTBALL_DATA_SIZE = 74; +const int FREEMAN3_DATA_SIZE = 460; +const int FREEMAN4_DATA_SIZE = 283; +const int DIVING_DATA_SIZE = 215; +using mindspore::DataType; +using mindspore::MSTensor; +using namespace std; + +std::vector<std::string> GetAllFiles(const std::string_view& dirName, + const std::string& seq_name) { + struct dirent* filename; + string seqName = string(dirName) + "/" + seq_name + "/img"; + + DIR* dir = OpenDir(seqName); + if (dir == nullptr) { + cout << "no dir" << endl; + return {}; + } + std::vector<std::string> res; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." 
|| filename->d_type != DT_REG) { + continue; + } + res.emplace_back(string(dirName) + "/" + seq_name + "/img/" + + filename->d_name); + } + std::sort(res.begin(), res.end()); + std::vector<std::string> res_all; + if (seq_name == "David") { + res_all.resize(DAVID_DATA_SIZE); + std::copy(res.begin() + DAVID_DATA_BEGIN, res.begin() + DAVID_DATA_END, + res_all.begin()); + } else if (seq_name == "Football1") { + res_all.resize(FOOTBALL_DATA_SIZE); + std::copy(res.begin(), res.begin() + FOOTBALL_DATA_SIZE, res_all.begin()); + } else if (seq_name == "Freeman3") { + res_all.resize(FREEMAN3_DATA_SIZE); + std::copy(res.begin(), res.begin() + FREEMAN3_DATA_SIZE, res_all.begin()); + } else if (seq_name == "Freeman4") { + res_all.resize(FREEMAN4_DATA_SIZE); + std::copy(res.begin(), res.begin() + FREEMAN4_DATA_SIZE, res_all.begin()); + } else if (seq_name == "Diving") { + res_all.resize(FREEMAN4_DATA_SIZE); + std::copy(res.begin(), res.begin() + FREEMAN4_DATA_SIZE, res_all.begin()); + } else { + for (size_t i = 0; i < res.size(); i++) { + res_all.emplace_back(res[i]); + } + } + return res_all; +} +std::vector<double> Getpos(const std::string& dirName) { + std::ifstream infile; + infile.open(dirName.c_str()); + std::string s; + getline(infile, s); + std::stringstream ss; + ss << s; + double temp; + std::vector<double> data; + while (ss >> temp) { + data.push_back(temp); + if (ss.peek() == ',' || ss.peek() == ' ' || ss.peek() == '\t') { + ss.ignore(); + } + } + infile.close(); + return data; +} + +int WriteResult(const std::string& imageFile, + const std::vector<MSTensor>& outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr<const void> netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), + '_' + std::to_string(i) + ".bin"); + 
std::string outFileName = homePath + "/" + fileName; + FILE* outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string& file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, + {static_cast<int64_t>(size)}, nullptr, size); + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast<char*>(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + +DIR* OpenDir(const std::string_view& dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! " << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" 
<< std::endl; + return nullptr; + } + DIR* dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(const std::string_view& path) { + char realPathMem[PATH_MAX] = {0}; + char* realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} + +cv::Mat BGRToRGB(cv::Mat& img) { + cv::Mat image(img.rows, img.cols, CV_8UC3); + for (int i = 0; i < img.rows; ++i) { + cv::Vec3b* p1 = img.ptr<cv::Vec3b>(i); + cv::Vec3b* p2 = image.ptr<cv::Vec3b>(i); + for (int j = 0; j < img.cols; ++j) { + p2[j][2] = p1[j][0]; + p2[j][1] = p1[j][1]; + p2[j][0] = p1[j][2]; + } + } + return image; +} +cv::Mat crop_and_pad(cv::Mat img, float cx, float cy, float size_z, float s_z) { + float xmin = cx - s_z / 2; + float xmax = cx + s_z / 2; + float ymin = cy - s_z / 2; + float ymax = cy + s_z / 2; + int w = img.cols; + int h = img.rows; + int left = 0; + int right = 0; + int top = 0; + int bottom = 0; + + if (xmin < 0) left = static_cast<int>(abs(xmin)); + if (xmax > w) right = static_cast<int>(xmax - w); + if (ymin < 0) top = static_cast<int>(abs(ymin)); + if (ymax > h) bottom = static_cast<int>(ymax - h); + + xmin = std::max(0, static_cast<int>(xmin)); + xmax = std::min(w, static_cast<int>(xmax)); + ymin = std::max(0, static_cast<int>(ymin)); + ymax = std::min(h, static_cast<int>(ymax)); + + cv::Mat im_patch = img(cv::Range(ymin, ymax), cv::Range(xmin, xmax)); + if (left != 0 || right != 0 || top != 0 || bottom != 0) { + cv::Scalar tempVal = cv::mean(img); + tempVal.val[0] = static_cast<int>(tempVal.val[0]); + tempVal.val[1] = static_cast<int>(tempVal.val[1]); + 
tempVal.val[2] = static_cast<int>(tempVal.val[2]); + cv::copyMakeBorder(im_patch, im_patch, top, bottom, left, right, + cv::BORDER_CONSTANT, tempVal); + } + if (size_z != s_z) { + cv::resize(im_patch, im_patch, cv::Size(size_z, size_z)); + } + return im_patch; +} + +float sumMat(cv::Mat& inputImg) { + float sum = 0.0; + int rowNumber = inputImg.rows; + int colNumber = inputImg.cols * inputImg.channels(); + for (int i = 0; i < rowNumber; i++) { + uchar* data = inputImg.ptr<uchar>(i); + for (int j = 0; j < colNumber; j++) { + sum = data[j] + sum; + } + } + + return sum; +} diff --git a/research/cv/SiamFC/eval.py b/research/cv/SiamFC/eval.py index 1951036215a5dbb914b61a164ca628ea0a6dfe72..47e85e85af745e7d3639089884f32329e011ab0e 100644 --- a/research/cv/SiamFC/eval.py +++ b/research/cv/SiamFC/eval.py @@ -26,8 +26,9 @@ sys.path.append(os.getcwd()) if __name__ == '__main__': parser = argparse.ArgumentParser(description='siamfc tracking') - parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend') - parser.add_argument('--model_path', default='/root/SiamFC/models/siamfc_{}.ckpt/SiamFC-6650.ckpt' + parser.add_argument('--device_id', type=int, default=7 + , help='device id of GPU or Ascend') + parser.add_argument('--model_path', default='/root/models/siamfc_{}.ckpt/SiamFC_177-47_6650.ckpt' , type=str, help='eval one special video') parser.add_argument('--dataset_path', default='/root/datasets/OTB2013', type=str) diff --git a/research/cv/SiamFC/export.py b/research/cv/SiamFC/export.py index c842137708027e1c61405ab6c4503eb002f47680..4ee984257c0cc5f6bd28bc6604070e415a74080e 100644 --- a/research/cv/SiamFC/export.py +++ b/research/cv/SiamFC/export.py @@ -19,14 +19,16 @@ import mindspore as ms from mindspore import Tensor, context from mindspore.train.serialization import load_checkpoint, export, load_param_into_net from src.alexnet import SiameseAlexNet + parser = argparse.ArgumentParser(description='siamfc export') 
-parser.add_argument("--device_id", type=int, default=0, help="Device id") -parser.add_argument('--model_path', default='/root/HRBEU-MedAI/SiamFC/models/siamfc_{}.ckpt/', - type=str, help='eval one special video') -parser.add_argument('--file_name', type=str, default='/root/HRBEU-MedAI/SiamFC/models', +parser.add_argument("--device_id", type=int, default=7, help="Device id") +parser.add_argument('--model_path', default='/root/models/siamfc_{}.ckpt/SiamFC_177-47_6650.ckpt' + , type=str, help='eval one special video') +parser.add_argument('--file_name_export1', type=str, default='/root/SiamFC/models1', + help='SiamFc output file name.') +parser.add_argument('--file_name_export2', type=str, default='/root/SiamFC/models2', help='SiamFc output file name.') -parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', - help='file format') +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format') parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", help="device target") args = parser.parse_args() @@ -34,12 +36,19 @@ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) if args.device_target == "Ascend": context.set_context(device_id=args.device_id) -if __name__ == "__main__": - net = SiameseAlexNet(train=False) - load_param_into_net(net, load_checkpoint(args.model_path), strict_load=True) - net.set_train(False) - - input_data_exemplar = Tensor(np.zeros([3, 256, 6, 6]), ms.float32) - input_data_instance = Tensor(np.zeros([3, 3, 255, 255]), ms.float32) - export(net, input_data_exemplar, input_data_instance, file_name=args.file_name, - file_format=args.file_format) +if __name__ == "__main__": + net1 = SiameseAlexNet(train=False) + load_param_into_net(net1, load_checkpoint(args.model_path), strict_load=True) + net1.set_train(False) + net2 = SiameseAlexNet(train=False) + load_param_into_net(net2, 
load_checkpoint(args.model_path), strict_load=True) + net2.set_train(False) + input_data_exemplar1 = Tensor(np.zeros([1, 3, 127, 127]), ms.float32) + input_data_instance1 = Tensor(np.zeros(1), ms.float32) + input_data_exemplar2 = Tensor(np.ones([1, 256, 6, 6]), ms.float32) + input_data_instance2 = Tensor(np.ones([1, 3, 255, 255]), ms.float32) + input1 = [input_data_exemplar1, input_data_instance1] + input2 = [input_data_exemplar2, input_data_instance2] + export(net1, *input1, file_name=args.file_name_export1, file_format=args.file_format) + export(net2, *input2, file_name=args.file_name_export2, file_format=args.file_format) + print("-- complete --") \ No newline at end of file diff --git a/research/cv/SiamFC/postprocess.py b/research/cv/SiamFC/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..a30b209ef930fffc6a4801eb1e3ef721aceb5bb3 --- /dev/null +++ b/research/cv/SiamFC/postprocess.py @@ -0,0 +1,50 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""post process for 310 inference""" +from __future__ import absolute_import +import argparse +import os +import sys +from got10k.experiments import ExperimentOTB +from mindspore import context +from src import SiamFCTracker +sys.path.append(os.getcwd()) +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='siamfc tracking') + parser.add_argument('--device_id', type=int, default=7 + , help='device id of GPU or Ascend') + + parser.add_argument('--dataset_path', default='/root/datasets/OTB2013', type=str) + + args = parser.parse_args() + context.set_context( + mode=context.GRAPH_MODE, + device_id=args.device_id, + save_graphs=False, + device_target='Ascend') + + root_dir = os.path.abspath(args.dataset_path) + e = ExperimentOTB(root_dir, version=2013) + prec_score = e.report(['SiamFC'])['SiamFC']['overall'] + score = ['success_score', 'precision_score', 'success_rate'] + mydic = [] + for key in score: + mydic.append(prec_score[key]) + ss = '-prec_score:%.3f -succ_score:%.3f -succ_rate:%.3f' % (float(mydic[1]), + float(mydic[0]), + float(mydic[2])) + + + print(ss) \ No newline at end of file diff --git a/research/cv/SiamFC/src/alexnet.py b/research/cv/SiamFC/src/alexnet.py index a70d476b72964651d43def41bea69d66df8175d6..fc0cd33fc7130658722cf5c40601e8a17815f5ac 100644 --- a/research/cv/SiamFC/src/alexnet.py +++ b/research/cv/SiamFC/src/alexnet.py @@ -118,7 +118,14 @@ class SiameseAlexNet(nn.Cell): score_map = n_p.transpose(score_map, (1, 0, 2, 3)) score_map = score_map*1e-3+self.corr_bias score = self.loss(score_map, self.train_gt)/8 - + elif x.shape.as_list()[0] == 1 : + exemplar = x + instance = y + if exemplar.size is not None and instance.size == 1: + exemplar = self.seq(exemplar) + return exemplar + instance = self.seq(instance) + score = self.Conv2D_1(instance, exemplar) else: exemplar = x instance = y