Skip to content
Snippets Groups Projects
Commit 06a105a7, authored by lixinqi
Browse files

gpu direct watch_op

parent bd42971f
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,7 @@
namespace oneflow {
void AddLbiDiffWatherOpConfs(Job* job) {
void AddLbiDiffWatcherOpConfs(Job* job) {
JobBuilder job_builder(job);
const auto& map = Global<LbiDiffWatcherInfo>::Get()->job_name2lbi_and_watcher_uuids();
if (map.find(GlobalJobDesc().job_name()) == map.end()) { return; }
......@@ -16,7 +16,6 @@ void AddLbiDiffWatherOpConfs(Job* job) {
CHECK(lbi2diff_lbi.emplace(pair.first(), pair.second()).second);
}
const auto& pair_list = map.at(GlobalJobDesc().job_name()).lbi_and_uuid_pair();
std::vector<OperatorConf> op_confs;
for (const LbiAndDiffWatcherUuidPair& pair : pair_list) {
if (lbi2diff_lbi.find(pair.lbi()) == lbi2diff_lbi.end()) { continue; }
OperatorConf foreign_watcher_op;
......@@ -24,11 +23,8 @@ void AddLbiDiffWatherOpConfs(Job* job) {
auto* foreign_watcher_conf = foreign_watcher_op.mutable_foreign_watch_conf();
foreign_watcher_conf->set_in(GenLogicalBlobName(lbi2diff_lbi.at(pair.lbi())));
foreign_watcher_conf->set_handler_uuid(pair.watcher_uuid());
op_confs.push_back(foreign_watcher_op);
job_builder.AddOps(job_builder.ParallelConf4Lbi(pair.lbi()), {foreign_watcher_op});
}
ParallelConf parallel_conf;
parallel_conf.add_device_name("0:cpu:0");
job_builder.AddOps(parallel_conf, op_confs);
}
} // namespace oneflow
......@@ -5,7 +5,7 @@
namespace oneflow {
void AddLbiDiffWatherOpConfs(Job* job);
void AddLbiDiffWatcherOpConfs(Job* job);
}
#endif // ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
......@@ -360,7 +360,7 @@ void JobCompleter::Complete(Job* job) const {
WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
WithOpGraphAndMutJobBuilder(job, &MakeNcclTupleBroadcastReduceSequence);
WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce);
AddLbiDiffWatherOpConfs(job);
AddLbiDiffWatcherOpConfs(job);
WithOpGraphAndMutJobBuilder(job, &MakeAllReduceSequence);
}
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
......
......@@ -6,13 +6,41 @@
namespace oneflow {
void ForeignWatchKernel::ForwardDataContent(
// CPU specialization: the blob already lives in host memory, so the handler
// can consume it directly — no staging copy and no synchronization needed.
template<>
void ForeignWatchKernel<DeviceType::kCPU>::WithInBlob(DeviceCtx* ctx, Blob* blob,
std::function<void(Blob*)> Handler) const {
Handler(blob);
}
// GPU specialization: stages the device blob into pinned host memory so the
// host-side handler can read it, then releases the staging buffer.
template<>
void ForeignWatchKernel<DeviceType::kGPU>::WithInBlob(DeviceCtx* ctx, Blob* blob,
                                                      std::function<void(Blob*)> Handler) const {
  // cudaMallocHost returns page-locked memory, which the stream-ordered copy
  // below requires to actually run asynchronously.
  char* host_raw_dptr = nullptr;
  CudaCheck(cudaMallocHost(&host_raw_dptr, blob->AlignedTotalByteSize()));
  MemoryCase mem_case;
  mem_case.mutable_host_mem();
  Blob host_blob(mem_case, &blob->blob_desc(), host_raw_dptr);
  Memcpy<DeviceType::kGPU>(ctx, host_blob.mut_dptr(), blob->dptr(), blob->ByteSizeOfBlobBody(),
                           cudaMemcpyDeviceToHost);
  // BUG FIX: the device->host copy above is enqueued on ctx's CUDA stream and
  // has not necessarily completed when it returns (assumes Memcpy<kGPU> is an
  // async, stream-ordered copy — confirm against the Memcpy implementation).
  // Synchronize BEFORE invoking the handler; the original called
  // Handler(&host_blob) first, letting it observe stale/uninitialized bytes.
  CudaCheck(cudaStreamSynchronize(ctx->cuda_stream()));
  Handler(&host_blob);
  CudaCheck(cudaFreeHost(host_raw_dptr));
}
// Forwards the "in" blob to the foreign (Python-side) watcher registered
// under this op's handler uuid. WithInBlob abstracts over device placement:
// on CPU the blob is passed through unchanged, on GPU it is first staged
// into host memory.
// NOTE(review): the scraped diff had the pre-change statements (direct OfBlob
// construction + watcher call on the raw blob) interleaved here; they are
// removed — keeping them would invoke the watcher twice, once with a blob
// that may live in device memory.
template<DeviceType device_type>
void ForeignWatchKernel<device_type>::ForwardDataContent(
    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  WithInBlob(ctx.device_ctx, BnInOp2Blob("in"), [&](Blob* in_blob) {
    // Wrap the host-readable blob and dispatch to the registered handler.
    OfBlob of_blob(ctx.device_ctx, in_blob);
    Global<ForeignWatcher>::Get()->Call(this->op_conf().foreign_watch_conf().handler_uuid(),
                                        reinterpret_cast<int64_t>(&of_blob));
  });
}
// The kernel is now a class template, so each device type is registered
// explicitly. The stale non-templated REGISTER_KERNEL line from the diff's
// pre-image is dropped: it names ForeignWatchKernel without template
// arguments (ill-formed) and would double-register the op.
REGISTER_KERNEL_WITH_DEVICE(OperatorConf::kForeignWatchConf, DeviceType::kCPU,
                            ForeignWatchKernel<DeviceType::kCPU>);
REGISTER_KERNEL_WITH_DEVICE(OperatorConf::kForeignWatchConf, DeviceType::kGPU,
                            ForeignWatchKernel<DeviceType::kGPU>);
} // namespace oneflow
......@@ -5,13 +5,15 @@
namespace oneflow {
class ForeignWatchKernel final : public KernelIf<DeviceType::kCPU> {
template<DeviceType device_type>
class ForeignWatchKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(ForeignWatchKernel);
ForeignWatchKernel() = default;
~ForeignWatchKernel() = default;
private:
void WithInBlob(DeviceCtx* device_ctx, Blob* blob, std::function<void(Blob*)> Handler) const;
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
......
......@@ -24,9 +24,14 @@ Maybe<void> ForeignWatchOp::InferBatchAxis(
return Maybe<void>::Ok();
}
// Pins the watch op to its producer's placement: the input must live on a
// single device (parallel_num == 1) and that placement must equal this op's
// own parallel_desc. (The stale pre-image GetSbpSignatures signature fused
// above this definition in the scraped diff is removed.)
// NOTE(review): an axis-0 split on a parallel_num==1 placement looks like a
// de-facto no-op signature — confirm this matches the framework's convention
// for single-device ops.
Maybe<void> ForeignWatchOp::InferSbpSignature(
    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,
    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
    const ParallelDesc& parallel_desc) const {
  OF_CHECK_EQ(JUST(SbpInferHint4Ibn("in"))->parallel_desc().parallel_num(), 1);
  OF_CHECK(JUST(SbpInferHint4Ibn("in"))->parallel_desc() == parallel_desc);
  (*sbp_signature->mutable_bn_in_op2sbp_parallel())["in"].mutable_split_parallel()->set_axis(0);
  return Maybe<void>::Ok();
}
......
......@@ -16,14 +16,15 @@ class ForeignWatchOp final : public Operator {
const PbMessage& GetCustomizedConf() const override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
LogicalNode* NewProperLogicalNode() const override { return new ForeignOutputLogicalNode; }
private:
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override;
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc*>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
Maybe<void> InferSbpSignature(
SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const override;
};
} // namespace oneflow
......
......@@ -25,7 +25,7 @@ def watch(watched, handler = None):
op_conf.name = id_util.UniqueStr("ForeignWatch_")
setattr(op_conf.foreign_watch_conf, "in", watched.logical_blob_name)
op_conf.foreign_watch_conf.handler_uuid = handler_uuid
with oneflow.fixed_placement("cpu", "0:0"): compile_context.CurJobAddOp(op_conf)
compile_context.CurJobAddOp(op_conf, watched.parallel_conf)
watcher_util.BindUuidAndHandler(handler_uuid, handler)
@oneflow_export("watch_diff")
......
......@@ -4,19 +4,13 @@ import numpy as np
flow.config.gpu_device_num(1)
flow.config.grpc_use_no_signal()
flow.config.piece_size(10)
flow.config.default_data_type(flow.float)
def Print(x, y):
print("x: ")
print(x)
print("y: ")
print(y)
@flow.function
def ReluJob(x = flow.input_blob_def((10,))):
y = flow.keras.activations.relu(x)
flow.watch([x, y], Print)
flow.watch(x, "x:")
flow.watch(y, "y:")
index = [-2, -1, 0, 1, 2]
data = []
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment