forked from pytorch/xla
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
af223d3
commit c394d1b
Showing
12 changed files
with
235 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#include "torch_xla/csrc/ops/cummax.h" | ||
|
||
#include <torch/csrc/lazy/core/tensor_util.h> | ||
|
||
#include "torch_xla/csrc/convert_ops.h" | ||
#include "torch_xla/csrc/helpers.h" | ||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/ops/infer_output_shape.h" | ||
#include "torch_xla/csrc/reduction.h" | ||
#include "torch_xla/csrc/shape_helper.h" | ||
#include "torch_xla/csrc/tensor_util.h" | ||
#include "torch_xla/csrc/torch_util.h" | ||
|
||
namespace torch_xla { | ||
namespace { | ||
|
||
xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim) { | ||
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); | ||
xla::XlaOp value_init_value = xla::ConstantLiteral( | ||
input.builder(), xla::LiteralUtil::MinValue(input_shape.element_type())); | ||
xla::XlaOp index_init_value = xla::ConstantLiteral( | ||
input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32)); | ||
xla::XlaOp iota = | ||
xla::Iota(input.builder(), | ||
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, | ||
input_shape.dimensions()), | ||
dim); | ||
xla::XlaComputation reducer = XlaHelpers::CreateMaxAndArgMaxComputation( | ||
input_shape.element_type(), xla::PrimitiveType::S32); | ||
return BuildCumulativeComputationWithIndices( | ||
input, iota, dim, reducer, value_init_value, index_init_value); | ||
} | ||
|
||
xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) { | ||
auto lower_for_shape_fn = | ||
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { | ||
xla::XlaOp values_and_indices = LowerCumMax(operands[0], dim); | ||
return values_and_indices; | ||
}; | ||
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn); | ||
} | ||
|
||
} // namespace | ||
|
||
CumMax::CumMax(const torch::lazy::Value& input, int64_t dim) | ||
: XlaNode( | ||
torch::lazy::OpKind(at::aten::cummax), {input}, | ||
[&]() { return NodeOutputShape(input, dim); }, | ||
/*num_outputs=*/2, torch::lazy::MHash(dim)), | ||
dim_(dim) {} | ||
|
||
torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const { | ||
return torch_xla::MakeNode<CumMax>(operands.at(0), dim_); | ||
} | ||
|
||
XlaOpVector CumMax::Lower(LoweringContext* loctx) const { | ||
xla::XlaOp input = loctx->GetOutputOp(operand(0)); | ||
xla::XlaOp values_and_indices = LowerCumMax(input, dim_); | ||
return ReturnOps({xla::GetTupleElement(values_and_indices, 0), | ||
xla::GetTupleElement(values_and_indices, 1)}, | ||
loctx); | ||
} | ||
|
||
std::string CumMax::ToString() const { | ||
std::stringstream ss; | ||
ss << XlaNode::ToString() << ", dim=" << dim_; | ||
return ss.str(); | ||
} | ||
|
||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ | ||
#define XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ | ||
|
||
#include <c10/core/ScalarType.h> | ||
|
||
#include "torch_xla/csrc/ir.h" | ||
|
||
namespace torch_xla { | ||
|
||
class CumMax : public XlaNode { | ||
public: | ||
CumMax(const torch::lazy::Value& input, int64_t dim); | ||
|
||
std::string ToString() const override; | ||
|
||
torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; | ||
|
||
XlaOpVector Lower(LoweringContext* loctx) const override; | ||
|
||
int64_t dim() const { return dim_; } | ||
|
||
private: | ||
int64_t dim_; | ||
}; | ||
|
||
} // namespace torch_xla | ||
|
||
#endif // XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.