Skip to content

Commit

Permalink
Support ComputeFn where output type differs from input type (#1771)
Browse files Browse the repository at this point in the history
This is useful for e.g. function taking in 2 float inputs and turn them to complex
  • Loading branch information
tridao authored Sep 6, 2024
1 parent 82f5075 commit 323c817
Showing 1 changed file with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,20 @@ struct Sm90Compute {
},
[&] (auto&&... cvt_frg_inputs) {
using ComputeOutput = ComputeFn<Array<ElementCompute, FragmentSize>>;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize, RoundStyle>;
ComputeOutput compute_output{};
ConvertOutput convert_output{};

if constexpr (cute::is_same_v<Arguments, EmptyArguments>) {
using ElementComputeOutput =
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs...))>::Element;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
ConvertOutput convert_output{};
return convert_output(compute_output(cvt_frg_inputs...));
}
else {
using ElementComputeOutput =
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs..., params))>::Element;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
ConvertOutput convert_output{};
return convert_output(compute_output(cvt_frg_inputs..., params));
}
}
Expand Down

0 comments on commit 323c817

Please sign in to comment.