Skip to content

Commit

Permalink
Add repeat op support in Forge to TTIR Lowering
Browse files — browse the repository at this point in the history
  • Loading branch information
ashokkumarkannan1 committed Feb 24, 2025
1 parent 5154af9 commit 01221e3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
7 changes: 7 additions & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ enum class TargetType
SourceType,
UInt32,
Int64,
DenseI64ArrayAttr,
};

struct AttributeRemap
Expand Down Expand Up @@ -106,6 +107,7 @@ class AttributeMapper
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
add_op_mapping("repeat", "repeats", AttributeRemap("repeat_dimensions", TargetType::DenseI64ArrayAttr));

// Add more default mappings here
}
Expand Down Expand Up @@ -237,6 +239,10 @@ class MLIRGenerator
TT_ASSERT(std::get<int>(value) >= 0, "Value must be an >= 0 for conversion to uint32");
return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));

case TargetType::DenseI64ArrayAttr:
return builder_.getDenseI64ArrayAttr(std::vector<int64_t>(
std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
default:
// If type not handled, throw an exception
throw std::runtime_error("Unhandled target type conversion");
Expand Down Expand Up @@ -636,6 +642,7 @@ class MLIRGenerator
lowering_handler_map["remainder"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RemainderOp>;
lowering_handler_map["repeat_interleave"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatInterleaveOp>;
lowering_handler_map["repeat"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["select"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SelectOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
Expand Down
22 changes: 17 additions & 5 deletions forge/test/mlir/operators/tm/test_tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,23 @@ def forward(self, *tensors):
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out


@pytest.mark.xfail(
reason="RuntimeError: Found Unsupported operations while lowering from TTForge to TTIR in forward graph - repeat"
@pytest.mark.parametrize(
["input_shapes", "repeats"],
[
pytest.param([(1, 2)], (10, 1)),
pytest.param([(1, 99)], (100, 1)),
pytest.param(
[(1, 100)],
(50, 2),
marks=pytest.mark.xfail(reason="info:Incompatible dimensions 200 and 100"),
),
pytest.param([(4, 1, 4)], (1, 10, 1)),
pytest.param([(2, 2, 1, 2)], (1, 1, 4, 1)),
pytest.param([(1, 4, 1, 4, 4)], (1, 1, 3, 1, 1)),
],
)
@pytest.mark.push
def test_repeat():
def test_repeat(input_shapes, repeats):
class Repeat(nn.Module):
def __init__(self, repeats):
super().__init__()
Expand All @@ -575,9 +587,9 @@ def __init__(self, repeats):
def forward(self, x):
return x.repeat(*self.repeats)

inputs = [torch.rand(1, 2, 1, 4, 4)]
inputs = [torch.rand(shape) for shape in input_shapes]

framework_model = Repeat(repeats=(1, 1, 4, 1, 1))
framework_model = Repeat(repeats=repeats)
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)
Expand Down
2 changes: 1 addition & 1 deletion third_party/tvm

0 comments on commit 01221e3

Please sign in to comment.