diff --git a/research/cv/StarGAN/infer/convert/convert_om.sh b/research/cv/StarGAN/infer/convert/convert_om.sh
new file mode 100644
index 0000000000000000000000000000000000000000..670f6667c5fc2a08c0ed1e3f74c7e08b52199493
--- /dev/null
+++ b/research/cv/StarGAN/infer/convert/convert_om.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# -ne 2 ]  # exactly two positional arguments are required
+then
+    echo "Wrong parameter format."
+    echo "Usage:"
+    echo " bash $0 INPUT_AIR_PATH OUTPUT_OM_PATH_NAME"
+    echo "Example: "
+    echo " bash convert_om.sh models/0-150_1251.air models/0-150_1251.om"
+
+    exit 255
+fi
+
+input_air_path=$1  # AIR model handed to atc via --model
+output_om_path=$2  # value handed to atc via --output
+
+export ASCEND_SLOG_PRINT_TO_STDOUT=1  # NOTE(review): presumably mirrors Ascend logs to stdout - confirm
+
+echo "Input AIR file path: ${input_air_path}"
+echo "Output OM file path: ${output_om_path}"
+
+atc --input_format=NCHW \
+--framework=1 \
+--model="${input_air_path}" \
+--output="${output_om_path}" \
+--soc_version=Ascend310 \
+--disable_reuse_memory=0 \
+--output_type=FP32 \
+--precision_mode=allow_fp32_to_fp16 \
+--op_select_implmode=high_precision
\ No newline at end of file
diff --git a/research/cv/StarGAN/infer/data/config/stargan.pipeline b/research/cv/StarGAN/infer/data/config/stargan.pipeline
new file mode 100644
index 0000000000000000000000000000000000000000..499d1d8416c8bcb364200e86f238cfa98b937d6e
--- /dev/null
+++ b/research/cv/StarGAN/infer/data/config/stargan.pipeline
@@ -0,0 +1,35 @@
+{
+    "stargan": {
+        "stream_config": {
+            "deviceId": "0"
+        },
+        "appsrc0": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": "mxpi_tensorinfer0:0"
+        },
+        "appsrc1": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": "mxpi_tensorinfer0:1"
+        },
+        "mxpi_tensorinfer0": {
+            "props": {
+                "dataSource": "appsrc0,appsrc1",
+                "modelPath": "../data/model/stargan.om"
+            },
+            "factory": "mxpi_tensorinfer",
+            "next": "appsink0"
+        },
+        "appsink0": {
+            "props": {
+                "blocksize": "4096000"
+            },
+            "factory": "appsink"
+        }
+    }
+}
\ No newline at end of file
diff --git a/research/cv/StarGAN/infer/data/model/label.names b/research/cv/StarGAN/infer/data/model/label.names
new file mode 100644
index 0000000000000000000000000000000000000000..3007b80e177fff9460b5ea186e252d914ca14c20
--- /dev/null
+++ b/research/cv/StarGAN/infer/data/model/label.names
@@ -0,0 +1,40 @@
+5_o_Clock_Shadow
+Arched_Eyebrows
+Attractive
+Bags_Under_Eyes
+Bald
+Bangs
+Big_Lips
+Big_Nose
+Black_Hair
+Blond_Hair
+Blurry
+Brown_Hair
+Bushy_Eyebrows
+Chubby
+Double_Chin
+Eyeglasses
+Goatee
+Gray_Hair
+Heavy_Makeup
+High_Cheekbones
+Male
+Mouth_Slightly_Open
+Mustache
+Narrow_Eyes
+No_Beard
+Oval_Face
+Pale_Skin
+Pointy_Nose
+Receding_Hairline
+Rosy_Cheeks
+Sideburns
+Smiling
+Straight_Hair
+Wavy_Hair
+Wearing_Earrings
+Wearing_Hat
+Wearing_Lipstick
+Wearing_Necklace
+Wearing_Necktie
+Young
\ No newline at end of file
diff --git a/research/cv/StarGAN/infer/docker_start_infer.sh b/research/cv/StarGAN/infer/docker_start_infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..022ff222968145977df53c32a26d31fde5429d39
--- /dev/null
+++ b/research/cv/StarGAN/infer/docker_start_infer.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+docker_image=$1  # image to start
+data_dir=$2  # host directory mounted at the same path inside the container
+
+function show_help() {
+    echo "Usage: docker_start_infer.sh docker_image data_dir"
+}
+
+function param_check() {
+    if [ -z "${docker_image}" ]; then
+        echo "please input docker_image"
+        show_help
+        exit 1
+    fi
+
+    if [ -z "${data_dir}" ]; then
+        echo "please input data_dir"
+        show_help
+        exit 1
+    fi
+}
+
+param_check
+
+docker run -it \
+    --device=/dev/davinci0 \
+    --device=/dev/davinci_manager \
+    --device=/dev/devmm_svm \
+    --device=/dev/hisi_hdc \
+    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+    -v "${data_dir}:${data_dir}" \
+    "${docker_image}" \
+    /bin/bash
\ No newline at end of file
diff --git a/research/cv/StarGAN/infer/mxbase/CMakeLists.txt b/research/cv/StarGAN/infer/mxbase/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0da1342654283d841d84e5db1a2a31be28cc4b79
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/CMakeLists.txt
@@ -0,0 +1,44 @@
+cmake_minimum_required(VERSION 3.14.0)
+project(stargan)
+set(TARGET_MAIN Stargan)
+
+add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+add_definitions(-Dgoogle=mindxsdk_private)
+add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
+add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
+# Check environment variable
+if(NOT DEFINED ENV{ASCEND_HOME})
+    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
+endif()
+if(NOT DEFINED ENV{ARCH_PATTERN})
+    message(WARNING "please define environment variable:ARCH_PATTERN")
+endif()
+set(ACL_LIB_PATH $ENV{ASCEND_HOME}/nnrt/latest/acllib)  # root of the ACL headers and libraries used below
+set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})  # NOTE(review): MX_SDK_HOME is not validated like ASCEND_HOME above
+set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
+set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
+set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
+set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
+
+set(OPENSOURCE_DIR $ENV{MX_SDK_HOME}/opensource)
+
+include_directories(src)
+include_directories(${ACL_LIB_PATH}/include)  # was ${ACL_INC_DIR}, which this file never sets
+include_directories(${OPENSOURCE_DIR}/include)
+include_directories(${OPENSOURCE_DIR}/include/opencv4)
+
+
+include_directories(${MXBASE_INC})
+include_directories(${MXBASE_POST_PROCESS_DIR})
+
+link_directories(${ACL_LIB_PATH}/lib64)  # was ${ACL_LIB_DIR}, which this file never sets
+link_directories(${OPENSOURCE_DIR}/lib)
+link_directories(${MXBASE_LIB_DIR})
+link_directories(${MXBASE_POST_LIB_DIR})
+
+
+add_executable(${TARGET_MAIN} src/main.cpp src/StarGanGeneration.cpp)  # NOTE(review): confirm src/main.cpp exists
+
+target_link_libraries(${TARGET_MAIN} glog cpprest mxbase opencv_world)
+
+install(TARGETS ${TARGET_MAIN} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
\ No newline at end of file
diff --git a/research/cv/StarGAN/infer/mxbase/StarGanGeneration.cpp b/research/cv/StarGAN/infer/mxbase/StarGanGeneration.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4226ea4b2a54b55318a42b056327d1c4e432fbc8
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/StarGanGeneration.cpp
@@ -0,0 +1,258 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "StarGanGeneration.h"
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MxBase/Log/Log.h"
+
+namespace {
+    const uint32_t FLOAT32_TYPE_BYTE_NUM = 4;
+    const float NORMALIZE_MEAN = 255.0f / 2;  // 127.5; the old "255/2" was integer division and yielded 127
+    const float NORMALIZE_STD = 255.0f / 2;  // same scale, so pixel values 0..255 map onto -1..1
+    const uint32_t OUTPUT_HEIGHT = 128;
+    const uint32_t OUTPUT_WIDTH = 128;
+    const uint32_t CHANNEL = 3;
+}
+
+void PrintTensorShape(const std::vector<MxBase::TensorDesc> &tensorDescVec, const std::string &tensorName) {
+    LogInfo << "The shape of " << tensorName << " is as follows:";
+    for (size_t i = 0; i < tensorDescVec.size(); ++i) {
+        LogInfo << " Tensor " << i << ":";
+        for (size_t j = 0; j < tensorDescVec[i].tensorDims.size(); ++j) {
+            LogInfo << " dim: " << j << ": " << tensorDescVec[i].tensorDims[j];
+        }
+    }
+}
+
+APP_ERROR StarGanGeneration::Init(const InitParam &initParam) {  // devices, context, DVPP, then the model
+    deviceId_ = initParam.deviceId;
+    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
+    if (ret != APP_ERR_OK) {
+        LogError << "Init devices failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
+    if (ret != APP_ERR_OK) {
+        LogError << "Set context failed, ret=" << ret << ".";
+        return ret;
+    }
+    dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
+    ret = dvppWrapper_->Init();
+    if (ret != APP_ERR_OK) {
+        LogError << "DvppWrapper init failed, ret=" << ret << ".";
+        return ret;
+    }
+    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
+    ret = model_->Init(initParam.modelPath, modelDesc_);  // also fills modelDesc_ used by the tensor helpers
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
+        return ret;
+    }
+    savePath_ = initParam.savePath;
+    PrintTensorShape(modelDesc_.inputTensors, "Model Input Tensors");
+    PrintTensorShape(modelDesc_.outputTensors, "Model Output Tensors");
+
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::DeInit() {  // releases what Init created; sub-call results are not checked
+    dvppWrapper_->DeInit();
+    model_->DeInit();
+    MxBase::DeviceManager::GetInstance()->DestroyDevices();
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::ReadImage(const std::string &imgPath, cv::Mat *imageMat) {
+    *imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
+    return imageMat->empty() ? APP_ERR_INTERNAL_ERROR : APP_ERR_OK;  // imread yields an empty Mat on failure
+}
+
+void StarGanGeneration::ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat) {
+    static constexpr uint32_t resizeHeight = OUTPUT_HEIGHT;
+    static constexpr uint32_t resizeWidth = OUTPUT_WIDTH;
+
+    cv::resize(srcImageMat, *dstImageMat, cv::Size(resizeWidth, resizeHeight));
+}
+
+APP_ERROR StarGanGeneration::CVMatToTensorBase(const cv::Mat& imageMat, MxBase::TensorBase *tensorBase) {
+    uint32_t dataSize = 1;
+    for (size_t i = 0; i < modelDesc_.inputTensors.size(); ++i) {  // NOTE(review): spans every model input
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.inputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.inputTensors[i].tensorDims[j]);
+        }
+        for (uint32_t s = 0; s < shape.size(); ++s) {
+            dataSize *= shape[s];
+        }
+    }
+    // mat NHWC to NCHW, BGR to RGB, and Normalize
+    size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
+
+    std::vector<float> mat_data(dataSize);  // owned, zero-filled staging buffer; the old new[] was never freed
+    dataSize = dataSize * FLOAT32_TYPE_BYTE_NUM;  // from here on dataSize is a byte count
+    for (size_t c = 0; c < C; c++) {
+        for (size_t h = 0; h < H; h++) {
+            for (size_t w = 0; w < W; w++) {
+                int id = (C - c - 1) * (H * W) + h * W + w;  // source channel c lands in plane C-1-c
+                mat_data[id] = (imageMat.at<cv::Vec3b>(h, w)[c] - NORMALIZE_MEAN) / NORMALIZE_STD;
+            }
+        }
+    }
+
+    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(mat_data.data()),
+        dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+
+    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);  // NOTE(review): sync copy?
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc failed.";
+        return ret;
+    }
+    std::vector<uint32_t> shape = { 1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH};
+    *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::Inference(const std::vector<MxBase::TensorBase> &inputs,
+    std::vector<MxBase::TensorBase> *outputs) {
+    auto dtypes = model_->GetOutputDataType();
+    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {  // one device tensor per model output
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
+        }
+        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
+        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
+        if (ret != APP_ERR_OK) {
+            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
+            return ret;
+        }
+        outputs->push_back(tensor);
+    }
+    MxBase::DynamicInfo dynamicInfo = {};
+    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
+    dynamicInfo.batchSize = 1;
+
+    APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInference failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg) {
+    APP_ERROR ret = outputs[0].ToHost();  // only the first output is used; callers must pass a non-empty vector
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "tohost fail.";
+        return ret;
+    }
+    float *outputPtr = reinterpret_cast<float *>(outputs[0].GetBuffer());
+
+    size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
+
+    for (size_t c = 0; c < C; c++) {
+        for (size_t h = 0; h < H; h++) {
+            for (size_t w = 0; w < W; w++) {
+                float tmpNum = *(outputPtr + (C - c - 1) * (H * W) + h * W + w) * NORMALIZE_STD + NORMALIZE_MEAN;
+                resultImg->at<cv::Vec3b>(h, w)[c] = cv::saturate_cast<uchar>(static_cast<int>(tmpNum));  // clamp
+            }
+        }
+    }
+
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::SaveResult(const cv::Mat &resultImg, const std::string &imgName) {
+    DIR *dirPtr = opendir(savePath_.c_str());
+    if (dirPtr != nullptr) {
+        closedir(dirPtr);  // the handle was only an existence probe; it used to leak once per image
+    } else {
+        system(("mkdir -p " + savePath_).c_str());  // NOTE(review): savePath_ reaches the shell unescaped
+    }
+    return cv::imwrite(savePath_ + "/" + imgName, resultImg) ? APP_ERR_OK : APP_ERR_INTERNAL_ERROR;
+}
+
+APP_ERROR StarGanGeneration::GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels) {
+    std::vector<float> mat_data(label.size());  // owned staging buffer; the old new[] was never freed
+    for (size_t i = 0; i < label.size(); i++) {
+        mat_data[i] = label[i];
+    }
+    MxBase::MemoryData memoryDataDst(label.size()*FLOAT32_TYPE_BYTE_NUM, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(mat_data.data()),
+        label.size()*FLOAT32_TYPE_BYTE_NUM, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);  // NOTE(review): sync copy?
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc failed.";
+        return ret;
+    }
+
+    const std::vector<uint32_t> shape = {1, (unsigned int)label.size()};
+    *imgLabels = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::Process(const std::string &imgPath,
+    const std::string &imgName, const std::vector<float> &label) {
+    cv::Mat imageMat;
+    APP_ERROR ret = ReadImage(imgPath, &imageMat);
+    if (ret != APP_ERR_OK) {
+        LogError << "ReadImage failed, ret=" << ret << ".";
+        return ret;
+    }
+    ResizeImage(imageMat, &imageMat);
+
+    std::vector<MxBase::TensorBase> inputs = {};
+    std::vector<MxBase::TensorBase> outputs = {};
+
+    MxBase::TensorBase tensorBase;
+    ret = CVMatToTensorBase(imageMat, &tensorBase);
+    if (ret != APP_ERR_OK) {
+        LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
+        return ret;
+    }
+    inputs.push_back(tensorBase);  // input 0: image tensor
+
+    MxBase::TensorBase imgLabels;
+    ret = GetImageLabel(label, &imgLabels);
+    if (ret != APP_ERR_OK) {
+        LogError << "Get Image label failed, ret=" << ret << ".";
+        return ret;
+    }
+    inputs.push_back(imgLabels);  // input 1: attribute label tensor
+    auto startTime = std::chrono::high_resolution_clock::now();
+    ret = Inference(inputs, &outputs);
+    auto endTime = std::chrono::high_resolution_clock::now();
+    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();  // save time
+    inferCostTimeMilliSec += costMs;  // accumulated even when Inference failed
+    if (ret != APP_ERR_OK) {
+        LogError << "Inference failed, ret=" << ret << ".";
+        return ret;
+    }
+    cv::Mat resultImg(OUTPUT_HEIGHT, OUTPUT_WIDTH, CV_8UC3);
+    ret = PostProcess(outputs, &resultImg);
+    if (ret != APP_ERR_OK) {
+        LogError << "PostProcess failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = SaveResult(resultImg, imgName);
+    if (ret != APP_ERR_OK) {
+        LogError << "Save infer results into file failed. ret = " << ret << ".";
+        return ret;
+    }
+
+    return APP_ERR_OK;
+}
diff --git a/research/cv/StarGAN/infer/mxbase/StarGanGeneration.h b/research/cv/StarGAN/infer/mxbase/StarGanGeneration.h
new file mode 100644
index 0000000000000000000000000000000000000000..d64689a65a2d851a86dda8b4f91a8208bcb3901a
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/StarGanGeneration.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2021. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MXBASE_STARGANGENERATION_H
+#define MXBASE_STARGANGENERATION_H
+#include <dirent.h>
+#include <memory>
+#include <vector>
+#include <map>
+#include <string>
+#include <fstream>
+#include <iostream>
+#include <opencv2/opencv.hpp>
+#include "MxBase/Log/Log.h"
+#include "MxBase/DvppWrapper/DvppWrapper.h"
+#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MxBase/Tensor/TensorContext/TensorContext.h"
+
+struct InitParam {  // settings handed to StarGanGeneration::Init
+    uint32_t deviceId;
+    std::string savePath;
+    std::string modelPath;
+};
+
+class StarGanGeneration {
+ public:
+    APP_ERROR Init(const InitParam &initParam);
+    APP_ERROR DeInit();
+    APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat);
+    void ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat);
+    APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase);
+    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs,
+        std::vector<MxBase::TensorBase> *outputs);
+    APP_ERROR PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg);
+    APP_ERROR Process(const std::string &imgPath, const std::string &imgName, const std::vector<float> &label);
+    APP_ERROR GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels);
+    // get infer time
+    double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
+
+ private:
+    APP_ERROR SaveResult(const cv::Mat &resultImg, const std::string &imgName);
+    std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
+    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
+    std::string savePath_;
+    MxBase::ModelDesc modelDesc_;
+    uint32_t deviceId_ = 0;
+    // infer time
+    double inferCostTimeMilliSec = 0.0;
+};
+
+
+#endif  // MXBASE_STARGANGENERATION_H
diff --git a/research/cv/StarGAN/infer/mxbase/build.sh b/research/cv/StarGAN/infer/mxbase/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1841d76d5cc1951c40a788524546c9b3fc8f10ac
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/build.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+path_cur=$(dirname "$0")
+
+function check_env()
+{
+    # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
+    if [ ! "${ASCEND_VERSION}" ]; then
+        export ASCEND_VERSION=ascend-toolkit/latest
+        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
+    else
+        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
+    fi
+
+    if [ ! "${ARCH_PATTERN}" ]; then
+        # set ARCH_PATTERN to ./ when it was not specified by user
+        export ARCH_PATTERN=./
+        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
+    else
+        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
+    fi
+}
+
+function build_stargan()
+{
+    cd "$path_cur" || exit 1  # never run the rm -rf below from an unexpected directory
+    rm -rf build
+    mkdir -p build
+    cd build || exit 1
+    cmake ..
+    make
+    ret=$?
+    if [ ${ret} -ne 0 ]; then
+        echo "Failed to build stargan."
+        exit ${ret}
+    fi
+    make install
+}
+
+check_env
+build_stargan
diff --git a/research/cv/StarGAN/infer/mxbase/main.cpp b/research/cv/StarGAN/infer/mxbase/main.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..831bb839bb4976b73a183f7ac87b6750a1243751
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/main.cpp
@@ -0,0 +1,139 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "StarGanGeneration.h"
+
+namespace {
+    std::vector<std::string> SELECTED_ATTRS {"Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"};
+    const std::string OM_MODEL_PATH = "../data/model/stargan.om";
+}
+
+APP_ERROR ScanImages(const std::string &path, std::vector<std::string> *imgFiles) {
+    DIR *dirPtr = opendir(path.c_str());
+    if (dirPtr == nullptr) {
+        LogError << "opendir failed. dir:" << path;
+        return APP_ERR_INTERNAL_ERROR;
+    }
+    dirent *direntPtr = nullptr;
+    while ((direntPtr = readdir(dirPtr)) != nullptr) {
+        std::string fileName = direntPtr->d_name;
+        if (fileName == "." || fileName == "..") {
+            continue;
+        }
+
+        imgFiles->emplace_back(path + "/" + fileName);  // every other entry is kept, whatever its type
+    }
+    closedir(dirPtr);
+    return APP_ERR_OK;
+}
+
+std::vector<std::string> split(std::string str, char ch) {  // a run of consecutive ch counts as one separator
+    size_t start = 0;
+    size_t len = 0;
+    std::vector<std::string> ret;
+    for (size_t i = 0; i < str.length(); i++) {
+        if (str[i] == ch && i+1 < str.length() && str[i+1] == ch) {
+            continue;  // only the last delimiter of a run closes the current token
+        }
+        if (str[i] == ch) {
+            ret.push_back(str.substr(start, len));
+            start = i+1;
+            len = 0;
+        } else {
+            len++;
+        }
+    }
+    if (start < str.length())
+        ret.push_back(str.substr(start, len));
+    return ret;
+}
+
+
+int main(int argc, char* argv[]) {
+    if (argc <= 1) {
+        LogWarn << "Please input image path, such as '../data/test_data/'.";
+        return APP_ERR_OK;
+    }
+
+    InitParam initParam = {};
+    initParam.deviceId = 0;
+    initParam.modelPath = OM_MODEL_PATH;
+    initParam.savePath = "./result";
+    auto stargan = std::make_shared<StarGanGeneration>();
+    APP_ERROR ret = stargan->Init(initParam);
+    if (ret != APP_ERR_OK) {
+        LogError << "stargan init failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    // Read the contents of a label
+    std::string dataPath = argv[1];
+    std::string imagePath = dataPath + "/images/";
+    std::string labelPath = dataPath + "/anno/list_attr_celeba.txt";
+
+    std::vector<std::string> imagePathList;
+    ret = ScanImages(imagePath, &imagePathList);
+    if (ret != APP_ERR_OK) {
+        LogError << "ScanImages failed, ret=" << ret << ".";
+        return ret;
+    }
+    std::ifstream fin(labelPath);
+    std::string s;
+    if (!fin.is_open()) { LogError << "Failed to open label file: " << labelPath; }
+    int i = 0;
+    int imgNum = 0;  // stays 0 when the label file is missing or empty
+    std::map<int, std::string> idx2attr;
+    std::map<std::string, int> attr2idx;
+    auto startTime = std::chrono::high_resolution_clock::now();
+
+    while (getline(fin, s)) {
+        i++;
+        if (i == 1) {
+            imgNum = atoi(s.c_str());  // line 1: image count
+        } else if (i == 2) {
+            std::vector<std::string> allAttrNames = split(s, ' ');  // line 2: attribute names
+            for (size_t j = 0; j < allAttrNames.size(); j++) {
+                idx2attr[j] = allAttrNames[j];
+                attr2idx[allAttrNames[j]] = j;
+            }
+        } else {
+            std::vector<std::string> eachAttr = split(s, ' ');
+            if (eachAttr.empty()) { continue; }  // blank line: nothing to process
+            std::string imgName = eachAttr[0];  // first one is file name
+            std::vector<float> label;
+            for (size_t j = 0; j < SELECTED_ATTRS.size(); j++) {
+                // Column 0 holds the file name, so attribute k is stored in column k + 1.
+                const size_t col = attr2idx[SELECTED_ATTRS[j]] + 1;
+                // A missing column is treated like any value other than 1, i.e. encoded as 0.
+                const bool isSet = col < eachAttr.size() && atoi(eachAttr[col].c_str()) == 1;
+                label.push_back(isSet ? 1.0f : 0.0f);
+            }
+            ret = stargan->Process(imagePath + imgName, imgName, label);
+            if (ret != APP_ERR_OK) {
+                LogError << "stargan process failed, ret=" << ret << ".";
+                stargan->DeInit();
+                return ret;
+            }
+        }
+    }
+    fin.close();
+    auto endTime = std::chrono::high_resolution_clock::now();
+    stargan->DeInit();
+    double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
+    double fps = stargan->GetInferCostMilliSec() > 0 ? 1000.0 * imgNum / stargan->GetInferCostMilliSec() : 0.0;
+    LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
+    return APP_ERR_OK;
+}
diff --git a/research/cv/StarGAN/infer/mxbase/src/StarGanGeneration.cpp b/research/cv/StarGAN/infer/mxbase/src/StarGanGeneration.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..47a1111255d9f2b78a5ebe893e988f85936ce380
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/src/StarGanGeneration.cpp
@@ -0,0 +1,259 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "StarGanGeneration.h"
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MxBase/Log/Log.h"
+
+namespace {
+    const uint32_t FLOAT32_TYPE_BYTE_NUM = 4;
+    const float NORMALIZE_MEAN = 255.0f / 2;  // 127.5; the old "255/2" was integer division and yielded 127
+    const float NORMALIZE_STD = 255.0f / 2;  // same scale, so pixel values 0..255 map onto -1..1
+    const uint32_t OUTPUT_HEIGHT = 128;
+    const uint32_t OUTPUT_WIDTH = 128;
+    const uint32_t CHANNEL = 3;
+}
+
+void PrintTensorShape(const std::vector<MxBase::TensorDesc> &tensorDescVec, const std::string &tensorName) {
+    LogInfo << "The shape of " << tensorName << " is as follows:";
+    for (size_t i = 0; i < tensorDescVec.size(); ++i) {
+        LogInfo << " Tensor " << i << ":";
+        for (size_t j = 0; j < tensorDescVec[i].tensorDims.size(); ++j) {
+            LogInfo << " dim: " << j << ": " << tensorDescVec[i].tensorDims[j];
+        }
+    }
+}
+
+APP_ERROR StarGanGeneration::Init(const InitParam &initParam) {  // devices, context, DVPP, then the model
+    deviceId_ = initParam.deviceId;
+    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
+    if (ret != APP_ERR_OK) {
+        LogError << "Init devices failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
+    if (ret != APP_ERR_OK) {
+        LogError << "Set context failed, ret=" << ret << ".";
+        return ret;
+    }
+    dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
+    ret = dvppWrapper_->Init();
+    if (ret != APP_ERR_OK) {
+        LogError << "DvppWrapper init failed, ret=" << ret << ".";
+        return ret;
+    }
+    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
+    ret = model_->Init(initParam.modelPath, modelDesc_);  // also fills modelDesc_ used by the tensor helpers
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
+        return ret;
+    }
+    savePath_ = initParam.savePath;
+    PrintTensorShape(modelDesc_.inputTensors, "Model Input Tensors");
+    PrintTensorShape(modelDesc_.outputTensors, "Model Output Tensors");
+
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::DeInit() {  // releases what Init created; sub-call results are not checked
+    dvppWrapper_->DeInit();
+    model_->DeInit();
+    MxBase::DeviceManager::GetInstance()->DestroyDevices();
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::ReadImage(const std::string &imgPath, cv::Mat *imageMat) {
+    *imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
+    return imageMat->empty() ? APP_ERR_INTERNAL_ERROR : APP_ERR_OK;  // imread yields an empty Mat on failure
+}
+
+APP_ERROR StarGanGeneration::ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat) {
+    static constexpr uint32_t resizeHeight = OUTPUT_HEIGHT;
+    static constexpr uint32_t resizeWidth = OUTPUT_WIDTH;
+
+    cv::resize(srcImageMat, *dstImageMat, cv::Size(resizeWidth, resizeHeight));
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::CVMatToTensorBase(const cv::Mat& imageMat, MxBase::TensorBase *tensorBase) {
+    uint32_t dataSize = 1;
+    for (size_t i = 0; i < modelDesc_.inputTensors.size(); ++i) {  // NOTE(review): spans every model input
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.inputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.inputTensors[i].tensorDims[j]);
+        }
+        for (uint32_t s = 0; s < shape.size(); ++s) {
+            dataSize *= shape[s];
+        }
+    }
+    // mat NHWC to NCHW, BGR to RGB, and Normalize
+    size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
+
+    std::vector<float> mat_data(dataSize);  // owned, zero-filled staging buffer; the old new[] was never freed
+    dataSize = dataSize * FLOAT32_TYPE_BYTE_NUM;  // from here on dataSize is a byte count
+    for (size_t c = 0; c < C; c++) {
+        for (size_t h = 0; h < H; h++) {
+            for (size_t w = 0; w < W; w++) {
+                int id = (C - c - 1) * (H * W) + h * W + w;  // source channel c lands in plane C-1-c
+                mat_data[id] = (imageMat.at<cv::Vec3b>(h, w)[c] - NORMALIZE_MEAN) / NORMALIZE_STD;
+            }
+        }
+    }
+
+    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(mat_data.data()),
+        dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+
+    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);  // NOTE(review): sync copy?
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc failed.";
+        return ret;
+    }
+    std::vector<uint32_t> shape = { 1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH};
+    *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::Inference(const std::vector<MxBase::TensorBase> &inputs,
+    std::vector<MxBase::TensorBase> *outputs) {
+    auto dtypes = model_->GetOutputDataType();
+    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {  // one device tensor per model output
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
+        }
+        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
+        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
+        if (ret != APP_ERR_OK) {
+            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
+            return ret;
+        }
+        outputs->push_back(tensor);
+    }
+    MxBase::DynamicInfo dynamicInfo = {};
+    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
+    dynamicInfo.batchSize = 1;
+
+    APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInference failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg) {
+    APP_ERROR ret = outputs[0].ToHost();  // only the first output is used; callers must pass a non-empty vector
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "tohost fail.";
+        return ret;
+    }
+    float *outputPtr = reinterpret_cast<float *>(outputs[0].GetBuffer());
+
+    size_t H = OUTPUT_HEIGHT, W = OUTPUT_WIDTH, C = CHANNEL;
+
+    for (size_t c = 0; c < C; c++) {
+        for (size_t h = 0; h < H; h++) {
+            for (size_t w = 0; w < W; w++) {
+                float tmpNum = *(outputPtr + (C - c - 1) * (H * W) + h * W + w) * NORMALIZE_STD + NORMALIZE_MEAN;
+                resultImg->at<cv::Vec3b>(h, w)[c] = cv::saturate_cast<uchar>(static_cast<int>(tmpNum));  // clamp
+            }
+        }
+    }
+
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::SaveResult(const cv::Mat &resultImg, const std::string &imgName) {
+    DIR *dirPtr = opendir(savePath_.c_str());
+    if (dirPtr != nullptr) {
+        closedir(dirPtr);  // the handle was only an existence probe; it used to leak once per image
+    } else {
+        system(("mkdir -p " + savePath_).c_str());  // NOTE(review): savePath_ reaches the shell unescaped
+    }
+    return cv::imwrite(savePath_ + "/" + imgName, resultImg) ? APP_ERR_OK : APP_ERR_INTERNAL_ERROR;
+}
+
+APP_ERROR StarGanGeneration::GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels) {
+    std::vector<float> mat_data(label.size());  // owned staging buffer; the old new[] was never freed
+    for (size_t i = 0; i < label.size(); i++) {
+        mat_data[i] = label[i];
+    }
+    MxBase::MemoryData memoryDataDst(label.size()*FLOAT32_TYPE_BYTE_NUM, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(mat_data.data()),
+        label.size()*FLOAT32_TYPE_BYTE_NUM, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);  // NOTE(review): sync copy?
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc failed.";
+        return ret;
+    }
+
+    const std::vector<uint32_t> shape = {1, (unsigned int)label.size()};
+    *imgLabels = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
+    return APP_ERR_OK;
+}
+
+APP_ERROR StarGanGeneration::Process(const std::string &imgPath,
+    const std::string &imgName, const std::vector<float> &label) {
+    cv::Mat imageMat;
+    APP_ERROR ret = ReadImage(imgPath, &imageMat);
+    if (ret != APP_ERR_OK) {
+        LogError << "ReadImage failed, ret=" << ret << ".";
+        return ret;
+    }
+    ResizeImage(imageMat, &imageMat);  // its APP_ERROR result is always APP_ERR_OK in this file
+
+    std::vector<MxBase::TensorBase> inputs = {};
+    std::vector<MxBase::TensorBase> outputs = {};
+
+    MxBase::TensorBase tensorBase;
+    ret = CVMatToTensorBase(imageMat, &tensorBase);
+    if (ret != APP_ERR_OK) {
+        LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
+        return ret;
+    }
+    inputs.push_back(tensorBase);  // input 0: image tensor
+
+    MxBase::TensorBase imgLabels;
+    ret = GetImageLabel(label, &imgLabels);
+    if (ret != APP_ERR_OK) {
+        LogError << "Get Image label failed, ret=" << ret << ".";
+        return ret;
+    }
+    inputs.push_back(imgLabels);  // input 1: attribute label tensor
+    auto startTime = std::chrono::high_resolution_clock::now();
+    ret = Inference(inputs, &outputs);
+    auto endTime = std::chrono::high_resolution_clock::now();
+    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();  // save time
+    inferCostTimeMilliSec += costMs;  // accumulated even when Inference failed
+    if (ret != APP_ERR_OK) {
+        LogError << "Inference failed, ret=" << ret << ".";
+        return ret;
+    }
+    cv::Mat resultImg(OUTPUT_HEIGHT, OUTPUT_WIDTH, CV_8UC3);
+    ret = PostProcess(outputs, &resultImg);
+    if (ret != APP_ERR_OK) {
+        LogError << "PostProcess failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = SaveResult(resultImg, imgName);
+    if (ret != APP_ERR_OK) {
+        LogError << "Save infer results into file failed. ret = " << ret << ".";
+        return ret;
+    }
+
+    return APP_ERR_OK;
+}
diff --git a/research/cv/StarGAN/infer/mxbase/src/StarGanGeneration.h b/research/cv/StarGAN/infer/mxbase/src/StarGanGeneration.h
new file mode 100644
index 0000000000000000000000000000000000000000..bc7dca072af72e127419823ed1937de67986fc60
--- /dev/null
+++ b/research/cv/StarGAN/infer/mxbase/src/StarGanGeneration.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2021. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
// Single-image StarGAN generator inference built on the MxBase model-inference API.
// Typical call order used by main.cpp: Init() once, Process() per image, DeInit() once.
class StarGanGeneration {
 public:
    // Set-up entry point; receives the device id, result directory and model path.
    APP_ERROR Init(const InitParam &initParam);
    // Tear-down counterpart of Init; the implementation finishes by destroying the
    // devices held by MxBase::DeviceManager.
    APP_ERROR DeInit();
    // Loads imgPath with cv::imread(..., cv::IMREAD_COLOR) into *imageMat.
    // NOTE(review): always returns APP_ERR_OK, even when imread yields an empty Mat.
    APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat);
    // Resizes srcImageMat to OUTPUT_WIDTH x OUTPUT_HEIGHT into *dstImageMat.
    APP_ERROR ResizeImage(const cv::Mat &srcImageMat, cv::Mat *dstImageMat);
    // Normalises the image, reorders it to CHW with reversed channel order and copies it
    // to a float32 device tensor of shape {1, CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH}.
    APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase);
    // Allocates one device tensor per model output (appended to *outputs) and runs
    // ModelInference with a static batch size of 1.
    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs,
                      std::vector<MxBase::TensorBase> *outputs);
    // Moves outputs[0] to host memory, de-normalises it and writes the pixels into
    // *resultImg, which the caller must pre-allocate as OUTPUT_HEIGHT x OUTPUT_WIDTH CV_8UC3.
    APP_ERROR PostProcess(std::vector<MxBase::TensorBase> outputs, cv::Mat *resultImg);
    // Full pipeline for one image: read, resize, build both input tensors, infer,
    // post-process and save under imgName. Adds the inference time to the running total.
    APP_ERROR Process(const std::string &imgPath, const std::string &imgName, const std::vector<float> &label);
    // Copies label to a float32 device tensor of shape {1, label.size()}.
    APP_ERROR GetImageLabel(std::vector<float> label, MxBase::TensorBase *imgLabels);
    // Total wall-clock time, in milliseconds, spent inside Inference() across all
    // Process() calls so far.
    double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}

 private:
    // Writes resultImg to savePath_/imgName, creating savePath_ when it cannot be opened.
    APP_ERROR SaveResult(const cv::Mat &resultImg, const std::string &imgName);
    // NOTE(review): none of the method bodies reviewed here reference this member —
    // confirm whether Init still needs it.
    std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
    // Model handle used by Inference() for output dtypes and ModelInference().
    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
    // Directory that SaveResult() writes generated images into.
    std::string savePath_;
    // Input/output tensor descriptions; Inference() reads the output dims from it.
    MxBase::ModelDesc modelDesc_;
    // Device that input and output tensors are allocated on.
    uint32_t deviceId_ = 0;
    // Accumulated inference time in milliseconds (see GetInferCostMilliSec).
    double inferCostTimeMilliSec = 0.0;
};
b/research/cv/StarGAN/infer/mxbase/src/main.cpp @@ -0,0 +1,139 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "StarGanGeneration.h" + +namespace { + std::vector<std::string> SELECTED_ATTRS {"Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"}; + const std::string OM_MODEL_PATH = "../data/model/stargan.om"; +} + +APP_ERROR ScanImages(const std::string &path, std::vector<std::string> *imgFiles) { + DIR *dirPtr = opendir(path.c_str()); + if (dirPtr == nullptr) { + LogError << "opendir failed. dir:" << path; + return APP_ERR_INTERNAL_ERROR; + } + dirent *direntPtr = nullptr; + while ((direntPtr = readdir(dirPtr)) != nullptr) { + std::string fileName = direntPtr->d_name; + if (fileName == "." 
|| fileName == "..") { + continue; + } + + imgFiles->emplace_back(path + "/" + fileName); + } + closedir(dirPtr); + return APP_ERR_OK; +} + +std::vector<std::string> split(std::string str, char ch) { + size_t start = 0; + size_t len = 0; + std::vector<std::string> ret; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == ch && i+1 < str.length() && str[i+1] == ch) { + continue; + } + if (str[i] == ch) { + ret.push_back(str.substr(start, len)); + start = i+1; + len = 0; + } else { + len++; + } + } + if (start < str.length()) + ret.push_back(str.substr(start, len)); + return ret; +} + + +int main(int argc, char* argv[]) { + if (argc <= 1) { + LogWarn << "Please input image path, such as '../data/test_data/'."; + return APP_ERR_OK; + } + + InitParam initParam = {}; + initParam.deviceId = 0; + initParam.modelPath = OM_MODEL_PATH; + initParam.savePath = "./result"; + auto stargan = std::make_shared<StarGanGeneration>(); + APP_ERROR ret = stargan->Init(initParam); + if (ret != APP_ERR_OK) { + LogError << "stargan init failed, ret=" << ret << "."; + return ret; + } + + // Read the contents of a label + std::string dataPath = argv[1]; + std::string imagePath = dataPath + "/images/"; + std::string labelPath = dataPath + "/anno/list_attr_celeba.txt"; + + std::vector<std::string> imagePathList; + ret = ScanImages(imagePath, &imagePathList); + if (ret != APP_ERR_OK) { + LogError << "stargan init failed, ret=" << ret << "."; + return ret; + } + std::ifstream fin; + std::string s; + fin.open(labelPath); + int i = 0; + int imgNum; + std::map<int, std::string> idx2attr; + std::map<std::string, int> attr2idx; + auto startTime = std::chrono::high_resolution_clock::now(); + + while (getline(fin, s)) { + i++; + if (i == 1) { + imgNum = atoi(s.c_str()); + } else if (i == 2) { + std::vector<std::string> allAttrNames = split(s, ' '); + for (size_t j = 0; j < allAttrNames.size(); j++) { + idx2attr[j] = allAttrNames[j]; + attr2idx[allAttrNames[j]] = j; + } + } else { + 
std::vector<std::string> eachAttr = split(s, ' '); + // first one is file name + std::string imgName = eachAttr[0]; + std::vector<float> label; + for (size_t j = 0; j < SELECTED_ATTRS.size(); j++) { + if (atoi(eachAttr[attr2idx[SELECTED_ATTRS[j]] + 1].c_str()) == 1) + label.push_back(1.0); + else + label.push_back(0.0); + // label.push_back(atoi(eachAttr[attr2idx[SELECTED_ATTRS[j]] + 1].c_str()) * -0.5); + } + ret = stargan->Process(imagePath + imgName, imgName, label); + if (ret != APP_ERR_OK) { + LogError << "stargan process failed, ret=" << ret << "."; + stargan->DeInit(); + return ret; + } + } + } + fin.close(); + auto endTime = std::chrono::high_resolution_clock::now(); + stargan->DeInit(); + double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); + double fps = 1000.0 * imgNum / stargan->GetInferCostMilliSec(); + LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec"; + return APP_ERR_OK; +} diff --git a/research/cv/StarGAN/infer/sdk/api/infer.py b/research/cv/StarGAN/infer/sdk/api/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf788045e4707e23ebf22a3c8ffbe636b4bffac --- /dev/null +++ b/research/cv/StarGAN/infer/sdk/api/infer.py @@ -0,0 +1,126 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class SdkApi:
    """Thin wrapper around StreamManagerApi for sending tensors and reading results.

    Call ``init()`` once after construction; it returns False when the stream
    manager or the pipeline cannot be created.
    """
    INFER_TIMEOUT = cfg.INFER_TIMEOUT
    STREAM_NAME = cfg.STREAM_NAME

    def __init__(self, pipeline_cfg):
        """Remember the pipeline file path; no SDK call is made here."""
        self.pipeline_cfg = pipeline_cfg
        self._stream_api = None
        self._data_input = None
        self._device_id = None

    def init(self):
        """Read the pipeline file, start the stream manager and create the streams.

        Returns:
            True on success, False when InitManager or CreateMultipleStreams fails.
        """
        # Read the file once; the same bytes provide the device id and the
        # pipeline definition handed to the SDK.
        with open(self.pipeline_cfg, 'rb') as fp:
            pipe_line = fp.read()
        self._device_id = int(
            json.loads(pipe_line)[self.STREAM_NAME]["stream_config"]["deviceId"])
        print(f"The device id: {self._device_id}.")

        # create api
        self._stream_api = StreamManagerApi()

        # init stream mgr
        ret = self._stream_api.InitManager()
        if ret != 0:
            print(f"Failed to init stream manager, ret={ret}.")
            return False

        # create streams
        ret = self._stream_api.CreateMultipleStreams(pipe_line)
        if ret != 0:
            print(f"Failed to create stream, ret={ret}.")
            return False

        self._data_input = MxDataInput()
        return True

    def __del__(self):
        # getattr: __del__ also runs when __init__ raised before the attribute existed.
        stream_api = getattr(self, "_stream_api", None)
        if not stream_api:
            return

        stream_api.DestroyAllStreams()

    def _send_protobuf(self, stream_name, plugin_id, element_name, buf_type,
                       pkg_list):
        """Serialize ``pkg_list`` and send it to one input plugin of the stream.

        Returns:
            True when SendProtobuf reports success, False otherwise (the error is logged).
        """
        protobuf = MxProtobufIn()
        protobuf.key = element_name.encode("utf-8")
        protobuf.type = buf_type
        protobuf.protobuf = pkg_list.SerializeToString()
        protobuf_vec = InProtobufVector()
        protobuf_vec.push_back(protobuf)
        err_code = self._stream_api.SendProtobuf(stream_name, plugin_id,
                                                 protobuf_vec)
        if err_code != 0:
            logging.error(
                "Failed to send data to stream, stream_name(%s), plugin_id(%s), element_name(%s), "
                "buf_type(%s), err_code(%s).", stream_name, plugin_id,
                element_name, buf_type, err_code)
            return False
        return True

    def send_tensor_input(self, stream_name, plugin_id, element_name,
                          input_data, input_shape, data_type):
        """Wrap raw tensor bytes in an MxpiTensorPackageList and send it.

        Args:
            stream_name: stream to send to.
            plugin_id: index of the input plugin.
            element_name: name of the input element, e.g. "appsrc0".
            input_data: tensor content as bytes.
            input_shape: iterable of ints describing the tensor shape.
            data_type: tensor dtype code.

        Returns:
            True when the data was sent, False otherwise.
        """
        tensor_list = MxpiDataType.MxpiTensorPackageList()
        tensor_pkg = tensor_list.tensorPackageVec.add()
        # init tensor vector
        tensor_vec = tensor_pkg.tensorVec.add()
        tensor_vec.deviceId = self._device_id
        tensor_vec.memType = 0
        tensor_vec.tensorShape.extend(input_shape)
        tensor_vec.tensorDataType = data_type
        tensor_vec.dataStr = input_data
        tensor_vec.tensorDataSize = len(input_data)

        buf_type = b"MxTools.MxpiTensorPackageList"
        return self._send_protobuf(stream_name, plugin_id, element_name,
                                   buf_type, tensor_list)

    def get_result(self, stream_name, out_plugin_id=0):
        """Fetch the inference output of ``mxpi_tensorinfer0`` from the stream.

        Args:
            stream_name: stream to read from.
            out_plugin_id: index of the output plugin; previously accepted but ignored.

        Returns:
            The parsed MxpiTensorPackageList.

        Raises:
            SystemExit: with status 1 when no result is returned or the result
                carries a non-zero error code. The earlier bare ``exit()``
                ended the process with status 0, hiding the failure from callers.
        """
        keys = [b"mxpi_tensorinfer0"]
        key_vec = StringVector()
        for key in keys:
            key_vec.push_back(key)
        infer_result = self._stream_api.GetProtobuf(stream_name, out_plugin_id, key_vec)
        print(infer_result)
        if infer_result.size() == 0:
            print("infer_result is null")
            raise SystemExit(1)

        if infer_result[0].errorCode != 0:
            print("GetProtobuf error. errorCode=%d" % (
                infer_result[0].errorCode))
            raise SystemExit(1)

        tensor_list = MxpiDataType.MxpiTensorPackageList()
        tensor_list.ParseFromString(infer_result[0].messageBuf)
        return tensor_list
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Model Config """ +MODEL_WIDTH = 128 +MODEL_HEIGHT = 128 +STREAM_NAME = "stargan" + +INFER_TIMEOUT = 100000 +SELECTED_ATTRS = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] + +TENSOR_DTYPE_FLOAT32 = 0 +TENSOR_DTYPE_FLOAT16 = 1 +TENSOR_DTYPE_INT8 = 2 diff --git a/research/cv/StarGAN/infer/sdk/main.py b/research/cv/StarGAN/infer/sdk/main.py new file mode 100644 index 0000000000000000000000000000000000000000..4432710247722bd6b78f72e212503222eaa7552e --- /dev/null +++ b/research/cv/StarGAN/infer/sdk/main.py @@ -0,0 +1,162 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def get_labels(img_dir, selected_attrs=None):
    """Load per-image attribute labels from ``<img_dir>/anno/list_attr_celeba.txt``.

    The file is read as: line 1 = image count (ignored), line 2 = attribute
    names, every following line = file name followed by one value per attribute.

    Args:
        img_dir: dataset root containing the ``anno`` directory.
        selected_attrs: attribute names to extract, in output order. Defaults
            to ``cfg.SELECTED_ATTRS``.

    Returns:
        Dict mapping file name to a float32 array holding 1.0 where the
        attribute value is the string '1' and 0.0 otherwise.

    Raises:
        KeyError: if a selected attribute is not present in the header line.
    """
    if selected_attrs is None:
        selected_attrs = cfg.SELECTED_ATTRS

    attr_file = os.path.join(img_dir, 'anno', 'list_attr_celeba.txt')
    # Context manager: the previous bare open() inside the comprehension left
    # the file handle to be closed by the garbage collector.
    with open(attr_file, 'r') as fp:
        lines = [line.rstrip() for line in fp]

    attr2idx = {name: idx for idx, name in enumerate(lines[1].split())}

    items = {}
    for line in lines[2:]:
        fields = line.split()
        if not fields:
            # A blank (e.g. trailing) line used to raise IndexError.
            continue
        filename, values = fields[0], fields[1:]
        label = [1.0 if values[attr2idx[name]] == '1' else 0.0
                 for name in selected_attrs]
        items[filename] = np.array(label).astype(np.float32)
    return items
def image_inference(pipeline_path, stream_name, img_dir, result_dir,
                    replace_last, model_type):
    """Run the StarGAN pipeline on every JPEG under ``<img_dir>/images``.

    For each image, the preprocessed image tensor is sent to ``appsrc0`` and
    its attribute label to ``appsrc1``. The generated image is then decoded
    and written to ``result_dir`` as ``<stem>.jpg``.

    Args:
        pipeline_path: path of the pipeline configuration file.
        stream_name: name of the stream to send to and read from.
        img_dir: dataset root containing ``images`` and ``anno``.
        result_dir: output directory; created when missing.
        replace_last: when False, images whose result file already exists are skipped.
        model_type: accepted for interface compatibility; not used here.
    """
    # init stream manager
    sdk_api = SdkApi(pipeline_path)
    if not sdk_api.init():
        exit(-1)

    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    img_data_plugin_id = 0
    img_label_plugin_id = 1

    label_items = get_labels(img_dir)

    print(f"\nBegin to inference for {img_dir}.\n\n")

    image_dir = os.path.join(img_dir, 'images')
    for file_name in os.listdir(image_dir):
        # The old tuple had "jpeg" without the dot, matching any name ending in those letters.
        if not file_name.lower().endswith((".jpg", ".jpeg")):
            continue
        file_path = os.path.join(image_dir, file_name)
        save_path = os.path.join(result_dir,
                                 f"{os.path.splitext(file_name)[0]}.jpg")
        if not replace_last and os.path.exists(save_path):
            print(f"The infer result image({save_path}) has existed, will be skip.")
            continue
        if file_name not in label_items:
            # Previously a KeyError aborted the whole run on the first unlabeled file.
            print(f"No label found for image({file_path}), will be skip.")
            continue

        img_np = process_img(file_path)

        start_time = time.time()
        if not sdk_api.send_tensor_input(stream_name, img_data_plugin_id, "appsrc0",
                                         img_np.tobytes(), img_np.shape, cfg.TENSOR_DTYPE_FLOAT32):
            exit(-1)

        # set label data
        label_dim = np.expand_dims(label_items[file_name], axis=0)
        # Declare the real label shape; the old hard-coded [1, 4] disagreed
        # with the number of label values actually sent.
        if not sdk_api.send_tensor_input(stream_name, img_label_plugin_id, "appsrc1",
                                         label_dim.tobytes(), list(label_dim.shape),
                                         cfg.TENSOR_DTYPE_FLOAT32):
            exit(-1)

        result = sdk_api.get_result(stream_name)
        elapsed = time.time() - start_time
        print(f"The image({save_path}) inference time is {elapsed}")
        data = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float32)
        data = data.reshape(3, cfg.MODEL_HEIGHT, cfg.MODEL_WIDTH)
        img = decode_image(data)
        # RGB -> BGR for cv2.imwrite
        img = img[:, :, ::-1]
        cv2.imwrite(save_path, img)
# Usage: bash run.sh IMAGE_PATH [RESULT_DIR]
#   IMAGE_PATH  dataset root passed to main.py as --img_path
#   RESULT_DIR  output directory (default: ./result)

set -e

# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }

if [ $# -lt 1 ]; then
    warn "Usage: bash $0 IMAGE_PATH [RESULT_DIR]"
    exit 1
fi

image_path=$1
# Fall back to main.py's own default; an empty --infer_result_dir= would override it
# with an empty path.
result_dir=${2:-./result}

export GST_PLUGIN_SCANNER="${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner"

# Quote the expansions so paths containing spaces are passed as single arguments.
python3 main.py --img_path="$image_path" --infer_result_dir="$result_dir"
exit 0
+# ============================================================================ +"""Train the model.""" +from time import time +import os +import argparse +import ast +import glob +import numpy as np +import mindspore.common.dtype as mstype +from mindspore import nn +from mindspore import Tensor, context +from mindspore.common import set_seed +from mindspore.context import ParallelMode +from mindspore.communication.management import init, get_rank +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext +from mindspore.train.serialization import export, load_param_into_net +from src.dataset import dataloader +from src.utils import get_network, resume_model +from src.cell import TrainOneStepCellGen, TrainOneStepCellDis +from src.loss import GeneratorLoss, DiscriminatorLoss, ClassificationLoss, WGANGPGradientPenalty +from src.reporter import Reporter + +set_seed(1) + +# Modelarts +parser = argparse.ArgumentParser(description='StarGAN_args') +parser.add_argument('--modelarts', type=ast.literal_eval, default=True, help='Dataset path') +parser.add_argument('--data_url', type=str, default=None, help='Dataset path') +parser.add_argument('--train_url', type=str, default=None, help='Train output path') +# Model configuration. 
+parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)') +parser.add_argument('--c2_dim', type=int, default=7, help='dimension of domain labels (2nd dataset)') +parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset') +parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset') +parser.add_argument('--image_size', type=int, default=128, help='image resolution') +parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G') +parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D') +parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G') +parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D') +parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss') +parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss') +parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') +# Training configuration. 
+parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both']) +parser.add_argument('--batch_size', type=int, default=4, help='mini-batch size') +parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D') +parser.add_argument('--epochs', type=int, default=59, help='number of epoch') +parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr') +parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G') +parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D') +parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update') +parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') +parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') +parser.add_argument('--resume_iters', type=int, default=200000, help='resume training from this step') +parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset', + default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']) +parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"), + help='network initialization, default is normal.') +parser.add_argument('--init_gain', type=float, default=0.02, + help='scaling factor for normal, xavier and orthogonal, default is 0.02.') +# Test configuration. +parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') +# Train Device. 
+parser.add_argument('--num_workers', type=int, default=8) +parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) +parser.add_argument('--device_target', type=str, default='Ascend') +parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.") +parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.") +parser.add_argument("--device_num", type=int, default=1, help="number of device, default: 0.") +parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") +# Directories. +parser.add_argument('--celeba_image_dir', type=str, default=r'/home/data/celeba/images') +parser.add_argument('--attr_path', type=str, default=r'/home/data/celeba/list_attr_celeba.txt') +parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train') +parser.add_argument('--log_dir', type=str, default='stargan/logs') +parser.add_argument('--model_save_dir', type=str, default='./models/') +parser.add_argument('--result_dir', type=str, default='./results') +# Step size. 
if __name__ == '__main__':
    # Script flow: prepare the runtime context and dataset (ModelArts or local),
    # train D every iteration and G every `n_critic` iterations, save checkpoints
    # every `model_save_step` iterations, then (ModelArts only) export the
    # generator and upload the output directory.
    if args_opt.modelarts:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
        context.set_context(device_id=device_id)
        # NOTE(review): the data cache is relative ('./cache/data') while the
        # checkpoint cache is absolute ('/cache/ckpt') — confirm this is intended.
        local_data_url = './cache/data'
        local_train_url = '/cache/ckpt'
        if not os.path.isdir(local_data_url):
            os.makedirs(local_data_url)
        if not os.path.isdir(local_train_url):
            os.makedirs(local_train_url)

        # unzip data
        path = os.getcwd()
        print("cwd: %s" % path)

        data_name = '/celeba.zip'
        print('listdir1: %s' % os.listdir('./'))

        a1time = time()
        mox.file.copy_parallel(args_opt.data_url, local_data_url)
        print('listdir2: %s' % os.listdir(local_data_url))
        b1time = time()
        print('time1:', b1time - a1time)

        a2time = time()
        zip_command = "unzip -o %s -d %s" % (local_data_url + data_name, local_data_url)
        if os.system(zip_command) == 0:
            print('Successful backup')
        else:
            print('FAILED backup')
        b2time = time()
        print('time2:', b2time - a2time)
        print('listdir3: %s' % os.listdir(local_data_url))

        # Device Environment
        if config.run_distribute:
            if config.device_target == "Ascend":
                rank = device_id
                context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                                  gradients_mean=True)
                init()
        else:
            rank = 0
            device_num = 1

        data_path = local_data_url + '/celeba/images'
        attr_path = local_data_url + '/celeba/list_attr_celeba.txt'
        # NOTE(review): this branch passes num_workers as device_num, while the
        # local branch passes config.device_num — confirm which one dataloader expects.
        dataset, length = dataloader(img_path=data_path,
                                     attr_path=attr_path,
                                     batch_size=config.batch_size,
                                     selected_attr=config.selected_attrs,
                                     device_num=config.num_workers,
                                     dataset=config.dataset,
                                     mode=config.mode,
                                     shuffle=True)

    else:
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target,
                            device_id=config.device_id, save_graphs=False)
        if args_opt.run_distribute:
            if os.getenv("DEVICE_ID", "not_set").isdigit():
                context.set_context(device_id=int(os.getenv("DEVICE_ID")))
            device_num = config.device_num
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            init()

            rank = get_rank()

        data_path = config.celeba_image_dir
        attr_path = config.attr_path
        local_train_url = config.model_save_dir
        dataset, length = dataloader(img_path=data_path,
                                     attr_path=attr_path,
                                     batch_size=config.batch_size,
                                     selected_attr=config.selected_attrs,
                                     device_num=config.device_num,
                                     dataset=config.dataset,
                                     mode=config.mode,
                                     shuffle=True)
    print(length)
    dataset_iter = dataset.create_dict_iterator()

    # Get and initial network
    generator, discriminator = get_network(config)

    cls_loss = ClassificationLoss()
    wgan_loss = WGANGPGradientPenalty(discriminator)

    # Define network with loss
    G_loss_cell = GeneratorLoss(config, generator, discriminator)
    D_loss_cell = DiscriminatorLoss(config, generator, discriminator)

    # Define Optimizer
    Optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=config.g_lr,
                          beta1=config.beta1, beta2=config.beta2)
    Optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=config.d_lr,
                          beta1=config.beta1, beta2=config.beta2)

    # Define One step train
    G_trainOneStep = TrainOneStepCellGen(G_loss_cell, Optimizer_G)
    D_trainOneStep = TrainOneStepCellDis(D_loss_cell, Optimizer_D)

    # Train
    G_trainOneStep.set_train()
    D_trainOneStep.set_train()

    print('Start Training')

    reporter = Reporter(config)

    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.model_save_step)
    ckpt_cb_g = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Generator')
    ckpt_cb_d = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Discriminator')

    cb_params_g = _InternalCallbackParam()
    cb_params_g.train_network = generator
    cb_params_g.cur_step_num = 0
    # Was hard-coded to 4; now matches the discriminator's callback parameters.
    cb_params_g.batch_num = config.batch_size
    cb_params_g.cur_epoch_num = 0

    cb_params_d = _InternalCallbackParam()
    cb_params_d.train_network = discriminator
    cb_params_d.cur_step_num = 0
    cb_params_d.batch_num = config.batch_size
    cb_params_d.cur_epoch_num = 0
    run_context_g = RunContext(cb_params_g)
    run_context_d = RunContext(cb_params_d)
    ckpt_cb_g.begin(run_context_g)
    ckpt_cb_d.begin(run_context_d)
    start = time()

    # g_out stays None until the generator has run once, so a log step that
    # falls before the first generator update no longer raises NameError.
    g_out = None
    for iterator in range(config.num_iters):
        data = next(dataset_iter)
        x_real = Tensor(data['image'], mstype.float32)
        c_trg = Tensor(data['attr'], mstype.float32)
        c_org = Tensor(data['attr'], mstype.float32)
        np.random.shuffle(c_trg)

        d_out = D_trainOneStep(x_real, c_org, c_trg)

        if (iterator + 1) % config.n_critic == 0:
            g_out = G_trainOneStep(x_real, c_org, c_trg)

        if (iterator + 1) % config.log_step == 0 and g_out is not None:
            reporter.print_info(start, iterator, g_out, d_out)
            reporter.return_loss_array(g_out, d_out)

        if (iterator + 1) % config.model_save_step == 0:
            cb_params_d.cur_step_num = iterator + 1
            cb_params_d.batch_num = iterator + 2
            cb_params_g.cur_step_num = iterator + 1
            cb_params_g.batch_num = iterator + 2
            ckpt_cb_g.step_end(run_context_g)
            ckpt_cb_d.step_end(run_context_d)

    if args_opt.modelarts:
        print('output dir3: %s' % os.listdir(local_train_url))
        ckpt_list = glob.glob(os.path.join(local_train_url, "*.ckpt"))
        if ckpt_list:
            # The newest checkpoint is reported for information only.
            ckpt_model = max(ckpt_list, key=os.path.getmtime)
            print("checkpoint path", ckpt_model)
        else:
            # Previously this message was followed by an IndexError on ckpt_list[-1].
            print("ckpt file not generated.")

        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
        G, D = get_network(config)
        # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
        # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
        G.set_train(True)
        param_G, _ = resume_model(config, G, D)
        load_param_into_net(G, param_G)
        # Export shapes follow the configured image size and label dimension
        # (128 and 5 with the default arguments).
        input_array = Tensor(np.random.uniform(
            -1.0, 1.0, size=(1, 3, config.image_size, config.image_size)).astype(np.float32))
        input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, config.c_dim)).astype(np.float32))
        export(G, input_array, input_label, file_name=os.path.join(local_train_url, 'stargan'),
               file_format=config.file_format)

        mox.file.copy_parallel(local_train_url, args_opt.train_url)