upsample2d op support
ashokkumarkannan1 committed Feb 27, 2025
1 parent 7254715 commit 6ccbc3b
Showing 7 changed files with 113 additions and 18 deletions.
1 change: 1 addition & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -647,6 +647,7 @@ class MLIRGenerator
         lowering_handler_map["tanh"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TanhOp>;
         lowering_handler_map["transpose"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TransposeOp>;
         lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::UnsqueezeOp>;
+        lowering_handler_map["upsample2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Upsample2dOp>;
     }
 };
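A minimal Python sketch (all names hypothetical, not from this commit) of the string-keyed dispatch pattern this hunk extends: op names map to emitter callables, and lowering simply looks the handler up by name, failing fast on unhandled ops.

from typing import Callable, Dict

def emit_upsample2d(op_name: str, inputs: list) -> str:
    # Stand-in for emit_mlir_ttforge_op<ttir::Upsample2dOp>: render a TTIR call.
    return f"ttir.upsample2d({', '.join(inputs)})"

# Mirrors lowering_handler_map: an unregistered op name raises KeyError at lookup.
lowering_handler_map: Dict[str, Callable[[str, list], str]] = {
    "upsample2d": emit_upsample2d,
}

print(lowering_handler_map["upsample2d"]("upsample2d", ["%act"]))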

2 changes: 1 addition & 1 deletion forge/forge/op/__init__.py
@@ -72,7 +72,7 @@
 from .constant import Constant
 from .nn import Softmax, Layernorm, LogSoftmax, Batchnorm, MaxPool2dModule
 from .eltwise_nary import Concatenate, Where, IndexCopy, Stack, Interleave
-from .resize import Resize2d, Resize3d
+from .resize import Resize2d, Resize3d, Upsample2d
 from .embedding import Embedding
 from .dram_queue import DRAMQueue
 from .quantize import Quantize, Dequantize, Requantize, ForgeRequantize
1 change: 1 addition & 0 deletions forge/forge/op/eval/forge/__init__.py
@@ -123,6 +123,7 @@
"constant": "constant",
"resize2d": "resize",
"resize3d": "resize",
"upsample2d": "resize",
"dram_queue": "dram_queue",
"softmax": "nn",
"log_softmax": "nn",
80 changes: 66 additions & 14 deletions forge/forge/op/eval/forge/resize.py
@@ -26,28 +26,29 @@ def eval(type, attr, ops):
         assert len(ops) == 1
         resize_method = INT_TO_RESIZE2d_METHOD[attr[-3]]
         acts = ops[0]
+        shape = acts.shape
+        channel_last = attr[-1]
+        sizes = attr
+
+        # Determine whether to use upsample2d or downsample2d to replicate the resize2d operation:
+        # if the target size is larger than the input size, upsample; otherwise downsampling would be needed.
+        # Example: given target size (14, 14) and input shape (1, 3, 7, 7), since 14 > 7, we upsample.
+        upsample = sizes[0] >= shape[-3] if channel_last else sizes[0] >= shape[-2]
+        scale_factor = sizes[0] // shape[-3] if channel_last else sizes[0] // shape[-2]
 
         if attr[-1]:
             # channel last
             acts = ops[0].permute((0, 3, 1, 2))
 
-        if resize_method == "nearest":
-            upsample = torch.nn.Upsample(
-                size=attr[0:2],
-                mode=resize_method,
-            )
+        if upsample:
+            upsample = torch.nn.Upsample(scale_factor=scale_factor, mode=resize_method)
+            result = upsample(acts)
         else:
-            upsample = torch.nn.Upsample(
-                size=attr[0:2],
-                mode=resize_method,
-                align_corners=bool(attr[-2]),
-            )
-
-        t_ops = to_torch_operands(acts)
-
-        result = upsample(*t_ops)
+            raise NotImplementedError("Downsampling of resize2d is not supported yet")
 
         if attr[-1]:
             result = result.permute((0, 2, 3, 1))
 
         return result
     elif type == "resize3d":
         assert len(attr) == 6, "Resize3d should have 6 attrs: [size, size, size, method, align_corners, channel_last]"
@@ -77,6 +78,14 @@ def eval(type, attr, ops):
         if attr[-1]:
             result = result.permute((0, 2, 3, 4, 1))
         return result
+    elif type == "upsample2d":
+        operandA = ops[0]
+        scale_factor = attr[0]
+        resize_method = attr[1]  # attrs are (scale_factor, mode, channel_last)
+
+        upsample = torch.nn.Upsample(scale_factor=scale_factor, mode=resize_method)
+        result = upsample(operandA)
+        return result
 
 
 def shape(type, attr, ops):
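The rewritten eval path relies on torch.nn.Upsample with a scale_factor reproducing the old size-based call whenever the target size is an integer multiple of the input, e.g. the (7, 7) -> (14, 14) case from the comment above. A quick sanity sketch of that equivalence:

import torch

x = torch.randn(1, 3, 7, 7)  # NCHW input, as after the channel-last permute
by_size = torch.nn.Upsample(size=(14, 14), mode="nearest")(x)
by_scale = torch.nn.Upsample(scale_factor=2, mode="nearest")(x)
assert torch.equal(by_size, by_scale)
assert by_scale.shape == (1, 3, 14, 14)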
@@ -158,6 +167,16 @@ def shape(type, attr, ops):
         shape[-3], shape[-2], shape[-1] = attr[0], attr[1], attr[2]
         return shape, []
 
+    elif type == "upsample2d":
+        channel_last = attr[2]
+        scale_factor = attr[0]
+        shape = list(ops[0])
+        if channel_last:
+            shape[-3], shape[-2] = shape[-3] * scale_factor, shape[-2] * scale_factor
+        else:
+            shape[-2], shape[-1] = shape[-2] * scale_factor, shape[-1] * scale_factor
+        return shape, []
+
 
 def lower(type, attr, lc, ops, outputs):
     raise RuntimeError("This should never be called.")
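A standalone sketch of the shape rule added above: only the two spatial dimensions are multiplied by scale_factor, and which dimensions those are depends on the layout.

def upsample2d_shape(shape, scale_factor, channel_last):
    shape = list(shape)
    if channel_last:  # NHWC: H is dim -3, W is dim -2
        shape[-3] *= scale_factor
        shape[-2] *= scale_factor
    else:  # NCHW: H is dim -2, W is dim -1
        shape[-2] *= scale_factor
        shape[-1] *= scale_factor
    return shape

assert upsample2d_shape((1, 3, 7, 7), 2, channel_last=False) == [1, 3, 14, 14]
assert upsample2d_shape((1, 7, 7, 3), 2, channel_last=True) == [1, 14, 14, 3]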
@@ -277,6 +296,39 @@ def decompose_resize3d(attr, dc, inputs, resize_method):


 def decompose(type, attr, dc, inputs):
+    if type == "resize2d":
+        resize_method = INT_TO_RESIZE2d_METHOD[attr[2]]
+        acts = inputs[0]
+        shape = acts.shape
+        channel_last = attr[-1]
+        sizes = attr[0:2]
+        result = inputs[0]
+
+        upsample = sizes[0] >= shape[-3] if channel_last else sizes[0] >= shape[-2]
+
+        if not channel_last:
+            # Change the layout from NCHW to NHWC, since ttir.upsample2d supports only the NHWC layout
+            result = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [result])
+            result = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [result])
+
+        if upsample:
+            scale_factor = sizes[0] // shape[-3] if channel_last else sizes[0] // shape[-2]
+            result = dc.op_with_named_attrs(
+                "upsample2d",
+                [result],
+                {"scale_factor": scale_factor, "mode": resize_method, "channel_last": True},
+                (scale_factor, resize_method, True),
+            )
+        else:
+            raise NotImplementedError("Downsampling of resize2d is not supported yet")
+
+        if not channel_last:
+            # Change the layout back from NHWC to NCHW after the ttir.upsample2d operation
+            result = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [result])
+            result = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [result])
+
+        dc.fuse(result)
+
     if type == "resize3d":
         assert len(attr) == 6, "Resize3d should have 6 attrs: [size, size, size, method, align_corners, channel_last]"
         assert len(inputs) == 1
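The two TransposeTM ops around the upsample implement the NCHW <-> NHWC layout change that ttir.upsample2d requires. A torch sketch showing that the two pairwise swaps compose to exactly that permute, and that the reverse pair restores the original layout:

import torch

x = torch.randn(1, 3, 7, 7)                      # NCHW
nhwc = x.transpose(-3, -2).transpose(-2, -1)     # swap C/H, then C/W -> NHWC
assert torch.equal(nhwc, x.permute(0, 2, 3, 1))
back = nhwc.transpose(-2, -1).transpose(-3, -2)  # reverse pair -> NCHW
assert torch.equal(back, x)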
38 changes: 38 additions & 0 deletions forge/forge/op/resize.py
@@ -62,6 +62,44 @@ def Resize2d(
     return result
 
 
+def Upsample2d(
+    name: str, operandA: Tensor, scale_factor: int, mode: str = "nearest", channel_last: bool = False
+) -> Tensor:
+    """
+    Upsample 2D operation
+
+    Parameters
+    ----------
+    name: str
+        Op name, unique to the module, or leave blank to autoset
+
+    operandA: Tensor
+        Input operand A
+
+    scale_factor: int
+        Multiplier for the spatial size
+
+    mode: str
+        The upsampling algorithm
+
+    channel_last: bool
+        Whether operandA is in channel-last (NHWC) layout
+
+    Returns
+    -------
+    Tensor
+        Forge tensor
+    """
+    result: Tensor = op(
+        "upsample2d",
+        name,
+        operandA,
+        attrs=(scale_factor, mode, channel_last),
+        scale_factor=scale_factor,
+        mode=mode,
+        channel_last=channel_last,
+    ).get_tensor()
+
+    return result
+
+
 def Resize3d(
     name: str,
     operandA: Tensor,
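A minimal usage sketch of the new wrapper; constructing the input via forge.Tensor.create_from_torch is an assumption based on the wider forge API, not something shown in this commit.

import torch
import forge
from forge.op import Upsample2d

act = forge.Tensor.create_from_torch(torch.randn(1, 3, 7, 7))
# NCHW input, 2x nearest-neighbor upsampling -> (1, 3, 14, 14)
out = Upsample2d("upsample2d_0", act, scale_factor=2, mode="nearest")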
2 changes: 2 additions & 0 deletions forge/forge/tvm_to_python.py
@@ -1762,6 +1762,7 @@ def populate_requantize_args(graph, nid, compiler_cfg):
"qnn.requantize": "requantize",
"qnn.dense": "matmul",
"atan": "atan",
"upsample2d": "upsample2d",
}

forge_op_to_function_name = {
@@ -1847,6 +1848,7 @@ def populate_requantize_args(graph, nid, compiler_cfg):
"dequantize": "forge.op.Dequantize",
"requantize": "forge.op.Requantize",
"atan": "forge.op.Atan",
"upsample2d": "forge.op.Upsample2d",
}
forge_ops_needing_arguments = {
"argmax": populate_argmax_args,
7 changes: 4 additions & 3 deletions forge/test/mlir/operators/nn/test_nn.py
@@ -188,11 +188,12 @@ def forward(self, x):
 @pytest.mark.parametrize(
     "shape, mode",
     [
-        ((1, 2048, 7, 7), "nearest"),
-        ((1, 2048, 7, 7), "bilinear"),
+        pytest.param((1, 2048, 7, 7), "nearest"),
+        pytest.param(
+            (1, 2048, 7, 7), "bilinear", marks=pytest.mark.xfail(reason="Runtime Error TTNN: info: Unsupported mode ")
+        ),
     ],
 )
-@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 @pytest.mark.push
 def test_interpolate(shape, mode):
     class Interpolate(nn.Module):
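The collapsed module body presumably wraps torch.nn.functional.interpolate, which the parametrization above drives; the nearest case now lowers end-to-end, while bilinear is expected to fail in the TTNN runtime. A sketch of the torch-level behavior being checked:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2048, 7, 7)
y = F.interpolate(x, scale_factor=2, mode="nearest")
assert y.shape == (1, 2048, 14, 14)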
