Added decomp for bilinear upsample that uses matmul #112

Merged (1 commit) on Dec 19, 2024
1 change: 1 addition & 0 deletions LICENSE
@@ -236,3 +236,4 @@ distributed as part of the software:
- pillow - Custom License (https://github.com/python-pillow/Pillow/blob/main/LICENSE)
- kornia - Apache v2.0 (https://github.com/kornia/kornia/blob/main/LICENSE)
- timm - MIT License (https://github.com/guigrpa/timm/blob/master/LICENSE)
- jax - Apache v2.0 (https://github.com/jax-ml/jax/blob/main/LICENSE)
43 changes: 43 additions & 0 deletions tests/torch/test_interpolation.py
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
import pytest

import tt_torch
from tt_torch.tools.verify import verify_module
from tt_torch.tools.utils import CompilerConfig, CompileDepth
import torch.nn.functional as F


@pytest.mark.parametrize("inH", [50, 128, 224, 960])
@pytest.mark.parametrize("inW", [50, 128, 224, 540])
@pytest.mark.parametrize("inC", [3])
@pytest.mark.parametrize("scale_factor", [2, 3])
@pytest.mark.parametrize("align_corners", [False, True])
def test_bilinear_interpolation(inH, inW, inC, scale_factor, align_corners):
    class Interpolate(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return F.interpolate(
                x,
                scale_factor=scale_factor,
                mode="bilinear",
                align_corners=align_corners,
            )

    input_shape = (1, inC, inH, inW)
    small = torch.randn(input_shape, dtype=torch.bfloat16)

    cc = CompilerConfig()
    cc.enable_consteval = True
    verify_module(
        Interpolate(),
        inputs=[small],
        compiler_config=cc,
        required_atol=3,
        required_pcc=0.99 - 0.05 * scale_factor,
    )
91 changes: 89 additions & 2 deletions tt_torch/dynamo/decompositions.py
@@ -9,6 +9,7 @@
import torch
from torch._decomp import get_decompositions, remove_decompositions
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
import numpy as np

DecompositionTable = Dict[torch._ops.OperatorBase, Callable]
DecompositionOpsList = Sequence[
@@ -67,6 +68,85 @@ def _extend_context_manager(
    ), "contextmanager unbalanced: popped different that pushed"


# This method is derived from the implementation of jax.image.resize in JAX:
# https://github.com/jax-ml/jax/blob/354bd5271077654af983965c8e01ee462ce4ce91/jax/_src/image/scale.py#L52
[Inline review comment from a Contributor: "What licence is this published under? Do we need to add it to our licence files?"]
#
# I've modified it to use torch rather than JAX. I've also added the ability
# to generate a weight matrix that allows the matmul to be identical to
# torch's upsample_bilinear2d when align_corners=True.
# This logic was derived from @brentyi's implementation in:
# https://github.com/jax-ml/jax/issues/11206#issuecomment-1423140760
def compute_bilinear_weight(input_size, output_size, scale, align_corners, dtype):
    translation = 0
    if align_corners:
        scale = (output_size - 1) / (input_size - 1)
        translation = 0.5 - (scale / 2)

    inv_scale = 1 / scale
    sample_f = (
        (torch.arange(output_size, dtype=torch.float64) + 0.5) * inv_scale
        - translation * inv_scale
        - 0.5
    )
    x = torch.abs(sample_f - torch.arange(input_size, dtype=torch.float64).unsqueeze(1))

    weights = torch.relu(1 - torch.abs(x))

    total_weight_sum = torch.sum(weights, axis=0, keepdims=True)
    weights = torch.divide(
        weights,
        torch.where(total_weight_sum != 0, total_weight_sum, 1),
    )

    weights = torch.where(
        torch.logical_and(sample_f >= -0.5, sample_f <= input_size - 0.5),
        weights,
        0,
    )
    weights = weights.squeeze()
    return weights.to(dtype)


def upsample_bilinear2d(
    input: torch.Tensor,
    output_size: List[int],
    align_corners: bool,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    input_size = input.shape[-2:]
    res = None

    if scales_h is None:
        scales_h = float(output_size[0]) / float(input_size[0])

    if scales_w is None:
        scales_w = float(output_size[1]) / float(input_size[1])

    scales = [scales_h, scales_w]
    if (
        scales_h == scales_w
        and input_size[0] == input_size[1]
        and output_size[0] == output_size[1]
    ):
        weight_w = compute_bilinear_weight(
            input_size[1], output_size[1], scales[1], False, input.dtype
        )
        weight_h = weight_w.transpose(-1, -2)
    else:
        weight_w = compute_bilinear_weight(
            input_size[1], output_size[1], scales[1], align_corners, input.dtype
        )
        weight_h = compute_bilinear_weight(
            input_size[0], output_size[0], scales[0], align_corners, input.dtype
        ).transpose(-1, -2)

    res = (input.transpose(-1, -2) @ weight_h.transpose(-1, -2)).transpose(
        -1, -2
    ) @ weight_w
    return res


# TODO: DO we ever need this?
def _get_default_decomposition_ops() -> DecompositionOpsList:
    aten = torch.ops.aten
@@ -78,7 +158,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
        aten.select_backward,
        aten.norm.ScalarOpt_dim,
        aten.native_group_norm,
        aten.upsample_bilinear2d.vec,
        aten.split.Tensor,
        aten.split_with_sizes,
        aten.native_layer_norm,
@@ -116,15 +195,23 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
        aten.unbind.int,
        aten.linspace.Tensor_Tensor,
        aten._scaled_dot_product_flash_attention_for_cpu.default,
        aten.upsample_bilinear2d,
        aten.slice_scatter,
    ]


def _get_custom_decompositions() -> DecompositionTable:
    aten = torch.ops.aten
    return {
        aten.upsample_bilinear2d.default: upsample_bilinear2d,
    }


# Some older APIs still use an op list instead of a table.
DEFAULT_DECOMPOSITIONS: DecompositionOpsList = _get_default_decomposition_ops()

# The table of default decompositions.
DEFAULT_DECOMPOSITION_TABLE: DecompositionTable = get_decompositions(
    DEFAULT_DECOMPOSITIONS
)

CUSTOM_DECOMPOSITION_TABLE = _get_custom_decompositions()
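
As an aside (not part of this diff), a minimal eager-mode sanity check can make the weight-matrix construction above concrete. It assumes only the module path tt_torch/dynamo/decompositions.py shown in this PR plus standard torch, and it is a sketch rather than an authoritative test:

import torch
import torch.nn.functional as F

# Name and path taken from this diff; adjust the import if the module lives elsewhere.
from tt_torch.dynamo.decompositions import upsample_bilinear2d

x = torch.randn(1, 3, 8, 8, dtype=torch.float32)

# Matmul-based path: per-axis weight matrices contracted against the input.
got = upsample_bilinear2d(x, [16, 16], align_corners=False)

# Eager reference produced by torch itself.
ref = F.interpolate(x, size=(16, 16), mode="bilinear", align_corners=False)

# For float32 upsampling the two are expected to agree to within matmul round-off.
print(torch.max(torch.abs(got - ref)).item())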
16 changes: 10 additions & 6 deletions tt_torch/dynamo/passes.py
@@ -9,7 +9,11 @@
from torch.func import functionalize
from typing import List, Optional, Union

from .decompositions import DEFAULT_DECOMPOSITIONS
from .decompositions import (
    DecompositionTable,
    DEFAULT_DECOMPOSITION_TABLE,
    CUSTOM_DECOMPOSITION_TABLE,
)


def run_shape_prop(gm, example_inputs):
@@ -65,17 +69,16 @@ def reduce_graph(module_or_graph: Union[torch.fx.Graph, torch.fx.GraphModule]):
def apply_decompositions(
    gm: torch.fx.GraphModule,
    example_inputs,
    decompose_ops: Optional[List[torch._ops.OpOverload]] = None,
    decompositions: Optional[DecompositionTable] = None,
):
    concrete_inputs = [
        x.view(tuple(int(dim) for dim in x.shape)) if isinstance(x, torch.Tensor) else x
        for x in example_inputs
    ]
    if decompose_ops is None:
    if decompositions is None:
        return gm

    with torch.no_grad():
        decompositions = get_decompositions(decompose_ops)
        gm = make_fx(
            functionalize(gm),
            decomposition_table=decompositions,
@@ -186,8 +189,9 @@ def order_constant_inputs(gm, parameters, constants):


def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
    decompose_ops = DEFAULT_DECOMPOSITIONS
    gm = apply_decompositions(gm, example_inputs, decompose_ops) # type: ignore
    decompositions = DEFAULT_DECOMPOSITION_TABLE
    decompositions.update(CUSTOM_DECOMPOSITION_TABLE)
    gm = apply_decompositions(gm, example_inputs, decompositions) # type: ignore
    if compiler_config.enable_consteval:
        gm, constants = constant_fold(gm, example_inputs)
    elif compiler_config.consteval_parameters:
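
For orientation (not part of this diff), the sketch below shows the mechanism the updated pass_pipeline relies on: the default table built by get_decompositions and the custom op-to-callable table share the same dict format, so they can be merged and handed directly to make_fx. Only the two table names are taken from this PR; the rest is standard torch API, and the copy-before-update is a choice made here to avoid mutating the shared default table:

import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

# Table names come from tt_torch/dynamo/decompositions.py in this diff.
from tt_torch.dynamo.decompositions import (
    DEFAULT_DECOMPOSITION_TABLE,
    CUSTOM_DECOMPOSITION_TABLE,
)


def f(x):
    return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)


# Copy so the module-level default table is left untouched, then let the custom
# entry for aten.upsample_bilinear2d.default take precedence.
table = dict(DEFAULT_DECOMPOSITION_TABLE)
table.update(CUSTOM_DECOMPOSITION_TABLE)

x = torch.randn(1, 3, 4, 4)
gm = make_fx(functionalize(f), decomposition_table=table)(x)

# The traced graph should lower the bilinear upsample to matmuls rather than
# keeping an aten.upsample_bilinear2d call.
print(gm.graph)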