Skip to content

Commit

Permalink
Lower cummax op (pytorch#8491)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyy-martin authored Jan 10, 2025
1 parent af223d3 commit c394d1b
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 21 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ supported:
- count_nonzero
- count_nonzero.dim_IntList
- cross
- cummax
- cumprod
- cumsum
- detach_copy
Expand Down
18 changes: 18 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <torch/torch.h>

#include <iostream>
#include <tuple>

#include "test/cpp/cpp_test_util.h"
#include "test/cpp/torch_xla_test.h"
Expand Down Expand Up @@ -2118,6 +2119,23 @@ TEST_F(AtenXlaTensorTest, TestCumProdCastLong) {
}
}

// Compares torch::cummax computed on the original tensor against the result
// computed on every device visited by ForEachDevice, for each valid dimension.
TEST_F(AtenXlaTensorTest, TestCumMax) {
  torch::Tensor input = torch::rand({4, 3, 4});
  const int rank = input.dim();
  // Sweep the full range [-rank, rank) so that both the negative and the
  // non-negative spelling of every dimension is exercised.
  for (int dim = -rank; dim < rank; ++dim) {
    const std::tuple<torch::Tensor, torch::Tensor> expected =
        torch::cummax(input, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor xla_input = CopyToDevice(input, device);
      const std::tuple<torch::Tensor, torch::Tensor> actual =
          torch::cummax(xla_input, dim);
      // Element 0 holds the running maxima, element 1 the matching indices.
      AllClose(std::get<0>(expected), std::get<0>(actual));
      AllClose(std::get<1>(expected), std::get<1>(actual));
    });
  }
  // No "aten::.*" counter may have moved, while "xla::cummax" must have.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::cummax", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestArgMin) {
torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::argmin(a, std::nullopt, /*keepdim=*/false);
Expand Down
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def get_allowed_ops_map(
AllowedOpInfoEntry('complex'),
AllowedOpInfoEntry('copysign'),
AllowedOpInfoEntry('cross'),
AllowedOpInfoEntry('cummax'),
AllowedOpInfoEntry('cummin'),
AllowedOpInfoEntry('deg2rad'),
AllowedOpInfoEntry('div', 'no_rounding_mode'),
Expand Down Expand Up @@ -289,6 +288,7 @@ def get_allowed_ops_map(
# AllowedOpInfoEntry('cos'),
# AllowedOpInfoEntry('cosh'),
# AllowedOpInfoEntry('cov'),
# AllowedOpInfoEntry('cummax'),
# AllowedOpInfoEntry('cumsum'),
# AllowedOpInfoEntry('cumprod'),
# AllowedOpInfoEntry('diff'),
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,16 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self,
XlaHelpers::I64Optional(dim)));
}

// Cumulative maximum of `self` along `dim`.
//
// Unwraps `self` into its XLA tensor, delegates to tensor_methods::cummax and
// wraps both results back into at::Tensor. Returns (values, indices).
std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummax(
    const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  std::tuple<XLATensorPtr, XLATensorPtr> values_and_indices =
      tensor_methods::cummax(bridge::GetXlaTensor(self), dim);
  at::Tensor values =
      bridge::AtenFromXlaTensor(std::get<0>(values_and_indices));
  at::Tensor indices =
      bridge::AtenFromXlaTensor(std::get<1>(values_and_indices));
  return std::make_tuple(std::move(values), std::move(indices));
}

at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim,
std::optional<at::ScalarType> dtype) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
Expand Down
31 changes: 31 additions & 0 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,31 @@ xla::XlaComputation CreateComputation(
return ConsumeValue(builder.Build(op(x, y)));
}

