From f7c2217f2ca7c5bf8fa99bed513c496947a07f4f Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Fri, 21 Feb 2025 18:18:08 +0000 Subject: [PATCH] Fix parse_json because it WILL break on uplift tonight without this --- tt_torch/tools/utils.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tt_torch/tools/utils.py b/tt_torch/tools/utils.py index 916832db..54e5c204 100644 --- a/tt_torch/tools/utils.py +++ b/tt_torch/tools/utils.py @@ -97,10 +97,37 @@ def tensor_from_tensor_desc(desc): tensor = Tensor(desc["shape"]) if "memory_desc" in desc["layout"]: tensor.data_type = desc["layout"]["memory_desc"]["data_type"] - tensor.buffer_type = desc["layout"]["memory_desc"]["memory_space"] - tensor.layout = desc["layout"]["memory_desc"]["memory_layout"] - grid_shape = desc["layout"]["core_range_set"][0]["size"] - tensor.grid_shape = [grid_shape["x"], grid_shape["y"]] + try: + tensor.buffer_type = desc["layout"]["memory_desc"]["memory_space"] + except KeyError: + + if "memory_config" in desc["layout"]["memory_desc"]: + # If the tensor is on device, the descriptor will have a "memory_config" field + tensor.buffer_type = desc["layout"]["memory_desc"][ + "memory_config" + ]["buffer_type"] + else: + # If the tensor is on host, the descriptor will have a "storage_type" field and no "memory_config" field + tensor.buffer_type = desc["layout"]["memory_desc"][ + "storage_type" + ] + + try: + tensor.layout = desc["layout"]["memory_desc"]["memory_layout"] + except KeyError: + if "memory_config" in desc["layout"]["memory_desc"]: + # If the tensor is on device, the descriptor will have a "memory_config" field + tensor.layout = desc["layout"]["memory_desc"]["memory_config"][ + "tensor_memory_layout" + ] + else: + # If the tensor is on host, there will be no "memory_config" and thus no "tensor_memory_layout" field + tensor.layout = "" + try: + grid_shape = desc["layout"]["core_range_set"][0]["size"] + tensor.grid_shape = [grid_shape["x"], grid_shape["y"]] + except KeyError: + pass return tensor for inp in binary["programs"][0]["inputs"]: