diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 28524ab6..1c8e853e 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -226,12 +226,12 @@ def compile_op(self, node, *inputs, **kwargs): ): getitem_nodes = [] graph_node.meta["val"] = node.meta["val"] - for idx, _ in enumerate(node.meta["tensor_meta"]): + for idx, tensor_meta in enumerate(node.meta["tensor_meta"]): getitem_node = graph.call_function( operator.getitem, args=(graph_node, idx) ) - # getitem_node.meta["val"] = graph_node.meta["val"] getitem_nodes.append(getitem_node) + getitem_node.meta["tensor_meta"] = tensor_meta out = graph.output(tuple(getitem_nodes)) else: out = graph.output((graph_node,))