// Builds a scalar reducer over (value, index) pairs for combined
// min/argmin or max/argmax reductions.
//
// The computation takes four scalar parameters, in this order:
//   0: lhs_value (value_type)   1: lhs_index (index_type)
//   2: rhs_value (value_type)   3: rhs_index (index_type)
// and returns the tuple (selected_value, selected_index).
//
// `is_min` selects the smaller value (using <=) when true and the larger
// value (using >=) when false.
//
// When both values compare equal, the larger of the two indices is returned.
// This makes a cumulative reduction report the most recent position of the
// running extremum, which is how ATen's cummax/cummin behave (they update the
// index whenever the new element is >= / <= the running value). Taking the
// larger index also keeps the reducer commutative and associative, so the
// result does not depend on the order in which operands are combined.
//
// NOTE(review): comparisons involving NaN presumably evaluate to false, in
// which case the rhs operand is selected regardless of which side is NaN;
// confirm whether that matches the NaN propagation expected by callers.
xla::XlaComputation CreateMinMaxComputation(const std::string& name,
                                            xla::PrimitiveType value_type,
                                            xla::PrimitiveType index_type,
                                            bool is_min) {
  xla::XlaBuilder builder(name);
  xla::XlaOp lhs_value = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(value_type, {}), "lhs_value");
  xla::XlaOp lhs_index = xla::Parameter(
      &builder, 1, xla::ShapeUtil::MakeShape(index_type, {}), "lhs_index");
  xla::XlaOp rhs_value = xla::Parameter(
      &builder, 2, xla::ShapeUtil::MakeShape(value_type, {}), "rhs_value");
  xla::XlaOp rhs_index = xla::Parameter(
      &builder, 3, xla::ShapeUtil::MakeShape(index_type, {}), "rhs_index");

  // True when the lhs pair should be kept.
  xla::XlaOp lhs_wins =
      is_min ? xla::Le(lhs_value, rhs_value) : xla::Ge(lhs_value, rhs_value);
  xla::XlaOp value = xla::Select(lhs_wins, lhs_value, rhs_value);
  xla::XlaOp index = xla::Select(lhs_wins, lhs_index, rhs_index);

  // Resolve ties towards the latest (largest) index.
  xla::XlaOp is_tie = xla::Eq(lhs_value, rhs_value);
  xla::XlaOp tie_index = xla::Max(lhs_index, rhs_index);
  index = xla::Select(is_tie, tie_index, index);

  xla::Tuple(&builder, {value, index});
  return ConsumeValue(builder.Build());
}

} // namespace

xla::PrecisionConfig::Precision XlaHelpers::s_mat_mul_precision =
Expand Down Expand Up @@ -229,6 +254,12 @@ xla::XlaComputation XlaHelpers::CreateOrComputation(xla::PrimitiveType type) {
[&](xla::XlaOp x, xla::XlaOp y) { return xla::Or(x, y); });
}

// Returns the (value, index) reducer named "MaxAndArgMaxComputation", i.e.
// CreateMinMaxComputation configured for the maximum rather than the minimum.
xla::XlaComputation XlaHelpers::CreateMaxAndArgMaxComputation(
    xla::PrimitiveType value_type, xla::PrimitiveType index_type) {
  const bool is_min = false;
  return CreateMinMaxComputation("MaxAndArgMaxComputation", value_type,
                                 index_type, is_min);
}

std::vector<int64_t> XlaHelpers::SizesOfXlaOp(xla::XlaOp op) {
const xla::Shape& op_shape = ShapeHelper::ShapeOfXlaOp(op);
return std::vector<int64_t>(op_shape.dimensions().begin(),
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ class XlaHelpers {

static xla::XlaComputation CreateOrComputation(xla::PrimitiveType type);

// Creates a reducer computation operating on (value, index) pairs: it takes
// scalar values of `value_type` and scalar indices of `index_type` and yields
// the tuple (max value, index associated with that value).
static xla::XlaComputation CreateMaxAndArgMaxComputation(
xla::PrimitiveType value_type, xla::PrimitiveType index_type);

// Returns an XLA operation which is a reshape to the expected rank, by
// appending 1s to the major dimension. If offset is greater than zero, 1s
// will be prepended to the minor dimension as well.
Expand Down
70 changes: 70 additions & 0 deletions torch_xla/csrc/ops/cummax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "torch_xla/csrc/ops/cummax.h"

#include <torch/csrc/lazy/core/tensor_util.h>

#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {
namespace {

// Emits the XLA ops computing the cumulative maximum of `input` along `dim`.
//
// The returned op is the result of BuildCumulativeComputationWithIndices,
// which yields a (values, indices) tuple; indices are produced as S32.
xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim) {
  const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
  // Identity elements for the reduction: the smallest representable value of
  // the input element type, paired with index 0.
  xla::XlaOp init_value = xla::ConstantLiteral(
      input.builder(), xla::LiteralUtil::MinValue(shape.element_type()));
  xla::XlaOp init_index = xla::ConstantLiteral(
      input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
  // Index tensor with the same dimensions as the input, counting along `dim`.
  xla::Shape index_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
                                                     shape.dimensions());
  xla::XlaOp positions = xla::Iota(input.builder(), index_shape, dim);
  xla::XlaComputation max_and_argmax =
      XlaHelpers::CreateMaxAndArgMaxComputation(shape.element_type(),
                                                xla::PrimitiveType::S32);
  return BuildCumulativeComputationWithIndices(
      input, positions, dim, max_and_argmax, init_value, init_index);
}

// Infers the output shape of the cummax node by running the same lowering
// used at execution time (LowerCumMax) on the shape of `input`.
xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
  auto shape_fn = [dim](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return LowerCumMax(operands[0], dim);
  };
  return InferOutputShape({GetXlaShape(input)}, shape_fn);
}

} // namespace

// Creates a cummax node over `input` along `dim`.
//
// The node has two outputs, its shape is computed by NodeOutputShape, and
// `dim` is mixed into the node hash.
CumMax::CumMax(const torch::lazy::Value& input, int64_t dim)
    : XlaNode(
          torch::lazy::OpKind(at::aten::cummax), {input},
          [&input, dim]() { return NodeOutputShape(input, dim); },
          /*num_outputs=*/2, torch::lazy::MHash(dim)),
      dim_(dim) {}

// Creates a new CumMax node that reads from the first entry of `operands`
// and keeps this node's dimension.
torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
  const torch::lazy::Value& input = operands.at(0);
  return torch_xla::MakeNode<CumMax>(input, dim_);
}

// Lowers the node to XLA: runs LowerCumMax on operand 0 and returns the two
// tuple elements of its result as the node's two outputs.
XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
  xla::XlaOp result = LowerCumMax(loctx->GetOutputOp(operand(0)), dim_);
  xla::XlaOp values = xla::GetTupleElement(result, 0);
  xla::XlaOp indices = xla::GetTupleElement(result, 1);
  return ReturnOps({values, indices}, loctx);
}

// Returns the base node description followed by ", dim=<dim>".
std::string CumMax::ToString() const {
  std::stringstream out;
  out << XlaNode::ToString();
  out << ", dim=" << dim_;
  return out.str();
}

} // namespace torch_xla
28 changes: 28 additions & 0 deletions torch_xla/csrc/ops/cummax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
#define XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_

#include <c10/core/ScalarType.h>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

// IR node for the cummax operation along a single dimension. The node has
// two outputs.
class CumMax : public XlaNode {
 public:
  // `input` is the value to scan and `dim` the dimension to scan along.
  CumMax(const torch::lazy::Value& input, int64_t dim);

  // Dimension this node scans along.
  int64_t dim() const { return dim_; }

  // Creates a copy of this node that reads from `operands`.
  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  // Emits the XLA ops for this node.
  XlaOpVector Lower(LoweringContext* loctx) const override;

  // Description of the node, including the dimension.
  std::string ToString() const override;

 private:
  int64_t dim_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
16 changes: 16 additions & 0 deletions torch_xla/csrc/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,22 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
/*base_dilations=*/{}, /*window_dilations=*/{}, padding);
}

// Runs `reducer` cumulatively along `dim` over the paired inputs
// (`value_input`, `index_input`), seeded with (`value_init`, `index_init`).
//
// Implemented as a reduce-window whose window spans the whole of `dim`, with
// stride 1 and `size - 1` elements of low-side padding on `dim`, so that the
// window for output position i ends at input position i.
xla::XlaOp BuildCumulativeComputationWithIndices(
    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
    const xla::XlaComputation& reducer, xla::XlaOp value_init,
    xla::XlaOp index_init) {
  const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(value_input);
  const int64_t dim_size = shape.dimensions(dim);

  // Window of size 1 everywhere except along `dim`.
  std::vector<int64_t> window_dims(shape.rank(), 1);
  window_dims[dim] = dim_size;
  std::vector<int64_t> window_strides(shape.rank(), 1);

  // Pad only the low side of `dim`; all other entries stay at (0, 0).
  std::vector<std::pair<int64_t, int64_t>> padding(shape.rank());
  padding[dim].first = dim_size - 1;

  return xla::ReduceWindowWithGeneralPadding(
      {value_input, index_input}, {value_init, index_init}, reducer,
      window_dims, window_strides,
      /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
}

xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const int64_t> dimensions,
bool keep_reduced_dimensions) {
return CreateSummation(input, dimensions, keep_reduced_dimensions,
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
const xla::XlaComputation& reducer,
xla::XlaOp init);

// Computes the cumulative computation specified by "reducer" along the given
// dimension "dim", operating on pairs of elements taken from "value_input"
// and "index_input". "value_init" and "index_init" are the initial values
// for the value and index parts of the reduction respectively.
// Returns a tuple XlaOp (values, indices).
xla::XlaOp BuildCumulativeComputationWithIndices(
xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
const xla::XlaComputation& reducer, xla::XlaOp value_init,
xla::XlaOp index_init);

xla::XlaOp BuildAll(xla::XlaOp input, absl::Span<const int64_t> dimensions,
bool keep_reduced_dimensions);

Expand Down
Loading

0 comments on commit c394d1b

Please sign in to comment.