From 648a58c2ee970260c0429d825759915158f8d54d Mon Sep 17 00:00:00 2001
From: Ashok Kumar Kannan <160501980+ashokkumarkannan1@users.noreply.github.com>
Date: Mon, 24 Feb 2025 20:01:33 +0000
Subject: [PATCH] upsample2d op support

---
 forge/csrc/passes/lower_to_mlir.cpp     |  1 +
 forge/forge/op/__init__.py              |  2 +-
 forge/forge/op/eval/forge/__init__.py   |  1 +
 forge/forge/op/eval/forge/resize.py     | 78 ++++++++++++++++++++-----
 forge/forge/op/resize.py                | 38 ++++++++++++
 forge/forge/tvm_to_python.py            |  2 +
 forge/test/mlir/operators/nn/test_nn.py |  7 ++-
 7 files changed, 111 insertions(+), 18 deletions(-)

diff --git a/forge/csrc/passes/lower_to_mlir.cpp b/forge/csrc/passes/lower_to_mlir.cpp
index a0dd747dd..a4fd6cf94 100644
--- a/forge/csrc/passes/lower_to_mlir.cpp
+++ b/forge/csrc/passes/lower_to_mlir.cpp
@@ -647,6 +647,7 @@ class MLIRGenerator
         lowering_handler_map["tanh"] = &MLIRGenerator::emit_mlir_ttforge_op;
         lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op;
         lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op;
+        lowering_handler_map["upsample2d"] = &MLIRGenerator::emit_mlir_ttforge_op;
     }
 };
diff --git a/forge/forge/op/__init__.py b/forge/forge/op/__init__.py
index 9684b518e..edcd47fd2 100644
--- a/forge/forge/op/__init__.py
+++ b/forge/forge/op/__init__.py
@@ -72,7 +72,7 @@
 from .constant import Constant
 from .nn import Softmax, Layernorm, LogSoftmax, Batchnorm, MaxPool2dModule
 from .eltwise_nary import Concatenate, Where, IndexCopy, Stack, Interleave
-from .resize import Resize2d, Resize3d
+from .resize import Resize2d, Resize3d, Upsample2d
 from .embedding import Embedding
 from .dram_queue import DRAMQueue
 from .quantize import Quantize, Dequantize, Requantize, ForgeRequantize
diff --git a/forge/forge/op/eval/forge/__init__.py b/forge/forge/op/eval/forge/__init__.py
index b55536748..1a564b873 100644
--- a/forge/forge/op/eval/forge/__init__.py
+++ b/forge/forge/op/eval/forge/__init__.py
@@ -123,6 +123,7 @@
     "constant": "constant",
     "resize2d": "resize",
     "resize3d": "resize",
+    "upsample2d": "resize",
     "dram_queue": "dram_queue",
     "softmax": "nn",
     "log_softmax": "nn",
diff --git a/forge/forge/op/eval/forge/resize.py b/forge/forge/op/eval/forge/resize.py
index ee1836ce1..b1a099a6f 100644
--- a/forge/forge/op/eval/forge/resize.py
+++ b/forge/forge/op/eval/forge/resize.py
@@ -26,28 +26,30 @@ def eval(type, attr, ops):
         assert len(ops) == 1
         resize_method = INT_TO_RESIZE2d_METHOD[attr[-3]]
         acts = ops[0]
+        shape = acts.shape
+        channel_last = attr[-1]
+        sizes = attr[0:2]
+
+        # Determine whether to use upsample2d or downsample2d to replicate the resize2d operation.
+        # If the target size is larger than the input size, apply upsampling; otherwise, apply downsampling.
+        # Example: given target size (14, 14) and input shape (1, 3, 7, 7), since 14 > 7, we use upsampling.
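+        # The integer scale factor follows from the same pair of sizes, e.g. 14 // 7 = 2.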
+        upsample = sizes[0] >= shape[-3] if channel_last else sizes[0] >= shape[-2]
+        scale_factor = sizes[0] // shape[-3] if channel_last else sizes[0] // shape[-2]
+
         if attr[-1]:
             # channel last
             acts = ops[0].permute((0, 3, 1, 2))

-        if resize_method == "nearest":
-            upsample = torch.nn.Upsample(
-                size=attr[0:2],
-                mode=resize_method,
-            )
+        if upsample:
+            upsample_op = torch.nn.Upsample(scale_factor=scale_factor, mode=resize_method)
+            result = upsample_op(acts)
         else:
-            upsample = torch.nn.Upsample(
-                size=attr[0:2],
-                mode=resize_method,
-                align_corners=bool(attr[-2]),
-            )
-
-            t_ops = to_torch_operands(acts)
-
-            result = upsample(*t_ops)
+            raise NotImplementedError("Downsampling of resize2d is not supported yet")

         if attr[-1]:
             result = result.permute((0, 2, 3, 1))
+        return result

     elif type == "resize3d":
         assert len(attr) == 6, "Resize3d should have 6 attrs: [size, size, size, method, align_corners, channel_last]"
@@ -77,6 +78,21 @@ def eval(type, attr, ops):
         if attr[-1]:
             result = result.permute((0, 2, 3, 4, 1))
         return result
+    elif type == "upsample2d":
+        operandA = ops[0]
+        scale_factor = attr[0]
+        resize_method = attr[1]
+        channel_last = attr[2]
+
+        # torch.nn.Upsample operates on NCHW tensors, so permute a
+        # channel-last (NHWC) input to NCHW and back around the computation.
+        if channel_last:
+            operandA = operandA.permute((0, 3, 1, 2))
+        upsample = torch.nn.Upsample(scale_factor=scale_factor, mode=resize_method)
+        result = upsample(operandA)
+        if channel_last:
+            result = result.permute((0, 2, 3, 1))
+        return result


 def shape(type, attr, ops):
@@ -158,6 +167,16 @@ def shape(type, attr, ops):
             shape[-3], shape[-2], shape[-1] = attr[0], attr[1], attr[2]
         return shape, []

+    elif type == "upsample2d":
+        channel_last = attr[2]
+        scale_factor = attr[0]
+        shape = list(ops[0])
+        if channel_last:
+            shape[-3], shape[-2] = shape[-3] * scale_factor, shape[-2] * scale_factor
+        else:
+            shape[-2], shape[-1] = shape[-2] * scale_factor, shape[-1] * scale_factor
+        return shape, []
+

 def lower(type, attr, lc, ops, outputs):
     raise RuntimeError("This should never be called.")
@@ -277,6 +296,38 @@ def decompose_resize3d(attr, dc, inputs, resize_method):


 def decompose(type, attr, dc, inputs):
+    if type == "resize2d":
+        resize_method = INT_TO_RESIZE2d_METHOD[attr[2]]
+        acts = inputs[0]
+        shape = acts.shape
+        channel_last = attr[-1]
+        sizes = attr[0:2]
+
+        if not channel_last:
+            # channel_last is False in this branch, so H and W sit at shape[-2] and shape[-1]
+            upsample = sizes[0] >= shape[-2]
+            scale_factor = sizes[0] // shape[-2]
+            result = inputs[0]

+            # Change the layout from NCHW to NHWC, as ttir.upsample2d supports only the NHWC layout
+            result = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [result])
+            result = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [result])
+
+            if upsample:
+                result = dc.op_with_named_attrs(
+                    "upsample2d",
+                    [result],
+                    {"scale_factor": scale_factor, "mode": resize_method, "channel_last": True},
+                    (scale_factor, resize_method, True),
+                )
+            else:
+                raise NotImplementedError("Downsampling of resize2d is not supported yet")
+
+            # Change the layout back to NCHW from NHWC after the ttir.upsample2d operation
+            result = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [result])
+            result = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [result])
+            dc.fuse(result)
+
     if type == "resize3d":
         assert len(attr) == 6, "Resize3d should have 6 attrs: [size, size, size, method, align_corners, channel_last]"
         assert len(inputs) == 1
diff --git a/forge/forge/op/resize.py b/forge/forge/op/resize.py
index edc7307e9..c6a12dc41 100644
--- a/forge/forge/op/resize.py
+++ b/forge/forge/op/resize.py
@@ -62,6 +62,47 @@ def Resize2d(
     return result


+def Upsample2d(
+    name: str, operandA: Tensor, scale_factor: int, mode: str = "nearest", channel_last: bool = False
+) -> Tensor:
+    """
+    Upsample 2D operation: scales the spatial dimensions (height and width) of the input by an integer factor
+
+    Parameters
+    ----------
+    name: str
+        Op name, unique to the module, or leave blank to autoset
+
+    operandA: Tensor
+        Input operand A
+
+    scale_factor: int
+        Multiplier for the spatial size, applied to both height and width
+
+    mode: str
+        The upsampling algorithm, e.g. "nearest" (default)
+
+    channel_last: bool
+        Whether operandA is in channel-last (NHWC) layout. Default: False
+
+    Returns
+    -------
+    Tensor
+        Forge tensor
+    """
+    result: Tensor = op(
+        "upsample2d",
+        name,
+        operandA,
+        attrs=(scale_factor, mode, channel_last),
+        scale_factor=scale_factor,
+        mode=mode,
+        channel_last=channel_last,
+    ).get_tensor()
+
+    return result
+
+
 def Resize3d(
     name: str,
     operandA: Tensor,
diff --git a/forge/forge/tvm_to_python.py b/forge/forge/tvm_to_python.py
index c411db8c2..424c81572 100644
--- a/forge/forge/tvm_to_python.py
+++ b/forge/forge/tvm_to_python.py
@@ -1762,6 +1762,7 @@ def populate_requantize_args(graph, nid, compiler_cfg):
     "qnn.requantize": "requantize",
     "qnn.dense": "matmul",
     "atan": "atan",
+    "upsample2d": "upsample2d",
 }

 forge_op_to_function_name = {
@@ -1847,6 +1848,7 @@ def populate_requantize_args(graph, nid, compiler_cfg):
     "dequantize": "forge.op.Dequantize",
     "requantize": "forge.op.Requantize",
     "atan": "forge.op.Atan",
+    "upsample2d": "forge.op.Upsample2d",
 }
 forge_ops_needing_arguments = {
     "argmax": populate_argmax_args,
diff --git a/forge/test/mlir/operators/nn/test_nn.py b/forge/test/mlir/operators/nn/test_nn.py
index c03d1c2e1..3206ef1a2 100644
--- a/forge/test/mlir/operators/nn/test_nn.py
+++ b/forge/test/mlir/operators/nn/test_nn.py
@@ -188,11 +188,12 @@ def forward(self, x):
 @pytest.mark.parametrize(
     "shape, mode",
     [
-        ((1, 2048, 7, 7), "nearest"),
-        ((1, 2048, 7, 7), "bilinear"),
+        pytest.param((1, 2048, 7, 7), "nearest"),
+        pytest.param(
+            (1, 2048, 7, 7), "bilinear", marks=pytest.mark.xfail(reason="Runtime Error TTNN: info: Unsupported mode")
+        ),
     ],
 )
-@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 @pytest.mark.push
 def test_interpolate(shape, mode):
     class Interpolate(nn.Module):
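For reviewers, a minimal sketch of the semantics the new op is expected to implement, written against plain PyTorch (the same torch.nn.Upsample the eval path uses); the shapes mirror the (1, 3, 7, 7) example in the resize2d decomposition comment above:

    import torch

    # NCHW input, matching the example in the resize2d decomposition comment.
    x = torch.randn(1, 3, 7, 7)

    # scale_factor=2 with mode="nearest" doubles H and W, exactly what the
    # shape() handler computes: shape[-2] * scale_factor, shape[-1] * scale_factor.
    y = torch.nn.Upsample(scale_factor=2, mode="nearest")(x)
    assert y.shape == (1, 3, 14, 14)

The user-facing entry point added by this patch is forge.op.Upsample2d(name, operandA, scale_factor, mode, channel_last); resize2d calls whose target size is an integer multiple of the input size decompose onto it.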