From 323c8170bffdd11d774437b450e42d842e203517 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 5 Sep 2024 20:25:03 -0700 Subject: [PATCH] Support ComputeFn where output type differs from input type (#1771) This is useful for e.g. function taking in 2 float inputs and turn them to complex --- .../sm90_visitor_compute_tma_warpspecialized.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 0b12badc7d..8f5ceb5489 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -181,14 +181,20 @@ struct Sm90Compute { }, [&] (auto&&... cvt_frg_inputs) { using ComputeOutput = ComputeFn>; - using ConvertOutput = NumericArrayConverter; ComputeOutput compute_output{}; - ConvertOutput convert_output{}; if constexpr (cute::is_same_v) { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; return convert_output(compute_output(cvt_frg_inputs...)); } else { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; return convert_output(compute_output(cvt_frg_inputs..., params)); } }