Skip to content
Snippets Groups Projects
Commit 71c7242f authored by lixinqi's avatar lixinqi
Browse files

wrapper Error for ErrorProto

parent 21d934a0
No related branches found
No related tags found
No related merge requests found
#include "oneflow/core/common/error.h"
namespace oneflow {
Error Error::Ok() { return std::make_shared<ErrorProto>(); }
Error Error::ProtoParseFailedError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_proto_parse_failed_error();
return error;
}
Error Error::JobSetEmpty() {
auto error = std::make_shared<ErrorProto>();
error->set_job_build_and_infer_error(JobBuildAndInferError::kJobSetEmpty);
return error;
}
Error Error::DeviceTagNotFound() {
auto error = std::make_shared<ErrorProto>();
error->set_job_build_and_infer_error(JobBuildAndInferError::kDeviceTagNotFound);
return error;
}
Error Error::JobTypeNotSet() {
auto error = std::make_shared<ErrorProto>();
error->set_job_build_and_infer_error(JobBuildAndInferError::kJobTypeNotSet);
return error;
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_COMMON_ERROR_H_
#define ONEFLOW_CORE_COMMON_ERROR_H_
#include "oneflow/core/common/error.pb.h"
namespace oneflow {
class Error final {
public:
Error(const std::shared_ptr<ErrorProto>& error_proto) : error_proto_(error_proto) {}
Error(const Error&) = default;
~Error() = default;
static Error Ok();
static Error ProtoParseFailedError();
static Error JobSetEmpty();
static Error DeviceTagNotFound();
static Error JobTypeNotSet();
std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; }
ErrorProto* operator->() const { return error_proto_.get(); }
private:
std::shared_ptr<ErrorProto> error_proto_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ERROR_H_
......@@ -67,7 +67,7 @@ message ProtoParseFailedError {
message UnkownError { }
message Error {
message ErrorProto {
optional string msg = 1 [default = ""];
oneof error_type {
ConfigAssertFailedError config_assert_failed_error = 2;
......
......@@ -2,33 +2,33 @@
namespace oneflow {
Error ErrorUtil::Ok() { return Error(); }
std::shared_ptr<ErrorProto> ErrorUtil::Ok() { return std::make_shared<ErrorProto>(); }
Error ErrorUtil::ProtoParseFailedError(const std::string& msg) {
Error error;
error.set_msg(msg);
error.mutable_proto_parse_failed_error();
std::shared_ptr<ErrorProto> ErrorUtil::ProtoParseFailedError(const std::string& msg) {
auto error = std::make_shared<ErrorProto>();
error->set_msg(msg);
error->mutable_proto_parse_failed_error();
return error;
}
Error ErrorUtil::JobSetEmpty(const std::string& msg) {
Error error;
error.set_msg(msg);
error.set_job_build_and_infer_error(JobBuildAndInferError::kJobSetEmpty);
std::shared_ptr<ErrorProto> ErrorUtil::JobSetEmpty(const std::string& msg) {
auto error = std::make_shared<ErrorProto>();
error->set_msg(msg);
error->set_job_build_and_infer_error(JobBuildAndInferError::kJobSetEmpty);
return error;
}
Error ErrorUtil::DeviceTagNotFound(const std::string& msg) {
Error error;
error.set_msg(msg);
error.set_job_build_and_infer_error(JobBuildAndInferError::kDeviceTagNotFound);
std::shared_ptr<ErrorProto> ErrorUtil::DeviceTagNotFound(const std::string& msg) {
auto error = std::make_shared<ErrorProto>();
error->set_msg(msg);
error->set_job_build_and_infer_error(JobBuildAndInferError::kDeviceTagNotFound);
return error;
}
Error ErrorUtil::JobTypeNotSet(const std::string& msg) {
Error error;
error.set_msg(msg);
error.set_job_build_and_infer_error(JobBuildAndInferError::kJobTypeNotSet);
std::shared_ptr<ErrorProto> ErrorUtil::JobTypeNotSet(const std::string& msg) {
auto error = std::make_shared<ErrorProto>();
error->set_msg(msg);
error->set_job_build_and_infer_error(JobBuildAndInferError::kJobTypeNotSet);
return error;
}
......
......@@ -6,12 +6,12 @@
namespace oneflow {
struct ErrorUtil final {
static Error Ok();
static Error ProtoParseFailedError(const std::string& msg);
static Error JobSetEmpty(const std::string& msg);
static Error DeviceTagNotFound(const std::string& msg);
static std::shared_ptr<ErrorProto> Ok();
static std::shared_ptr<ErrorProto> ProtoParseFailedError(const std::string& msg);
static std::shared_ptr<ErrorProto> JobSetEmpty(const std::string& msg);
static std::shared_ptr<ErrorProto> DeviceTagNotFound(const std::string& msg);
static Error JobTypeNotSet(const std::string& msg);
static std::shared_ptr<ErrorProto> JobTypeNotSet(const std::string& msg);
};
} // namespace oneflow
......
......@@ -3,37 +3,35 @@
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/either_ptr.h"
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/error.h"
namespace oneflow {
template<typename T>
class MaybeBase {
public:
MaybeBase(const std::shared_ptr<const Error>& error) : data_or_error_(error) {}
MaybeBase(const std::shared_ptr<T>& data) : data_or_error_(data) {}
MaybeBase(const std::shared_ptr<ErrorProto>& error) : data_or_error_(error) {}
MaybeBase(const MaybeBase<T>&) = default;
virtual ~MaybeBase() = default;
~MaybeBase() = default; // no virtual is what we want
bool IsOk() const { return data_or_error_.template Has<T>(); }
const std::shared_ptr<T>& data() const { return data_or_error_.template Get<T>(); }
std::shared_ptr<const Error> error() const { return data_or_error_.template Get<const Error>(); }
std::shared_ptr<ErrorProto> error() const { return data_or_error_.template Get<ErrorProto>(); }
private:
EitherPtr<T, const Error> data_or_error_;
EitherPtr<T, ErrorProto> data_or_error_;
};
template<typename T>
class Maybe final : public MaybeBase<T> {
public:
Maybe(const Error& error) : MaybeBase<T>(std::make_shared<const Error>(error)) {}
Maybe(const T& data) : MaybeBase<T>(std::make_shared<T>(data)) {}
Maybe(const std::shared_ptr<const Error>& error) : MaybeBase<T>(error) {}
Maybe(const Error& error) : MaybeBase<T>(error.error_proto()) {}
Maybe(const std::shared_ptr<T>& data) : MaybeBase<T>(data) {}
Maybe(const Error* error) : MaybeBase<T>(std::shared_ptr<const Error>(error)) {}
Maybe(T* data) : MaybeBase<T>(std::shared_ptr<T>(data)) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : MaybeBase<T>(error) {}
Maybe(const Maybe<T>&) = default;
~Maybe() override = default;
~Maybe() = default;
static Maybe<T> Ok() { return Maybe<T>(); }
};
......@@ -41,19 +39,16 @@ class Maybe final : public MaybeBase<T> {
template<>
class Maybe<void> final : public MaybeBase<void> {
public:
Maybe(const Error& error) : MaybeBase<void>(std::make_shared<const Error>(error)) {
CheckError();
}
Maybe(const std::shared_ptr<const Error>& error) : MaybeBase<void>(error) { CheckError(); }
Maybe(const Error* error) : MaybeBase<void>(std::shared_ptr<const Error>(error)) { CheckError(); }
Maybe(const Error& error) : MaybeBase<void>(error.error_proto()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : MaybeBase<void>(error) { CheckError(); }
Maybe(const Maybe<void>&) = default;
~Maybe() override = default;
~Maybe() = default;
static Maybe<void> Ok() { return Maybe<void>(); }
private:
Maybe() : MaybeBase<void>(std::shared_ptr<void>()) {}
void CheckError() const { CHECK_NE(error()->error_type_case(), Error::ERROR_TYPE_NOT_SET); }
void CheckError() const { CHECK_NE(error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); }
};
template<typename T>
......@@ -143,14 +138,14 @@ std::string Sprintf(const Args&... args) {
return oss.str(); \
}()
#define CHECK_OR_RETURN(expr, ...) \
{ \
if (!(expr)) { \
Error error; \
error.set_msg(GEN_ERROR_MSG(ErrorType::kCondition, #expr, __VA_ARGS__)); \
error.mutable_unknown_error(); \
return Maybe<void>(error); \
} \
#define CHECK_OR_RETURN(expr, ...) \
{ \
if (!(expr)) { \
auto error = std::make_shared<ErrorProto>(); \
error->set_msg(GEN_ERROR_MSG(ErrorType::kCondition, #expr, __VA_ARGS__)); \
error->mutable_unknown_error(); \
return error; \
} \
}
#define CHECK_EQ_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_EQ(lhs, rhs), __VA_ARGS__)
......@@ -168,11 +163,12 @@ std::string Sprintf(const Args&... args) {
#define CHECK_STREQ_OR_RETURN(lhs, rhs, ...) \
CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs), __VA_ARGS__)
#define ENFORCE_THEN_RETURN(type, ...) \
{ \
Error error; \
error.set_msg(GEN_ERROR_MSG(ErrorType::kEnforce, #type, __VA_ARGS__)); \
return Maybe<void>(error); \
#define ENFORCE_THEN_RETURN(type, ...) \
{ \
auto error = std::make_shared<ErrorProto>(); \
error->set_msg(GEN_ERROR_MSG(ErrorType::kEnforce, #type, __VA_ARGS__)); \
error->mutable_unknown_error(); \
return error; \
}
#define UNSUPPORTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_UNSUPPORTED, __VA_ARGS__)
......
......@@ -3,10 +3,11 @@
namespace oneflow {
Error GenJobBuildAndInferError(JobBuildAndInferError err_code, std::string msg) {
Error err;
err.set_msg(msg);
err.set_job_build_and_infer_error(err_code);
std::shared_ptr<ErrorProto> GenJobBuildAndInferError(JobBuildAndInferError err_code,
std::string msg) {
auto err = std::make_shared<ErrorProto>();
err->set_msg(msg);
err->set_job_build_and_infer_error(err_code);
return err;
}
......
......@@ -12,7 +12,8 @@
namespace oneflow {
Error GenJobBuildAndInferError(JobBuildAndInferError err_code, std::string msg);
std::shared_ptr<ErrorProto> GenJobBuildAndInferError(JobBuildAndInferError err_code,
std::string msg);
class JobBuildAndInferCtx {
public:
......
......@@ -69,7 +69,8 @@ Maybe<void> NormalizationGradOp::InferBlobDescs(
BlobDesc* dx = GetBlobDesc4BnInOp("dx");
if (dx) { *dx = *x; }
const Shape param_shape({x->shape().At(conf.axis())});
const std::function<void(const std::string&)> CheckParamBlobDesc = [&](const std::string& bn) {
const std::function<void(const std::string&)> CheckParamBlobDesc =
[&](const std::string& bn) -> Maybe<void> {
const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn);
if (blob_desc != nullptr) {
CHECK_EQ_OR_RETURN(blob_desc->data_type(), data_type);
......@@ -77,7 +78,8 @@ Maybe<void> NormalizationGradOp::InferBlobDescs(
}
return Maybe<void>::Ok();
};
const std::function<void(const std::string&)> SetParamBlobDesc = [&](const std::string& bn) {
const std::function<void(const std::string&)> SetParamBlobDesc =
[&](const std::string& bn) -> Maybe<void> {
BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn);
if (blob_desc != nullptr) {
blob_desc->set_data_type(data_type);
......
......@@ -62,7 +62,8 @@ Maybe<void> NormalizationOp::InferBlobDescs(
const DataType data_type = in->data_type();
*GetBlobDesc4BnInOp("out") = *in;
const Shape param_shape({in->shape().At(conf.axis())});
const std::function<void(const std::string&)> CheckParamBlobDesc = [&](const std::string& bn) {
const std::function<void(const std::string&)> CheckParamBlobDesc =
[&](const std::string& bn) -> Maybe<void> {
const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn);
if (blob_desc != nullptr) {
CHECK_EQ_OR_RETURN(blob_desc->data_type(), data_type);
......@@ -70,7 +71,8 @@ Maybe<void> NormalizationOp::InferBlobDescs(
}
return Maybe<void>::Ok();
};
const std::function<void(const std::string&)> SetParamBlobDesc = [&](const std::string& bn) {
const std::function<void(const std::string&)> SetParamBlobDesc =
[&](const std::string& bn) -> Maybe<void> {
BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn);
if (blob_desc != nullptr) {
blob_desc->set_data_type(data_type);
......
......@@ -21,7 +21,7 @@ def Init(config_proto):
def InitGlobalOneflow():
serialized_error = oneflow_internal.InitGlobalOneflow()
error = text_format.Parse(serialized_error, error_util.Error())
error = text_format.Parse(serialized_error, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def GetInterUserJobInfo():
......@@ -43,12 +43,12 @@ def DestroyGlobalEnvironment():
def JobBuildAndInferCtx_Open(job_name):
job_name = str(job_name)
serialized_error = oneflow_internal.JobBuildAndInferCtx_Open(job_name)
error = text_format.Parse(serialized_error, error_util.Error())
error = text_format.Parse(serialized_error, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def JobBuildAndInferCtx_GetCurrentJobName():
job_name, error_str = oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return job_name
......@@ -58,7 +58,7 @@ def JobBuildAndInferCtx_Close():
def CurJobBuildAndInferCtx_SetJobConf(job_config_proto):
serialized_job_conf = str(text_format.MessageToString(job_config_proto))
serialized_error = oneflow_internal.CurJobBuildAndInferCtx_SetJobConf(serialized_job_conf)
error = text_format.Parse(serialized_error, error_util.Error())
error = text_format.Parse(serialized_error, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def CurJobBuildAndInferCtx_AddAndInferOp(op_conf_proto, parallel_conf_proto):
......@@ -66,29 +66,29 @@ def CurJobBuildAndInferCtx_AddAndInferOp(op_conf_proto, parallel_conf_proto):
serialized_parallel_conf = str(text_format.MessageToString(parallel_conf_proto))
add_and_infer = oneflow_internal.CurJobBuildAndInferCtx_AddAndInferOp
error_str = add_and_infer(serialized_op_conf, serialized_parallel_conf)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def CurJobBuildAndInferCtx_AddLossLogicalBlobName(lbn):
lbn = str(lbn)
error_str = oneflow_internal.CurJobBuildAndInferCtx_AddLossLogicalBlobName(lbn)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def CurJobBuildAndInferCtx_AddPlacementGroup(placement_group_proto):
serialized_placement_grp = str(text_format.MessageToString(placement_group_proto))
error_str = oneflow_internal.CurJobBuildAndInferCtx_AddPlacementGroup(serialized_placement_grp)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def CurJobBuildAndInferCtx_CheckJob():
error_str = oneflow_internal.CurJobBuildAndInferCtx_CheckJob()
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
def CurJobBuildAndInferCtx_HasJobConf():
has_job_conf, error_str = oneflow_internal.CurJobBuildAndInferCtx_HasJobConf()
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return has_job_conf
......@@ -97,7 +97,7 @@ def JobBuildAndInferCtx_GetStaticShape(job_name, lbn):
lbn = str(lbn)
axis_str, error_str = \
oneflow_internal.JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(job_name, lbn)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
int_list = text_format.Parse(axis_str, record_util.Int64List())
return tuple(map(int, int_list.value))
......@@ -106,7 +106,7 @@ def JobBuildAndInferCtx_GetDataType(job_name, lbn):
job_name = str(job_name)
lbn = str(lbn)
dtype, erro_str = oneflow_internal.JobBuildAndInferCtx_GetDataType(job_name, lbn)
error = text_format.Parse(erro_str, error_util.Error())
error = text_format.Parse(erro_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return int(dtype)
......@@ -115,7 +115,7 @@ def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
lbn = str(lbn)
batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(job_name, lbn)
batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
if batch_axis.HasField("value"): return batch_axis.value
return None
......@@ -125,7 +125,7 @@ def JobBuildAndInferCtx_GetHasSplitAxisFromProducerView(job_name, lbn):
lbn = str(lbn)
has_split_axis, error_str = \
oneflow_internal.JobBuildAndInferCtx_GetHasSplitAxisFromProducerView(job_name, lbn)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return has_split_axis
......@@ -134,7 +134,7 @@ def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
lbn = str(lbn)
split_axis, error_str = \
oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return split_axis
......@@ -142,14 +142,14 @@ def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
job_name = str(job_name)
lbn = str(lbn)
parallel_conf, error_str = oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView(job_name, lbn)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return text_format.Parse(parallel_conf, placment_util.ParallelConf())
def DeviceType4DeviceTag(device_tag):
device_tag = str(device_tag)
device_type, error_str = oneflow_internal.DeviceType4DeviceTag(device_tag)
error = text_format.Parse(error_str, error_util.Error())
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return device_type
......@@ -12,14 +12,14 @@ std::string JobBuildAndInferCtx_Open(const std::string& job_name) {
using namespace oneflow;
auto maybe_ok = TRY(Global<JobBuildAndInferCtxMgr>::Get()->OpenJobBuildAndInferCtx(job_name));
if (maybe_ok.IsOk() == false) { return PbMessage2TxtString(*maybe_ok.error()); }
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
std::string JobBuildAndInferCtx_GetCurrentJobName(std::string* error_str) {
using namespace oneflow;
auto maybe_ok = TRY(Global<JobBuildAndInferCtxMgr>::Get()->GetCurrentJobName());
if (maybe_ok.IsOk()) {
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return *maybe_ok.data();
} else {
PbMessage2TxtString(*maybe_ok.error(), error_str);
......@@ -39,7 +39,7 @@ std::string CurJobBuildAndInferCtx_CheckJob() {
if (ctx == nullptr) { return error_str; }
auto maybe_ok = TRY(ctx->CheckJob());
if (maybe_ok.IsOk() == false) { return PbMessage2TxtString(*maybe_ok.error()); }
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
std::string CurJobBuildAndInferCtx_SetJobConf(const std::string& serialized_job_conf) {
......@@ -47,7 +47,7 @@ std::string CurJobBuildAndInferCtx_SetJobConf(const std::string& serialized_job_
// parse
JobConfigProto job_conf;
if (TxtString2PbMessage(serialized_job_conf, &job_conf) == false) {
return PbMessage2TxtString(ErrorUtil::ProtoParseFailedError("job conf parse failed"));
return PbMessage2TxtString(*ErrorUtil::ProtoParseFailedError("job conf parse failed"));
}
// get current JobBuildandInferCtx
std::string error_str;
......@@ -56,14 +56,14 @@ std::string CurJobBuildAndInferCtx_SetJobConf(const std::string& serialized_job_
// set job_conf
auto maybe_ok = TRY(ctx->SetJobConf(job_conf));
if (maybe_ok.IsOk() == false) { return PbMessage2TxtString(*maybe_ok.error()); }
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
bool CurJobBuildAndInferCtx_HasJobConf(std::string* error_str) {
using namespace oneflow;
JobBuildAndInferCtx* ctx = JobBuildAndInferHelper::GetCurInferCtx(error_str);
if (ctx == nullptr) { return false; }
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return ctx->HasJobConf();
}
......@@ -73,11 +73,11 @@ std::string CurJobBuildAndInferCtx_AddAndInferOp(const std::string& serialized_o
// parse
OperatorConf op_conf;
if (TxtString2PbMessage(serialized_op_conf, &op_conf) == false) {
return PbMessage2TxtString(ErrorUtil::ProtoParseFailedError("operator conf parse failed"));
return PbMessage2TxtString(*ErrorUtil::ProtoParseFailedError("operator conf parse failed"));
}
ParallelConf parallel_conf;
if (TxtString2PbMessage(serialized_parallel_conf, &parallel_conf) == false) {
return PbMessage2TxtString(ErrorUtil::ProtoParseFailedError("parallel conf parse failed"));
return PbMessage2TxtString(*ErrorUtil::ProtoParseFailedError("parallel conf parse failed"));
}
// get current JobBuildandInferCtx
std::string error_str;
......@@ -86,7 +86,7 @@ std::string CurJobBuildAndInferCtx_AddAndInferOp(const std::string& serialized_o
// add and infer input_op
auto maybe_ok = TRY(ctx->AddAndInferOp(op_conf, parallel_conf));
if (maybe_ok.IsOk() == false) { return PbMessage2TxtString(*maybe_ok.error()); }
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
std::string CurJobBuildAndInferCtx_AddLossLogicalBlobName(const std::string& lbn) {
......@@ -98,7 +98,7 @@ std::string CurJobBuildAndInferCtx_AddLossLogicalBlobName(const std::string& lbn
// add loss_lbn
auto maybe_ok = TRY(ctx->AddLossLogicalBlobName(lbn));
if (maybe_ok.IsOk() == false) { return PbMessage2TxtString(*maybe_ok.error()); }
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
std::string CurJobBuildAndInferCtx_AddPlacementGroup(const std::string& serialized_placement_grp) {
......@@ -106,7 +106,7 @@ std::string CurJobBuildAndInferCtx_AddPlacementGroup(const std::string& serializ
// parse
PlacementGroup placement_group;
if (TxtString2PbMessage(serialized_placement_grp, &placement_group) == false) {
return PbMessage2TxtString(ErrorUtil::ProtoParseFailedError("placement group parse failed"));
return PbMessage2TxtString(*ErrorUtil::ProtoParseFailedError("placement group parse failed"));
}
// get current JobBuildandInferCtx
std::string error_str;
......@@ -115,7 +115,7 @@ std::string CurJobBuildAndInferCtx_AddPlacementGroup(const std::string& serializ
// add and infer input_op
auto maybe_ok = TRY(ctx->AddPlacementGroup(placement_group));
if (maybe_ok.IsOk() == false) { return PbMessage2TxtString(*maybe_ok.error()); }
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
std::string JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(const std::string& job_name,
......@@ -133,7 +133,7 @@ std::string JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(const std::stri
PbMessage2TxtString(*maybe_shape.error(), error_str);
return PbMessage2TxtString(id_list);
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
const auto& shape = *maybe_shape.data();
*id_list.mutable_value() = {shape.dim_vec().begin(), shape.dim_vec().end()};
return PbMessage2TxtString(id_list);
......@@ -152,7 +152,7 @@ long long JobBuildAndInferCtx_GetDataType(const std::string& job_name, const std
PbMessage2TxtString(*maybe_data_type.error(), error_str);
return 0LL;
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return *maybe_data_type.data();
}
......@@ -169,7 +169,7 @@ std::string JobBuildAndInferCtx_GetBatchAxis(const std::string& job_name, const
PbMessage2TxtString(*maybe_has_batch_dim.error(), error_str);
return PbMessage2TxtString(OptInt64());
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return PbMessage2TxtString(*maybe_has_batch_dim.data());
}
......@@ -187,7 +187,7 @@ bool JobBuildAndInferCtx_GetHasSplitAxisFromProducerView(const std::string& job_
PbMessage2TxtString(*maybe_has_split_axis.error(), error_str);
return false;
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return *maybe_has_split_axis.data();
}
......@@ -205,7 +205,7 @@ long long JobBuildAndInferCtx_GetSplitAxisFromProducerView(const std::string& jo
PbMessage2TxtString(*maybe_split_axis.error(), error_str);
return 0LL;
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return *maybe_split_axis.data();
}
......@@ -222,6 +222,6 @@ std::string JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView(
PbMessage2TxtString(*maybe_parallel_conf.error(), error_str);
return PbMessage2TxtString(ParallelConf());
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return PbMessage2TxtString(maybe_parallel_conf.data()->parallel_conf());
}
......@@ -57,14 +57,14 @@ std::string InitGlobalOneflow() {
ClusterControl::MasterSendSessionStart();
const JobSet& job_set = Global<JobBuildAndInferCtxMgr>::Get()->job_set();
if (job_set.job().empty()) {
return PbMessage2TxtString(ErrorUtil::JobSetEmpty("no function defined"));
return PbMessage2TxtString(*ErrorUtil::JobSetEmpty("no function defined"));
}
CHECK_ISNULL(Global<Oneflow>::Get());
Global<CtrlClient>::Get()->PushKV("session_job_set", job_set);
Global<RuntimeBufferManagersScope>::New();
Global<JobSetCompileCtx>::New();
Global<Oneflow>::New(job_set);
return PbMessage2TxtString(ErrorUtil::Ok());
return PbMessage2TxtString(*ErrorUtil::Ok());
}
std::string GetSerializedInterUserJobInfo() {
......@@ -136,7 +136,7 @@ long long DeviceType4DeviceTag(const std::string& device_tag, std::string* error
PbMessage2TxtString(*maybe_dev_type.error(), error_str);
return DeviceType::kInvalidDevice;
}
PbMessage2TxtString(ErrorUtil::Ok(), error_str);
PbMessage2TxtString(*ErrorUtil::Ok(), error_str);
return *maybe_dev_type.data();
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment