Skip to content
Snippets Groups Projects
Unverified Commit b7e9c264 authored by daquexian's avatar daquexian Committed by GitHub
Browse files

remove SourceIntruction, add ResettingIdToObjectMap, fix bug in...

remove SourceIntruction, add ResettingIdToObjectMap, fix bug in ForEachMutMirroredObject4MutPhyInstrOperand (#4734)

* Add NeedsRunInAdvance()

* Rename to ResettingIdToObjectMap

* rename IsImmediateOperandsOnly -> HasImmediateOperandsOnly

Signed-off-by: default avatardaquexian <daquexian566@gmail.com>

* Fix ForEachMut2MirroredObject -> ForEachMutMirroredObject in ForEachMutMirroredObject4MutPhyInstrOperand

Signed-off-by: default avatardaquexian <daquexian566@gmail.com>

* split CHECK(..&&..) to two CHECK(..)

Signed-off-by: default avatardaquexian <daquexian566@gmail.com>

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 96b7cf8e
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,8 @@ class NewSymbolInstructionType final : public InstructionType {
NewSymbolInstructionType() = default;
~NewSymbolInstructionType() override = default;
bool ResettingIdToObjectMap() const override { return true; }
using stream_type = ControlStreamType;
// clang-format off
......
......@@ -33,6 +33,7 @@ class InstructionType {
bool IsSequential() const { return IsFrontSequential(); }
virtual bool IsFrontSequential() const { return false; }
virtual bool ResettingIdToObjectMap() const { return false; }
virtual void Compute(Instruction* instruction) const = 0;
virtual void Infer(Instruction* instruction) const = 0;
......
......@@ -63,6 +63,8 @@ class NewObjectInstructionType final : public InstructionType {
NewObjectInstructionType() = default;
~NewObjectInstructionType() override = default;
bool ResettingIdToObjectMap() const override { return true; }
using stream_type = ControlStreamType;
// clang-format off
......@@ -111,6 +113,8 @@ class BroadcastObjectReferenceInstructionType final : public InstructionType {
BroadcastObjectReferenceInstructionType() = default;
~BroadcastObjectReferenceInstructionType() override = default;
bool ResettingIdToObjectMap() const override { return true; }
using stream_type = ControlStreamType;
// clang-format off
......@@ -168,6 +172,8 @@ class ReplaceMirroredInstructionType final : public InstructionType {
ReplaceMirroredInstructionType() = default;
~ReplaceMirroredInstructionType() override = default;
bool ResettingIdToObjectMap() const override { return true; }
using stream_type = ControlStreamType;
// clang-format off
......
......@@ -32,6 +32,8 @@ class NewParallelDescSymbolInstructionType final : public InstructionType {
NewParallelDescSymbolInstructionType() = default;
~NewParallelDescSymbolInstructionType() override = default;
bool ResettingIdToObjectMap() const override { return true; }
using stream_type = ControlStreamType;
// clang-format off
......
......@@ -27,7 +27,7 @@ namespace vm {
namespace {
bool IsSourceInstruction(const InstructionMsg& instr_msg) {
bool HasImmediateOperandsOnly(const InstructionMsg& instr_msg) {
for (const auto& instr_operand : instr_msg.operand()) {
if (instr_operand->has_const_operand()) { return false; }
if (instr_operand->has_mut_operand()) { return false; }
......@@ -90,12 +90,13 @@ void VirtualMachine::TryReleaseFinishedInstructions(
}
}
void VirtualMachine::FilterAndRunSourceInstructions(TmpPendingInstrMsgList* instr_msg_list) {
void VirtualMachine::FilterAndRunInstructionsInAdvance(TmpPendingInstrMsgList* instr_msg_list) {
OBJECT_MSG_LIST_FOR_EACH_PTR(instr_msg_list, instr_msg) {
const auto& instr_type_id = instr_msg->instr_type_id();
const StreamType& stream_type = instr_type_id.stream_type_id().stream_type();
if (stream_type.IsControlStreamType() && !instr_type_id.instruction_type().IsSequential()
&& IsSourceInstruction(*instr_msg)) {
if (instr_type_id.instruction_type().ResettingIdToObjectMap()) {
const StreamType& stream_type = instr_type_id.stream_type_id().stream_type();
CHECK(stream_type.IsControlStreamType());
CHECK(HasImmediateOperandsOnly(*instr_msg));
const auto& parallel_desc = CHECK_JUST(GetInstructionParallelDesc(*instr_msg));
if (!parallel_desc || parallel_desc->ContainingMachineId(this_machine_id())) {
stream_type.Run(this, instr_msg);
......@@ -291,7 +292,7 @@ void ForEachMutMirroredObject4MutPhyInstrOperand(InterpretType interpret_type,
phy_instr_operand.ForEachMutMirroredObject(
[&](MirroredObject* infer, MirroredObject* compute) { Callback(compute); });
} else if (interpret_type == InterpretType::kInfer) {
phy_instr_operand.ForEachMut2MirroredObject(
phy_instr_operand.ForEachMutMirroredObject(
[&](MirroredObject* infer, MirroredObject* compute) { Callback(infer); });
} else {
UNIMPLEMENTED();
......@@ -639,7 +640,7 @@ void VirtualMachine::Schedule() {
if (pending_msg_list().size() > 0) {
TmpPendingInstrMsgList tmp_pending_msg_list;
mut_pending_msg_list()->MoveTo(&tmp_pending_msg_list);
FilterAndRunSourceInstructions(&tmp_pending_msg_list);
FilterAndRunInstructionsInAdvance(&tmp_pending_msg_list);
NewInstructionList new_instruction_list;
MakeInstructions(&tmp_pending_msg_list, /*out*/ &new_instruction_list);
ConsumeMirroredObjects(mut_id2logical_object(), &new_instruction_list);
......
......@@ -89,7 +89,7 @@ OBJECT_MSG_BEGIN(VirtualMachine);
/*out*/ ReadyInstructionList* ready_instruction_list);
void TryReleaseFinishedInstructions(
Stream* stream, /*out*/ ReadyInstructionList* ready_instruction_list);
void FilterAndRunSourceInstructions(TmpPendingInstrMsgList* instr_msg_list);
void FilterAndRunInstructionsInAdvance(TmpPendingInstrMsgList* instr_msg_list);
void MakeInstructions(TmpPendingInstrMsgList*, /*out*/ NewInstructionList* ret_instruction_list);
template<int64_t (*TransformLogicalObjectId)(int64_t), typename DoEachT>
void ForEachMirroredObject(Id2LogicalObject* id2logical_object,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment