Skip to content
Snippets Groups Projects
Unverified Commit 2638681c authored by Juncheng's avatar Juncheng Committed by GitHub
Browse files

Fix DeviceTick time shape (#4587)


Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 9dafa9d9
No related branches found
No related tags found
No related merge requests found
......@@ -75,11 +75,11 @@ void AccCompActor::Act() {
if (acc_cnt_ == 0) {
Blob* out_blob = out_regst->GetMutSoleBlob();
if (GetDeviceType() == DeviceType::kCPU) {
Memset<DeviceType::kCPU>(kernel_ctx.device_ctx, out_blob->mut_dptr(), 0,
Memset<DeviceType::kCPU>(kernel_ctx.device_ctx, out_blob->ForceMutDptr(), 0,
out_blob->ByteSizeOfBlobBody());
} else if (GetDeviceType() == DeviceType::kGPU) {
#ifdef WITH_CUDA
Memset<DeviceType::kGPU>(kernel_ctx.device_ctx, out_blob->mut_dptr(), 0,
Memset<DeviceType::kGPU>(kernel_ctx.device_ctx, out_blob->ForceMutDptr(), 0,
out_blob->ByteSizeOfBlobBody());
#else
UNIMPLEMENTED();
......
......@@ -206,22 +206,26 @@ OperatorConf MakeDeviceTickOpConf(const std::string& tick_name) {
}
OperatorConf AppendTick(const std::string tick_name, const std::vector<std::string>& op_names,
ParallelConf parallel_conf, JobBuilder* job_builder) {
const std::shared_ptr<const Shape>& time_shape, ParallelConf parallel_conf,
JobBuilder* job_builder) {
OperatorConf device_tick_op_conf = MakeDeviceTickOpConf(tick_name);
if (time_shape) {
time_shape->ToProto(device_tick_op_conf.mutable_device_tick_conf()->mutable_time_shape());
}
for (const auto& op_name : op_names) { device_tick_op_conf.add_ctrl_in_op_name(op_name); }
job_builder->AddOps(parallel_conf, {device_tick_op_conf});
return device_tick_op_conf;
}
OperatorConf AppendTick(const std::string tick_name, const std::list<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
const std::shared_ptr<const Shape>& time_shape, JobBuilder* job_builder) {
std::vector<std::string> op_names;
for (const auto* op_node : op_nodes) {
CHECK(op_nodes.front()->parallel_desc() == op_node->parallel_desc());
op_names.push_back(op_node->op().op_name());
}
return AppendTick(tick_name, op_names, op_nodes.front()->parallel_desc().parallel_conf(),
job_builder);
return AppendTick(tick_name, op_names, time_shape,
op_nodes.front()->parallel_desc().parallel_conf(), job_builder);
}
OperatorConf PrependTick(const HashSet<const OpNode*>& op_nodes, JobBuilder* job_builder) {
......@@ -244,7 +248,7 @@ OperatorConf AppendAccTick(const Shape& src_shape, const std::list<const OpNode*
JobBuilder* job_builder) {
std::shared_ptr<const Shape> tick_shape = CHECK_JUST(op_nodes.front()->op().GetOpTimeShape());
CHECK_EQ(tick_shape->elem_cnt() % src_shape.elem_cnt(), 0);
const OperatorConf& tick_op_conf = AppendTick("AppendAcc", op_nodes, job_builder);
const OperatorConf& tick_op_conf = AppendTick("AppendAcc", op_nodes, tick_shape, job_builder);
OperatorConf acc_op_conf;
{
acc_op_conf.set_name(std::string("System-AutoTick-AccTick_") + NewUniqueId());
......@@ -329,7 +333,8 @@ std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape,
const std::pair<Shape, Shape>& ts = pair.first.second;
if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) {
CHECK_GE(ts.first.elem_cnt(), ts.second.elem_cnt());
op_confs.push_back(AppendTick("Append", pair.second, job_builder));
op_confs.push_back(
AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder));
} else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) {
op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder));
} else {
......@@ -419,7 +424,7 @@ void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt > 0) { return; }
CHECK(op_node->op().op_conf().has_device_tick_conf());
CHECK(*CHECK_JUST(op_node->op().GetOpTimeShape()) == src_time_shape);
CHECK(CHECK_JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape.elem_cnt());
CHECK(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
});
OperatorConf src_subset_tick = CHECK_JUST(FindSrcSubsetTickOpConf(job_builder->job()));
......
......@@ -50,9 +50,19 @@ Maybe<void> AccTickOp::InferOpTimeShape(
const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,
std::shared_ptr<const Shape>* time_shape) const {
const int32_t max_acc_num = op_conf().acc_tick_conf().max_acc_num();
CHECK_EQ_OR_RETURN(JUST(GetTimeShape4BnInOp("one"))->elem_cnt() % max_acc_num, 0);
std::shared_ptr<Shape> op_time_shape(
new Shape({JUST(GetTimeShape4BnInOp("one"))->elem_cnt() / max_acc_num}));
std::shared_ptr<const Shape> in_shape = JUST(GetTimeShape4BnInOp("one"));
CHECK_EQ_OR_RETURN(in_shape->elem_cnt() % max_acc_num, 0);
DimVector in_dim_vec = in_shape->dim_vec();
std::shared_ptr<Shape> op_time_shape;
if (in_dim_vec.back() == max_acc_num) {
in_dim_vec.pop_back();
op_time_shape.reset(new Shape(in_dim_vec));
} else if (in_dim_vec.back() % max_acc_num == 0) {
in_dim_vec.back() /= max_acc_num;
op_time_shape.reset(new Shape(in_dim_vec));
} else {
op_time_shape.reset(new Shape({in_shape->elem_cnt() / max_acc_num}));
}
*time_shape = op_time_shape;
return Maybe<void>::Ok();
}
......
......@@ -53,6 +53,34 @@ Maybe<void> DeviceTickOp::GetSbpSignatures(
return Maybe<void>::Ok();
}
Maybe<void> DeviceTickOp::InferOpTimeShape(
const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,
std::shared_ptr<const Shape>* time_shape) const {
std::shared_ptr<const Shape> in_time_shape;
for (const auto& bn : input_bns()) {
std::shared_ptr<const Shape> ts = JUST(GetTimeShape4BnInOp(bn));
if (!in_time_shape) {
in_time_shape = ts;
} else {
CHECK_OR_RETURN(*in_time_shape == *ts);
}
}
if (this->op_conf().device_tick_conf().has_time_shape()) {
if (!in_time_shape) {
in_time_shape.reset(new Shape(this->op_conf().device_tick_conf().time_shape()));
} else {
CHECK_OR_RETURN(in_time_shape->elem_cnt()
== Shape(this->op_conf().device_tick_conf().time_shape()).elem_cnt());
}
}
if (in_time_shape) {
*time_shape = in_time_shape;
} else {
*time_shape = std::make_shared<const Shape>(Shape({1, 1}));
}
return Maybe<void>::Ok();
}
REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kDeviceTickConf, 2);
REGISTER_OP(OperatorConf::kDeviceTickConf, DeviceTickOp);
REGISTER_TICK_TOCK_OP(OperatorConf::kDeviceTickConf);
......
......@@ -33,6 +33,9 @@ class DeviceTickOp final : public Operator {
Maybe<void> InferOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOpTimeShape(
const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,
std::shared_ptr<const Shape>* time_shape) const override;
private:
Maybe<void> GetSbpSignatures(
......
......@@ -171,6 +171,7 @@ message TickOpConf {
message DeviceTickOpConf {
repeated string tick = 1;
required string out = 2;
optional ShapeProto time_shape = 3;
}
message WaitAndSendIdsOpConf {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment