diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index 657ade2d279..b3a187fc60a 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -171,6 +171,7 @@ supported:
   - count_nonzero
   - count_nonzero.dim_IntList
   - cross
+  - cummax
   - cumprod
   - cumsum
   - detach_copy
diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index f3833b28d1a..92d3bce02ed 100755
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -2,6 +2,7 @@
 #include <torch/torch.h>
 
 #include <iostream>
+#include <tuple>
 
 #include "test/cpp/cpp_test_util.h"
 #include "test/cpp/torch_xla_test.h"
@@ -2118,6 +2119,23 @@ TEST_F(AtenXlaTensorTest, TestCumProdCastLong) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestCumMax) {
+  torch::Tensor input = torch::rand({4, 3, 4});
+  int rank = input.dim();
+  for (int dim = -rank; dim < rank; ++dim) {
+    std::tuple<torch::Tensor, torch::Tensor> result = torch::cummax(input, dim);
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      std::tuple<torch::Tensor, torch::Tensor> xla_result =
+          torch::cummax(xla_input, dim);
+      AllClose(std::get<0>(result), std::get<0>(xla_result));
+      AllClose(std::get<1>(result), std::get<1>(xla_result));
+    });
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::cummax", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestArgMin) {
   torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::argmin(a, std::nullopt, /*keepdim=*/false);
diff --git a/test/test_ops.py b/test/test_ops.py
index 51a80eecf93..f0cc9dc22c8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -86,7 +86,6 @@ def get_allowed_ops_map(
       AllowedOpInfoEntry('complex'),
       AllowedOpInfoEntry('copysign'),
       AllowedOpInfoEntry('cross'),
-      AllowedOpInfoEntry('cummax'),
       AllowedOpInfoEntry('cummin'),
       AllowedOpInfoEntry('deg2rad'),
       AllowedOpInfoEntry('div', 'no_rounding_mode'),
@@ -289,6 +288,7 @@ def get_allowed_ops_map(
       # AllowedOpInfoEntry('cos'),
      # AllowedOpInfoEntry('cosh'),
       # AllowedOpInfoEntry('cov'),
+      # AllowedOpInfoEntry('cummax'),
       # AllowedOpInfoEntry('cumsum'),
       # AllowedOpInfoEntry('cumprod'),
       # AllowedOpInfoEntry('diff'),
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 250161e3a97..7d979158f16 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1308,6 +1308,16 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self,
                                      XlaHelpers::I64Optional(dim)));
 }
 
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummax(
+    const at::Tensor& self, int64_t dim) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
+  std::tuple<XLATensorPtr, XLATensorPtr> res =
+      tensor_methods::cummax(self_tensor, dim);
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
+                         bridge::AtenFromXlaTensor(std::get<1>(res)));
+}
+
 at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim,
                                        std::optional<at::ScalarType> dtype) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
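Note on the ATen-level contract: XLANativeFunctions::cummax above must return the same (values, indices) pair as eager torch.cummax, with both outputs shaped like the input and the indices as int64 (hence the at::ScalarType::Long cast in tensor_methods::cummax below). A minimal CPU-only sketch of that contract, using only the public torch API; the concrete numbers are illustrative and not taken from this change:

import torch

x = torch.tensor([[3.0, 1.0, 4.0, 1.0],
                  [5.0, 9.0, 2.0, 6.0]])
values, indices = torch.cummax(x, dim=1)

# Running maximum along dim=1, plus the position where that maximum occurred.
assert torch.equal(values, torch.tensor([[3.0, 3.0, 4.0, 4.0],
                                         [5.0, 9.0, 9.0, 9.0]]))
assert torch.equal(indices, torch.tensor([[0, 0, 2, 2],
                                          [0, 1, 1, 1]]))
assert values.shape == x.shape and indices.shape == x.shape
assert indices.dtype == torch.int64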
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index c9d82a7a02d..7160ef4715f 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -44,6 +44,31 @@ xla::XlaComputation CreateComputation(
   return ConsumeValue(builder.Build(op(x, y)));
 }
 
+xla::XlaComputation CreateMinMaxComputation(const std::string& name,
+                                            xla::PrimitiveType value_type,
+                                            xla::PrimitiveType index_type,
+                                            bool is_min) {
+  xla::XlaBuilder builder(name);
+  xla::XlaOp lhs_value = xla::Parameter(
+      &builder, 0, xla::ShapeUtil::MakeShape(value_type, {}), "lhs_value");
+  xla::XlaOp lhs_index = xla::Parameter(
+      &builder, 1, xla::ShapeUtil::MakeShape(index_type, {}), "lhs_index");
+  xla::XlaOp rhs_value = xla::Parameter(
+      &builder, 2, xla::ShapeUtil::MakeShape(value_type, {}), "rhs_value");
+  xla::XlaOp rhs_index = xla::Parameter(
+      &builder, 3, xla::ShapeUtil::MakeShape(index_type, {}), "rhs_index");
+
+  xla::XlaOp cmp =
+      is_min ? xla::Le(lhs_value, rhs_value) : xla::Ge(lhs_value, rhs_value);
+  xla::XlaOp max = xla::Select(cmp, lhs_value, rhs_value);
+  xla::XlaOp arg_max = xla::Select(cmp, lhs_index, rhs_index);
+  xla::XlaOp eq = xla::Eq(lhs_value, rhs_value);
+  xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index);
+  arg_max = xla::Select(eq, tie_id, arg_max);
+  xla::Tuple(&builder, {max, arg_max});
+  return ConsumeValue(builder.Build());
+}
+
 }  // namespace
 
 xla::PrecisionConfig::Precision XlaHelpers::s_mat_mul_precision =
@@ -229,6 +254,12 @@ xla::XlaComputation XlaHelpers::CreateOrComputation(xla::PrimitiveType type) {
       [&](xla::XlaOp x, xla::XlaOp y) { return xla::Or(x, y); });
 }
 
+xla::XlaComputation XlaHelpers::CreateMaxAndArgMaxComputation(
+    xla::PrimitiveType value_type, xla::PrimitiveType index_type) {
+  return CreateMinMaxComputation("MaxAndArgMaxComputation", value_type,
+                                 index_type, /*is_min=*/false);
+}
+
 std::vector<int64_t> XlaHelpers::SizesOfXlaOp(xla::XlaOp op) {
   const xla::Shape& op_shape = ShapeHelper::ShapeOfXlaOp(op);
   return std::vector<int64_t>(op_shape.dimensions().begin(),
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index a8ec39a973a..9ac60207476 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -230,6 +230,9 @@ class XlaHelpers {
 
   static xla::XlaComputation CreateOrComputation(xla::PrimitiveType type);
 
+  static xla::XlaComputation CreateMaxAndArgMaxComputation(
+      xla::PrimitiveType value_type, xla::PrimitiveType index_type);
+
   // Returns an XLA operation which is a reshape to the expected rank, by
   // appending 1s to the major dimension. If offset is greater than zero, 1s
   // will be prepened to the minor dimension as well.
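CreateMinMaxComputation above is the pairwise reducer that ReduceWindow will apply to (value, index) pairs: it keeps the larger value (Ge for the max variant, Le for min) and, when the two values compare equal, the smaller index. A plain-Python sketch of that pairwise rule, for illustration only and not part of the change:

def max_and_argmax(lhs, rhs):
    """Combine two (value, index) pairs the way the max reducer above does."""
    lhs_value, lhs_index = lhs
    rhs_value, rhs_index = rhs
    if lhs_value == rhs_value:
        # Tie: keep the smaller index, mirroring the Min(lhs_index, rhs_index)
        # select at the end of CreateMinMaxComputation.
        return lhs_value, min(lhs_index, rhs_index)
    if lhs_value >= rhs_value:
        return lhs_value, lhs_index
    return rhs_value, rhs_index

assert max_and_argmax((2.0, 3), (7.0, 1)) == (7.0, 1)
assert max_and_argmax((7.0, 3), (7.0, 1)) == (7.0, 1)  # tie -> smaller index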
diff --git a/torch_xla/csrc/ops/cummax.cpp b/torch_xla/csrc/ops/cummax.cpp
new file mode 100644
index 00000000000..89791d41848
--- /dev/null
+++ b/torch_xla/csrc/ops/cummax.cpp
@@ -0,0 +1,70 @@
+#include "torch_xla/csrc/ops/cummax.h"
+
+#include <torch/csrc/lazy/core/tensor_util.h>
+
+#include "torch_xla/csrc/convert_ops.h"
+#include "torch_xla/csrc/helpers.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/torch_util.h"
+
+namespace torch_xla {
+namespace {
+
+xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  xla::XlaOp value_init_value = xla::ConstantLiteral(
+      input.builder(), xla::LiteralUtil::MinValue(input_shape.element_type()));
+  xla::XlaOp index_init_value = xla::ConstantLiteral(
+      input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
+  xla::XlaOp iota =
+      xla::Iota(input.builder(),
+                xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
+                                          input_shape.dimensions()),
+                dim);
+  xla::XlaComputation reducer = XlaHelpers::CreateMaxAndArgMaxComputation(
+      input_shape.element_type(), xla::PrimitiveType::S32);
+  return BuildCumulativeComputationWithIndices(
+      input, iota, dim, reducer, value_init_value, index_init_value);
+}
+
+xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    xla::XlaOp values_and_indices = LowerCumMax(operands[0], dim);
+    return values_and_indices;
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
+}  // namespace
+
+CumMax::CumMax(const torch::lazy::Value& input, int64_t dim)
+    : XlaNode(
+          torch::lazy::OpKind(at::aten::cummax), {input},
+          [&]() { return NodeOutputShape(input, dim); },
+          /*num_outputs=*/2, torch::lazy::MHash(dim)),
+      dim_(dim) {}
+
+torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
+  return torch_xla::MakeNode<CumMax>(operands.at(0), dim_);
+}
+
+XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp values_and_indices = LowerCumMax(input, dim_);
+  return ReturnOps({xla::GetTupleElement(values_and_indices, 0),
+                    xla::GetTupleElement(values_and_indices, 1)},
+                   loctx);
+}
+
+std::string CumMax::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", dim=" << dim_;
+  return ss.str();
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/cummax.h b/torch_xla/csrc/ops/cummax.h
new file mode 100644
index 00000000000..1a75242e6f5
--- /dev/null
+++ b/torch_xla/csrc/ops/cummax.h
@@ -0,0 +1,28 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
+#define XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
+
+#include <c10/core/ScalarType.h>
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class CumMax : public XlaNode {
+ public:
+  CumMax(const torch::lazy::Value& input, int64_t dim);
+
+  std::string ToString() const override;
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  int64_t dim() const { return dim_; }
+
+ private:
+  int64_t dim_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
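The CumMax node pairs each element with its position along `dim` (the S32 iota) and scans the pairs with the max-and-argmax reducer. A small pure-Python reference of the semantics the lowering is meant to reproduce along one dimension; `cummax_1d` is a hypothetical helper written for illustration, not code from this change:

def cummax_1d(xs):
    """Reference semantics of cummax along one dimension."""
    values, indices = [], []
    best_value, best_index = None, None
    for i, x in enumerate(xs):
        if best_value is None or x > best_value:
            best_value, best_index = x, i
        values.append(best_value)
        indices.append(best_index)
    return values, indices

# The running maximum and the position where it was (first) reached.
assert cummax_1d([3, 1, 4, 1, 5]) == ([3, 3, 4, 4, 5], [0, 0, 2, 2, 4])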
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index 56702e79279..9ec01eb46c2 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -284,6 +284,22 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
       /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
 }
 
+xla::XlaOp BuildCumulativeComputationWithIndices(
+    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
+    const xla::XlaComputation& reducer, xla::XlaOp value_init,
+    xla::XlaOp index_init) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(value_input);
+  std::vector<int64_t> window_strides(input_shape.rank(), 1);
+  std::vector<int64_t> window_dims(input_shape.rank(), 1);
+  window_dims[dim] = input_shape.dimensions(dim);
+  std::vector<std::pair<int64_t, int64_t>> padding(input_shape.rank());
+  padding[dim].first = input_shape.dimensions(dim) - 1;
+  return xla::ReduceWindowWithGeneralPadding(
+      {value_input, index_input}, {value_init, index_init}, reducer,
+      window_dims, window_strides,
+      /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
+}
+
 xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                      bool keep_reduced_dimensions) {
   return CreateSummation(input, dimensions, keep_reduced_dimensions,
diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h
index f71fb6f1c3c..8c4ba2d4418 100644
--- a/torch_xla/csrc/reduction.h
+++ b/torch_xla/csrc/reduction.h
@@ -88,6 +88,14 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
                                       const xla::XlaComputation& reducer,
                                       xla::XlaOp init);
 
+// Computes the cumulative computation specified by "reducer" and "init" in the
+// given dimension "dim".
+// Returns a tuple XlaOp (values, indices).
+xla::XlaOp BuildCumulativeComputationWithIndices(
+    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
+    const xla::XlaComputation& reducer, xla::XlaOp value_init,
+    xla::XlaOp index_init);
+
 xla::XlaOp BuildAll(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                     bool keep_reduced_dimensions);
 
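BuildCumulativeComputationWithIndices turns the scan into a ReduceWindow: the window along `dim` has length dim_size and the low padding is dim_size - 1, so the window that produces output position i covers exactly input positions 0 through i. A plain-Python sketch of that padding arithmetic; `cumulative_window` is a hypothetical helper for illustration, not part of the change:

def cumulative_window(i, dim_size):
    """Input positions reduced into output position i by the padded window."""
    pad = dim_size - 1  # low padding added in front of the scanned dimension
    # Stride 1, window length dim_size: output i sees padded positions
    # [i, i + dim_size - 1]; subtracting the padding maps them back to real
    # positions [i - dim_size + 1, i], of which only 0..i actually exist.
    real_positions = [p - pad for p in range(i, i + dim_size)]
    return [p for p in real_positions if 0 <= p < dim_size]

# For a dimension of size 4, output i reduces over inputs 0..i, i.e. a prefix scan.
assert [cumulative_window(i, 4) for i in range(4)] == [
    [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]
]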
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 411596797a8..d8939972e67 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -40,6 +40,7 @@
 #include "torch_xla/csrc/ops/convolution_backward_overrideable.h"
 #include "torch_xla/csrc/ops/convolution_overrideable.h"
 #include "torch_xla/csrc/ops/count_nonzero.h"
+#include "torch_xla/csrc/ops/cummax.h"
 #include "torch_xla/csrc/ops/cumprod.h"
 #include "torch_xla/csrc/ops/cumsum.h"
 #include "torch_xla/csrc/ops/custom_call.h"
@@ -1294,6 +1295,25 @@ XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
   return tensor_ops::Cross(input, other, dim);
 }
 
+std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
+                                              int64_t dim) {
+  torch::lazy::NodePtr node = torch_xla::MakeNode<CumMax>(
+      input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex(
+                               dim, input->shape().get().rank()));
+  XLATensorPtr t_value = input->CreateFrom(torch::lazy::Value(node, 0),
+                                           /*delay_eager_executation=*/true);
+  XLATensorPtr t_index =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run the `cummax` and in one hlo
+    std::vector<XLATensorPtr> tensors_to_sync = {t_value, t_index};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t_value, t_index);
+}
+
 XLATensorPtr cumprod(const XLATensorPtr& input, int64_t dim,
                      std::optional<at::ScalarType> dtype) {
   int64_t canonical_dim =
@@ -1364,8 +1384,8 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   } else if (!input_is_float && other_is_float) {
     scalar_type = MaybeUpcastToHostTorchType(other_type);
   }
-  // We need to cast both input and other to float to perform true divide, floor
-  // divide and trunc divide.
+  // We need to cast both input and other to float to perform true divide,
+  // floor divide and trunc divide.
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = GetFloatingIrValue(other, scalar_type);
   torch::lazy::Value res = Div(input_value, other_value);
@@ -1381,9 +1401,9 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   }
 
   // Promote the result to the logical_element_type if one of the
-  // input and the other is float. If that is not the case logical_element_type
-  // will be non-floating-point type, we should only promote the result to that
-  // when rounding_mode is not nullopt.
+  // input and the other is float. If that is not the case
+  // logical_element_type will be non-floating-point type, we should only
+  // promote the result to that when rounding_mode is not nullopt.
   if (input_is_float || other_is_float || rounding_mode.has_value()) {
     if (logical_element_type.has_value()) {
       xla::PrimitiveType res_intended_type =
@@ -1872,7 +1892,8 @@ XLATensorPtr linalg_vector_norm(const XLATensorPtr& input,
                                 const at::Scalar& ord,
                                 std::vector<int64_t> dimensions, bool keep_dim,
                                 std::optional<at::ScalarType> dtype) {
-  // If the input is a scalar, we have to manually create the dimensions vector.
+  // If the input is a scalar, we have to manually create the dimensions
+  // vector.
   auto input_rank = input->shape().get().rank();
   std::vector<int64_t> canonical_dims;
   if (input_rank != 0) {
@@ -1988,8 +2009,8 @@ XLATensorPtr logsumexp(const XLATensorPtr& input,
 XLATensorPtr xlogy(const XLATensorPtr& input, const XLATensorPtr& other) {
   // Here we explictly pass std::nullopt as logical_element_type because
   // otherwise result will inherit the input's logical_element_type. In the
-  // case of xlogy(int,int) -> float, we want to derive the dtype from IR value
-  // instead of input's logical_element_type.
+  // case of xlogy(int,int) -> float, we want to derive the dtype from IR
+  // value instead of input's logical_element_type.
   return input->CreateFrom(
       XLogY(input->GetIrValue(),
             GetFloatingIrValue(other, at::ScalarType::Float)),
@@ -2016,9 +2037,9 @@ XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
   auto input_value = input->GetIrValue();
   // This ensures that input tensor is at least the same shape as mask tensor.
   // Note that we can't use the existing MaybeExpand function since
-  // input tensor may sometimes be bigger than the mask tensor, and MaybeExpand
-  // requires the first parameter to always be less or equal to the second
-  // parameter.
+  // input tensor may sometimes be bigger than the mask tensor, and
+  // MaybeExpand requires the first parameter to always be less or equal to
+  // the second parameter.
   if (input->shape().get().dimensions() < mask->shape().get().dimensions()) {
     input_value = MaybeExpand(input->GetIrValue(), mask->shape());
   }
@@ -2335,7 +2356,8 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm(
       running_var->SetIrValue(
           torch_xla::MakeNode<LinearInterpolation>(
               torch::lazy::Value(node, 2), running_var->GetIrValue(), momentum),
-          /*inplace=*/true, /*delay_eager_executation=*/true);
+          /*inplace=*/true,
+          /*delay_eager_executation=*/true);
     }
   } else {
     at::Tensor at_input = bridge::AtenFromXlaTensor(input);
@@ -2381,8 +2403,8 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm_backward(
                                        /*delay_eager_executation=*/true);
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
-    // Execute the HLO that will run the `native_batch_norm_backward` and in one
-    // hlo
+    // Execute the HLO that will run the `native_batch_norm_backward` and in
+    // one hlo
     std::vector<XLATensorPtr> tensors_to_sync = {grad_input, grad_weight,
                                                  grad_bias};
     graph_executor->ApplyEagerSync(tensors_to_sync);
@@ -2489,8 +2511,8 @@ XLATensorPtr norm(const XLATensorPtr& input, const std::optional<at::Scalar>& p,
   }
   auto out = Norm(input->GetIrValue(), p, dtype, canonical_dims, keepdim);
   if (dtype.has_value()) {
-    // The returned tensor is actually of type `dtype`. Therefore, it should not
-    // inherit the data-type from the input, when creating the XLATensor.
+    // The returned tensor is actually of type `dtype`. Therefore, it should
+    // not inherit the data-type from the input, when creating the XLATensor.
     return input->CreateFrom(out, dtype);
   } else {
     return input->CreateFrom(out);
@@ -3058,7 +3080,8 @@ std::tuple<XLATensorPtr, XLATensorPtr> eigh(const XLATensorPtr& input,
   // from IR value instead of input's dtype.
   return std::make_tuple(
       input->CreateFrom(torch::lazy::Value(node, 0), std::nullopt),
-      // From https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html,
+      // From
+      // https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html,
       // eigenvectors will have the same dtype as A.
       input->CreateFrom(torch::lazy::Value(node, 1)));
 }
@@ -3137,8 +3160,8 @@ std::vector<XLATensorPtr> split(const XLATensorPtr& input, int64_t split_size,
       torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.get().rank());
   int64_t dim_size = input_shape.get().dimensions(split_dim);
   if (dim_size == 0) {
-    // Deal with dim_size=0, it's a corner case which only return 1 0-dim tensor
-    // no matter what split_size is.
+    // Deal with dim_size=0, it's a corner case which only return 1 0-dim
+    // tensor no matter what split_size is.
     xla::Literal literal(input_shape.get());
     return {
         input->CreateFrom(torch_xla::MakeNode<Constant>(std::move(literal)))};
@@ -3403,7 +3426,8 @@ XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1) {
                         GetXlaShape(ir_value));
   } else {
     std::vector<int64_t> permute_dims = torch::lazy::MakeTransposePermutation(
-        /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.get().rank());
+        /*dim0=*/dim0, /*dim1=*/dim1,
+        /*rank=*/input_shape.get().rank());
     view_info = ViewInfo(ViewInfo::Type::kPermute, input_shape, permute_dims);
   }
   return input->CreateViewTensor(std::move(view_info));
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 677bd1c303b..c1d20b70f7a 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -361,6 +361,11 @@ XLATensorPtr count_nonzero(const XLATensorPtr& input,
 XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
                    std::optional<int64_t> dim);
 
+// Returns a tuple of the cumulative max of elements and the corresponding
+// indices of input in the given dimension.
+std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
+                                              int64_t dim);
+
 // Returns the cumulative product of elements of input in the given dimension.
 XLATensorPtr cumprod(const XLATensorPtr& input, int64_t dim,
                      std::optional<at::ScalarType> dtype);
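With the lowering registered, torch.cummax is expected to behave on XLA tensors the same way it does on CPU. A hedged end-to-end sketch; the torch_xla import and the xm.xla_device() call follow the commonly used torch_xla API and are assumptions here, not something this diff adds:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # assumes an XLA backend (CPU/GPU/TPU) is available

x = torch.rand(4, 3, 4)
cpu_values, cpu_indices = torch.cummax(x, dim=-1)
xla_values, xla_indices = torch.cummax(x.to(device), dim=-1)

# Values and, for tie-free random input, indices should match eager CPU.
assert torch.allclose(cpu_values, xla_values.cpu())
assert torch.equal(cpu_indices, xla_indices.cpu())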