diff --git a/test/cpp/profiler/containers.cpp b/test/cpp/profiler/containers.cpp index c0e0bf14745c8..5f8e974343b94 100644 --- a/test/cpp/profiler/containers.cpp +++ b/test/cpp/profiler/containers.cpp @@ -77,3 +77,18 @@ TEST(ProfilerTest, clock_converter) { EXPECT_LT(std::abs(deltas[n / 2]), 200); EXPECT_LT(deltas[n * 3 / 4] - deltas[n / 4], 50); } + +TEST(ProfilerTest, soft_assert) { + EXPECT_TRUE(SOFT_ASSERT(true)); + torch::profiler::impl::setSoftAssertRaises(true); + EXPECT_ANY_THROW(SOFT_ASSERT(false)); + torch::profiler::impl::setSoftAssertRaises(false); + EXPECT_NO_THROW(SOFT_ASSERT(false)); + // Reset soft assert behavior to default + torch::profiler::impl::setSoftAssertRaises(c10::nullopt); +#ifdef NDEBUG + EXPECT_NO_THROW(SOFT_ASSERT(false)); +#else + EXPECT_ANY_THROW(SOFT_ASSERT(false)); +#endif +} diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index dd19bbe7e8ce4..292c55becb02e 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -336,6 +336,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .def_property_readonly("duration_time_ns", [](const Result& r) { return r.endTimeNS() - r.start_time_ns_; }); + + m.def("_soft_assert_raises", &setSoftAssertRaises); } py::class_(m, "_ProfilerResult") diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 6ae2e745806f3..3a26e04bfaddd 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -85,6 +85,24 @@ std::function ApproximateClockToUnixTimeConverter:: }; } +namespace { +c10::optional soft_assert_raises_; +} // namespace + +void setSoftAssertRaises(c10::optional value) { + soft_assert_raises_ = value; +} + +bool softAssertRaises() { + return soft_assert_raises_.value_or( +#ifdef NDEBUG + false +#else + true +#endif + ); +} + // ---------------------------------------------------------------------------- // -- NVTX -------------------------------------------------------------------- // ---------------------------------------------------------------------------- diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index 928e1889a0ee1..8bee4275c22f9 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -35,9 +36,25 @@ #endif #endif +// TODO: replace with pytorch/rfcs#43 when it is ready. +#define SOFT_ASSERT(cond, ...) \ + [&]() -> bool { \ + if (C10_UNLIKELY(!(cond))) { \ + if (torch::profiler::impl::softAssertRaises()) { \ + TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__); \ + } else { \ + TORCH_WARN(__VA_ARGS__); \ + } \ + return false; \ + } \ + return true; \ + }() + namespace torch { namespace profiler { namespace impl { +TORCH_API bool softAssertRaises(); +TORCH_API void setSoftAssertRaises(c10::optional value); using time_t = int64_t; using steady_clock_t = std::conditional<