diff --git a/oneflow/api/python/profiler.cpp b/oneflow/api/python/profiler.cpp index d8e0e83cfa4592922467b3de09204c9fb1884634..5848dc8a997cd2cd27477bc17ba0b7be789a94c0 100644 --- a/oneflow/api/python/profiler.cpp +++ b/oneflow/api/python/profiler.cpp @@ -26,6 +26,10 @@ ONEFLOW_API_PYBIND11_MODULE("profiler", m) { m.def("RangePush", [](const std::string& str) { OF_PROFILER_RANGE_PUSH(str); }); m.def("RangePop", []() { OF_PROFILER_RANGE_POP(); }); + + m.def("ProfilerStart", []() { profiler::ProfilerStart(); }); + + m.def("ProfilerStop", []() { profiler::ProfilerStop(); }); } } // namespace oneflow diff --git a/oneflow/core/profiler/profiler.cpp b/oneflow/core/profiler/profiler.cpp index 5ea3d84ac03532994283894128f9141a2488efb3..b3502d340875e62982274ee2e7bffd9d307a08b9 100644 --- a/oneflow/core/profiler/profiler.cpp +++ b/oneflow/core/profiler/profiler.cpp @@ -19,6 +19,8 @@ limitations under the License. #include <nvtx3/nvToolsExt.h> #include <sys/syscall.h> #include <iostream> +#include <cuda_profiler_api.h> +#include "oneflow/core/device/cuda_util.h" #endif // OF_ENABLE_PROFILER namespace oneflow { @@ -107,6 +109,18 @@ void LogHostMemoryUsage(const std::string& name) { #endif // OF_ENABLE_PROFILER } +void ProfilerStart() { +#ifdef OF_ENABLE_PROFILER + OF_CUDA_CHECK(cudaProfilerStart()); +#endif // OF_ENABLE_PROFILER +} + +void ProfilerStop() { +#ifdef OF_ENABLE_PROFILER + OF_CUDA_CHECK(cudaProfilerStop()); +#endif // OF_ENABLE_PROFILER +} + } // namespace profiler } // namespace oneflow diff --git a/oneflow/core/profiler/profiler.h b/oneflow/core/profiler/profiler.h index f6c26131a8c479d791aa088be42dc14c5de97b6a..7a90fa8c5a943b41582e252046a812d1269ff636 100644 --- a/oneflow/core/profiler/profiler.h +++ b/oneflow/core/profiler/profiler.h @@ -32,6 +32,10 @@ void RangePop(); void LogHostMemoryUsage(const std::string& name); +void ProfilerStart(); + +void ProfilerStop(); + class RangeGuardCtx; class RangeGuard final { diff --git a/oneflow/python/framework/profiler.py b/oneflow/python/framework/profiler.py index 0fe5c2a14dbd8708fb1a894f1cfbb159e99f4551..6ba67601da45c87594b5d95a851ca76151846575 100644 --- a/oneflow/python/framework/profiler.py +++ b/oneflow/python/framework/profiler.py @@ -27,3 +27,13 @@ def RangePush(range_name): @oneflow_export("profiler.range_pop") def RangePop(): oneflow._oneflow_internal.profiler.RangePop() + + +@oneflow_export("profiler.profiler_start") +def ProfilerStart(): + oneflow._oneflow_internal.profiler.ProfilerStart() + + +@oneflow_export("profiler.profiler_stop") +def ProfilerStop(): + oneflow._oneflow_internal.profiler.ProfilerStop()