Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat op support from Forge to TTIR Lowering #1214

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ enum class TargetType
SourceType,
UInt32,
Int64,
DenseI64ArrayAttr,
};

struct AttributeRemap
Expand Down Expand Up @@ -106,6 +107,7 @@ class AttributeMapper
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
add_op_mapping("repeat", "repeats", AttributeRemap("repeat_dimensions", TargetType::DenseI64ArrayAttr));

// Add more default mappings here
}
Expand Down Expand Up @@ -237,6 +239,10 @@ class MLIRGenerator
TT_ASSERT(std::get<int>(value) >= 0, "Value must be an >= 0 for conversion to uint32");
return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));

case TargetType::DenseI64ArrayAttr:
return builder_.getDenseI64ArrayAttr(std::vector<int64_t>(
std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
default:
// If type not handled, throw an exception
throw std::runtime_error("Unhandled target type conversion");
Expand Down Expand Up @@ -636,6 +642,7 @@ class MLIRGenerator
lowering_handler_map["remainder"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RemainderOp>;
lowering_handler_map["repeat_interleave"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatInterleaveOp>;
lowering_handler_map["repeat"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["select"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SelectOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
Expand Down
24 changes: 22 additions & 2 deletions forge/forge/op/eval/forge/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def eval(type, attr, ops):

if type == "repeat":
sizes = attr
assert len(t_ops[0].shape) == len(sizes)
return t_ops[0].repeat(*sizes)

if type == "repeat_interleave":
Expand Down Expand Up @@ -486,7 +485,15 @@ def shape(type, attr, ops):

if type == "repeat":
sizes = attr
return tuple(dim * size for dim, size in zip(list(ops[0]), sizes)), []
if len(ops[0]) < len(sizes):
# Scenario: When the input is a 1D tensor and needs to be repeated in 2D,
# `ttir.repeat` does not currently support this directly,
# so we are calculating the new shape by expanding the dimensions
# to match repeat attr dimensions and calculate the output shape
shape = (1,) * (len(sizes) - len(ops[0])) + tuple(ops[0])
else:
shape = ops[0]
return tuple(dim * size for dim, size in zip(list(shape), sizes)), []

if type == "repeat_interleave":
assert len(attr) <= 3, "repeat_interleave should have two attributes - repeats and dim"
Expand Down Expand Up @@ -1379,6 +1386,19 @@ def decompose(type, attr, dc, inputs):
rank -= 1
dc.fuse(result)
return
if type == "repeat":
input_shape = inputs[0].shape.as_list()
target_shape = attr
result = inputs[0]

if len(input_shape) < len(target_shape):
# Scenario: When the input is a 1D tensor and needs to be repeated in 2D,
# `ttir.repeat` does not currently support this directly.
# To handle this, we first reshape the input to ensure both the input and the repeats have the same dimensions
new_shape = (1,) * (len(target_shape) - len(input_shape)) + tuple(input_shape)
result = dc.op("reshape", [result], new_shape)
result = dc.op_with_named_attrs("repeat", [result], {"repeats": target_shape}, target_shape)
dc.fuse(result)


def create_row_picker_matrix(col_indices, lhs_num_cols, lhs_num_channels=None, lhs_batch_size=None):
Expand Down
1 change: 0 additions & 1 deletion forge/forge/op/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,6 @@ def Repeat(name: str, operandA: Tensor, repeats: List[int]) -> Tensor:
Tensor
Forge tensor
"""
assert len(operandA.shape) == len(repeats)
return op("repeat", name, operandA, attrs=repeats, repeats=repeats).get_tensor()


Expand Down
27 changes: 22 additions & 5 deletions forge/test/mlir/operators/tm/test_tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,28 @@ def forward(self, *tensors):
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out


@pytest.mark.xfail(
reason="RuntimeError: Found Unsupported operations while lowering from TTForge to TTIR in forward graph - repeat"
@pytest.mark.parametrize(
["input_shape", "repeats"],
[
pytest.param((1, 2), (10, 1)),
pytest.param((1, 99), (100, 1)),
pytest.param(
(1, 100),
(50, 2),
marks=pytest.mark.xfail(reason="info:Incompatible dimensions 200 and 100"),
),
pytest.param(
(3,),
(4, 2),
marks=pytest.mark.xfail(reason="info:Incompatible dimensions 6 and 3"),
),
pytest.param((4, 1, 4), (1, 10, 1)),
pytest.param((2, 2, 1, 2), (1, 1, 4, 1)),
pytest.param((1, 4, 1, 4, 4), (1, 1, 3, 1, 1)),
],
)
@pytest.mark.push
def test_repeat():
def test_repeat(input_shape, repeats):
class Repeat(nn.Module):
def __init__(self, repeats):
super().__init__()
Expand All @@ -575,9 +592,9 @@ def __init__(self, repeats):
def forward(self, x):
return x.repeat(*self.repeats)

inputs = [torch.rand(1, 2, 1, 4, 4)]
inputs = [torch.rand(input_shape)]

framework_model = Repeat(repeats=(1, 1, 4, 1, 1))
framework_model = Repeat(repeats=repeats)
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)
Expand Down
Loading