diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 078c405195f286..baa54b0024e45b 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -145,6 +145,15 @@ class BackendWithCompiler : public PyTorchBackendInterface { auto x_ptr = float_data_ptr(x); auto h_ptr = float_data_ptr(h); auto y_ptr = float_data_ptr(y); +#ifndef NO_PROFILING + RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER( + x_ptr, + x.numel() * sizeof(float), + x.numel() * sizeof(float), + x.numel() * sizeof(float) + y.numel() * sizeof(float) + + h.numel() * sizeof(float), + c10::Device(c10::kCPU)); +#endif if (instruction == "aten::add") { y_ptr[0] = x_ptr[0] + h_ptr[0]; } else { diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index 95ba2b7b853edf..867b775c1adb45 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -25,7 +25,9 @@ bool checkMetaData( if (line.find(op_name) != std::string::npos) { while (std::getline(trace_file, line)) { if (line.find(metadata_name) != std::string::npos) { - return (line.find(metadata_val) != std::string::npos); + if (line.find(metadata_val) != std::string::npos) { + return true; + } } } } @@ -122,6 +124,39 @@ TEST(MobileProfiler, Backend) { checkMetaData("aten::add", metadata_name, "test_backend", trace_file)); } +TEST(MobileProfiler, BackendMemoryEvents) { + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_backend_for_profiling.ptl"); + + std::vector inputs; + inputs.emplace_back(at::rand({64, 64})); + inputs.emplace_back(at::rand({64, 64})); + std::string trace_file_name("/tmp/test_trace_backend_memory.trace"); + + mobile::Module bc = _load_for_mobile(testModelFile); + { + mobile::KinetoEdgeCPUProfiler profiler( + bc, + trace_file_name, + false, // record input_shapes + true, // profile memory + true, // record callstack + false, // record flops + true); // record module hierarchy + bc.forward(inputs); + } + std::ifstream trace_file(trace_file_name); + std::string line; + ASSERT_TRUE(trace_file.is_open()); + trace_file.seekg(0, std::ios_base::beg); + std::string metadata_name("Bytes"); + ASSERT_TRUE(checkMetaData("[memory]", metadata_name, "16384", trace_file)); + trace_file.seekg(0, std::ios_base::beg); + metadata_name = "Total Reserved"; + ASSERT_TRUE(checkMetaData("[memory]", metadata_name, "49152", trace_file)); +} + } // namespace mobile } // namespace jit } // namespace torch