diff --git a/oneflow/core/framework/py_remote_blob.cpp b/oneflow/core/framework/py_remote_blob.cpp
index 7da4754837e14f3ab02763dd6dd42a1955c44448..12892d1f72e2e205549ede41a3927c8c17c0fe8a 100644
--- a/oneflow/core/framework/py_remote_blob.cpp
+++ b/oneflow/core/framework/py_remote_blob.cpp
@@ -197,6 +197,7 @@ int64_t EagerBlobTrait::split_axis() const {
     return INVALID_SPLIT_AXIS;
   } else {
     UNIMPLEMENTED();
+    return 0;
   }
 }
 
diff --git a/oneflow/core/job_rewriter/pass_util.h b/oneflow/core/job_rewriter/pass_util.h
index 2b1ea58add2abd4426af46e3f53d7d7be584ec0e..23620b7b2086b554a62ce597fa60e47174881215 100644
--- a/oneflow/core/job_rewriter/pass_util.h
+++ b/oneflow/core/job_rewriter/pass_util.h
@@ -20,6 +20,7 @@ limitations under the License.
 namespace oneflow {
 
 #define INSERT_CHECK(expr) CHECK(expr.second)
+#define INSERT_CHECK_OR_RETURN(expr) CHECK_OR_RETURN(expr.second)
 
 template<typename MapT, typename KeyT>
 bool IsKeyFound(const MapT& m, const KeyT& k) {
diff --git a/oneflow/core/job_rewriter/quantization_aware_training.cpp b/oneflow/core/job_rewriter/quantization_aware_training.cpp
index c888ed7121887d75180775397750df12080446a5..97ffaa447297d1c4ffda31f29a25c5773b387f66 100644
--- a/oneflow/core/job_rewriter/quantization_aware_training.cpp
+++ b/oneflow/core/job_rewriter/quantization_aware_training.cpp
@@ -40,11 +40,12 @@ const std::string MUL_BIAS_SUFFIX = "-fake-quant-mul-bias";
 const std::string OBSERVER_SUFFIX = "-fake-quant-observer";
 const std::string TRAIN_STEP_SUFFIX = "-fake-train-step";
 
-void VerifyQATList(const OpTypeSet& op_list) {
+Maybe<void> VerifyQATList(const OpTypeSet& op_list) {
   for (const auto& op_type : op_list) {
-    CHECK(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr)
+    CHECK_OR_RETURN(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr)
         << "Cannot find " << op_type << " of QuantAwareTraining list in OpRegistry.";
   }
+  return Maybe<void>::Ok();
 }
 
 HashMap<std::string, std::string> scale_map;
@@ -168,47 +169,47 @@ std::string QuantizationSchemeAttr4QatConfig(const QatConfig& qat_config) {
 }
 
 // TODO: refactor the following 4 methods by registration
-std::string QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) {
+Maybe<std::string> QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "" || target_backend == "tensorrt") {
-    return "google";
+    return std::string("google");
   } else if (target_backend == "cambricon") {
-    return "cambricon";
+    return std::string("cambricon");
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
   }
 }
 
-OpTypeSet Int8List4QatConfig(const QatConfig& qat_config) {
+Maybe<OpTypeSet> Int8List4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "") {
-    return {"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"};
+    return OpTypeSet{"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"};
   } else if (target_backend == "cambricon" || target_backend == "tensorrt") {
-    return {"conv2d", "matmul"};
+    return OpTypeSet{"conv2d", "matmul"};
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
   }
 }
 
-OpTypeSet TransparentList4QatConfig(const QatConfig& qat_config) {
+Maybe<OpTypeSet> TransparentList4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "" || target_backend == "tensorrt") {
-    return {"reshape"};
+    return OpTypeSet{"reshape"};
   } else if (target_backend == "cambricon") {
-    return {};
+    return OpTypeSet{};
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
  }
 }
 
-bool InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) {
+Maybe<bool> InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) {
   const auto target_backend = qat_config.target_backend();
   if (target_backend == "" || target_backend == "tensorrt") {
     return true;
   } else if (target_backend == "cambricon") {
     return false;
   } else {
-    UNIMPLEMENTED();
+    UNIMPLEMENTED_THEN_RETURN();
   }
 }
 
@@ -226,16 +227,18 @@ user_op::UserOpConfWrapper MultiplyOp(const std::string& name, const std::string
   return op_wrapper;
 }
 
-user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::string& input,
-                                          const QatConfig& qat_config,
-                                          const int64_t scope_symbol_id, OpConfMap* inserted_ops) {
+Maybe<user_op::UserOpConfWrapper> MinMaxObserver(const std::string& name, const std::string& input,
+                                                 const QatConfig& qat_config,
+                                                 const int64_t scope_symbol_id,
+                                                 OpConfMap* inserted_ops) {
   const auto op_wrapper =
       user_op::UserOpConfWrapperBuilder(name)
           .Op("min_max_observer")
           .Input("in", input)
           .Output("scale")
          .Output("zero_point")
-          .Attr<std::string>("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config))
+          .Attr<std::string>("quantization_formula",
+                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))
           .Attr<std::string>("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config))
           .Attr("per_layer_quantization", PerLayerQuantizationAttr4Config(qat_config))
           .ScopeSymbolId(scope_symbol_id)
@@ -244,11 +247,9 @@ user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::st
   return op_wrapper;
 }
 
-user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const std::string& input,
-                                                const std::string& train_step_lbn,
-                                                const QatConfig& qat_config,
-                                                const int64_t scope_symbol_id,
-                                                OpConfMap* inserted_ops) {
+Maybe<user_op::UserOpConfWrapper> MovingMinMaxObserver(
+    const std::string& name, const std::string& input, const std::string& train_step_lbn,
+    const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) {
   const std::string moving_max_name = name + MOVING_MAX_SUFFIX;
   const std::string moving_min_name = name + MOVING_MIN_SUFFIX;
   const auto moving_max_var =
@@ -276,7 +277,8 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s
           .Output("zero_point")
           .Attr("training", GlobalJobDesc().IsTrain())
           .Attr("stop_update_after_iters", qat_config.moving_min_max_stop_update_after_iters())
-          .Attr<std::string>("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config))
+          .Attr<std::string>("quantization_formula",
+                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))
           .Attr<std::string>("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config))
           .Attr("momentum", qat_config.moving_min_max_momentum())
           .ScopeSymbolId(scope_symbol_id)
@@ -285,17 +287,20 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s
   return op_wrapper;
 }
 
-user_op::UserOpConfWrapper FakeQuantOp(const std::string& name, const std::string& input,
-                                       const std::string& scale, const std::string& zero_point,
-                                       const QatConfig& qat_config, const int64_t scope_symbol_id,
-                                       OpConfMap* inserted_ops) {
+Maybe<user_op::UserOpConfWrapper> FakeQuantOp(const std::string& name, const std::string& input,
+                                              const std::string& scale,
+                                              const std::string& zero_point,
+                                              const QatConfig& qat_config,
+                                              const int64_t scope_symbol_id,
+                                              OpConfMap* inserted_ops) {
   const auto op_wrapper =
       user_op::UserOpConfWrapperBuilder(name)
           .Op("fake_quantization")
           .Input("in", input)
           .Input("scale", scale)
          .Input("zero_point", zero_point)
-          .Attr<std::string>("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config))
+          .Attr<std::string>("quantization_formula",
+                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))
           .Attr<std::string>("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config))
           .Output("out")
           .ScopeSymbolId(scope_symbol_id)
@@ -329,15 +334,15 @@ Maybe<void> GetScaleAndZeroPointLbn4Edge(OpEdge* edge, const std::string train_s
     const std::string observer_op_name = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX;
     if (IsWeightEdge(edge)) {
       const auto observer_op =
-          MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops);
-      *scale = observer_op.output("scale", 0);
-      *zero_point = observer_op.output("zero_point", 0);
+          JUST(MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops));
+      *scale = observer_op->output("scale", 0);
+      *zero_point = observer_op->output("zero_point", 0);
     } else {
       CHECK_OR_RETURN(qat_config.has_moving_min_max_stop_update_after_iters());
-      const auto observer_op = MovingMinMaxObserver(observer_op_name, lbn, train_step_lbn,
-                                                    qat_config, scope_symbol_id, inserted_ops);
-      *scale = observer_op.output("scale", 0);
-      *zero_point = observer_op.output("zero_point", 0);
+      const auto observer_op = JUST(MovingMinMaxObserver(
+          observer_op_name, lbn, train_step_lbn, qat_config, scope_symbol_id, inserted_ops));
+      *scale = observer_op->output("scale", 0);
+      *zero_point = observer_op->output("zero_point", 0);
     }
   }
   return Maybe<void>::Ok();
@@ -374,9 +379,9 @@
                               HashSet<OpNode*> downstream_white, Job* job) const;
 };
 
-bool IsNodeQuantizationEnabled(const OpNode& node) {
+Maybe<bool> IsNodeQuantizationEnabled(const OpNode& node) {
   int64_t scope_symbol_id = node.op().op_conf().scope_symbol_id();
-  CHECK(Global<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));
+  CHECK_OR_RETURN(Global<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));
   const Scope& scope = Global<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);
   return scope.Bool("quantization_aware_training");
 }
@@ -384,20 +389,20 @@ bool IsNodeQuantizationEnabled(const OpNode& node) {
 Maybe<void> QuantAwareTraining::Apply(Job* job, JobPassCtx* ctx) const {
   if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }
   const OpGraph op_graph(*job);
-  CHECK(GlobalJobDesc().DefaultDataType() == DataType::kFloat);
+  CHECK_OR_RETURN(GlobalJobDesc().DefaultDataType() == DataType::kFloat);
 
   const auto qat_config = ctx->job_desc().job_conf().qat_config();
 
-  OpTypeSet int8_list = Int8List4QatConfig(qat_config);
-  OpTypeSet transparent_list = TransparentList4QatConfig(qat_config);
+  OpTypeSet int8_list = *JUST(Int8List4QatConfig(qat_config));
+  OpTypeSet transparent_list = *JUST(TransparentList4QatConfig(qat_config));
   // if `insert_quant_op_after_int8_ops` is false,
   // always insert quant op before int8 ops.
   // if `insert_quant_op_after_int8_ops` is true,
   // always insert quant op after int8 ops
-  bool insert_quant_op_after_int8_ops = InsertQuantOpAfterInt8Ops4QatConfig(qat_config);
+  bool insert_quant_op_after_int8_ops = JUST(InsertQuantOpAfterInt8Ops4QatConfig(qat_config));
 
-  VerifyQATList(int8_list);
-  VerifyQATList(transparent_list);
+  JUST(VerifyQATList(int8_list));
+  JUST(VerifyQATList(transparent_list));
 
   std::function<std::string(OpNode* const&)> OpName4Node = [](OpNode* const& node) {
     return node->op().op_name();
@@ -456,7 +461,7 @@ Maybe<void> QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config,
     const std::string lbn = GenLogicalBlobName(edge->lbis().front());
     scale_map[lbn] = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX + "/scale_0";
     VLOG(3) << "set " << lbn << " to " << scale_map[lbn];
-    INSERT_CHECK(white_set_edges.insert(edge));
+    INSERT_CHECK_OR_RETURN(white_set_edges.insert(edge));
     return Maybe<void>::Ok();
   };
   auto PropagateScale = [](OpNode* node) -> Maybe<void> {
@@ -478,16 +483,16 @@
     if (IsKeyFound(white_set, node)) {
       for (OpEdge* edge : node->in_edges()) {
         if (IsKeyFound(white_set, edge->src_node())) { continue; }
-        if (IsNodeQuantizationEnabled(*edge->dst_node())) { JUST(AddWhiteSetEdge(edge)); }
+        if (JUST(IsNodeQuantizationEnabled(*edge->dst_node()))) { JUST(AddWhiteSetEdge(edge)); }
       }
       if (IsNodeInList(int8_list, node)) {
         if (insert_quant_op_after_int8_ops) {
           OpNode* inference_node = JUST(GetInferenceOutputNode(op_graph, node));
-          if (IsNodeQuantizationEnabled(*inference_node)) {
+          if (JUST(IsNodeQuantizationEnabled(*inference_node))) {
             for (OpEdge* edge : inference_node->out_edges()) { JUST(AddWhiteSetEdge(edge)); }
           }
         } else {
-          if (IsNodeQuantizationEnabled(*node)) {
+          if (JUST(IsNodeQuantizationEnabled(*node))) {
             for (OpEdge* edge : node->in_edges()) {
               if (white_set_edges.find(edge) == white_set_edges.end()) {
                 JUST(AddWhiteSetEdge(edge));
@@ -535,10 +540,10 @@
     JUST(GetScaleAndZeroPointLbn4Edge(edge, job->job_conf().train_conf().train_step_lbn(), &scale,
                                       &zero_point, qat_config, scope_symbol_id, &inserted_ops));
     const std::string fake_quant_op_name = ReplaceSlashToDash4Lbn(lbn) + FAKE_QUANT_SUFFIX;
-    const auto fake_quant_op = FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, qat_config,
-                                           scope_symbol_id, &inserted_ops);
+    const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point,
+                                                qat_config, scope_symbol_id, &inserted_ops));
 
-    const std::string fake_quant_op_output_name = fake_quant_op.output("out", 0);
+    const std::string fake_quant_op_output_name = fake_quant_op->output("out", 0);
 
     JUST(ReplaceInputLbn4DstNodeOfEdge(edge, fake_quant_op_output_name, &op_conf_cache));
   }
diff --git a/oneflow/xrt/xla/ops/scalar_binary_op.cpp b/oneflow/xrt/xla/ops/scalar_binary_op.cpp
index 2e646aa940af411b606d0bd42a0bee3c96bbe0bd..a8f2131e4925085efc3b8296e055828db06403a2 100644
--- a/oneflow/xrt/xla/ops/scalar_binary_op.cpp
+++ b/oneflow/xrt/xla/ops/scalar_binary_op.cpp
@@ -44,6 +44,8 @@ class ScalarBinaryOp : public XlaOpKernel {
       double value = ctx->Attr<double>("float_operand");
       return FloatLiteral(builder, data_type, value);
     }
+    UNIMPLEMENTED();
+    return xla::XlaOp();
   }
 };
 
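Note on the pattern: the core of this patch is converting helpers that used to abort the whole process on bad input (CHECK(...), UNIMPLEMENTED()) into functions that return Maybe<T> and propagate the failure to the caller via JUST(...), plus an *_OR_RETURN variant of the insert-check macro (INSERT_CHECK_OR_RETURN above). The sketch below illustrates that error-propagation style in isolation. It is a toy model, not OneFlow's real implementation (the real Maybe<T>, JUST, and CHECK_OR_RETURN live in oneflow/core/common/maybe.h); MyMaybe, Err, MY_JUST, MY_CHECK_OR_RETURN, and FormulaForBackend are hypothetical names invented for illustration.

// Toy sketch of the Maybe/JUST error-propagation pattern (not OneFlow's API).
#include <initializer_list>
#include <iostream>
#include <optional>
#include <string>
#include <utility>

// Toy error carrier; implicitly convertible to any MyMaybe<T>.
struct Err {
  std::string msg;
};

// Toy stand-in for oneflow::Maybe<T>: holds either a T or an error message.
template<typename T>
class MyMaybe {
 public:
  MyMaybe(T value) : value_(std::move(value)) {}  // success
  MyMaybe(Err e) : error_(std::move(e.msg)) {}    // failure
  bool ok() const { return value_.has_value(); }
  const T& operator*() const { return *value_; }
  Err error() const { return Err{error_}; }

 private:
  std::optional<T> value_;
  std::string error_;
};

// Toy JUST: unwrap on success, early-return the error otherwise. Like the
// real JUST, this uses the GCC/Clang statement-expression extension, and the
// enclosing function must itself return a MyMaybe<...>.
#define MY_JUST(expr)                             \
  ({                                              \
    auto maybe_ = (expr);                         \
    if (!maybe_.ok()) { return maybe_.error(); }  \
    *maybe_;                                      \
  })

// Toy CHECK_OR_RETURN: fail the enclosing function instead of aborting.
#define MY_CHECK_OR_RETURN(cond) \
  if (!(cond)) return Err{"check failed: " #cond}

// Analogue of QuantizationFormulaAttr4QatConfig after this diff: an unknown
// backend becomes a returned error (UNIMPLEMENTED_THEN_RETURN) rather than a
// process abort (UNIMPLEMENTED).
MyMaybe<std::string> FormulaForBackend(const std::string& backend) {
  if (backend.empty() || backend == "tensorrt") { return std::string("google"); }
  if (backend == "cambricon") { return std::string("cambricon"); }
  return Err{"unimplemented backend: " + backend};
}

// Analogue of a caller such as MinMaxObserver: MY_JUST forwards any error
// from the helper to our own caller, as JUST does throughout the pass.
MyMaybe<std::string> ObserverFormulaAttr(const std::string& backend) {
  std::string formula = MY_JUST(FormulaForBackend(backend));
  return "quantization_formula=" + formula;
}

int main() {
  for (const std::string& backend : {"tensorrt", "npu"}) {
    const auto result = ObserverFormulaAttr(backend);
    if (result.ok()) {
      std::cout << backend << " -> " << *result << '\n';
    } else {
      std::cout << backend << " -> error: " << result.error().msg << '\n';
    }
  }
  return 0;
}

Two places in the patch cannot adopt Maybe because their signatures are fixed by an external contract: EagerBlobTrait::split_axis() and the XLA ScalarBinaryOp kernel. There the UNIMPLEMENTED() abort is kept and an unreachable return value is added after it (return 0; and return xla::XlaOp();) so the compiler does not warn about control reaching the end of a non-void function.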