diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index a946f2f36b49572bbeb79dc47322e051a39ecb79..cf8511199d9f69e9c4212a1544bdc7cbdb57eae8 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -896,7 +896,7 @@ Maybe<void> InstructionsBuilder::AccessBlobByCallback( JUST(tensor->compute_local_dep_object()); *instruction->mutable_phy_instr_operand() = std::make_shared<vm::AccessBlobArgCbPhyInstrOperand>( eager_blob_object, infer_local_dep_object, compute_local_dep_object, callback, modifier); - instruction->set_parallel_desc_symbol_id(JUST(tensor->parallel_desc()->symbol_id())); + *instruction->mut_parallel_desc() = tensor->parallel_desc(); instruction_list_->EmplaceBack(std::move(instruction.Mutable())); return Maybe<void>::Ok(); } diff --git a/oneflow/core/vm/instruction.cpp b/oneflow/core/vm/instruction.cpp index 5b62d5b1c2e097b9b04d5cceedcf143a4bc207ce..4739062ba5c868745634f94b897b900e1b1dc3d3 100644 --- a/oneflow/core/vm/instruction.cpp +++ b/oneflow/core/vm/instruction.cpp @@ -314,7 +314,7 @@ const MirroredObject* Instruction::GetMirroredObject(const Operand& operand, int64_t Instruction::GetOperandDefaultGlobalDeviceId() const { return stream().global_device_id(); } void Instruction::__Init__(InstructionMsg* instr_msg, Stream* stream, - const std::shared_ptr<ParallelDesc>& parallel_desc) { + const std::shared_ptr<const ParallelDesc>& parallel_desc) { mutable_status_buffer(); reset_instr_msg(instr_msg); set_stream(stream); diff --git a/oneflow/core/vm/instruction.msg.h b/oneflow/core/vm/instruction.msg.h index 3939cf48e6cd9c8c471e0e64d8f3203e9167743c..d38e6fac80d7d765434d43d8adefbd8242d28923 100644 --- a/oneflow/core/vm/instruction.msg.h +++ b/oneflow/core/vm/instruction.msg.h @@ -82,6 +82,7 @@ OBJECT_MSG_BEGIN(InstructionMsg); // instr_type_name is a necessary reduandant field for method ToProto OBJECT_MSG_DEFINE_STRUCT(std::string, instr_type_name); OBJECT_MSG_DEFINE_OPTIONAL(int64_t, parallel_desc_symbol_id); + OBJECT_MSG_DEFINE_STRUCT(std::shared_ptr<const ParallelDesc>, parallel_desc); OBJECT_MSG_DEFINE_OPTIONAL(InstructionOperandList, operand_list); OBJECT_MSG_DEFINE_STRUCT(std::shared_ptr<PhyInstrOperand>, phy_instr_operand); @@ -124,7 +125,7 @@ class Stream; // clang-format off OBJECT_MSG_BEGIN(Instruction); // methods - OF_PUBLIC void __Init__(InstructionMsg* instr_msg, Stream* stream, const std::shared_ptr<ParallelDesc>& parallel_desc); + OF_PUBLIC void __Init__(InstructionMsg* instr_msg, Stream* stream, const std::shared_ptr<const ParallelDesc>& parallel_desc); OF_PUBLIC void __Delete__(); OF_PUBLIC bool Done() const; OF_PUBLIC const StreamType& stream_type() const; @@ -216,7 +217,7 @@ OBJECT_MSG_BEGIN(Instruction); // fields OBJECT_MSG_DEFINE_FLAT_MSG(InstructionStatusBuffer, status_buffer); OBJECT_MSG_DEFINE_OPTIONAL(InstructionMsg, instr_msg); - OBJECT_MSG_DEFINE_STRUCT(std::shared_ptr<ParallelDesc>, parallel_desc); + OBJECT_MSG_DEFINE_STRUCT(std::shared_ptr<const ParallelDesc>, parallel_desc); OBJECT_MSG_DEFINE_PTR(Stream, stream); // links diff --git a/oneflow/core/vm/stream.cpp b/oneflow/core/vm/stream.cpp index 503208d617ed0497cefb8fabb92869869f1d51a4..ac8aa2ab93d5625cf1490c9a5107bf36a6d3b054 100644 --- a/oneflow/core/vm/stream.cpp +++ b/oneflow/core/vm/stream.cpp @@ -43,7 +43,7 @@ const StreamTypeId& Stream::stream_type_id() const { } ObjectMsgPtr<Instruction> Stream::NewInstruction( - InstructionMsg* instr_msg, const std::shared_ptr<ParallelDesc>& parallel_desc) { + InstructionMsg* instr_msg, const std::shared_ptr<const ParallelDesc>& parallel_desc) { if (free_instruction_list().empty()) { return ObjectMsgPtr<Instruction>::NewFrom(mut_allocator(), instr_msg, this, parallel_desc); } diff --git a/oneflow/core/vm/stream.msg.h b/oneflow/core/vm/stream.msg.h index 459a87c91a353963d90290c315ffd606e7f8a5e9..414825d84116f20e78390bf5efb73fa5e7dd1277 100644 --- a/oneflow/core/vm/stream.msg.h +++ b/oneflow/core/vm/stream.msg.h @@ -29,7 +29,7 @@ class ThreadCtx; OBJECT_MSG_BEGIN(Stream); // methods OF_PUBLIC void __Init__(ThreadCtx* thread_ctx, const StreamId& stream_id, const int64_t max_device_num_per_machine); - OF_PUBLIC ObjectMsgPtr<Instruction> NewInstruction(InstructionMsg* instr_msg, const std::shared_ptr<ParallelDesc>& parallel_desc); + OF_PUBLIC ObjectMsgPtr<Instruction> NewInstruction(InstructionMsg* instr_msg, const std::shared_ptr<const ParallelDesc>& parallel_desc); OF_PUBLIC void DeleteInstruction(ObjectMsgPtr<Instruction>&&); OF_PUBLIC int64_t global_device_id() const { return stream_id().global_device_id(); } OF_PUBLIC int64_t machine_id() const; diff --git a/oneflow/core/vm/virtual_machine.cpp b/oneflow/core/vm/virtual_machine.cpp index 3d8036de2a4a3b4b343e840eaa712221283eeb1d..aa530f8e09217214ce3109195950a10f5e6c06fe 100644 --- a/oneflow/core/vm/virtual_machine.cpp +++ b/oneflow/core/vm/virtual_machine.cpp @@ -154,15 +154,19 @@ void VirtualMachine::MakeInstructions(TmpPendingInstrMsgList* instr_msg_list, } } -Maybe<ParallelDesc> VirtualMachine::GetInstructionParallelDesc(const InstructionMsg& instr_msg) { - static const std::shared_ptr<ParallelDesc> empty_ptr; +Maybe<const ParallelDesc> VirtualMachine::GetInstructionParallelDesc( + const InstructionMsg& instr_msg) { + static const std::shared_ptr<const ParallelDesc> empty_ptr; + if (instr_msg.parallel_desc()) { return instr_msg.parallel_desc(); } if (!instr_msg.has_parallel_desc_symbol_id()) { return empty_ptr; } int64_t symbol_id = instr_msg.parallel_desc_symbol_id(); auto* logical_object = mut_id2logical_object()->FindPtr(symbol_id); CHECK_NOTNULL_OR_RETURN(logical_object) << "symbol_id: " << symbol_id; auto* map = logical_object->mut_global_device_id2mirrored_object(); CHECK_EQ_OR_RETURN(map->size(), 1); - return JUST(map->Begin()->rw_mutexed_object().Get<ObjectWrapper<ParallelDesc>>()).GetPtr(); + const std::shared_ptr<const ParallelDesc> parallel_desc = + JUST(map->Begin()->rw_mutexed_object().Get<ObjectWrapper<ParallelDesc>>()).GetPtr(); + return parallel_desc; } MirroredObject* VirtualMachine::MutMirroredObject(int64_t logical_object_id, diff --git a/oneflow/core/vm/virtual_machine.msg.h b/oneflow/core/vm/virtual_machine.msg.h index 3ba4023f2d33c4d2729be937b8442408800303b8..880acbf2f0b3d1d4705f4fc713bcf4a4014e6eb6 100644 --- a/oneflow/core/vm/virtual_machine.msg.h +++ b/oneflow/core/vm/virtual_machine.msg.h @@ -42,7 +42,7 @@ OBJECT_MSG_BEGIN(VirtualMachine); OF_PUBLIC void Receive(ObjectMsgPtr<InstructionMsg>&& instruction_msg); OF_PUBLIC void Schedule(); OF_PUBLIC bool Empty() const; - OF_PUBLIC Maybe<ParallelDesc> GetInstructionParallelDesc(const InstructionMsg&); + OF_PUBLIC Maybe<const ParallelDesc> GetInstructionParallelDesc(const InstructionMsg&); OF_PUBLIC MirroredObject* MutMirroredObject(int64_t logical_object_id, int64_t global_device_id); OF_PUBLIC const MirroredObject* GetMirroredObject(int64_t logical_object_id, int64_t global_device_id); diff --git a/oneflow/core/vm/vm_object.msg.h b/oneflow/core/vm/vm_object.msg.h index c3317696edfaa4ea41cb0d1c3069d555cfbf746f..901aea0d3ffc7b53423bfc794b99c3ad1ffaac81 100644 --- a/oneflow/core/vm/vm_object.msg.h +++ b/oneflow/core/vm/vm_object.msg.h @@ -118,16 +118,16 @@ class VirtualMachine; OBJECT_MSG_BEGIN(LogicalObject); // methods OF_PUBLIC void __Init__(const ObjectId& logical_object_id) { - __Init__(logical_object_id, std::shared_ptr<ParallelDesc>()); + __Init__(logical_object_id, std::shared_ptr<const ParallelDesc>()); } OF_PUBLIC void __Init__(const ObjectId& logical_object_id, - const std::shared_ptr<ParallelDesc>& parallel_desc) { + const std::shared_ptr<const ParallelDesc>& parallel_desc) { set_logical_object_id(logical_object_id); *mutable_parallel_desc() = parallel_desc; } // fields - OBJECT_MSG_DEFINE_STRUCT(std::shared_ptr<ParallelDesc>, parallel_desc); + OBJECT_MSG_DEFINE_STRUCT(std::shared_ptr<const ParallelDesc>, parallel_desc); // links OBJECT_MSG_DEFINE_MAP_KEY(ObjectId, logical_object_id);