Skip to content

Commit

Permalink
Move masked_fill blocklist to appropriate file
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinwuTT committed Oct 28, 2024
1 parent 1f9b033 commit fdf3a57
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
5 changes: 5 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
],
]

aten_masked_fill_scalar_blocklist = [
["Tensor<[2, 1, 7, 7]> self = ?", "Tensor<[2, 1, 7, 7]> mask = ?", "number value = -3.3895313892515355e+38"],
]

# Need to remove this from the blocklist so that yolos can pass
aten_view_default_blocklist.remove(["Tensor<[1, 192, 32, 42]> self = ?", "List[int] size = [1, 192, 1344]"])

Expand Down Expand Up @@ -187,6 +191,7 @@
GUARD[torch.ops.aten._to_copy.default] = partial(guard_aten, aten__to_copy_default_blocklist)
GUARD[torch.ops.aten.unsqueeze.default] = partial(guard_aten, aten_unsqueeze_default_blocklist)
GUARD[torch.ops.aten.squeeze.dim] = partial(guard_aten, aten_squeeze_dim_blocklist)
GUARD[torch.ops.aten.masked_fill.Scalar] = partial(guard_aten, aten_masked_fill_scalar_blocklist)


def can_lowering_to_ttnn(node):
Expand Down
5 changes: 0 additions & 5 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,6 @@
["Tensor<[1, 40, 28, 28]> self = ?", "List[int] size = [1, 40, 28, 28]", "List[int] stride = [31360, 784, 28, 1]"],
]
aten_mm_default_blocklist = [["Tensor<[1, 21843]> self = ?", "Tensor<[21843, 768]> mat2 = ?"]]
aten_masked_fill_scalar_blocklist = [
["Tensor<[2, 1, 7, 7]> self = ?", "Tensor<[2, 1, 7, 7]> mask = ?", "number value = -3.3895313892515355e+38"],
]


def get_inputs(node):
Expand Down Expand Up @@ -677,7 +674,6 @@ def guard_aten(blocklist, node):
torch.ops.aten.native_dropout.default: partial(guard_aten, aten_native_dropout_default_blocklist),
torch.ops.aten.new_empty_strided.default: partial(guard_aten, aten_new_empty_strided_default_blocklist),
torch.ops.aten.mm.default: partial(guard_aten, aten_mm_default_blocklist),
torch.ops.aten.masked_fill.Scalar: partial(guard_aten, aten_masked_fill_scalar_blocklist),
}

guard_ops = [
Expand Down Expand Up @@ -715,5 +711,4 @@ def guard_aten(blocklist, node):
"torch.ops.aten.native_dropout.default",
"torch.ops.aten.new_empty_strided.default",
"torch.ops.aten.mm.default",
"torch.ops.aten.masked_fill.Scalar",
]

0 comments on commit fdf3a57

Please sign in to comment.