Skip to content

Commit

Permalink
Add support for OpModel pydantic model in downstream parser
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksKnezevic committed Jan 30, 2025
1 parent 766b188 commit 7e09bab
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 5 deletions.
4 changes: 3 additions & 1 deletion tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CompilerConfig,
CompileDepth,
Op,
Tensor,
OpCompilationStatus,
calculate_atol,
calculate_pcc,
Expand Down Expand Up @@ -162,7 +163,7 @@ def compile_op(self, node, *inputs, **kwargs):
raise ValueError(f"Node target is not an OpOverload: {name}")
return None, None

op = Op(name, input_shapes_and_constants)
op = Op(name, input_shapes_and_constants, self.compiler_config.model_name)
if op.unique_key() not in self.compiler_config.unique_ops:
self.compiler_config.unique_ops[op.unique_key()] = op
else:
Expand Down Expand Up @@ -278,6 +279,7 @@ def compile_op(self, node, *inputs, **kwargs):
op.add_ttnn_graph(result["ttnn"])
ttnn_event.set()
op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN
op.parse_json()
break
except mp.queues.Empty:
pass
Expand Down
65 changes: 61 additions & 4 deletions tt_torch/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,41 @@ class OpCompilationStatus(IntEnum):
EXECUTED = 7


class Tensor:
    """Shape and on-device layout metadata for a single tensor.

    Instances start with only a shape; the layout fields (``data_type``,
    ``buffer_type``, ``layout``, ``grid_shape``) default to empty and are
    filled in later (e.g. by ``Op.parse_json``) once compiled-binary
    information is available.
    """

    def __init__(self, shape):
        """Record *shape*, replacing every non-int dimension with -1.

        Non-int dims (e.g. symbolic/dynamic dimensions) are normalized to
        -1 so the stored shape is a plain list of ints.
        """
        # Comprehension replaces the original append loop; behavior is
        # identical. Note bool is an int subclass and would be kept as-is.
        self.shape = [dim if isinstance(dim, int) else -1 for dim in shape]

        # Layout metadata — empty until populated from a compiled binary.
        self.data_type = ""
        self.buffer_type = ""
        self.layout = ""
        self.grid_shape = []

    def to_dict(self):
        """Return a JSON-serializable dict of all tensor fields."""
        return {
            "shape": self.shape,
            "data_type": self.data_type,
            "buffer_type": self.buffer_type,
            "layout": self.layout,
            "grid_shape": self.grid_shape,
        }


class Op:
def __init__(self, torch_name, input_shapes):
self.torch_name = torch_name
def __init__(self, torch_name, input_shapes, model_name):
self.framework_op_name = torch_name
self.num_ops = 1
self.model_name = model_name
self.input_shapes = input_shapes
self.input_tensors = []
self.output_shapes = []
self.output_tensors = []
self.frontend = "tt-torch"

self.torch_ir_graph = ""
self.stable_hlo_graph = ""
Expand All @@ -56,6 +85,24 @@ def __init__(self, torch_name, input_shapes):
self.pcc = None
self.atol = None

def parse_json(self):
    """Populate ``input_tensors``/``output_tensors`` from the compiled binary.

    Parses ``self.json`` — a JSON string set elsewhere (not visible in this
    view; presumably the serialized ttnn binary produced during compilation,
    TODO confirm) — and appends one ``Tensor`` per program input and output.
    Raises ``KeyError`` if the expected schema keys are missing and
    ``json.JSONDecodeError`` if ``self.json`` is not valid JSON.
    """
    binary = json.loads(self.json)

    def tensor_from_tensor_desc(desc):
        # Build a Tensor from one tensor-descriptor dict, copying the
        # layout fields out of the nested memory_desc.
        tensor = Tensor(desc["shape"])
        tensor.data_type = desc["layout"]["memory_desc"]["data_type"]
        tensor.buffer_type = desc["layout"]["memory_desc"]["memory_space"]
        tensor.layout = desc["layout"]["memory_desc"]["memory_layout"]
        # NOTE(review): only the first core_range_set entry is read —
        # assumes a single core range per tensor; confirm against the
        # binary schema.
        grid_shape = desc["layout"]["core_range_set"][0]["size"]
        tensor.grid_shape = [grid_shape["x"], grid_shape["y"]]
        return tensor

    # Only programs[0] is inspected — presumably one program per binary
    # for a single compiled op; TODO confirm.
    for inp in binary["programs"][0]["inputs"]:
        self.input_tensors.append(tensor_from_tensor_desc(inp["desc"]))

    for out in binary["programs"][0]["outputs"]:
        self.output_tensors.append(tensor_from_tensor_desc(out["desc"]))

def print_shapes(self, shapes):
output = []
for shape in shapes:
Expand All @@ -78,10 +125,20 @@ def scrub_nan_inf(value):
pcc = scrub_nan_inf(self.pcc)
atol = scrub_nan_inf(self.atol)

if len(self.input_tensors) == 0:
self.input_tensors = [Tensor(shp) for shp in self.input_shapes]

if len(self.output_tensors) == 0:
self.output_tensors = [Tensor(shp) for shp in self.output_shapes]

return {
"torch_name": self.torch_name,
"framework_op_name": self.framework_op_name,
"frontend": self.frontend,
"model_name": self.model_name,
"input_shapes": self.print_shapes(self.input_shapes),
"input_tensors": [tensor.to_dict() for tensor in self.input_tensors],
"output_shapes": self.print_shapes(self.output_shapes),
"output_tensors": [tensor.to_dict() for tensor in self.output_tensors],
"num_ops": self.num_ops,
"compilation_status": self.compilation_status,
"parsed_stable_hlo_ops": self.parsed_stable_hlo_ops,
Expand All @@ -97,7 +154,7 @@ def scrub_nan_inf(value):
}

def unique_key(self):
key = self.torch_name
key = self.framework_op_name
for shape in self.input_shapes:
if isinstance(shape, torch.Size):
key += f"_{print_shape(shape)}"
Expand Down

0 comments on commit 7e09bab

Please sign in to comment.