diff --git a/LICENSE b/LICENSE
index 04c28499..b239892b 100644
--- a/LICENSE
+++ b/LICENSE
@@ -236,3 +236,4 @@ distributed as part of the software:
 - pillow - Custom License (https://github.com/python-pillow/Pillow/blob/main/LICENSE)
 - kornia - Apache v2.0 (https://github.com/kornia/kornia/blob/main/LICENSE)
 - timm - MIT License (https://github.com/guigrpa/timm/blob/master/LICENSE)
+- jax - Apache v2.0 (https://github.com/jax-ml/jax/blob/main/LICENSE)
diff --git a/tests/torch/test_interpolation.py b/tests/torch/test_interpolation.py
new file mode 100644
index 00000000..5a4edd7d
--- /dev/null
+++ b/tests/torch/test_interpolation.py
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from torch import nn
+import pytest
+
+import tt_torch
+from tt_torch.tools.verify import verify_module
+from tt_torch.tools.utils import CompilerConfig, CompileDepth
+import torch.nn.functional as F
+
+
+@pytest.mark.parametrize("inH", [50, 128, 224, 960])
+@pytest.mark.parametrize("inW", [50, 128, 224, 540])
+@pytest.mark.parametrize("inC", [3])
+@pytest.mark.parametrize("scale_factor", [2, 3])
+@pytest.mark.parametrize("align_corners", [False, True])
+def test_bilinear_interpolation(inH, inW, inC, scale_factor, align_corners):
+    class Interpolate(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return F.interpolate(
+                x,
+                scale_factor=scale_factor,
+                mode="bilinear",
+                align_corners=align_corners,
+            )
+
+    input_shape = (1, inC, inH, inW)
+    small = torch.randn(input_shape, dtype=torch.bfloat16)
+
+    cc = CompilerConfig()
+    cc.enable_consteval = True
+    verify_module(
+        Interpolate(),
+        inputs=[small],
+        compiler_config=cc,
+        required_atol=3,
+        required_pcc=0.99 - 0.05 * scale_factor,
+    )
diff --git a/tt_torch/dynamo/decompositions.py b/tt_torch/dynamo/decompositions.py
index 41764383..c447b45a 100644
--- a/tt_torch/dynamo/decompositions.py
+++ b/tt_torch/dynamo/decompositions.py
@@ -9,6 +9,7 @@ import torch
 from torch._decomp import get_decompositions, remove_decompositions
 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
+import numpy as np
 
 DecompositionTable = Dict[torch._ops.OperatorBase, Callable]
 DecompositionOpsList = Sequence[
@@ -67,6 +68,85 @@ def _extend_context_manager(
     ), "contextmanager unbalanced: popped different that pushed"
 
 
+# This method is derived from the implementation of jax.image.resize in JAX:
+# https://github.com/jax-ml/jax/blob/354bd5271077654af983965c8e01ee462ce4ce91/jax/_src/image/scale.py#L52
+#
+# I've modified it to use torch rather than JAX. I've also added the ability
+# to generate a weight matrix that allows the matmul to be identical to
+# torch's upsample_bilinear2d when align_corners=True.
+# This logic was derived from @brentyi's implementation in:
+# https://github.com/jax-ml/jax/issues/11206#issuecomment-1423140760
+def compute_bilinear_weight(input_size, output_size, scale, align_corners, dtype):
+    translation = 0
+    if align_corners:
+        scale = (output_size - 1) / (input_size - 1)
+        translation = 0.5 - (scale / 2)
+
+    inv_scale = 1 / scale
+    sample_f = (
+        (torch.arange(output_size, dtype=torch.float64) + 0.5) * inv_scale
+        - translation * inv_scale
+        - 0.5
+    )
+    x = torch.abs(sample_f - torch.arange(input_size, dtype=torch.float64).unsqueeze(1))
+
+    weights = torch.relu(1 - torch.abs(x))
+
+    total_weight_sum = torch.sum(weights, axis=0, keepdims=True)
+    weights = torch.divide(
+        weights,
+        torch.where(total_weight_sum != 0, total_weight_sum, 1),
+    )
+
+    weights = torch.where(
+        torch.logical_and(sample_f >= -0.5, sample_f <= input_size - 0.5),
+        weights,
+        0,
+    )
+    weights = weights.squeeze()
+    return weights.to(dtype)
+
+
+def upsample_bilinear2d(
+    input: torch.Tensor,
+    output_size: List[int],
+    align_corners: bool,
+    scales_h: Optional[float] = None,
+    scales_w: Optional[float] = None,
+):
+    input_size = input.shape[-2:]
+    res = None
+
+    if scales_h is None:
+        scales_h = float(output_size[0]) / float(input_size[0])
+
+    if scales_w is None:
+        scales_w = float(output_size[1]) / float(input_size[1])
+
+    scales = [scales_h, scales_w]
+    if (
+        scales_h == scales_w
+        and input_size[0] == input_size[1]
+        and output_size[0] == output_size[1]
+    ):
+        weight_w = compute_bilinear_weight(
+            input_size[1], output_size[1], scales[1], align_corners, input.dtype
+        )
+        weight_h = weight_w.transpose(-1, -2)
+    else:
+        weight_w = compute_bilinear_weight(
+            input_size[1], output_size[1], scales[1], align_corners, input.dtype
+        )
+        weight_h = compute_bilinear_weight(
+            input_size[0], output_size[0], scales[0], align_corners, input.dtype
+        ).transpose(-1, -2)
+
+    res = (input.transpose(-1, -2) @ weight_h.transpose(-1, -2)).transpose(
+        -1, -2
+    ) @ weight_w
+    return res
+
+
 # TODO: DO we ever need this?
 def _get_default_decomposition_ops() -> DecompositionOpsList:
     aten = torch.ops.aten
@@ -78,7 +158,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
         aten.select_backward,
         aten.norm.ScalarOpt_dim,
         aten.native_group_norm,
-        aten.upsample_bilinear2d.vec,
         aten.split.Tensor,
         aten.split_with_sizes,
         aten.native_layer_norm,
@@ -116,11 +195,17 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
         aten.unbind.int,
         aten.linspace.Tensor_Tensor,
         aten._scaled_dot_product_flash_attention_for_cpu.default,
-        aten.upsample_bilinear2d,
         aten.slice_scatter,
     ]
 
 
+def _get_custom_decompositions() -> DecompositionTable:
+    aten = torch.ops.aten
+    return {
+        aten.upsample_bilinear2d.default: upsample_bilinear2d,
+    }
+
+
 # Some older APIs still use an op list instead of a table.
 DEFAULT_DECOMPOSITIONS: DecompositionOpsList = _get_default_decomposition_ops()
 
@@ -128,3 +213,5 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
 DEFAULT_DECOMPOSITION_TABLE: DecompositionTable = get_decompositions(
     DEFAULT_DECOMPOSITIONS
 )
+
+CUSTOM_DECOMPOSITION_TABLE = _get_custom_decompositions()
diff --git a/tt_torch/dynamo/passes.py b/tt_torch/dynamo/passes.py
index e591957b..31f1742d 100644
--- a/tt_torch/dynamo/passes.py
+++ b/tt_torch/dynamo/passes.py
@@ -9,7 +9,11 @@ from torch.func import functionalize
 
 from typing import List, Optional, Union
 
-from .decompositions import DEFAULT_DECOMPOSITIONS
+from .decompositions import (
+    DecompositionTable,
+    DEFAULT_DECOMPOSITION_TABLE,
+    CUSTOM_DECOMPOSITION_TABLE,
+)
 
 
 def run_shape_prop(gm, example_inputs):
@@ -65,17 +69,16 @@ def reduce_graph(module_or_graph: Union[torch.fx.Graph, torch.fx.GraphModule]):
 def apply_decompositions(
     gm: torch.fx.GraphModule,
     example_inputs,
-    decompose_ops: Optional[List[torch._ops.OpOverload]] = None,
+    decompositions: Optional[DecompositionTable] = None,
 ):
     concrete_inputs = [
         x.view(tuple(int(dim) for dim in x.shape)) if isinstance(x, torch.Tensor) else x
         for x in example_inputs
     ]
-    if decompose_ops is None:
+    if decompositions is None:
         return gm
 
     with torch.no_grad():
-        decompositions = get_decompositions(decompose_ops)
         gm = make_fx(
             functionalize(gm),
             decomposition_table=decompositions,
@@ -186,8 +189,9 @@ def order_constant_inputs(gm, parameters, constants):
 
 
 def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
-    decompose_ops = DEFAULT_DECOMPOSITIONS
-    gm = apply_decompositions(gm, example_inputs, decompose_ops)  # type: ignore
+    decompositions = DEFAULT_DECOMPOSITION_TABLE
+    decompositions.update(CUSTOM_DECOMPOSITION_TABLE)
+    gm = apply_decompositions(gm, example_inputs, decompositions)  # type: ignore
     if compiler_config.enable_consteval:
         gm, constants = constant_fold(gm, example_inputs)
     elif compiler_config.consteval_parameters:
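
# --- Standalone sketch (assumes this patch is applied) --------------------------
# What compute_bilinear_weight produces, on a tiny illustrative case: resizing
# along one axis is a linear map, so a 2 -> 4 upsample is just a 2x4 weight matrix
# applied by matmul. The sizes and expected values below are illustrative and are
# not taken from the patch itself.
import torch

from tt_torch.dynamo.decompositions import compute_bilinear_weight

# align_corners=False (half-pixel centers): after normalization the edge columns
# collapse onto the nearest input pixel, matching torch's edge clamping.
w_false = compute_bilinear_weight(2, 4, 2.0, False, torch.float32)
assert torch.allclose(
    w_false,
    torch.tensor([[1.00, 0.75, 0.25, 0.00], [0.00, 0.25, 0.75, 1.00]]),
)

# align_corners=True: sample positions are j * (in - 1) / (out - 1), so the first
# and last columns copy the input corners exactly.
w_true = compute_bilinear_weight(2, 4, 2.0, True, torch.float32)
assert torch.allclose(
    w_true,
    torch.tensor([[1.0, 2 / 3, 1 / 3, 0.0], [0.0, 1 / 3, 2 / 3, 1.0]]),
)

# A length-2 row [1, 3] maps to [1, 1.5, 2.5, 3] under the align_corners=False weights.
print(torch.tensor([[1.0, 3.0]]) @ w_false)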
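
# --- Standalone sketch (assumes this patch is applied) --------------------------
# Compares the matmul-based upsample_bilinear2d against eager F.interpolate on CPU.
# The input is deliberately non-square so the general per-axis weight branch is
# exercised; the shapes and tolerance are illustrative.
import torch
import torch.nn.functional as F

from tt_torch.dynamo.decompositions import upsample_bilinear2d

x = torch.randn(1, 3, 57, 43)
out_size = [114, 86]  # scale factor 2 on both axes
for align_corners in (False, True):
    ref = F.interpolate(x, size=out_size, mode="bilinear", align_corners=align_corners)
    got = upsample_bilinear2d(x, out_size, align_corners)
    max_err = (got - ref).abs().max().item()
    print(f"align_corners={align_corners}: max abs error {max_err:.2e}")
    assert max_err < 1e-5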
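
# --- Standalone sketch (assumes this patch is applied) --------------------------
# Shows how the new table is consumed, mirroring apply_decompositions /
# pass_pipeline above but outside the tt-torch pipeline. F.interpolate should
# reach aten.upsample_bilinear2d.default during tracing, so the decomposed graph
# should contain the matmul formulation and no upsample op. The table is copied
# first only so this snippet does not mutate the module-level default the way
# pass_pipeline's in-place update does.
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

from tt_torch.dynamo.decompositions import (
    CUSTOM_DECOMPOSITION_TABLE,
    DEFAULT_DECOMPOSITION_TABLE,
)


def resize(x):
    return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)


table = dict(DEFAULT_DECOMPOSITION_TABLE)
table.update(CUSTOM_DECOMPOSITION_TABLE)

x = torch.randn(1, 3, 32, 48)
with torch.no_grad():
    gm = make_fx(functionalize(resize), decomposition_table=table)(x)

assert not any("upsample_bilinear2d" in str(node.target) for node in gm.graph.nodes)
print(gm.graph)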