Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Dec 13, 2024
1 parent cf9922a commit 7339c21
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 45 deletions.
30 changes: 17 additions & 13 deletions tests/torch/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import torch.nn.functional as F


@pytest.mark.parametrize("inH", [50, 128, 224])
@pytest.mark.parametrize("inW", [50, 128, 224])
@pytest.mark.parametrize("inC", [1, 3])
@pytest.mark.parametrize("scale_factor", [2, 3, 4, 5])
@pytest.mark.parametrize("inH", [50, 128, 224, 960])
@pytest.mark.parametrize("inW", [50, 128, 224, 540])
@pytest.mark.parametrize("inC", [3])
@pytest.mark.parametrize("scale_factor", [2, 3])
@pytest.mark.parametrize("align_corners", [False, True])
def test_bilinear_interpolation(inH, inW, inC, scale_factor, align_corners):
torch.set_printoptions(linewidth=1000000, threshold=1000000)
Expand All @@ -25,17 +25,21 @@ def __init__(self):

def forward(self, x):
return F.interpolate(
x, scale_factor=2, mode="bilinear", align_corners=align_corners
x,
scale_factor=scale_factor,
mode="bilinear",
align_corners=align_corners,
)

input_shape = (1, inC, inH, inW)
out_shape = (1, inC, inH * scale_factor, inW * scale_factor)
small = (
(torch.arange(torch.prod(torch.tensor(input_shape))) + 1)
.reshape(input_shape)
.float()
)
small = torch.randn(input_shape, dtype=torch.bfloat16)

cc = CompilerConfig()
cc.compile_depth = CompileDepth.STABLEHLO
verify_module(Basic(), inputs=[small], compiler_config=cc, required_atol=3.2e-2)
cc.enable_costeval = True
verify_module(
Basic(),
inputs=[small],
compiler_config=cc,
required_atol=3,
required_pcc=0.99 - 0.15 * scale_factor,
)
47 changes: 19 additions & 28 deletions tt_torch/dynamo/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,41 +77,29 @@ def _extend_context_manager(
# This logic was derived from @brentyi's implementation in:
# https://github.com/jax-ml/jax/issues/11206#issuecomment-1423140760
def compute_bilinear_weight(input_size, output_size, scale, align_corners, dtype):
zero_tensor = torch.full([1, 1, 1, 1], 0.0)
one_tensor = torch.full([1, 1, 1, 1], 1.0)
two_tensor = torch.full([1, 1, 1, 1], 2.0)
half_tensor = torch.full([1, 1, 1, 1], 0.5)
neg_half_tensor = torch.full([1, 1, 1, 1], -0.5)
output_size_f = torch.full([1, 1, 1, 1], float(output_size))
input_size_f = torch.full([1, 1, 1, 1], float(input_size))

scale = torch.full([1, 1, 1, 1], float(scale))
translation = zero_tensor
translation = 0
if align_corners:
scale = (output_size_f - one_tensor) / (input_size_f - one_tensor)
translation = half_tensor - (scale / two_tensor)
scale = (output_size - 1) / (input_size - 1)
translation = 0.5 - (scale / 2)

inv_scale = one_tensor / scale
inv_scale = 1 / scale
sample_f = (
(torch.arange(output_size).reshape(1, 1, 1, output_size) + half_tensor)
* inv_scale
(torch.arange(output_size, dtype=torch.float64) + 0.5) * inv_scale
- translation * inv_scale
- half_tensor
- 0.5
)
x = torch.abs(sample_f - torch.arange(input_size).reshape(1, 1, input_size, 1))
x = torch.abs(sample_f - torch.arange(input_size, dtype=torch.float64).unsqueeze(1))

weights = torch.relu(one_tensor - torch.abs(x))
weights = torch.relu(1 - torch.abs(x))

total_weight_sum = torch.sum(weights, axis=2, keepdims=True)
total_weight_sum = torch.sum(weights, axis=0, keepdims=True)
weights = torch.divide(
weights,
torch.where(total_weight_sum != zero_tensor, total_weight_sum, one_tensor),
torch.where(total_weight_sum != 0, total_weight_sum, 1),
)

weights = torch.where(
torch.logical_and(
sample_f >= neg_half_tensor, sample_f <= input_size_f - half_tensor
),
torch.logical_and(sample_f >= -0.5, sample_f <= input_size - 0.5),
weights,
0,
)
Expand All @@ -138,13 +126,13 @@ def upsample_bilinear2d(
scales = [scales_h, scales_w]
if (
scales_h == scales_w
and input_size[0] == output_size[0]
and input_size[1] == output_size[1]
and input_size[0] == input_size[1]
and output_size[0] == output_size[1]
):
weight_w = compute_bilinear_weight(
input_size[1], output_size[1], scales[1], False, input.dtype
)
weigh_h = weight_w.transpose(-1, -2)
weight_h = weight_w.transpose(-1, -2)
else:
weight_w = compute_bilinear_weight(
input_size[1], output_size[1], scales[1], align_corners, input.dtype
Expand All @@ -153,7 +141,11 @@ def upsample_bilinear2d(
input_size[0], output_size[0], scales[0], align_corners, input.dtype
).transpose(-1, -2)

res = weight_h @ input @ weight_w
# breakpoint()
# res = weight_h @ input @ weight_w
res = (input.transpose(-1, -2) @ weight_h.transpose(-1, -2)).transpose(
-1, -2
) @ weight_w
return res


Expand All @@ -168,7 +160,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
aten.select_backward,
aten.norm.ScalarOpt_dim,
aten.native_group_norm,
aten.upsample_bilinear2d.vec,
aten.split.Tensor,
aten.split_with_sizes,
aten.native_layer_norm,
Expand Down
9 changes: 5 additions & 4 deletions tt_torch/dynamo/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,11 @@ def apply_decompositions(
return gm

with torch.no_grad():
decompositions = get_decompositions(decompose_ops)
fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(example_inputs)
fake_tensor_mode.allow_non_fake_inputs = True
gm = make_fx(
gm,
tracing_mode="symbolic",
# tracing_mode="symbolic",
_allow_non_fake_inputs=True,
decomposition_table=decompositions,
)(*example_inputs)
Expand Down Expand Up @@ -131,8 +130,10 @@ def constant_fold(gm, example_inputs):


def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
decompose_ops = DEFAULT_DECOMPOSITIONS
gm = apply_decompositions(gm, example_inputs, decompose_ops) # type: ignore
decompositions = DEFAULT_DECOMPOSITION_TABLE
decompositions.update(CUSTOM_DECOMPOSITION_TABLE)

gm = apply_decompositions(gm, example_inputs, decompositions) # type: ignore
if compiler_config.enable_costeval:
gm, graph_constants = constant_fold(gm, example_inputs)
else:
Expand Down

0 comments on commit 7339c21

Please sign in to comment.