Skip to content
Snippets Groups Projects
Commit 5e095c04 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1192 update mindir preprocess

Merge pull request !1192 from luoyang/mymaster
parents c2726e1a 43d7abf4
No related branches found
No related tags found
No related merge requests found
......@@ -13,24 +13,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <gflags/gflags.h>
#include <sys/time.h>
#include <fstream>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "include/api/model.h"
#include "inc/utils.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "inc/utils.h"
#include "include/api/types.h"
using mindspore::Context;
using mindspore::DataType;
using mindspore::GraphCell;
using mindspore::Model;
using mindspore::ModelType;
......@@ -39,6 +41,7 @@ using mindspore::Serialization;
using mindspore::Status;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(batch_mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_string(image_path, ".", "image path");
DEFINE_int32(device_id, 0, "device id");
......@@ -49,77 +52,103 @@ int main(int argc, char **argv) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
if (RealPath(FLAGS_batch_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
std::vector<mindspore::Graph> graph;
Serialization::Load({FLAGS_mindir_path, FLAGS_batch_mindir_path}, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
Status ret = model.Build(GraphCell(graph[0]), context);
if (ret.IsError()) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
if (!model.HasPreprocess()) {
std::cout << "data preprocess not exists in MindIR " << std::endl;
return 1;
}
std::cout << "Check if data preprocess exists: " << model.HasPreprocess() << std::endl;
Model model2;
ret = model2.Build(GraphCell(graph[1]), context);
if (ret.IsError()) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
if (!model2.HasPreprocess()) {
std::cout << "data preprocess not exists in MindIR " << std::endl;
return 1;
}
// way 1, construct a common MSTensor
std::vector<MSTensor> inputs1 = {ReadFileToTensor(FLAGS_image_path)};
// preprocess and predict with batch 1
std::vector<std::vector<MSTensor>> inputs1;
MSTensor *t1 = MSTensor::CreateTensorFromFile(FLAGS_image_path);
inputs1 = {{*t1}};
std::vector<MSTensor> outputs1;
ret = model.PredictWithPreprocess(inputs1, &outputs1);
ret = model.Preprocess(inputs1, &outputs1);
if (ret.IsError()) {
std::cout << "ERROR: Predict failed." << std::endl;
std::cout << ret.GetErrDescription() << std::endl;
std::cout << "ERROR: Preprocess failed." << std::endl;
return 1;
}
std::ofstream o1("result1.txt", std::ios::out);
o1.write(reinterpret_cast<const char *>(outputs1[0].MutableData()), std::streamsize(outputs1[0].DataSize()));
// way 2, construct a pointer of MSTensor, be careful of destroy
MSTensor *tensor = MSTensor::CreateImageTensor(FLAGS_image_path);
std::vector<MSTensor> inputs2 = {*tensor};
MSTensor::DestroyTensorPtr(tensor);
std::vector<MSTensor> outputs2;
ret = model.PredictWithPreprocess(inputs2, &outputs2);
std::vector<MSTensor> outputs1_1;
ret = model.Predict(outputs1, &outputs1_1);
if (ret.IsError()) {
std::cout << ret.GetErrDescription() << std::endl;
std::cout << "ERROR: Predict failed." << std::endl;
return 1;
}
std::ofstream o2("result2.txt", std::ios::out);
o2.write(reinterpret_cast<const char *>(outputs2[0].MutableData()), std::streamsize(outputs2[0].DataSize()));
// way 3, split preprocess and predict
std::vector<MSTensor> inputs3 = {ReadFileToTensor(FLAGS_image_path)};
std::vector<MSTensor> outputs3;
std::ofstream o1("result1.txt", std::ios::out);
o1.write(reinterpret_cast<const char *>(outputs1_1[0].MutableData()),
std::streamsize(outputs1_1[0].DataSize()));
ret = model.Preprocess(inputs3, &outputs3);
if (ret.IsError()) {
std::cout << "ERROR: Preprocess failed." << std::endl;
return 1;
// check shape
auto shape1 = outputs1_1[0].Shape();
std::cout << "outputs1_1 shape: " << std::endl;
for (auto s : shape1) {
std::cout << s << ", ";
}
std::cout << std::endl;
MSTensor::DestroyTensorPtr(t1);
std::vector<MSTensor> outputs4;
ret = model.Predict(outputs3, &outputs4);
// preprocess and predict with batch 3
std::vector<std::vector<MSTensor>> inputs2;
MSTensor *t2 = MSTensor::CreateTensorFromFile(FLAGS_image_path);
MSTensor *t3 = MSTensor::CreateTensorFromFile(FLAGS_image_path);
MSTensor *t4 = MSTensor::CreateTensorFromFile(FLAGS_image_path);
inputs2 = {{*t2}, {*t3}, {*t4}};
std::vector<MSTensor> outputs2;
ret = model2.PredictWithPreprocess(inputs2, &outputs2);
if (ret.IsError()) {
std::cout << "ERROR: Preprocess failed." << std::endl;
std::cout << ret.GetErrDescription() << std::endl;
std::cout << "ERROR: Predict failed." << std::endl;
return 1;
}
std::ofstream o3("result3.txt", std::ios::out);
o3.write(reinterpret_cast<const char *>(outputs4[0].MutableData()), std::streamsize(outputs4[0].DataSize()));
std::ofstream o2("result2.txt", std::ios::out);
o2.write(reinterpret_cast<const char *>(outputs2[0].MutableData()),
std::streamsize(outputs2[0].DataSize()));
// check shape
auto shape = outputs1[0].Shape();
std::cout << "Output Shape: " << std::endl;
for (auto s : shape) {
auto shape2 = outputs2[0].Shape();
std::cout << "outputs2 shape: " << std::endl;
for (auto s : shape2) {
std::cout << s << ", ";
}
std::cout << std::endl;
MSTensor::DestroyTensorPtr(t2);
MSTensor::DestroyTensorPtr(t3);
MSTensor::DestroyTensorPtr(t4);
return 0;
}
......@@ -16,9 +16,8 @@
resnext export mindir.
"""
import os
import numpy as np
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore import context, load_checkpoint, load_param_into_net, export
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.image_classification import get_network
......@@ -46,12 +45,9 @@ def run_export():
else:
auto_mixed_precision(network)
network.set_train(False)
input_shp = [config.batch_size, 3, config.height, config.width]
de_dataset = classification_dataset("src/", config.image_size, config.per_batch_size, 1, 0, 1, mode="eval")
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(network, input_array, file_name=config.file_name, file_format=config.file_format, dataset=de_dataset)
de_dataset = classification_dataset("src/", config.image_size, 1, 1, 0, 1, mode="eval")
export(network, de_dataset, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
run_export()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment