Unverified commit 218f61e3 authored by Li Xinqi, committed by GitHub

use less event records (#4861)


* use less event records

* more comments

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent e5dadaf5
@@ -129,6 +129,8 @@ void AccessBlobByCallbackInstructionType::Compute(vm::Instruction* instruction)
DeviceCtx* device_ctx = instruction->stream().device_ctx().get();
OfBlob ofblob(device_ctx, ptr->eager_blob_object()->mut_blob());
ptr->callback()(reinterpret_cast<uint64_t>(&ofblob));
// Always record an instruction-complete event.
instruction->set_has_event_record(true);
}
class ReadTensorShapeByCallbackInstructionType : public vm::InstructionType {
......
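The callback convention above passes the address of the OfBlob wrapper as a uint64_t handle, which the callback casts back to a pointer; the no-op lambda added further below relies on the same convention. A minimal standalone sketch of that handle-passing pattern (FakeBlob and AccessBlob are illustrative stand-ins, not OneFlow APIs, and this sketch is not part of the commit):

// Illustrative sketch only, not part of this commit: the blob wrapper's
// address travels through the callback as a uint64_t and is cast back inside.
#include <cstdint>
#include <functional>
#include <iostream>

struct FakeBlob { int value = 42; };  // stand-in for OfBlob

void AccessBlob(FakeBlob* blob, const std::function<void(uint64_t)>& callback) {
  callback(reinterpret_cast<uint64_t>(blob));  // same cast as in Compute() above
}

int main() {
  FakeBlob blob;
  AccessBlob(&blob, [](uint64_t of_blob_ptr) {
    auto* b = reinterpret_cast<FakeBlob*>(of_blob_ptr);
    std::cout << b->value << std::endl;
  });
  return 0;
}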
@@ -60,6 +60,7 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
std::shared_ptr<const Device> op_device;
std::shared_ptr<const ParallelDesc> op_parallel_desc;
CHECK_EQ(out_devices->size(), output_eager_blob_objects->size());
bool need_event_record = false;
if (!user_op_expr.has_device_infer_fn()) {
op_device = default_device;
op_parallel_desc = op_device->parallel_desc_ptr();
@@ -72,6 +73,9 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}
} else {
op_device = JUST(user_op_expr.InferDevices(attrs, inputs, out_devices));
for (const auto& input_tensor : inputs) {
need_event_record = need_event_record || !(*op_device == *input_tensor->device());
}
op_parallel_desc = op_device->parallel_desc_ptr();
for (int i = 0; i < output_eager_blob_objects->size(); i++) {
const auto& tensor_device = out_devices->at(i);
@@ -98,6 +102,16 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
const auto& instr_type_name = JUST(op_device->local_call_instruction_name());
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
if (need_event_record) {
for (const auto& input_tensor : inputs) {
const auto& tensor = std::dynamic_pointer_cast<one::MirroredTensor>(input_tensor);
CHECK_OR_RETURN(static_cast<bool>(tensor));
// Instruction `AccessBlobByCallback` records an event which can be used to synchronize the
// CUDA stream.
JUST(builder->AccessBlobByCallback(
tensor, [](uint64_t) {}, "mut"));
}
}
return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects,
attrs, op_parallel_desc, instr_type_name);
}));
......
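When the inferred op device differs from an input tensor's device, the no-op AccessBlobByCallback(tensor, [](uint64_t) {}, "mut") issued above makes the VM record a CUDA event on the input's stream, so the cross-device consumer is only considered ready once that event has completed. A plain CUDA-runtime analogue of that producer/consumer ordering (an illustration only: OneFlow's scheduler decides readiness by polling the recorded event, it does not necessarily call cudaStreamWaitEvent as this sketch does):

// Illustrative sketch, not part of this commit: order a consumer stream after
// an event recorded on the producer stream.
#include <cuda_runtime.h>

void OrderConsumerAfterProducer(cudaStream_t producer, cudaStream_t consumer) {
  cudaEvent_t ready;
  cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);
  // ... work producing the input blob has been enqueued on `producer` ...
  cudaEventRecord(ready, producer);         // what need_event_record asks for above
  cudaStreamWaitEvent(consumer, ready, 0);  // consumer may belong to another device
  // ... the op's kernels are enqueued on `consumer` here ...
  cudaEventDestroy(ready);  // destruction is deferred by the runtime until the
                            // pending record completes
}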
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/core/vm/cuda_optional_event_record_status_querier.h"
#include "oneflow/core/device/device_context.h"
namespace oneflow {
namespace vm {
bool CudaOptionalEventRecordStatusQuerier::event_completed() const {
cudaSetDevice(device_id_);
return cudaEventQuery(event_) == cudaSuccess;
}
void CudaOptionalEventRecordStatusQuerier::SetLaunched(DeviceCtx* device_ctx) {
if (has_event_record_) {
cudaSetDevice(device_id_);
OF_CUDA_CHECK(
cudaEventCreateWithFlags(&event_, cudaEventBlockingSync | cudaEventDisableTiming));
OF_CUDA_CHECK(cudaEventRecord(event_, device_ctx->cuda_stream()));
}
launched_ = true;
}
} // namespace vm
} // namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_CUDA_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_
#define ONEFLOW_CORE_VM_CUDA_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_
#include <atomic>
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
class DeviceCtx;
namespace vm {
#ifdef WITH_CUDA
class CudaOptionalEventRecordStatusQuerier {
public:
~CudaOptionalEventRecordStatusQuerier() = default;
bool done() const { return launched_ && (!has_event_record_ || event_completed()); }
void set_has_event_record(bool val) { has_event_record_ = val; }
void SetLaunched(DeviceCtx* device_ctx);
static const CudaOptionalEventRecordStatusQuerier* Cast(const char* mem_ptr) {
return reinterpret_cast<const CudaOptionalEventRecordStatusQuerier*>(mem_ptr);
}
static CudaOptionalEventRecordStatusQuerier* MutCast(char* mem_ptr) {
return reinterpret_cast<CudaOptionalEventRecordStatusQuerier*>(mem_ptr);
}
static CudaOptionalEventRecordStatusQuerier* PlacementNew(char* mem_ptr, int64_t device_id) {
return new (mem_ptr) CudaOptionalEventRecordStatusQuerier(device_id);
}
private:
explicit CudaOptionalEventRecordStatusQuerier(int64_t device_id)
: launched_(false), has_event_record_(false), device_id_(device_id) {}
bool event_completed() const;
std::atomic<bool> launched_;
std::atomic<bool> has_event_record_;
int64_t device_id_;
cudaEvent_t event_;
};
#endif
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_VM_CUDA_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_
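Taken together, SetLaunched() and done() implement the optimization named in the commit title: an event is created and recorded only when has_event_record_ was set, and every other instruction counts as done as soon as its kernels were enqueued, since later work on the same CUDA stream is ordered behind it anyway. A self-contained analogue of that optional-event completion check (the struct and its names are illustrative, not OneFlow's API):

// Illustrative, self-contained analogue of the querier above; names are mine.
#include <cuda_runtime.h>
#include <atomic>

struct OptionalEventStatus {
  std::atomic<bool> launched{false};
  bool has_event_record{false};
  cudaEvent_t event{};

  // Called by the worker thread after the instruction's kernels are enqueued.
  void SetLaunched(cudaStream_t stream, int device_id) {
    if (has_event_record) {
      cudaSetDevice(device_id);
      cudaEventCreateWithFlags(&event, cudaEventBlockingSync | cudaEventDisableTiming);
      cudaEventRecord(event, stream);
    }
    launched = true;
  }

  // Polled by the scheduler thread: without an event record, "enqueued"
  // already counts as done.
  bool Done(int device_id) {
    if (!launched) { return false; }
    if (!has_event_record) { return true; }
    cudaSetDevice(device_id);
    return cudaEventQuery(event) == cudaSuccess;
  }
};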
@@ -19,7 +19,7 @@ limitations under the License.
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/vm/stream.msg.h"
#include "oneflow/core/vm/thread_ctx.msg.h"
#include "oneflow/core/vm/cuda_instruction_status_querier.h"
#include "oneflow/core/vm/cuda_optional_event_record_status_querier.h"
#include "oneflow/core/vm/cuda_stream_handle_device_context.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/common/util.h"
@@ -34,8 +34,9 @@ void CudaStreamType::InitDeviceCtx(std::unique_ptr<DeviceCtx>* device_ctx, Strea
void CudaStreamType::InitInstructionStatus(const Stream& stream,
InstructionStatusBuffer* status_buffer) const {
-static_assert(sizeof(CudaInstrStatusQuerier) < kInstructionStatusBufferBytes, "");
-CudaInstrStatusQuerier::PlacementNew(status_buffer->mut_buffer()->mut_data(), stream.device_id());
+static_assert(sizeof(CudaOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, "");
+CudaOptionalEventRecordStatusQuerier::PlacementNew(status_buffer->mut_buffer()->mut_data(),
+                                                   stream.device_id());
}
void CudaStreamType::DeleteInstructionStatus(const Stream& stream,
@@ -45,7 +46,13 @@ void CudaStreamType::DeleteInstructionStatus(const Stream& stream,
bool CudaStreamType::QueryInstructionStatusDone(
const Stream& stream, const InstructionStatusBuffer& status_buffer) const {
-return CudaInstrStatusQuerier::Cast(status_buffer.buffer().data())->done();
+return CudaOptionalEventRecordStatusQuerier::Cast(status_buffer.buffer().data())->done();
}
void CudaStreamType::set_has_event_record(InstructionStatusBuffer* status_buffer, bool val) const {
auto* querier =
CudaOptionalEventRecordStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data());
return querier->set_has_event_record(val);
}
void CudaStreamType::Compute(Instruction* instruction) const {
@@ -59,7 +66,7 @@ void CudaStreamType::Compute(Instruction* instruction) const {
}
stream->mut_callback_list()->MoveTo(instruction->mut_callback_list());
char* data_ptr = instruction->mut_status_buffer()->mut_buffer()->mut_data();
-CudaInstrStatusQuerier::MutCast(data_ptr)->SetLaunched(stream->device_ctx().get());
+CudaOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched(stream->device_ctx().get());
}
ObjectMsgPtr<StreamDesc> CudaStreamType::MakeStreamDesc(const Resource& resource,
......
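The querier lives inside the instruction's raw status buffer: PlacementNew constructs it in place in InitInstructionStatus, and Cast/MutCast later reinterpret the same bytes when the status is queried or updated. A standalone sketch of that placement-new-into-a-buffer pattern (Status and kBufferBytes are illustrative, standing in for the querier and kInstructionStatusBufferBytes):

// Illustrative sketch only: construct an object inside a raw status buffer and
// reinterpret the same bytes later, as the querier above is handled.
#include <cstddef>
#include <cstdint>
#include <new>

struct Status {
  explicit Status(int64_t device_id) : device_id(device_id) {}
  int64_t device_id;
  bool done = false;
};

constexpr std::size_t kBufferBytes = 64;   // must hold sizeof(Status)
alignas(Status) char buffer[kBufferBytes];

Status* PlacementNew(char* mem_ptr, int64_t device_id) { return new (mem_ptr) Status(device_id); }
Status* MutCast(char* mem_ptr) { return reinterpret_cast<Status*>(mem_ptr); }

int main() {
  PlacementNew(buffer, /*device_id=*/0);  // as InitInstructionStatus does above
  MutCast(buffer)->done = true;           // later code casts the bytes back
  MutCast(buffer)->~Status();             // explicit destruction of the placement-new'd object
  return 0;
}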
@@ -42,6 +42,7 @@ class CudaStreamType final : public StreamType {
InstructionStatusBuffer* status_buffer) const override;
bool QueryInstructionStatusDone(const Stream& stream,
const InstructionStatusBuffer& status_buffer) const override;
void set_has_event_record(InstructionStatusBuffer* status_buffer, bool val) const override;
void Compute(Instruction* instruction) const override;
ObjectMsgPtr<StreamDesc> MakeStreamDesc(const Resource& resource,
int64_t this_machine_id) const override;
......
@@ -333,6 +333,10 @@ bool Instruction::Done() const {
return stream_type().QueryInstructionStatusDone(stream(), status_buffer());
}
void Instruction::set_has_event_record(bool val) {
return stream_type().set_has_event_record(mut_status_buffer(), val);
}
const StreamType& Instruction::stream_type() const { return stream().stream_type(); }
} // namespace vm
......
@@ -128,6 +128,7 @@ OBJECT_MSG_BEGIN(Instruction);
OF_PUBLIC void __Init__(InstructionMsg* instr_msg, Stream* stream, const std::shared_ptr<const ParallelDesc>& parallel_desc);
OF_PUBLIC void __Delete__();
OF_PUBLIC bool Done() const;
OF_PUBLIC void set_has_event_record(bool val);
OF_PUBLIC const StreamType& stream_type() const;
OF_PUBLIC template<OperandMemZoneModifier mem_zone_modifier>
......
@@ -57,6 +57,9 @@ class StreamType {
InstructionStatusBuffer* status_buffer) const = 0;
virtual bool QueryInstructionStatusDone(const Stream& stream,
const InstructionStatusBuffer& status_buffer) const = 0;
virtual void set_has_event_record(InstructionStatusBuffer* status_buffer, bool val) const {
// Do nothing.
}
virtual void Compute(Instruction* instruction) const = 0;
virtual void Infer(Instruction* instruction) const { LOG(FATAL) << "UNIMPLEMENTED"; }
......
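set_has_event_record is added to the StreamType base class with an empty default body, so only stream types whose completion check actually depends on a recorded event (CudaStreamType above) override it; every other stream type silently ignores the hint. A tiny standalone sketch of that default-no-op virtual hook (class names are illustrative, not OneFlow's):

// Illustrative sketch of the default-no-op hook pattern; names are not OneFlow's.
#include <iostream>

class StreamTypeLike {
 public:
  virtual ~StreamTypeLike() = default;
  virtual void set_has_event_record(bool /*val*/) const {
    // Do nothing: most stream types have no use for the hint.
  }
};

class CudaLikeStreamType final : public StreamTypeLike {
 public:
  void set_has_event_record(bool val) const override {
    std::cout << "event record " << (val ? "requested" : "cleared") << std::endl;
  }
};

int main() {
  CudaLikeStreamType cuda;
  StreamTypeLike other;
  cuda.set_has_event_record(true);   // overridden: the hint is consumed
  other.set_has_event_record(true);  // base default: ignored
  return 0;
}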