From d960356d804c55e9cf0e0e808947a1e059f59f1d Mon Sep 17 00:00:00 2001 From: Joe Malone Date: Fri, 28 Feb 2025 20:51:15 +0000 Subject: [PATCH] Updates conversion of torch.expand to ttnn.repeat to allow repeating on last dimension. This has been fixed in tt-metal, so we can now use tensor expand on the last dimension. This also allows us to run mobilenet_ssd end to end. closes 436 --- tests/models/mobilenet_ssd/test_mobilenet_ssd.py | 1 + torch_ttnn/passes/lowering/to_tt_pass.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/mobilenet_ssd/test_mobilenet_ssd.py b/tests/models/mobilenet_ssd/test_mobilenet_ssd.py index 4069424bd..abd37c61a 100644 --- a/tests/models/mobilenet_ssd/test_mobilenet_ssd.py +++ b/tests/models/mobilenet_ssd/test_mobilenet_ssd.py @@ -32,6 +32,7 @@ def _load_inputs(self): "mode", ["eval"], ) +@pytest.mark.converted_end_to_end def test_mobilenet_ssd(record_property, mode): model_name = "MobileNetSSD" record_property("model_name", model_name) diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index f741fb389..c55968c2f 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -676,8 +676,7 @@ def reshape_1d(code, args=args, kwargs=kwargs): # aten.expand and ttnn.repeat has different meaning for their `shape` argument # aten.expand: the desired output shape, where respective singleton dims are broadcasted # ttnn.repeat: the number of times to repeat a respective singleton dim - # Repeat fails if last dimension of input is 1 - if input_tensor_shape[-1] != 1 and len(input_tensor_shape) == len(output_shape): + if len(input_tensor_shape) == len(output_shape): return g.call_function(target_wrappers.repeat, args=(args[0], multiplier.tolist())) return None