From 7339c2152287807bf5fd001b2a793ddb75599512 Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Fri, 13 Dec 2024 20:11:06 +0000 Subject: [PATCH] cleanup --- tests/torch/test_interpolation.py | 30 +++++++++++--------- tt_torch/dynamo/decompositions.py | 47 +++++++++++++------------------ tt_torch/dynamo/passes.py | 9 +++--- 3 files changed, 41 insertions(+), 45 deletions(-) diff --git a/tests/torch/test_interpolation.py b/tests/torch/test_interpolation.py index 3d255523..fb9654af 100644 --- a/tests/torch/test_interpolation.py +++ b/tests/torch/test_interpolation.py @@ -11,10 +11,10 @@ import torch.nn.functional as F -@pytest.mark.parametrize("inH", [50, 128, 224]) -@pytest.mark.parametrize("inW", [50, 128, 224]) -@pytest.mark.parametrize("inC", [1, 3]) -@pytest.mark.parametrize("scale_factor", [2, 3, 4, 5]) +@pytest.mark.parametrize("inH", [50, 128, 224, 960]) +@pytest.mark.parametrize("inW", [50, 128, 224, 540]) +@pytest.mark.parametrize("inC", [3]) +@pytest.mark.parametrize("scale_factor", [2, 3]) @pytest.mark.parametrize("align_corners", [False, True]) def test_bilinear_interpolation(inH, inW, inC, scale_factor, align_corners): torch.set_printoptions(linewidth=1000000, threshold=1000000) @@ -25,17 +25,21 @@ def __init__(self): def forward(self, x): return F.interpolate( - x, scale_factor=2, mode="bilinear", align_corners=align_corners + x, + scale_factor=scale_factor, + mode="bilinear", + align_corners=align_corners, ) input_shape = (1, inC, inH, inW) - out_shape = (1, inC, inH * scale_factor, inW * scale_factor) - small = ( - (torch.arange(torch.prod(torch.tensor(input_shape))) + 1) - .reshape(input_shape) - .float() - ) + small = torch.randn(input_shape, dtype=torch.bfloat16) cc = CompilerConfig() - cc.compile_depth = CompileDepth.STABLEHLO - verify_module(Basic(), inputs=[small], compiler_config=cc, required_atol=3.2e-2) + cc.enable_costeval = True + verify_module( + Basic(), + inputs=[small], + compiler_config=cc, + required_atol=3, + required_pcc=0.99 - 0.15 * scale_factor, + ) diff --git a/tt_torch/dynamo/decompositions.py b/tt_torch/dynamo/decompositions.py index e09a758b..561c5dc6 100644 --- a/tt_torch/dynamo/decompositions.py +++ b/tt_torch/dynamo/decompositions.py @@ -77,41 +77,29 @@ def _extend_context_manager( # This logic was derived from @brentyi's implementation in: # https://github.com/jax-ml/jax/issues/11206#issuecomment-1423140760 def compute_bilinear_weight(input_size, output_size, scale, align_corners, dtype): - zero_tensor = torch.full([1, 1, 1, 1], 0.0) - one_tensor = torch.full([1, 1, 1, 1], 1.0) - two_tensor = torch.full([1, 1, 1, 1], 2.0) - half_tensor = torch.full([1, 1, 1, 1], 0.5) - neg_half_tensor = torch.full([1, 1, 1, 1], -0.5) - output_size_f = torch.full([1, 1, 1, 1], float(output_size)) - input_size_f = torch.full([1, 1, 1, 1], float(input_size)) - - scale = torch.full([1, 1, 1, 1], float(scale)) - translation = zero_tensor + translation = 0 if align_corners: - scale = (output_size_f - one_tensor) / (input_size_f - one_tensor) - translation = half_tensor - (scale / two_tensor) + scale = (output_size - 1) / (input_size - 1) + translation = 0.5 - (scale / 2) - inv_scale = one_tensor / scale + inv_scale = 1 / scale sample_f = ( - (torch.arange(output_size).reshape(1, 1, 1, output_size) + half_tensor) - * inv_scale + (torch.arange(output_size, dtype=torch.float64) + 0.5) * inv_scale - translation * inv_scale - - half_tensor + - 0.5 ) - x = torch.abs(sample_f - torch.arange(input_size).reshape(1, 1, input_size, 1)) + x = torch.abs(sample_f - torch.arange(input_size, dtype=torch.float64).unsqueeze(1)) - weights = torch.relu(one_tensor - torch.abs(x)) + weights = torch.relu(1 - torch.abs(x)) - total_weight_sum = torch.sum(weights, axis=2, keepdims=True) + total_weight_sum = torch.sum(weights, axis=0, keepdims=True) weights = torch.divide( weights, - torch.where(total_weight_sum != zero_tensor, total_weight_sum, one_tensor), + torch.where(total_weight_sum != 0, total_weight_sum, 1), ) weights = torch.where( - torch.logical_and( - sample_f >= neg_half_tensor, sample_f <= input_size_f - half_tensor - ), + torch.logical_and(sample_f >= -0.5, sample_f <= input_size - 0.5), weights, 0, ) @@ -138,13 +126,13 @@ def upsample_bilinear2d( scales = [scales_h, scales_w] if ( scales_h == scales_w - and input_size[0] == output_size[0] - and input_size[1] == output_size[1] + and input_size[0] == input_size[1] + and output_size[0] == output_size[1] ): weight_w = compute_bilinear_weight( input_size[1], output_size[1], scales[1], False, input.dtype ) - weigh_h = weight_w.transpose(-1, -2) + weight_h = weight_w.transpose(-1, -2) else: weight_w = compute_bilinear_weight( input_size[1], output_size[1], scales[1], align_corners, input.dtype @@ -153,7 +141,11 @@ def upsample_bilinear2d( input_size[0], output_size[0], scales[0], align_corners, input.dtype ).transpose(-1, -2) - res = weight_h @ input @ weight_w + # breakpoint() + # res = weight_h @ input @ weight_w + res = (input.transpose(-1, -2) @ weight_h.transpose(-1, -2)).transpose( + -1, -2 + ) @ weight_w return res @@ -168,7 +160,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList: aten.select_backward, aten.norm.ScalarOpt_dim, aten.native_group_norm, - aten.upsample_bilinear2d.vec, aten.split.Tensor, aten.split_with_sizes, aten.native_layer_norm, diff --git a/tt_torch/dynamo/passes.py b/tt_torch/dynamo/passes.py index 3fc345f2..dced3dd3 100644 --- a/tt_torch/dynamo/passes.py +++ b/tt_torch/dynamo/passes.py @@ -78,12 +78,11 @@ def apply_decompositions( return gm with torch.no_grad(): - decompositions = get_decompositions(decompose_ops) fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(example_inputs) fake_tensor_mode.allow_non_fake_inputs = True gm = make_fx( gm, - tracing_mode="symbolic", + # tracing_mode="symbolic", _allow_non_fake_inputs=True, decomposition_table=decompositions, )(*example_inputs) @@ -131,8 +130,10 @@ def constant_fold(gm, example_inputs): def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config): - decompose_ops = DEFAULT_DECOMPOSITIONS - gm = apply_decompositions(gm, example_inputs, decompose_ops) # type: ignore + decompositions = DEFAULT_DECOMPOSITION_TABLE + decompositions.update(CUSTOM_DECOMPOSITION_TABLE) + + gm = apply_decompositions(gm, example_inputs, decompositions) # type: ignore if compiler_config.enable_costeval: gm, graph_constants = constant_fold(gm, example_inputs) else: