Skip to content
Snippets Groups Projects
Commit fa05de93 authored by Jin Sutong's avatar Jin Sutong
Browse files

Merge remote-tracking branch 'upstream/master'

stargan
parent e32ce331
No related branches found
No related tags found
No related merge requests found
Showing
with 1828 additions and 0 deletions
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Convert an AIR model file into an OM offline model by invoking the `atc` tool.
# Arguments:
#   $1 - path of the input AIR file
#   $2 - output path/name handed to `atc --output`
# Exits with 255 when the arguments are wrong or the input file is missing;
# otherwise the exit status is that of `atc` (the last command executed).
if [ $# -ne 2 ]
then
    echo "Wrong parameter format."
    echo "Usage:"
    echo " bash $0 INPUT_AIR_PATH OUTPUT_OM_PATH_NAME"
    echo "Example: "
    echo " bash convert_om.sh models/0-150_1251.air models/0-150_1251.om"
    exit 255
fi

input_air_path=$1
output_om_path=$2

# Fail early with a clear message rather than letting atc report a missing model.
if [ ! -f "${input_air_path}" ]; then
    echo "Input AIR file does not exist: ${input_air_path}"
    exit 255
fi

export ASCEND_SLOG_PRINT_TO_STDOUT=1

echo "Input AIR file path: ${input_air_path}"
echo "Output OM file path: ${output_om_path}"

# Paths are quoted so that names containing spaces or glob characters are
# passed to atc as a single argument.
atc --input_format=NCHW \
    --framework=1 \
    --model="${input_air_path}" \
    --output="${output_om_path}" \
    --soc_version=Ascend310 \
    --disable_reuse_memory=0 \
    --output_type=FP32 \
    --precision_mode=allow_fp32_to_fp16 \
    --op_select_implmode=high_precision
\ No newline at end of file
{
"stargan": {
"stream_config": {
"deviceId": "0"
},
"appsrc0": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:0"
},
"appsrc1": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:1"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "appsrc0,appsrc1",
"modelPath": "../data/model/stargan.om"
},
"factory": "mxpi_tensorinfer",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}
\ No newline at end of file
5_o_Clock_Shadow
Arched_Eyebrows
Attractive
Bags_Under_Eyes
Bald
Bangs
Big_Lips
Big_Nose
Black_Hair
Blond_Hair
Blurry
Brown_Hair
Bushy_Eyebrows
Chubby
Double_Chin
Eyeglasses
Goatee
Gray_Hair
Heavy_Makeup
High_Cheekbones
Male
Mouth_Slightly_Open
Mustache
Narrow_Eyes
No_Beard
Oval_Face
Pale_Skin
Pointy_Nose
Receding_Hairline
Rosy_Cheeks
Sideburns
Smiling
Straight_Hair
Wavy_Hair
Wearing_Earrings
Wearing_Hat
Wearing_Lipstick
Wearing_Necklace
Wearing_Necktie
Young
\ No newline at end of file
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Positional arguments: $1 is the docker image to run, $2 is the host directory
# that is bind-mounted into the container at the same path.
docker_image=$1
data_dir=$2
function show_help() {
    # Print the one-line usage string of this script to stdout.
    printf '%s\n' "Usage: docker_start.sh docker_image data_dir"
}
function param_check() {
    # Abort (exit 1) with a hint and the usage text when either required
    # positional argument is empty. docker_image is checked before data_dir.
    [ -n "${docker_image}" ] || {
        echo "please input docker_image"
        show_help
        exit 1
    }

    [ -n "${data_dir}" ] || {
        echo "please input data_dir"
        show_help
        exit 1
    }
}
# Validate the positional arguments before touching docker.
param_check

# Start an interactive container with the Ascend device nodes and the host
# driver directory exposed. The data directory is mounted at the same path
# inside the container. Expansions are quoted so paths or image names with
# spaces are not split into several arguments.
docker run -it \
    --device=/dev/davinci0 \
    --device=/dev/davinci_manager \
    --device=/dev/devmm_svm \
    --device=/dev/hisi_hdc \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
    -v "${data_dir}:${data_dir}" \
    "${docker_image}" \
    /bin/bash
\ No newline at end of file
cmake_minimum_required(VERSION 3.14.0)
project(stargan)

set(TARGET_MAIN Stargan)

# Check environment variables. ASCEND_HOME and MX_SDK_HOME are both expanded
# into paths below, so a missing value is a hard error.
if(NOT DEFINED ENV{ASCEND_HOME})
  message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{MX_SDK_HOME})
  message(FATAL_ERROR "please define environment variable:MX_SDK_HOME")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
  message(WARNING "please define environment variable:ARCH_PATTERN")
endif()

set(ACL_LIB_PATH "$ENV{ASCEND_HOME}/nnrt/latest/acllib")
# ACL_INC_DIR / ACL_LIB_DIR were referenced but never defined, so they expanded
# to nothing.
# NOTE(review): the include/ and lib64/ sub-directories are assumed from the
# usual acllib layout - confirm against the installed package.
set(ACL_INC_DIR "${ACL_LIB_PATH}/include")
set(ACL_LIB_DIR "${ACL_LIB_PATH}/lib64")

set(MXBASE_ROOT_DIR "$ENV{MX_SDK_HOME}")
set(MXBASE_INC "${MXBASE_ROOT_DIR}/include")
set(MXBASE_LIB_DIR "${MXBASE_ROOT_DIR}/lib")
set(MXBASE_POST_LIB_DIR "${MXBASE_ROOT_DIR}/lib/modelpostprocessors")
set(MXBASE_POST_PROCESS_DIR "${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include")
set(OPENSOURCE_DIR "$ENV{MX_SDK_HOME}/opensource")

add_executable(${TARGET_MAIN}
  src/main.cpp
  src/StarGanGeneration.cpp
)

# All build settings are attached to the target instead of the directory scope.
target_compile_definitions(${TARGET_MAIN} PRIVATE
  _GLIBCXX_USE_CXX11_ABI=0
  google=mindxsdk_private
)
target_compile_options(${TARGET_MAIN} PRIVATE
  -std=c++11
  -fPIE
  -fstack-protector-all
  -fPIC
  -Wall
)
target_link_options(${TARGET_MAIN} PRIVATE
  -Wl,-z,relro,-z,now,-z,noexecstack
  -s
  -pie
)
target_include_directories(${TARGET_MAIN} PRIVATE
  "${CMAKE_CURRENT_SOURCE_DIR}/src"
  "${ACL_INC_DIR}"
  "${OPENSOURCE_DIR}/include"
  "${OPENSOURCE_DIR}/include/opencv4"
  "${MXBASE_INC}"
  "${MXBASE_POST_PROCESS_DIR}"
)
target_link_directories(${TARGET_MAIN} PRIVATE
  "${ACL_LIB_DIR}"
  "${OPENSOURCE_DIR}/lib"
  "${MXBASE_LIB_DIR}"
  "${MXBASE_POST_LIB_DIR}"
)
target_link_libraries(${TARGET_MAIN} PRIVATE glog cpprest mxbase opencv_world)

