Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for upsample2d in lower to mlir #1315

Merged
merged 1 commit into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ class MLIRGenerator
lowering_handler_map["tanh"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TanhOp>;
lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TransposeOp>;
lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::UnsqueezeOp>;
lowering_handler_map["upsample2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Upsample2dOp>;
}
};

Expand Down
2 changes: 1 addition & 1 deletion forge/forge/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from .constant import Constant
from .nn import Softmax, Layernorm, LogSoftmax, Batchnorm, MaxPool2dModule
from .eltwise_nary import Concatenate, Where, IndexCopy, Stack, Interleave
from .resize import Resize2d, Resize3d
from .resize import Resize2d, Resize3d, Upsample2d
from .embedding import Embedding
from .dram_queue import DRAMQueue
from .quantize import Quantize, Dequantize, Requantize, ForgeRequantize
Expand Down
1 change: 1 addition & 0 deletions forge/forge/op/eval/forge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
"constant": "constant",
"resize2d": "resize",
"resize3d": "resize",
"upsample2d": "resize",
"dram_queue": "dram_queue",
"softmax": "nn",
"log_softmax": "nn",
Expand Down
80 changes: 66 additions & 14 deletions forge/forge/op/eval/forge/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,29 @@ def eval(type, attr, ops):
assert len(ops) == 1
resize_method = INT_TO_RESIZE2d_METHOD[attr[-3]]
acts = ops[0]
shape = acts.shape
channel_last = attr[-1]
sizes = attr

# Determine whether to use upsample2d or downsample2d to replicate the resize2d operation.
# If the target size is larger than the input size, apply upsampling; otherwise, apply downsampling.
# Example: Given target size (14,14) and input shape (1,3,7,7), since 14 > 7, we use upsampling.
upsample = sizes[0] >= shape[-3] if channel_last else sizes[0] >= shape[-2]
scale_factor = sizes[0] // shape[-3] if channel_last else sizes[0] // shape[-2]

if attr[-1]:
# channel last
acts = ops[0].permute((0, 3, 1, 2))

if resize_method == "nearest":
upsample = torch.nn.Upsample(
size=attr[0:2],
mode=resize_method,
)
if upsample:
upsample = torch.nn.Upsample(scale_factor=scale_factor, mode=resize_method)
result = upsample(acts)
else:
upsample = torch.nn.Upsample(
size=attr[0:2],
mode=resize_method,
align_corners=bool(attr[-2]),
)

t_ops = to_torch_operands(acts)

result = upsample(*t_ops)
raise NotImplementedError("Downsampling of resize2d is not supported yet")

if attr[-1]:
result = result.permute((0, 2, 3, 1))

return result
elif type == "resize3d":
assert len(attr) == 6, "Resize3d should have 6 attrs: [size, size, size, method, align_corners, channel_last]"
Expand Down Expand Up @@ -77,6 +78,14 @@ def eval(type, attr, ops):
if attr[-1]:
result = result.permute((0, 2, 3, 4, 1))
return result
elif type == "upsample2d":
operandA = ops[0]
scale_factor = attr[0]
resize_method = attr[2]

upsample = torch.nn.Upsample(scale_factor=scale_factor, mode=resize_method)
result = upsample(operandA)
return result


def shape(type, attr, ops):
Expand Down Expand Up @@ -158,6 +167,16 @@ def shape(type, attr, ops):
shape[-3], shape[-2], shape[-1] = attr[0], attr[1], attr[2]
return shape, []

elif type == "upsample2d":
channel_last = attr[2]
scale_factor = attr[0]
shape = list(ops[0])
if channel_last:
shape[-3], shape[-2] = shape[-3] * scale_factor, shape[-2] * scale_factor
else:
shape[-2], shape[-1] = shape[-2] * scale_factor, shape[-1] * scale_factor
return shape, []


def lower(type, attr, lc, ops, outputs):
raise RuntimeError("This should never be called.")
Expand Down Expand Up @@ -277,6 +296,39 @@ def decompose_resize3d(attr, dc, inputs, resize_method):


def decompose(type, attr, dc, inputs):
if type == "resize2d":
resize_method = INT_TO_RESIZE2d_METHOD[attr[2]]
acts = inputs[0]
shape = acts.shape
channel_last = attr[-1]
sizes = attr[0:2]
result = inputs[0]

upsample = sizes[0] >= shape[-3] if channel_last else sizes[0] >= shape[-2]

if not channel_last:
# Changing the Layout from NCHW to NHWC as ttir.upsample2d supports only the NHWC layout
result = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [result])
result = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [result])

if upsample:
scale_factor = sizes[0] // shape[-3] if channel_last else sizes[0] // shape[-2]
result = dc.op_with_named_attrs(
"upsample2d",
[result],
{"scale_factor": scale_factor, "mode": resize_method, "channel_last": True},
(scale_factor, resize_method, True),
)
else:
raise NotImplementedError("Downsampling of resize2d is not supported yet")

if not channel_last:
# Changing the Layout back to NCHW from NHWC after ttir.upsample2d operation
result = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [result])
result = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [result])

dc.fuse(result)

if type == "resize3d":
assert len(attr) == 6, "Resize3d should have 6 attrs: [size, size, size, method, align_corners, channel_last]"
assert len(inputs) == 1
Expand Down
38 changes: 38 additions & 0 deletions forge/forge/op/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,44 @@ def Resize2d(
return result


def Upsample2d(
name: str, operandA: Tensor, scale_factor: int, mode: str = "nearest", channel_last: bool = False
) -> Tensor:
"""
Upsample 2D operation

Parameters
----------
name: str
Op name, unique to the module, or leave blank to autoset

operandA: Tensor
Input operand A

scale_factor: int
multiplier for spatial size.

mode: str
the upsampling algorithm

Returns
-------
Tensor
Forge tensor
"""
result: Tensor = op(
"upsample2d",
name,
operandA,
attrs=(scale_factor, mode, channel_last),
scale_factor=scale_factor,
mode=mode,
channel_last=channel_last,
).get_tensor()

return result


def Resize3d(
name: str,
operandA: Tensor,
Expand Down
2 changes: 2 additions & 0 deletions forge/forge/tvm_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,7 @@ def populate_requantize_args(graph, nid, compiler_cfg):
"qnn.requantize": "requantize",
"qnn.dense": "matmul",
"atan": "atan",
"upsample2d": "upsample2d",
}

forge_op_to_function_name = {
Expand Down Expand Up @@ -1847,6 +1848,7 @@ def populate_requantize_args(graph, nid, compiler_cfg):
"dequantize": "forge.op.Dequantize",
"requantize": "forge.op.Requantize",
"atan": "forge.op.Atan",
"upsample2d": "forge.op.Upsample2d",
}
forge_ops_needing_arguments = {
"argmax": populate_argmax_args,
Expand Down
7 changes: 4 additions & 3 deletions forge/test/mlir/operators/nn/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,12 @@ def forward(self, x):
@pytest.mark.parametrize(
"shape, mode",
[
((1, 2048, 7, 7), "nearest"),
((1, 2048, 7, 7), "bilinear"),
pytest.param((1, 2048, 7, 7), "nearest"),
pytest.param(
(1, 2048, 7, 7), "bilinear", marks=pytest.mark.xfail(reason="Runtime Error TTNN: info: Unsupported mode ")
),
],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
@pytest.mark.push
def test_interpolate(shape, mode):
class Interpolate(nn.Module):
Expand Down
Loading