# The binary is installed into the project root, as before.
install(TARGETS ${TARGET_MAIN} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
\ No newline at end of file
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "StarGanGeneration.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
namespace {
// Size in bytes of one float32 element.
const uint32_t FLOAT32_TYPE_BYTE_NUM = 4;
// Pixels are normalized as (x - mean) / std and de-normalized as x * std + mean.
// The float literal matters: the previous `255/2` was an integer division and
// evaluated to 127 instead of 127.5.
const float NORMALIZE_MEAN = 255.0f / 2;
const float NORMALIZE_STD = 255.0f / 2;
// Spatial size and channel count used for both the network input and the result image.
const uint32_t OUTPUT_HEIGHT = 128;
const uint32_t OUTPUT_WIDTH = 128;
const uint32_t CHANNEL = 3;
}
// Log every dimension of every tensor descriptor in tensorDescVec.
// tensorName is only used in the heading line of the log output.
void PrintTensorShape(const std::vector<MxBase::TensorDesc> &tensorDescVec, const std::string &tensorName) {
    LogInfo << "The shape of " << tensorName << " is as follows:";
    size_t tensorIdx = 0;
    for (const auto &desc : tensorDescVec) {
        LogInfo << " Tensor " << tensorIdx << ":";
        size_t dimIdx = 0;
        for (const auto &dim : desc.tensorDims) {
            LogInfo << " dim: " << dimIdx << ": " << dim;
            ++dimIdx;
        }
        ++tensorIdx;
    }
}
// Bring up everything needed for inference, in order: devices, tensor context,
// DVPP wrapper, model. Stops at the first failing step and returns its error code.
// On success the save path is stored and the model's input/output shapes are logged.
APP_ERROR StarGanGeneration::Init(const InitParam &initParam) {
    deviceId_ = initParam.deviceId;

    APP_ERROR status = MxBase::DeviceManager::GetInstance()->InitDevices();
    if (status != APP_ERR_OK) {
        LogError << "Init devices failed, ret=" << status << ".";
        return status;
    }

    status = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
    if (status != APP_ERR_OK) {
        LogError << "Set context failed, ret=" << status << ".";
        return status;
    }

    dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
    status = dvppWrapper_->Init();
    if (status != APP_ERR_OK) {
        LogError << "DvppWrapper init failed, ret=" << status << ".";
        return status;
    }

    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
    status = model_->Init(initParam.modelPath, modelDesc_);
    if (status != APP_ERR_OK) {
        LogError << "ModelInferenceProcessor init failed, ret=" << status << ".";
        return status;
    }

    savePath_ = initParam.savePath;

    // modelDesc_ was filled in by model_->Init above.
    PrintTensorShape(modelDesc_.inputTensors, "Model Input Tensors");
    PrintTensorShape(modelDesc_.outputTensors, "Model Output Tensors");
    return APP_ERR_OK;
}
// Release the DVPP wrapper, the model and the devices.
// dvppWrapper_ and model_ are only created inside Init, so they are still null
// when Init never ran or failed early; guard them instead of dereferencing null.
// Always returns APP_ERR_OK.
APP_ERROR StarGanGeneration::DeInit() {
    if (dvppWrapper_ != nullptr) {
        dvppWrapper_->DeInit();
    }
    if (model_ != nullptr) {
        model_->DeInit();
    }
    MxBase::DeviceManager::GetInstance()->DestroyDevices();
    return APP_ERR_OK;
}
// Load imgPath as a 3-channel colour image into *imageMat.
// cv::imread signals failure by returning an empty Mat rather than an error,
// so that case is turned into an error code here instead of being passed on
// to the resize/convert steps.
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR when the image cannot be read.
APP_ERROR StarGanGeneration::ReadImage(const std::string &imgPath, cv::Mat *imageMat) {
    *imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
    if (imageMat->empty()) {
        LogError << "Failed to read image: " << imgPath;
        return APP_ERR_INTERNAL_ERROR;
    }
    return APP_ERR_OK;
}
// Resize srcImageMat to OUTPUT_WIDTH x OUTPUT_HEIGHT and store the result in *dstImageMat.
void StarGanGeneration::ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat) {
    // cv::Size takes (width, height).
    const cv::Size targetSize(OUTPUT_WIDTH, OUTPUT_HEIGHT);
    cv::resize(srcImageMat, *dstImageMat, targetSize);
}
// Convert an OUTPUT_HEIGHT x OUTPUT_WIDTH 8-bit 3-channel Mat into a float32
// device tensor of shape {1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH}.
// The channel order is reversed and each value is normalized as
// (x - NORMALIZE_MEAN) / NORMALIZE_STD.
//
// Changes from the previous version:
//  - the host staging buffer is a std::vector, so it is no longer leaked;
//  - the buffer holds exactly C*H*W floats (the shape of the tensor built here)
//    instead of the product of the dims of every model input, which allocated
//    and copied uninitialised memory;
//  - the Mat is validated before at<cv::Vec3b>() is used on it.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR for an unsuitable Mat,
// or the error code reported by MxbsMallocAndCopy.
APP_ERROR StarGanGeneration::CVMatToTensorBase(const cv::Mat& imageMat, MxBase::TensorBase *tensorBase) {
    const size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
    if (imageMat.empty() || imageMat.type() != CV_8UC3 ||
        static_cast<size_t>(imageMat.rows) != H || static_cast<size_t>(imageMat.cols) != W) {
        LogError << "Input image must be a non-empty " << W << "x" << H << " 8-bit 3-channel Mat.";
        return APP_ERR_INTERNAL_ERROR;
    }

    // mat NHWC to NCHW, BGR to RGB, and Normalize
    std::vector<float> hostData(C * H * W);
    for (size_t c = 0; c < C; c++) {
        // Source channel c is written to destination plane (C - c - 1).
        const size_t planeOffset = (C - c - 1) * (H * W);
        for (size_t h = 0; h < H; h++) {
            for (size_t w = 0; w < W; w++) {
                const cv::Vec3b &pixel = imageMat.at<cv::Vec3b>(static_cast<int>(h), static_cast<int>(w));
                hostData[planeOffset + h * W + w] = (pixel[c] - NORMALIZE_MEAN) / NORMALIZE_STD;
            }
        }
    }

    const uint32_t dataSize = static_cast<uint32_t>(hostData.size() * FLOAT32_TYPE_BYTE_NUM);
    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
    // The source descriptor only points at hostData, which stays alive until
    // the copy below has returned.
    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(hostData.data()),
                                     dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Memory malloc failed.";
        return ret;
    }

    std::vector<uint32_t> shape = { 1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH};
    *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
    return APP_ERR_OK;
}
// Allocate one device tensor per model output (shape taken from modelDesc_,
// dtype from the model), append them to *outputs, then run the model with a
// static batch size of 1.
// Returns APP_ERR_OK on success or the code of the failing allocation/inference call.
APP_ERROR StarGanGeneration::Inference(const std::vector<MxBase::TensorBase> &inputs,
                                       std::vector<MxBase::TensorBase> *outputs) {
    auto outputTypes = model_->GetOutputDataType();

    size_t outIdx = 0;
    for (const auto &desc : modelDesc_.outputTensors) {
        std::vector<uint32_t> dims;
        for (const auto &dim : desc.tensorDims) {
            dims.push_back((uint32_t)dim);
        }

        MxBase::TensorBase outTensor(dims, outputTypes[outIdx],
                                     MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
        APP_ERROR mallocStatus = MxBase::TensorBase::TensorBaseMalloc(outTensor);
        if (mallocStatus != APP_ERR_OK) {
            LogError << "TensorBaseMalloc failed, ret=" << mallocStatus << ".";
            return mallocStatus;
        }
        outputs->push_back(outTensor);
        ++outIdx;
    }

    MxBase::DynamicInfo dynamicInfo = {};
    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
    dynamicInfo.batchSize = 1;

    APP_ERROR inferStatus = model_->ModelInference(inputs, *outputs, dynamicInfo);
    if (inferStatus != APP_ERR_OK) {
        LogError << "ModelInference failed, ret=" << inferStatus << ".";
        return inferStatus;
    }
    return APP_ERR_OK;
}
// Turn the first output tensor into an 8-bit image in *resultImg.
// The tensor is moved to host memory, read as planar float data with the
// channel order reversed, and de-normalized with x * NORMALIZE_STD + NORMALIZE_MEAN.
//
// Hardening over the previous version:
//  - an empty `outputs` vector or a null host buffer is rejected instead of dereferenced;
//  - *resultImg must already be an OUTPUT_HEIGHT x OUTPUT_WIDTH CV_8UC3 Mat;
//  - the truncated value is saturated to [0, 255] rather than wrapping when
//    it is narrowed to uchar.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR for bad arguments,
// or the error code reported by ToHost().
APP_ERROR StarGanGeneration::PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg) {
    const size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
    if (outputs.empty()) {
        LogError << "PostProcess received no output tensor.";
        return APP_ERR_INTERNAL_ERROR;
    }
    if (resultImg == nullptr || resultImg->type() != CV_8UC3 ||
        static_cast<size_t>(resultImg->rows) != H || static_cast<size_t>(resultImg->cols) != W) {
        LogError << "Result image must be a " << W << "x" << H << " 8-bit 3-channel Mat.";
        return APP_ERR_INTERNAL_ERROR;
    }

    APP_ERROR ret = outputs[0].ToHost();
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "tohost fail.";
        return ret;
    }

    const float *outputPtr = reinterpret_cast<float *>(outputs[0].GetBuffer());
    if (outputPtr == nullptr) {
        LogError << "Output tensor has no host buffer.";
        return APP_ERR_INTERNAL_ERROR;
    }

    for (size_t c = 0; c < C; c++) {
        // Image channel c is read from tensor plane (C - c - 1).
        const float *plane = outputPtr + (C - c - 1) * (H * W);
        for (size_t h = 0; h < H; h++) {
            for (size_t w = 0; w < W; w++) {
                const float tmpNum = plane[h * W + w] * NORMALIZE_STD + NORMALIZE_MEAN;
                // Truncate as before, then clamp to the uchar range.
                resultImg->at<cv::Vec3b>(static_cast<int>(h), static_cast<int>(w))[c] =
                    cv::saturate_cast<uchar>(static_cast<int>(tmpNum));
            }
        }
    }
    return APP_ERR_OK;
}
// Write resultImg to savePath_/imgName, creating savePath_ first when it
// cannot be opened as a directory.
//
// Fixes over the previous version:
//  - the DIR handle returned by opendir is closed (it was leaked on every call);
//  - failures of `mkdir -p` and cv::imwrite are reported instead of ignored.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR otherwise.
APP_ERROR StarGanGeneration::SaveResult(const cv::Mat &resultImg, const std::string &imgName) {
    DIR *dirPtr = opendir(savePath_.c_str());
    if (dirPtr != nullptr) {
        closedir(dirPtr);
    } else {
        // NOTE(review): savePath_ is pasted into a shell command unescaped; this
        // is only safe while the path comes from trusted configuration.
        const std::string mkdirCmd = "mkdir -p " + savePath_;
        if (system(mkdirCmd.c_str()) != 0) {
            LogError << "Failed to create result directory: " << savePath_;
            return APP_ERR_INTERNAL_ERROR;
        }
    }

    const std::string outputFile = savePath_ + "/" + imgName;
    if (!cv::imwrite(outputFile, resultImg)) {
        LogError << "Failed to write result image: " << outputFile;
        return APP_ERR_INTERNAL_ERROR;
    }
    return APP_ERR_OK;
}
// Copy `label` into a new float32 device tensor of shape {1, label.size()}
// and store it in *imgLabels.
//
// The previous version staged the values in a `new float[]` buffer that was
// never freed. `label` is already a by-value copy, so its own storage is used
// as the host source. An empty label is rejected up front.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR for an empty label,
// or the error code reported by MxbsMallocAndCopy.
APP_ERROR StarGanGeneration::GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels) {
    if (label.empty()) {
        LogError << "Label vector is empty.";
        return APP_ERR_INTERNAL_ERROR;
    }

    const size_t byteSize = label.size() * FLOAT32_TYPE_BYTE_NUM;
    MxBase::MemoryData memoryDataDst(byteSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
    // The source descriptor only points at `label`, which stays alive until
    // the copy below has returned.
    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(label.data()),
                                     byteSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Memory malloc failed.";
        return ret;
    }

    const std::vector<uint32_t> shape = {1, (unsigned int)label.size()};
    *imgLabels = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
    return APP_ERR_OK;
}
// Run the whole pipeline for one image: read, resize, build the image and
// label input tensors, infer, post-process and save the result as imgName.
// The wall-clock time of the Inference call is added to inferCostTimeMilliSec
// whether or not inference succeeded.
// Returns APP_ERR_OK on success or the error code of the first failing stage.
APP_ERROR StarGanGeneration::Process(const std::string &imgPath,
                                     const std::string &imgName, const std::vector<float> &label) {
    cv::Mat image;
    APP_ERROR status = ReadImage(imgPath, &image);
    if (status != APP_ERR_OK) {
        LogError << "ReadImage failed, ret=" << status << ".";
        return status;
    }
    // Resized in place: source and destination are the same Mat.
    ResizeImage(image, &image);

    std::vector<MxBase::TensorBase> inputs = {};
    std::vector<MxBase::TensorBase> outputs = {};

    // First model input: the image tensor.
    MxBase::TensorBase imageTensor;
    status = CVMatToTensorBase(image, &imageTensor);
    if (status != APP_ERR_OK) {
        LogError << "CVMatToTensorBase failed, ret=" << status << ".";
        return status;
    }
    inputs.push_back(imageTensor);

    // Second model input: the label tensor.
    MxBase::TensorBase labelTensor;
    status = GetImageLabel(label, &labelTensor);
    if (status != APP_ERR_OK) {
        LogError << "Get Image label failed, ret=" << status << ".";
        return status;
    }
    inputs.push_back(labelTensor);

    const auto inferBegin = std::chrono::high_resolution_clock::now();
    status = Inference(inputs, &outputs);
    const auto inferEnd = std::chrono::high_resolution_clock::now();
    // Accumulated before the status check, so failed runs are counted too.
    inferCostTimeMilliSec += std::chrono::duration<double, std::milli>(inferEnd - inferBegin).count();
    if (status != APP_ERR_OK) {
        LogError << "Inference failed, ret=" << status << ".";
        return status;
    }

    cv::Mat resultImg(OUTPUT_HEIGHT, OUTPUT_WIDTH, CV_8UC3);
    status = PostProcess(outputs, &resultImg);
    if (status != APP_ERR_OK) {
        LogError << "PostProcess failed, ret=" << status << ".";
        return status;
    }

    status = SaveResult(resultImg, imgName);
    if (status != APP_ERR_OK) {
        LogError << "Save infer results into file failed. ret = " << status << ".";
        return status;
    }
    return APP_ERR_OK;
}
/*
* Copyright (c) 2021. Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_STARGANGENERATION_H
#define MXBASE_STARGANGENERATION_H
#include <dirent.h>
#include <memory>
#include <vector>
#include <map>
#include <string>
#include <fstream>
#include <iostream>
#include <opencv2/opencv.hpp>
#include "MxBase/Log/Log.h"
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
// Configuration consumed by StarGanGeneration::Init.
struct InitParam {
// Device id; Init stores it and passes it to TensorContext::SetContext.
uint32_t deviceId;
// Directory the result images are written to by SaveResult.
std::string savePath;
// Path of the model file handed to ModelInferenceProcessor::Init.
std::string modelPath;
};
// Wraps the StarGAN inference pipeline: device/model setup, image and label
// pre-processing, model inference, post-processing and saving of the result.
class StarGanGeneration {
public:
// Initialise devices, tensor context, DVPP wrapper and the model.
APP_ERROR Init(const InitParam &initParam);
// Release the DVPP wrapper, the model and the devices.
APP_ERROR DeInit();
// Load imgPath with cv::imread (colour mode) into *imageMat.
APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat);
// Resize srcImageMat to the fixed network input size into *dstImageMat.
void ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat);
// Convert a resized Mat into a normalized float32 device tensor.
APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase);
// Allocate the output tensors and run the model on `inputs`.
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> *outputs);
// De-normalize the first output tensor into *resultImg.
APP_ERROR PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg);
// Full pipeline for one image; the result is saved under the name imgName.
APP_ERROR Process(const std::string &imgPath, const std::string &imgName, const std::vector<float> &label);
// Copy `label` into a {1, label.size()} float32 device tensor.
APP_ERROR GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels);
// get infer time
// Returns the inference time accumulated over all Process calls, in milliseconds.
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
// Write resultImg to savePath_/imgName.
APP_ERROR SaveResult(const cv::Mat &resultImg, const std::string &imgName);
// Created and initialised in Init, released in DeInit.
std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
// Output directory, copied from InitParam::savePath.
std::string savePath_;
// Model input/output description filled in by model_->Init.
MxBase::ModelDesc modelDesc_;
// Device id, copied from InitParam::deviceId.
uint32_t deviceId_ = 0;
// infer time
// Sum of the durations of all Inference calls made by Process, in milliseconds.
double inferCostTimeMilliSec = 0.0;
};
#endif
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Directory containing this script; quoted so a path with spaces survives word splitting.
path_cur=$(dirname "$0")
function check_env()
{
    # Export default values for ASCEND_VERSION and ARCH_PATTERN when the user
    # has not provided them, and report which value is in effect.
    if [ -n "${ASCEND_VERSION}" ]; then
        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
    else
        # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
        export ASCEND_VERSION=ascend-toolkit/latest
        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
    fi

    if [ -n "${ARCH_PATTERN}" ]; then
        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
    else
        # set ARCH_PATTERN to ./ when it was not specified by user
        export ARCH_PATTERN=./
        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
    fi
}
function build_stargan()
{
    # Configure, compile and install stargan from a fresh build directory next
    # to this script. Every step is checked: a failed `cd` would otherwise let
    # `rm -rf build` / cmake / make run in the wrong directory, and a failed
    # cmake or `make install` used to go unnoticed.
    cd "${path_cur}" || { echo "Failed to enter ${path_cur}."; exit 1; }
    rm -rf build
    mkdir -p build
    cd build || { echo "Failed to enter build directory."; exit 1; }

    cmake ..
    ret=$?
    if [ ${ret} -ne 0 ]; then
        echo "Failed to build stargan."
        exit ${ret}
    fi

    make
    ret=$?
    if [ ${ret} -ne 0 ]; then
        echo "Failed to build stargan."
        exit ${ret}
    fi

    make install
    ret=$?
    if [ ${ret} -ne 0 ]; then
        echo "Failed to install stargan."
        exit ${ret}
    fi
}
# Entry point: settle the environment defaults first, then build and install.
check_env
build_stargan
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "StarGanGeneration.h"
namespace {
// Attribute names whose columns are read from the annotation file to build the
// label vector, in this order. Declared const: the list is only ever read.
const std::vector<std::string> SELECTED_ATTRS {"Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"};
// Model file loaded at start-up; the path is relative to the working directory.
const std::string OM_MODEL_PATH = "../data/model/stargan.om";
}
// Append "path/<entry>" to *imgFiles for every directory entry of `path`
// except "." and "..".
// Returns APP_ERR_INTERNAL_ERROR when the directory cannot be opened,
// APP_ERR_OK otherwise.
APP_ERROR ScanImages(const std::string &path, std::vector<std::string> *imgFiles) {
    DIR *dir = opendir(path.c_str());
    if (dir == nullptr) {
        LogError << "opendir failed. dir:" << path;
        return APP_ERR_INTERNAL_ERROR;
    }

    for (dirent *entry = readdir(dir); entry != nullptr; entry = readdir(dir)) {
        const std::string entryName = entry->d_name;
        // Skip the self and parent links.
        if (entryName != "." && entryName != "..") {
            imgFiles->emplace_back(path + "/" + entryName);
        }
    }

    closedir(dir);
    return APP_ERR_OK;
}
// Split `str` on `ch`, treating a run of consecutive delimiters as a single
// separator.
//  - A leading delimiter run produces one empty first token.
//  - A trailing delimiter run produces no extra token.
//  - An empty string produces an empty vector.
std::vector<std::string> split(std::string str, char ch) {
    std::vector<std::string> tokens;
    size_t cursor = 0;
    while (cursor < str.length()) {
        const size_t delimPos = str.find(ch, cursor);
        if (delimPos == std::string::npos) {
            // Last token: everything up to the end of the string.
            tokens.push_back(str.substr(cursor));
            break;
        }
        tokens.push_back(str.substr(cursor, delimPos - cursor));

        // Jump over the whole delimiter run; stop if nothing follows it.
        const size_t nextToken = str.find_first_not_of(ch, delimPos);
        if (nextToken == std::string::npos) {
            break;
        }
        cursor = nextToken;
    }
    return tokens;
}
int main(int argc, char* argv[]) {
if (argc <= 1) {
LogWarn << "Please input image path, such as '../data/test_data/'.";
return APP_ERR_OK;
}
InitParam initParam = {};
initParam.deviceId = 0;
initParam.modelPath = OM_MODEL_PATH;
initParam.savePath = "./result";
auto stargan = std::make_shared<StarGanGeneration>();
APP_ERROR ret = stargan->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "stargan init failed, ret=" << ret << ".";
return ret;
}
// Read the contents of a label
std::string dataPath = argv[1];
std::string imagePath = dataPath + "/images/";
std::string labelPath = dataPath + "/anno/list_attr_celeba.txt";
std::vector<std::string> imagePathList;
ret = ScanImages(imagePath, &imagePathList);
if (ret != APP_ERR_OK) {
LogError << "stargan init failed, ret=" << ret << ".";
return ret;
}
std::ifstream fin;
std::string s;
fin.open(labelPath);
int i = 0;
int imgNum;
std::map<int, std::string> idx2attr;
std::map<std::string, int> attr2idx;
auto startTime = std::chrono::high_resolution_clock::now();
while (getline(fin, s)) {
i++;
if (i == 1) {
imgNum = atoi(s.c_str());
} else if (i == 2) {
std::vector<std::string> allAttrNames = split(s, ' ');
for (size_t j = 0; j < allAttrNames.size(); j++) {
idx2attr[j] = allAttrNames[j];
attr2idx[allAttrNames[j]] = j;
}
} else {
std::vector<std::string> eachAttr = split(s, ' ');
// first one is file name
std::string imgName = eachAttr[0];
std::vector<float> label;
for (size_t j = 0; j < SELECTED_ATTRS.size(); j++) {
if (atoi(eachAttr[attr2idx[SELECTED_ATTRS[j]] + 1].c_str()) == 1)
label.push_back(1.0);
else
label.push_back(0.0);
// label.push_back(atoi(eachAttr[attr2idx[SELECTED_ATTRS[j]] + 1].c_str()) * -0.5);
}
ret = stargan->Process(imagePath + imgName, imgName, label);
if (ret != APP_ERR_OK) {
LogError << "stargan process failed, ret=" << ret << ".";
stargan->DeInit();
return ret;
}
}
}
fin.close();
auto endTime = std::chrono::high_resolution_clock::now();
stargan->DeInit();
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
double fps = 1000.0 * imgNum / stargan->GetInferCostMilliSec();
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
return APP_ERR_OK;
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "StarGanGeneration.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
namespace {
// Size in bytes of one float32 element.
const uint32_t FLOAT32_TYPE_BYTE_NUM = 4;
// Pixels are normalized as (x - mean) / std and de-normalized as x * std + mean.
// The float literal matters: the previous `255/2` was an integer division and
// evaluated to 127 instead of 127.5.
const float NORMALIZE_MEAN = 255.0f / 2;
const float NORMALIZE_STD = 255.0f / 2;
// Spatial size and channel count used for both the network input and the result image.
const uint32_t OUTPUT_HEIGHT = 128;
const uint32_t OUTPUT_WIDTH = 128;
const uint32_t CHANNEL = 3;
}
// Log every dimension of every tensor descriptor in tensorDescVec.
// tensorName is only used in the heading line of the log output.
void PrintTensorShape(const std::vector<MxBase::TensorDesc> &tensorDescVec, const std::string &tensorName) {
    LogInfo << "The shape of " << tensorName << " is as follows:";
    size_t tensorIdx = 0;
    for (const auto &desc : tensorDescVec) {
        LogInfo << " Tensor " << tensorIdx << ":";
        size_t dimIdx = 0;
        for (const auto &dim : desc.tensorDims) {
            LogInfo << " dim: " << dimIdx << ": " << dim;
            ++dimIdx;
        }
        ++tensorIdx;
    }
}
// Bring up everything needed for inference, in order: devices, tensor context,
// DVPP wrapper, model. Stops at the first failing step and returns its error code.
// On success the save path is stored and the model's input/output shapes are logged.
APP_ERROR StarGanGeneration::Init(const InitParam &initParam) {
    deviceId_ = initParam.deviceId;

    APP_ERROR status = MxBase::DeviceManager::GetInstance()->InitDevices();
    if (status != APP_ERR_OK) {
        LogError << "Init devices failed, ret=" << status << ".";
        return status;
    }

    status = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
    if (status != APP_ERR_OK) {
        LogError << "Set context failed, ret=" << status << ".";
        return status;
    }

    dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
    status = dvppWrapper_->Init();
    if (status != APP_ERR_OK) {
        LogError << "DvppWrapper init failed, ret=" << status << ".";
        return status;
    }

    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
    status = model_->Init(initParam.modelPath, modelDesc_);
    if (status != APP_ERR_OK) {
        LogError << "ModelInferenceProcessor init failed, ret=" << status << ".";
        return status;
    }

    savePath_ = initParam.savePath;

    // modelDesc_ was filled in by model_->Init above.
    PrintTensorShape(modelDesc_.inputTensors, "Model Input Tensors");
    PrintTensorShape(modelDesc_.outputTensors, "Model Output Tensors");
    return APP_ERR_OK;
}
// Release the DVPP wrapper, the model and the devices.
// dvppWrapper_ and model_ are only created inside Init, so they are still null
// when Init never ran or failed early; guard them instead of dereferencing null.
// Always returns APP_ERR_OK.
APP_ERROR StarGanGeneration::DeInit() {
    if (dvppWrapper_ != nullptr) {
        dvppWrapper_->DeInit();
    }
    if (model_ != nullptr) {
        model_->DeInit();
    }
    MxBase::DeviceManager::GetInstance()->DestroyDevices();
    return APP_ERR_OK;
}
// Load imgPath as a 3-channel colour image into *imageMat.
// cv::imread signals failure by returning an empty Mat rather than an error,
// so that case is turned into an error code here instead of being passed on
// to the resize/convert steps.
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR when the image cannot be read.
APP_ERROR StarGanGeneration::ReadImage(const std::string &imgPath, cv::Mat *imageMat) {
    *imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
    if (imageMat->empty()) {
        LogError << "Failed to read image: " << imgPath;
        return APP_ERR_INTERNAL_ERROR;
    }
    return APP_ERR_OK;
}
// Resize srcImageMat to OUTPUT_WIDTH x OUTPUT_HEIGHT and store the result in *dstImageMat.
// Always returns APP_ERR_OK.
APP_ERROR StarGanGeneration::ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat) {
    // cv::Size takes (width, height).
    const cv::Size targetSize(OUTPUT_WIDTH, OUTPUT_HEIGHT);
    cv::resize(srcImageMat, *dstImageMat, targetSize);
    return APP_ERR_OK;
}
// Convert an OUTPUT_HEIGHT x OUTPUT_WIDTH 8-bit 3-channel Mat into a float32
// device tensor of shape {1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH}.
// The channel order is reversed and each value is normalized as
// (x - NORMALIZE_MEAN) / NORMALIZE_STD.
//
// Changes from the previous version:
//  - the host staging buffer is a std::vector, so it is no longer leaked;
//  - the buffer holds exactly C*H*W floats (the shape of the tensor built here)
//    instead of the product of the dims of every model input, which allocated
//    and copied uninitialised memory;
//  - the Mat is validated before at<cv::Vec3b>() is used on it.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR for an unsuitable Mat,
// or the error code reported by MxbsMallocAndCopy.
APP_ERROR StarGanGeneration::CVMatToTensorBase(const cv::Mat& imageMat, MxBase::TensorBase *tensorBase) {
    const size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
    if (imageMat.empty() || imageMat.type() != CV_8UC3 ||
        static_cast<size_t>(imageMat.rows) != H || static_cast<size_t>(imageMat.cols) != W) {
        LogError << "Input image must be a non-empty " << W << "x" << H << " 8-bit 3-channel Mat.";
        return APP_ERR_INTERNAL_ERROR;
    }

    // mat NHWC to NCHW, BGR to RGB, and Normalize
    std::vector<float> hostData(C * H * W);
    for (size_t c = 0; c < C; c++) {
        // Source channel c is written to destination plane (C - c - 1).
        const size_t planeOffset = (C - c - 1) * (H * W);
        for (size_t h = 0; h < H; h++) {
            for (size_t w = 0; w < W; w++) {
                const cv::Vec3b &pixel = imageMat.at<cv::Vec3b>(static_cast<int>(h), static_cast<int>(w));
                hostData[planeOffset + h * W + w] = (pixel[c] - NORMALIZE_MEAN) / NORMALIZE_STD;
            }
        }
    }

    const uint32_t dataSize = static_cast<uint32_t>(hostData.size() * FLOAT32_TYPE_BYTE_NUM);
    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
    // The source descriptor only points at hostData, which stays alive until
    // the copy below has returned.
    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(hostData.data()),
                                     dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Memory malloc failed.";
        return ret;
    }

    std::vector<uint32_t> shape = { 1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH};
    *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
    return APP_ERR_OK;
}
// Allocate one device tensor per model output (shape taken from modelDesc_,
// dtype from the model), append them to *outputs, then run the model with a
// static batch size of 1.
// Returns APP_ERR_OK on success or the code of the failing allocation/inference call.
APP_ERROR StarGanGeneration::Inference(const std::vector<MxBase::TensorBase> &inputs,
                                       std::vector<MxBase::TensorBase> *outputs) {
    auto outputTypes = model_->GetOutputDataType();

    size_t outIdx = 0;
    for (const auto &desc : modelDesc_.outputTensors) {
        std::vector<uint32_t> dims;
        for (const auto &dim : desc.tensorDims) {
            dims.push_back((uint32_t)dim);
        }

        MxBase::TensorBase outTensor(dims, outputTypes[outIdx],
                                     MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
        APP_ERROR mallocStatus = MxBase::TensorBase::TensorBaseMalloc(outTensor);
        if (mallocStatus != APP_ERR_OK) {
            LogError << "TensorBaseMalloc failed, ret=" << mallocStatus << ".";
            return mallocStatus;
        }
        outputs->push_back(outTensor);
        ++outIdx;
    }

    MxBase::DynamicInfo dynamicInfo = {};
    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
    dynamicInfo.batchSize = 1;

    APP_ERROR inferStatus = model_->ModelInference(inputs, *outputs, dynamicInfo);
    if (inferStatus != APP_ERR_OK) {
        LogError << "ModelInference failed, ret=" << inferStatus << ".";
        return inferStatus;
    }
    return APP_ERR_OK;
}
// Turn the first output tensor into an 8-bit image in *resultImg.
// The tensor is moved to host memory, read as planar float data with the
// channel order reversed, and de-normalized with x * NORMALIZE_STD + NORMALIZE_MEAN.
//
// Hardening over the previous version:
//  - an empty `outputs` vector or a null host buffer is rejected instead of dereferenced;
//  - *resultImg must already be an OUTPUT_HEIGHT x OUTPUT_WIDTH CV_8UC3 Mat;
//  - the truncated value is saturated to [0, 255] rather than wrapping when
//    it is narrowed to uchar.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR for bad arguments,
// or the error code reported by ToHost().
APP_ERROR StarGanGeneration::PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg) {
    const size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
    if (outputs.empty()) {
        LogError << "PostProcess received no output tensor.";
        return APP_ERR_INTERNAL_ERROR;
    }
    if (resultImg == nullptr || resultImg->type() != CV_8UC3 ||
        static_cast<size_t>(resultImg->rows) != H || static_cast<size_t>(resultImg->cols) != W) {
        LogError << "Result image must be a " << W << "x" << H << " 8-bit 3-channel Mat.";
        return APP_ERR_INTERNAL_ERROR;
    }

    APP_ERROR ret = outputs[0].ToHost();
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "tohost fail.";
        return ret;
    }

    const float *outputPtr = reinterpret_cast<float *>(outputs[0].GetBuffer());
    if (outputPtr == nullptr) {
        LogError << "Output tensor has no host buffer.";
        return APP_ERR_INTERNAL_ERROR;
    }

    for (size_t c = 0; c < C; c++) {
        // Image channel c is read from tensor plane (C - c - 1).
        const float *plane = outputPtr + (C - c - 1) * (H * W);
        for (size_t h = 0; h < H; h++) {
            for (size_t w = 0; w < W; w++) {
                const float tmpNum = plane[h * W + w] * NORMALIZE_STD + NORMALIZE_MEAN;
                // Truncate as before, then clamp to the uchar range.
                resultImg->at<cv::Vec3b>(static_cast<int>(h), static_cast<int>(w))[c] =
                    cv::saturate_cast<uchar>(static_cast<int>(tmpNum));
            }
        }
    }
    return APP_ERR_OK;
}
// Write resultImg to savePath_/imgName, creating savePath_ first when it
// cannot be opened as a directory.
//
// Fixes over the previous version:
//  - the DIR handle returned by opendir is closed (it was leaked on every call);
//  - failures of `mkdir -p` and cv::imwrite are reported instead of ignored.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR otherwise.
APP_ERROR StarGanGeneration::SaveResult(const cv::Mat &resultImg, const std::string &imgName) {
    DIR *dirPtr = opendir(savePath_.c_str());
    if (dirPtr != nullptr) {
        closedir(dirPtr);
    } else {
        // NOTE(review): savePath_ is pasted into a shell command unescaped; this
        // is only safe while the path comes from trusted configuration.
        const std::string mkdirCmd = "mkdir -p " + savePath_;
        if (system(mkdirCmd.c_str()) != 0) {
            LogError << "Failed to create result directory: " << savePath_;
            return APP_ERR_INTERNAL_ERROR;
        }
    }

    const std::string outputFile = savePath_ + "/" + imgName;
    if (!cv::imwrite(outputFile, resultImg)) {
        LogError << "Failed to write result image: " << outputFile;
        return APP_ERR_INTERNAL_ERROR;
    }
    return APP_ERR_OK;
}
// Copy `label` into a new float32 device tensor of shape {1, label.size()}
// and store it in *imgLabels.
//
// The previous version staged the values in a `new float[]` buffer that was
// never freed. `label` is already a by-value copy, so its own storage is used
// as the host source. An empty label is rejected up front.
//
// Returns APP_ERR_OK on success, APP_ERR_INTERNAL_ERROR for an empty label,
// or the error code reported by MxbsMallocAndCopy.
APP_ERROR StarGanGeneration::GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels) {
    if (label.empty()) {
        LogError << "Label vector is empty.";
        return APP_ERR_INTERNAL_ERROR;
    }

    const size_t byteSize = label.size() * FLOAT32_TYPE_BYTE_NUM;
    MxBase::MemoryData memoryDataDst(byteSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
    // The source descriptor only points at `label`, which stays alive until
    // the copy below has returned.
    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(label.data()),
                                     byteSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Memory malloc failed.";
        return ret;
    }

    const std::vector<uint32_t> shape = {1, (unsigned int)label.size()};
    *imgLabels = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
    return APP_ERR_OK;
}
// Run the whole pipeline for one image: read, resize, build the image and
// label input tensors, infer, post-process and save the result as imgName.
// The wall-clock time of the Inference call is added to inferCostTimeMilliSec
// whether or not inference succeeded.
// Returns APP_ERR_OK on success or the error code of the first failing stage.
APP_ERROR StarGanGeneration::Process(const std::string &imgPath,
                                     const std::string &imgName, const std::vector<float> &label) {
    cv::Mat image;
    APP_ERROR status = ReadImage(imgPath, &image);
    if (status != APP_ERR_OK) {
        LogError << "ReadImage failed, ret=" << status << ".";
        return status;
    }
    // Resized in place: source and destination are the same Mat. The return
    // value is not inspected, matching the original call.
    ResizeImage(image, &image);

    std::vector<MxBase::TensorBase> inputs = {};
    std::vector<MxBase::TensorBase> outputs = {};

    // First model input: the image tensor.
    MxBase::TensorBase imageTensor;
    status = CVMatToTensorBase(image, &imageTensor);
    if (status != APP_ERR_OK) {
        LogError << "CVMatToTensorBase failed, ret=" << status << ".";
        return status;
    }
    inputs.push_back(imageTensor);

    // Second model input: the label tensor.
    MxBase::TensorBase labelTensor;
    status = GetImageLabel(label, &labelTensor);
    if (status != APP_ERR_OK) {
        LogError << "Get Image label failed, ret=" << status << ".";
        return status;
    }
    inputs.push_back(labelTensor);

    const auto inferBegin = std::chrono::high_resolution_clock::now();
    status = Inference(inputs, &outputs);
    const auto inferEnd = std::chrono::high_resolution_clock::now();
    // Accumulated before the status check, so failed runs are counted too.
    inferCostTimeMilliSec += std::chrono::duration<double, std::milli>(inferEnd - inferBegin).count();
    if (status != APP_ERR_OK) {
        LogError << "Inference failed, ret=" << status << ".";
        return status;
    }

    cv::Mat resultImg(OUTPUT_HEIGHT, OUTPUT_WIDTH, CV_8UC3);
    status = PostProcess(outputs, &resultImg);
    if (status != APP_ERR_OK) {
        LogError << "PostProcess failed, ret=" << status << ".";
        return status;
    }

    status = SaveResult(resultImg, imgName);
    if (status != APP_ERR_OK) {
        LogError << "Save infer results into file failed. ret = " << status << ".";
        return status;
    }
    return APP_ERR_OK;
}
/*
* Copyright (c) 2021. Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_STARGANGENERATION_H
#define MXBASE_STARGANGENERATION_H
#include <dirent.h>
#include <memory>
#include <vector>
#include <map>
#include <string>
#include <fstream>
#include <iostream>
#include <opencv2/opencv.hpp>
#include "MxBase/Log/Log.h"
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
// Configuration consumed by StarGanGeneration::Init.
struct InitParam {
// Device id; Init stores it and passes it to TensorContext::SetContext.
uint32_t deviceId;
// Directory the result images are written to by SaveResult.
std::string savePath;
// Path of the model file handed to ModelInferenceProcessor::Init.
std::string modelPath;
};
// Wraps the StarGAN inference pipeline: device/model setup, image and label
// pre-processing, model inference, post-processing and saving of the result.
class StarGanGeneration {
public:
// Initialise devices, tensor context, DVPP wrapper and the model.
APP_ERROR Init(const InitParam &initParam);
// Release the DVPP wrapper, the model and the devices.
APP_ERROR DeInit();
// Load imgPath with cv::imread (colour mode) into *imageMat.
APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat);
// Resize srcImageMat to the fixed network input size into *dstImageMat.
APP_ERROR ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat);
// Convert a resized Mat into a normalized float32 device tensor.
APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase);
// Allocate the output tensors and run the model on `inputs`.
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> *outputs);
// De-normalize the first output tensor into *resultImg.
APP_ERROR PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg);
// Full pipeline for one image; the result is saved under the name imgName.
APP_ERROR Process(const std::string &imgPath, const std::string &imgName, const std::vector<float> &label);
// Copy `label` into a {1, label.size()} float32 device tensor.
APP_ERROR GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels);
// get infer time
// Returns the inference time accumulated over all Process calls, in milliseconds.
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
// Write resultImg to savePath_/imgName.
APP_ERROR SaveResult(const cv::Mat &resultImg, const std::string &imgName);
// Created and initialised in Init, released in DeInit.
std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
// Output directory, copied from InitParam::savePath.
std::string savePath_;
// Model input/output description filled in by model_->Init.
MxBase::ModelDesc modelDesc_;
// Device id, copied from InitParam::deviceId.
uint32_t deviceId_ = 0;
// infer time
// Sum of the durations of all Inference calls made by Process, in milliseconds.
double inferCostTimeMilliSec = 0.0;
};
#endif
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "StarGanGeneration.h"
// File-local constants used by main().
namespace {
// Attribute names looked up in the header row of list_attr_celeba.txt; the label vector passed to
// Process() holds one 0/1 entry per name, in this order.
// NOTE(review): not declared const although nothing in this file modifies it.
std::vector<std::string> SELECTED_ATTRS {"Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"};
// Model file assigned to InitParam::modelPath in main().
const std::string OM_MODEL_PATH = "../data/model/stargan.om";
}
// Lists the directory `path` and appends "<path>/<entry>" to *imgFiles for every entry
// except "." and "..". Entries are not filtered by type or extension.
// Returns APP_ERR_INTERNAL_ERROR when the directory cannot be opened, APP_ERR_OK otherwise.
APP_ERROR ScanImages(const std::string &path, std::vector<std::string> *imgFiles) {
    DIR *dir = opendir(path.c_str());
    if (!dir) {
        LogError << "opendir failed. dir:" << path;
        return APP_ERR_INTERNAL_ERROR;
    }
    for (dirent *entry = readdir(dir); entry != nullptr; entry = readdir(dir)) {
        const std::string name(entry->d_name);
        // Skip the self and parent links that every directory contains.
        if (name != "." && name != "..") {
            imgFiles->emplace_back(path + "/" + name);
        }
    }
    closedir(dir);
    return APP_ERR_OK;
}
// Splits `str` on the delimiter `ch`, treating any run of consecutive delimiters as one
// separator. Leading and trailing delimiters are ignored, so the result never contains
// empty tokens; an empty or all-delimiter input yields an empty vector.
// (Previously a leading delimiter produced a spurious empty first token, which shifted
// every column index computed from the result.)
std::vector<std::string> split(std::string str, char ch) {
    std::vector<std::string> tokens;
    const size_t total = str.length();
    size_t pos = 0;
    while (pos < total) {
        // Step over the separator run in front of the next token.
        while (pos < total && str[pos] == ch) {
            ++pos;
        }
        if (pos >= total) {
            break;
        }
        size_t tokenEnd = str.find(ch, pos);
        if (tokenEnd == std::string::npos) {
            tokenEnd = total;
        }
        tokens.push_back(str.substr(pos, tokenEnd - pos));
        pos = tokenEnd;
    }
    return tokens;
}
int main(int argc, char* argv[]) {
if (argc <= 1) {
LogWarn << "Please input image path, such as '../data/test_data/'.";
return APP_ERR_OK;
}
InitParam initParam = {};
initParam.deviceId = 0;
initParam.modelPath = OM_MODEL_PATH;
initParam.savePath = "./result";
auto stargan = std::make_shared<StarGanGeneration>();
APP_ERROR ret = stargan->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "stargan init failed, ret=" << ret << ".";
return ret;
}
// Read the contents of a label
std::string dataPath = argv[1];
std::string imagePath = dataPath + "/images/";
std::string labelPath = dataPath + "/anno/list_attr_celeba.txt";
std::vector<std::string> imagePathList;
ret = ScanImages(imagePath, &imagePathList);
if (ret != APP_ERR_OK) {
LogError << "stargan init failed, ret=" << ret << ".";
return ret;
}
std::ifstream fin;
std::string s;
fin.open(labelPath);
int i = 0;
int imgNum;
std::map<int, std::string> idx2attr;
std::map<std::string, int> attr2idx;
auto startTime = std::chrono::high_resolution_clock::now();
while (getline(fin, s)) {
i++;
if (i == 1) {
imgNum = atoi(s.c_str());
} else if (i == 2) {
std::vector<std::string> allAttrNames = split(s, ' ');
for (size_t j = 0; j < allAttrNames.size(); j++) {
idx2attr[j] = allAttrNames[j];
attr2idx[allAttrNames[j]] = j;
}
} else {
std::vector<std::string> eachAttr = split(s, ' ');
// first one is file name
std::string imgName = eachAttr[0];
std::vector<float> label;
for (size_t j = 0; j < SELECTED_ATTRS.size(); j++) {
if (atoi(eachAttr[attr2idx[SELECTED_ATTRS[j]] + 1].c_str()) == 1)
label.push_back(1.0);
else
label.push_back(0.0);
// label.push_back(atoi(eachAttr[attr2idx[SELECTED_ATTRS[j]] + 1].c_str()) * -0.5);
}
ret = stargan->Process(imagePath + imgName, imgName, label);
if (ret != APP_ERR_OK) {
LogError << "stargan process failed, ret=" << ret << ".";
stargan->DeInit();
return ret;
}
}
}
fin.close();
auto endTime = std::chrono::high_resolution_clock::now();
stargan->DeInit();
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
double fps = 1000.0 * imgNum / stargan->GetInferCostMilliSec();
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
return APP_ERR_OK;
}
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Infer """
import json
import logging
import MxpiDataType_pb2 as MxpiDataType
from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, MxProtobufIn, StringVector
from config import config as cfg
class SdkApi:
    """Small wrapper around StreamManagerApi for sending tensors to a stream and reading results.

    Typical use: construct with the pipeline file path, call ``init()``, then
    ``send_tensor_input()`` once per model input followed by ``get_result()``.
    """
    INFER_TIMEOUT = cfg.INFER_TIMEOUT
    STREAM_NAME = cfg.STREAM_NAME

    def __init__(self, pipeline_cfg):
        """Remember the pipeline file path; no SDK call is made until ``init()``."""
        self.pipeline_cfg = pipeline_cfg
        self._stream_api = None
        self._data_input = None
        self._device_id = None

    def init(self):
        """Read the pipeline file, start the stream manager and create the streams.

        Returns:
            True on success, False when the stream manager or the streams cannot be created.
        """
        # Read the file once; the same bytes provide the device id and the stream definition.
        with open(self.pipeline_cfg, 'rb') as fp:
            pipe_line = fp.read()
        self._device_id = int(
            json.loads(pipe_line)[self.STREAM_NAME]["stream_config"]["deviceId"])
        print(f"The device id: {self._device_id}.")

        # create api
        self._stream_api = StreamManagerApi()

        # init stream mgr
        ret = self._stream_api.InitManager()
        if ret != 0:
            print(f"Failed to init stream manager, ret={ret}.")
            return False

        # create streams
        ret = self._stream_api.CreateMultipleStreams(pipe_line)
        if ret != 0:
            print(f"Failed to create stream, ret={ret}.")
            return False

        self._data_input = MxDataInput()
        return True

    def __del__(self):
        # Nothing to tear down when init() was never reached.
        if not self._stream_api:
            return
        self._stream_api.DestroyAllStreams()

    def _send_protobuf(self, stream_name, plugin_id, element_name, buf_type,
                       pkg_list):
        """Serialize ``pkg_list`` and send it to input ``plugin_id`` of ``stream_name``.

        Returns:
            True when SendProtobuf reports 0, False (after logging) otherwise.
        """
        protobuf = MxProtobufIn()
        protobuf.key = element_name.encode("utf-8")
        protobuf.type = buf_type
        protobuf.protobuf = pkg_list.SerializeToString()
        protobuf_vec = InProtobufVector()
        protobuf_vec.push_back(protobuf)
        err_code = self._stream_api.SendProtobuf(stream_name, plugin_id,
                                                 protobuf_vec)
        if err_code != 0:
            logging.error(
                "Failed to send data to stream, stream_name(%s), plugin_id(%s), element_name(%s), "
                "buf_type(%s), err_code(%s).", stream_name, plugin_id,
                element_name, buf_type, err_code)
            return False
        return True

    def send_tensor_input(self, stream_name, plugin_id, element_name,
                          input_data, input_shape, data_type):
        """Wrap raw tensor bytes in an MxpiTensorPackageList and send it to the stream.

        Args:
            stream_name: stream to send to.
            plugin_id: index of the input element within the stream.
            element_name: name used as the protobuf key.
            input_data: tensor content as bytes.
            input_shape: iterable of dimensions describing ``input_data``.
            data_type: tensor data-type code (see the TENSOR_DTYPE_* constants in config).

        Returns:
            True on success, False when sending fails.
        """
        tensor_list = MxpiDataType.MxpiTensorPackageList()
        tensor_pkg = tensor_list.tensorPackageVec.add()
        # init tensor vector
        tensor_vec = tensor_pkg.tensorVec.add()
        tensor_vec.deviceId = self._device_id
        tensor_vec.memType = 0
        tensor_vec.tensorShape.extend(input_shape)
        tensor_vec.tensorDataType = data_type
        tensor_vec.dataStr = input_data
        tensor_vec.tensorDataSize = len(input_data)
        buf_type = b"MxTools.MxpiTensorPackageList"
        return self._send_protobuf(stream_name, plugin_id, element_name,
                                   buf_type, tensor_list)

    def get_result(self, stream_name, out_plugin_id=0):
        """Fetch the inference output of ``mxpi_tensorinfer0`` and parse it.

        Args:
            stream_name: stream to read from.
            out_plugin_id: index of the output element to read; defaults to 0.

        Returns:
            The parsed MxpiTensorPackageList. The process exits when no result is
            available or the SDK reports an error.
        """
        keys = [b"mxpi_tensorinfer0"]
        key_vec = StringVector()
        for key in keys:
            key_vec.push_back(key)
        # Honour the caller-supplied output plugin id (it used to be ignored in favour of 0).
        infer_result = self._stream_api.GetProtobuf(stream_name, out_plugin_id, key_vec)
        print(infer_result)
        if infer_result.size() == 0:
            print("infer_result is null")
            exit()
        if infer_result[0].errorCode != 0:
            print("GetProtobuf error. errorCode=%d" % (
                infer_result[0].errorCode))
            exit()
        tensor_list = MxpiDataType.MxpiTensorPackageList()
        tensor_list.ParseFromString(infer_result[0].messageBuf)
        return tensor_list
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Config """
# Width and height (in pixels) that process_img() resizes every input image to.
MODEL_WIDTH = 128
MODEL_HEIGHT = 128
# Stream name; SdkApi.init() uses it as the top-level key when reading the pipeline file.
STREAM_NAME = "stargan"
# Exposed as SdkApi.INFER_TIMEOUT. The unit is not established by the visible code — TODO confirm.
INFER_TIMEOUT = 100000
# Attribute names picked from list_attr_celeba.txt; one 0/1 label entry is produced per name, in this order.
SELECTED_ATTRS = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
# Data-type codes written to tensorDataType by SdkApi.send_tensor_input().
# Presumably they mirror the SDK's tensor data-type enum — verify against the SDK.
TENSOR_DTYPE_FLOAT32 = 0
TENSOR_DTYPE_FLOAT16 = 1
TENSOR_DTYPE_INT8 = 2
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Main """
import argparse
import time
import os
import numpy as np
import cv2
from api.infer import SdkApi
from config import config as cfg
def parser_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with ``img_path`` (required), ``pipeline_path``,
        ``model_type``, ``infer_result_dir`` and ``ann_file``.
    """
    parser = argparse.ArgumentParser(description="stargan inference")

    parser.add_argument("--img_path",
                        type=str,
                        required=True,
                        help="image directory.")
    parser.add_argument(
        "--pipeline_path",
        type=str,
        required=False,
        default="../data/config/stargan.pipeline",
        help="pipeline file path. The default is '../data/config/stargan.pipeline'. ")
    parser.add_argument(
        "--model_type",
        type=str,
        required=False,
        default="dvpp",
        help=
        "rgb: high-precision, dvpp: high performance. The default is 'dvpp'.")
    parser.add_argument(
        "--infer_result_dir",
        type=str,
        required=False,
        default="./result",
        help=
        "cache dir of inference result. The default is './result'."
    )
    parser.add_argument("--ann_file",
                        type=str,
                        required=False,
                        help="eval ann_file.")

    args_ = parser.parse_args()
    return args_
def get_labels(img_dir):
    """Build the per-image label vectors from ``<img_dir>/anno/list_attr_celeba.txt``.

    The second line of the file lists the attribute names; every following
    non-blank line is ``<file name> <value> <value> ...`` with one value per
    attribute. A value of ``'1'`` maps to 1.0, anything else to 0.0.

    Args:
        img_dir: data directory containing the ``anno`` sub-directory.

    Returns:
        dict mapping file name to a float32 array with one entry per name in
        ``cfg.SELECTED_ATTRS``, in that order.
    """
    selected_attrs = cfg.SELECTED_ATTRS
    anno_path = os.path.join(img_dir, 'anno', 'list_attr_celeba.txt')
    # Use a context manager so the file handle is released deterministically.
    with open(anno_path, 'r') as anno_file:
        lines = [line.rstrip() for line in anno_file]

    attr2idx = {attr_name: i for i, attr_name in enumerate(lines[1].split())}

    items = {}
    for line in lines[2:]:
        fields = line.split()
        if not fields:
            # Blank line (commonly at the end of the file): nothing to parse.
            continue
        filename, values = fields[0], fields[1:]
        label = [1.0 if values[attr2idx[attr_name]] == '1' else 0.0
                 for attr_name in selected_attrs]
        items[filename] = np.array(label).astype(np.float32)
    return items
def process_img(img_file):
    """Load an image file and convert it into the tensor layout sent to the model.

    The image is resized to (cfg.MODEL_WIDTH, cfg.MODEL_HEIGHT), its channel
    axis is reversed, the array is rearranged from HWC to CHW, a leading axis
    of size 1 is added and the values are mapped with ``(x - 127.5) / 127.5``.

    Returns:
        float32 array of shape (1, channels, height, width).
    """
    print(img_file)
    loaded = cv2.imread(img_file)
    resized = cv2.resize(loaded, (cfg.MODEL_WIDTH, cfg.MODEL_HEIGHT))
    # Reverse the channel axis, then move it to the front (HWC -> CHW).
    chw = np.transpose(resized[:, :, ::-1], (2, 0, 1))
    batched = chw[np.newaxis, ...]
    normalized = (batched - 127.5) / 127.5
    return np.array(normalized).astype(np.float32)
def decode_image(img):
    """Map a normalized CHW array back to an 8-bit HWC image.

    Applies ``img * 127.5 + 127.5`` (the inverse of the scaling in
    process_img), saturates the result to [0, 255] and converts it to uint8.
    Without the saturation, values slightly outside [-1, 1] would wrap around
    on the uint8 cast and show up as speckle artefacts.

    Args:
        img: array of shape (channels, height, width).

    Returns:
        uint8 array of shape (height, width, channels).
    """
    mean = 0.5 * 255
    std = 0.5 * 255
    pixels = np.clip(img * std + mean, 0, 255)
    return pixels.astype(np.uint8).transpose((1, 2, 0))
def image_inference(pipeline_path, stream_name, img_dir, result_dir,
                    replace_last, model_type):
    """Run every JPEG in ``<img_dir>/images`` through the stream and save the outputs.

    Args:
        pipeline_path: pipeline file used to create the stream.
        stream_name: name of the stream, as bytes.
        img_dir: data directory holding ``images/`` and ``anno/list_attr_celeba.txt``.
        result_dir: directory the generated images are written to (created if missing).
        replace_last: when False, images whose result file already exists are skipped.
        model_type: accepted for interface compatibility; not used by this function.
    """
    # init stream manager
    sdk_api = SdkApi(pipeline_path)
    if not sdk_api.init():
        exit(-1)

    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    img_data_plugin_id = 0
    img_label_plugin_id = 1

    label_items = get_labels(img_dir)

    print(f"\nBegin to inference for {img_dir}.\n\n")

    images_dir = os.path.join(img_dir, 'images')
    for file_name in os.listdir(images_dir):
        if not file_name.lower().endswith((".jpg", ".jpeg")):
            continue
        if file_name not in label_items:
            # Without a label row there is nothing to feed to the second model input.
            print(f"No label found for image({file_name}), will be skip.")
            continue
        file_path = os.path.join(images_dir, file_name)
        save_path = os.path.join(result_dir,
                                 f"{os.path.splitext(file_name)[0]}.jpg")
        if not replace_last and os.path.exists(save_path):
            print(f"The infer result image({save_path}) has existed, will be skip.")
            continue

        img_np = process_img(file_path)

        start_time = time.time()
        if not sdk_api.send_tensor_input(stream_name, img_data_plugin_id, "appsrc0",
                                         img_np.tobytes(), img_np.shape, cfg.TENSOR_DTYPE_FLOAT32):
            print(f"Failed to send image data of {file_path}.")
            exit(-1)

        # set label data; the declared shape must match the bytes actually sent
        # (one value per selected attribute), so take it from the array itself.
        label_dim = np.expand_dims(label_items[file_name], axis=0)
        if not sdk_api.send_tensor_input(stream_name, img_label_plugin_id, "appsrc1",
                                         label_dim.tobytes(), list(label_dim.shape),
                                         cfg.TENSOR_DTYPE_FLOAT32):
            print(f"Failed to send label data of {file_path}.")
            exit(-1)

        result = sdk_api.get_result(stream_name)
        end_time = time.time() - start_time
        print(f"The image({save_path}) inference time is {end_time}")

        data = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float32)
        data = data.reshape(3, cfg.MODEL_HEIGHT, cfg.MODEL_WIDTH)
        img = decode_image(data)
        # Reverse the channel axis again before writing with OpenCV (undoes the reversal in process_img).
        img = img[:, :, ::-1]
        cv2.imwrite(save_path, img)
if __name__ == "__main__":
    # Script entry point: parse the options, then run inference on the whole image directory.
    args = parser_args()
    # Existing results are always overwritten, and the stream name is fixed (as bytes).
    args.replace_last, args.stream_name = True, "stargan".encode("utf-8")
    image_inference(
        args.pipeline_path,
        args.stream_name,
        args.img_path,
        args.infer_result_dir,
        args.replace_last,
        args.model_type,
    )
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Usage: bash run.sh IMAGE_PATH [RESULT_DIR]
#   IMAGE_PATH  data directory passed to main.py as --img_path
#   RESULT_DIR  output directory passed as --infer_result_dir (default: ./result)
if [ $# -lt 1 ] || [ $# -gt 2 ]
then
  echo "Wrong parameter format." >&2
  echo "Usage:" >&2
  echo "  bash $0 IMAGE_PATH [RESULT_DIR]" >&2
  exit 255
fi

image_path=$1
# Fall back to main.py's own default instead of passing an empty directory name.
result_dir=${2:-./result}

set -e

# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }

export GST_PLUGIN_SCANNER="${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner"

# Quote the paths so directories containing spaces survive word splitting.
python3 main.py --img_path="$image_path" --infer_result_dir="$result_dir"
exit 0
\ No newline at end of file
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train the model."""
from time import time
import os
import argparse
import ast
import glob
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import Tensor, context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
from mindspore.train.serialization import export, load_param_into_net
from src.dataset import dataloader
from src.utils import get_network, resume_model
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis
from src.loss import GeneratorLoss, DiscriminatorLoss, ClassificationLoss, WGANGPGradientPenalty
from src.reporter import Reporter
set_seed(1)
# Command-line options of the training script. They are parsed at import time with
# parse_known_args(), so options this script does not define are ignored
# rather than rejected. `config` is an alias of the parsed namespace.

# Modelarts
parser = argparse.ArgumentParser(description='StarGAN_args')
parser.add_argument('--modelarts', type=ast.literal_eval, default=True,
                    help='Whether to run on ModelArts (True/False)')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')

# Model configuration.
parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
parser.add_argument('--c2_dim', type=int, default=7, help='dimension of domain labels (2nd dataset)')
parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
parser.add_argument('--image_size', type=int, default=128, help='image resolution')
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')

# Training configuration.
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
parser.add_argument('--batch_size', type=int, default=4, help='mini-batch size')
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
parser.add_argument('--epochs', type=int, default=59, help='number of epoch')
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--resume_iters', type=int, default=200000, help='resume training from this step')
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                    default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"),
                    help='network initialization, default is normal.')
parser.add_argument('--init_gain', type=float, default=0.02,
                    help='scaling factor for normal, xavier and orthogonal, default is 0.02.')

# Test configuration.
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')

# Train Device.
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
parser.add_argument('--device_target', type=str, default='Ascend')
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="number of device, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")

# Directories.
parser.add_argument('--celeba_image_dir', type=str, default=r'/home/data/celeba/images')
parser.add_argument('--attr_path', type=str, default=r'/home/data/celeba/list_attr_celeba.txt')
parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
parser.add_argument('--log_dir', type=str, default='stargan/logs')
parser.add_argument('--model_save_dir', type=str, default='./models/')
parser.add_argument('--result_dir', type=str, default='./results')

# Step size.
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=5000)
parser.add_argument('--model_save_step', type=int, default=5000)
parser.add_argument('--lr_update_step', type=int, default=1000)

# export
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR',
                    help='file format')

args_opt, unparsed = parser.parse_known_args()
config = args_opt
# Training entry point. Overall flow, as far as the statements below show:
#   1. set up the device context and build the dataset (ModelArts branch or local branch),
#   2. build generator/discriminator, their loss cells, optimizers and one-step training cells,
#   3. run config.num_iters iterations, saving checkpoints every config.model_save_step,
#   4. on ModelArts, export the generator and copy the output directory to args_opt.train_url.
# NOTE(review): indentation was lost in this copy of the file, so the nesting of the
# if/else branches below cannot be read off the text; the notes are hedged accordingly.
if __name__ == '__main__':
#config = get_config()
# --- ModelArts branch: device comes from the environment, data is copied in and unzipped ---
if args_opt.modelarts:
import moxing as mox
# NOTE(review): os.getenv returns None when the variable is unset, in which case int() raises TypeError.
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
context.set_context(device_id=device_id)
# NOTE(review): the data directory is relative ('./cache/data') while the checkpoint
# directory is absolute ('/cache/ckpt') — confirm this asymmetry is intended.
local_data_url = './cache/data'
local_train_url = '/cache/ckpt'
if not os.path.isdir(local_data_url):
os.makedirs(local_data_url)
if not os.path.isdir(local_train_url):
os.makedirs(local_train_url)
# local_data_url = os.path.join(local_data_url, str(device_id))
# local_train_url = os.path.join(local_train_url, str(device_id))
# unzip data
path = os.getcwd()
print("cwd: %s" % path)
# NOTE(review): data_url is assigned here but not referenced again below; the copy uses args_opt.data_url.
data_url = 'obs://data/CelebA/'
data_name = '/celeba.zip'
print('listdir1: %s' % os.listdir('./'))
# Time the copy of the dataset from args_opt.data_url into the local data directory.
a1time = time()
mox.file.copy_parallel(args_opt.data_url, local_data_url)
print('listdir2: %s' % os.listdir(local_data_url))
b1time = time()
print('time1:', b1time - a1time)
# Time the extraction of <local_data_url>/celeba.zip into the local data directory.
a2time = time()
zip_command = "unzip -o %s -d %s" % (local_data_url + data_name, local_data_url)
# NOTE(review): the messages say "backup" although the command is an unzip; a failed
# unzip is only reported, the script carries on.
if os.system(zip_command) == 0:
print('Successful backup')
else:
print('FAILED backup')
b2time = time()
print('time2:', b2time - a2time)
print('listdir3: %s' % os.listdir(local_data_url))
# Device Environment
if config.run_distribute:
if config.device_target == "Ascend":
rank = device_id
# device_num = device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
# NOTE(review): which `if` this `else` pairs with cannot be determined from this copy — confirm against the original file.
else:
rank = 0
device_num = 1
data_path = local_data_url + '/celeba/images'
attr_path = local_data_url + '/celeba/list_attr_celeba.txt'
# NOTE(review): this call passes device_num=config.num_workers, whereas the local branch
# below passes device_num=config.device_num — verify which value dataloader expects.
dataset, length = dataloader(img_path=data_path,
attr_path=attr_path,
batch_size=config.batch_size,
selected_attr=config.selected_attrs,
device_num=config.num_workers,
dataset=config.dataset,
mode=config.mode,
shuffle=True)
# --- Local branch: device and paths come from the command-line options ---
else:
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target,
device_id=config.device_id, save_graphs=False)
if args_opt.run_distribute:
# A numeric DEVICE_ID environment variable overrides config.device_id.
if os.getenv("DEVICE_ID", "not_set").isdigit():
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
device_num = config.device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
init()
rank = get_rank()
data_path = config.celeba_image_dir
attr_path = config.attr_path
local_train_url = config.model_save_dir
dataset, length = dataloader(img_path=data_path,
attr_path=attr_path,
batch_size=config.batch_size,
selected_attr=config.selected_attrs,
device_num=config.device_num,
dataset=config.dataset,
mode=config.mode,
shuffle=True)
print(length)
dataset_iter = dataset.create_dict_iterator()
# Get and initial network
generator, discriminator = get_network(config)
# NOTE(review): cls_loss and wgan_loss are created here but not referenced again in this block.
cls_loss = ClassificationLoss()
wgan_loss = WGANGPGradientPenalty(discriminator)
# Define network with loss
G_loss_cell = GeneratorLoss(config, generator, discriminator)
D_loss_cell = DiscriminatorLoss(config, generator, discriminator)
# Define Optimizer
# NOTE(review): star_iter and iter_sum are assigned but not used below.
star_iter = 0
iter_sum = config.num_iters
Optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=config.g_lr,
beta1=config.beta1, beta2=config.beta2)
Optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=config.d_lr,
beta1=config.beta1, beta2=config.beta2)
# Define One step train
G_trainOneStep = TrainOneStepCellGen(G_loss_cell, Optimizer_G)
D_trainOneStep = TrainOneStepCellDis(D_loss_cell, Optimizer_D)
# Train
G_trainOneStep.set_train()
D_trainOneStep.set_train()
print('Start Training')
reporter = Reporter(config)
# Checkpoint callbacks are driven by hand: begin() here, step_end() inside the loop.
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.model_save_step)
ckpt_cb_g = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Generator')
ckpt_cb_d = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Discriminator')
cb_params_g = _InternalCallbackParam()
cb_params_g.train_network = generator
cb_params_g.cur_step_num = 0
# NOTE(review): hard-coded 4 here, while the discriminator params below use config.batch_size.
cb_params_g.batch_num = 4
cb_params_g.cur_epoch_num = 0
cb_params_d = _InternalCallbackParam()
cb_params_d.train_network = discriminator
cb_params_d.cur_step_num = 0
cb_params_d.batch_num = config.batch_size
cb_params_d.cur_epoch_num = 0
run_context_g = RunContext(cb_params_g)
run_context_d = RunContext(cb_params_d)
ckpt_cb_g.begin(run_context_g)
ckpt_cb_d.begin(run_context_d)
start = time()
for iterator in range(config.num_iters):
# NOTE(review): next() raises StopIteration if the iterator runs out before num_iters
# iterations; whether it can run out depends on dataloader, which is not visible here.
data = next(dataset_iter)
x_real = Tensor(data['image'], mstype.float32)
c_trg = Tensor(data['attr'], mstype.float32)
c_org = Tensor(data['attr'], mstype.float32)
# Shuffle the target labels so they differ from the original ones.
np.random.shuffle(c_trg)
# The discriminator is updated every iteration, the generator every n_critic iterations.
d_out = D_trainOneStep(x_real, c_org, c_trg)
if (iterator + 1) % config.n_critic == 0:
g_out = G_trainOneStep(x_real, c_org, c_trg)
# NOTE(review): g_out is only assigned on n_critic steps. If this logging `if` is not nested
# inside the one above and log_step is not a multiple of n_critic, the first log raises NameError.
if (iterator + 1) % config.log_step == 0:
reporter.print_info(start, iterator, g_out, d_out)
_, _, dict_G, dict_D = reporter.return_loss_array(g_out, d_out)
if (iterator + 1) % config.model_save_step == 0:
cb_params_d.cur_step_num = iterator + 1
cb_params_d.batch_num = iterator + 2
cb_params_g.cur_step_num = iterator + 1
cb_params_g.batch_num = iterator + 2
ckpt_cb_g.step_end(run_context_g)
ckpt_cb_d.step_end(run_context_d)
# --- ModelArts only: export the generator and upload the output directory ---
if args_opt.modelarts:
print('output dir3: %s' % os.listdir(local_train_url))
# NOTE(review): the checkpoint directory is spelled out again here instead of reusing local_train_url.
ckpt_list = glob.glob("/cache/ckpt/*.ckpt")
# NOTE(review): an empty list is only reported; unless something returns in between,
# ckpt_list[-1] two lines below raises IndexError.
if not ckpt_list:
print("ckpt file not generated.")
# Pick the most recently modified checkpoint.
ckpt_list.sort(key=os.path.getmtime)
ckpt_model = ckpt_list[-1]
print("checkpoint path", ckpt_model)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
G, D = get_network(config)
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
# Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
G.set_train(True)
# NOTE(review): ckpt_model is only printed; which checkpoint gets loaded is decided inside resume_model (not visible here).
param_G, _ = resume_model(config, G, D)
load_param_into_net(G, param_G)
# Random example inputs for the export call: one image (1, 3, 128, 128) and one label vector (1, 5).
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 5)).astype(np.float32))
# NOTE(review): G_file is not used; the export call hard-codes its own file name.
G_file = f"StarGAN_Generator"
export(G, input_array, input_label, file_name='/cache/ckpt/stargan', file_format=config.file_format)
mox.file.copy_parallel(local_train_url, args_opt.train_url)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment