Commit

Adds support for 1D convolutions by rewriting them as 2D convolutions.
ttjost authored Jun 7, 2023
1 parent 3130853 commit 998c9ff
Showing 3 changed files with 175 additions and 0 deletions.
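
The rewrite relies on a simple equivalence: a 1D convolution over an (N, C, L) input matches a 2D convolution once a trailing unit dimension is appended to both the input and the weight. A minimal PyTorch sketch of that equivalence (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 768, 768)   # (N, C_in, L)
w = torch.randn(768, 768, 1)   # (C_out, C_in, kernel)
b = torch.ones(768)

out_1d = F.conv1d(x, w, b, stride=1, padding=0, dilation=1)
# Append a trailing unit dim and run conv2d with neutral params on that dim.
out_2d = F.conv2d(x.unsqueeze(-1), w.unsqueeze(-1), b,
                  stride=(1, 1), padding=(0, 0), dilation=(1, 1))
assert torch.allclose(out_1d, out_2d.squeeze(-1), atol=1e-5)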
10 changes: 10 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -14,6 +14,8 @@
from torch_mlir._version import torch_baseversion

LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"Conv1dNoPaddingModule_basic",
"Conv1dNoPaddingTransposeModule_basic",
# 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
}
@@ -274,6 +276,10 @@
# ERROR: Unsupported: dynamic shape operator: aten.repeat_interleave.Tensor
"RepeatInterleaveModule_basic",

# failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
"Conv1dNoPaddingModule_basic",
"Conv1dNoPaddingTransposeModule_basic",

# 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
}
@@ -598,6 +604,7 @@
"NumToTensorFloatModule_basic",
"AtenToDeviceModule_basic",
"AvgPool2dStaticModule_basic",
"Conv1dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
@@ -925,6 +932,7 @@
"ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic",
"TypePromotionAlphaWiderModule_basic",
"Conv1dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
@@ -1215,6 +1223,8 @@

LTC_CRASHING_SET = {
# https://github.com/llvm/torch-mlir/issues/2186
"Conv1dNoPaddingModule_basic",
"Conv1dNoPaddingTransposeModule_basic",
"Add_Module_basic"
}

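For context, a hedged sketch (not torch-mlir's actual harness code; names are illustrative) of how an xfail set like the ones above is typically consumed: a test listed in the set is expected to fail on that backend.

def report_result(test_name, passed, xfail_set):
    # A test in the xfail set is expected to fail; a pass there is stale
    # bookkeeping, and a failure outside the set is a real regression.
    expected_to_fail = test_name in xfail_set
    if passed:
        return "UNEXPECTED PASS" if expected_to_fail else "PASS"
    return "XFAIL" if expected_to_fail else "FAIL"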
122 changes: 122 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1945,6 +1945,16 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
m_TorchListOfConstantInts(padding_2d)))
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");

bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(
op, "transpose must be a bool constant");

if (transposed)
return rewriter.notifyMatchFailure(
op, "Unimplemented: only non-transposed convolutions supported");

// TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
// The Torch OFM computation uses 2*pad in each spatial direction, implying
// the same t=b and l=r values for TOSA.
@@ -3690,6 +3700,112 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
return success();
}

// Template for patterns that legalize an ATen op by rewriting it in terms
// of other ATen ops that already have TOSA lowerings.
template <typename AtenOpT>
class SimplifyAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

template <>
LogicalResult SimplifyAtenOp<AtenConvolutionOp>::matchAndRewrite(
AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// TOSA doesn't support 1D convolutions.
// We model them through a combination of AtenViewOp and 2D convolution.
// A Conv1D is replaced by:
// %view = AtenViewOp (%input) : (3D type) -> (4D type)
// %conv2d = AtenConvolution (%view) : (4D type) -> (4D type)
// %view2 = AtenViewOp (%conv2d) : (4D type) -> (3D type)

auto inputTy = adaptor.getInput().getType().cast<RankedTensorType>();
auto weightTy = adaptor.getWeight().getType().cast<RankedTensorType>();
auto outputTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();

auto ty = op.getType().dyn_cast_or_null<BaseTensorType>();
if (!ty || !ty.hasSizes())
return rewriter.notifyMatchFailure(
op, "unimplemented: output must have known sizes");

if (!inputTy || !weightTy || !outputTy)
return rewriter.notifyMatchFailure(
op, "Input, weight and output to Convolution must be ranked tensors");

if (!weightTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Unimplemented: TOSA only supports statically shaped weights");

if (inputTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Unimplemented: only 1D convolutions are simplified");

auto loc = op->getLoc();

auto getListConstructElementsPlusValue =
[&](Value listConstruct, int64_t addedValue) -> std::optional<Value> {
SmallVector<Value> values;
if (!getListConstructElements(listConstruct, values)) {
return std::nullopt;
}

Type ty = listConstruct.getType();
values.push_back(
rewriter.create<Torch::ConstantIntOp>(op->getLoc(), addedValue));
return rewriter.create<PrimListConstructOp>(op->getLoc(), ty, values);
};

auto stride = getListConstructElementsPlusValue(op.getStride(), 1);
if (!stride.has_value())
return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");

auto dilation = getListConstructElementsPlusValue(op.getDilation(), 1);
if (!dilation.has_value())
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");

auto paddingValue = getListConstructElementsPlusValue(op.getPadding(), 0);
if (!paddingValue.has_value())
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");

auto outputPaddingValue =
getListConstructElementsPlusValue(op.getOutputPadding(), 0);
if (!outputPaddingValue.has_value()) {
return rewriter.notifyMatchFailure(
op, "non-const output padding list unsupported");
}

auto addDimOneToSizes = [&](BaseTensorType ty) {
SmallVector<int64_t> newSizes(ty.getSizes());
newSizes.push_back(1);
return newSizes;
};

auto input = op.getInput();
auto weight = op.getWeight();

auto newSizes = addDimOneToSizes(cast<BaseTensorType>(input.getType()));
Value view1dTo2d = reshapeTo(loc, rewriter, input, newSizes);

auto newWeightSizes = addDimOneToSizes(cast<BaseTensorType>(weight.getType()));
weight = reshapeTo(loc, rewriter, weight, newWeightSizes);

// The 2D convolution produces the 1D result shape with a trailing unit
// dimension appended.
SmallVector<int64_t> conv2dOutSizes = addDimOneToSizes(ty);
auto conv2dResultTy = ty.getWithSizesAndDtype(
llvm::ArrayRef<int64_t>(conv2dOutSizes), ty.getOptionalDtype());
auto conv2dOp = rewriter.create<AtenConvolutionOp>(
loc, conv2dResultTy, view1dTo2d, weight, op.getBias(), *stride,
*paddingValue, *dilation, op.getTransposed(), *outputPaddingValue,
op.getGroups());

Value view2dTo1d = reshapeTo(loc, rewriter, conv2dOp, ty.getSizes());
rewriter.replaceOp(op, view2dTo1d);
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
@@ -5064,6 +5180,12 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
patterns.add<SimplifyAten_IndexPutImplOp>(context);
patterns.add<SimplifyAten_IndexPutImplOpNone>(context);

#define INSERT_SIMPLIFY_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<SimplifyAtenOp<AtenOp>>(typeConverter, context);
INSERT_SIMPLIFY_OP_PATTERN(AtenConvolutionOp)
#undef INSERT_SIMPLIFY_OP_PATTERN

#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
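The per-dimension bookkeeping in SimplifyAtenOp<AtenConvolutionOp> above is easy to state in isolation: every shape gains a trailing 1, and every per-dimension parameter list gains a neutral entry (1 for stride and dilation, 0 for padding and output padding). A Python rendering of that bookkeeping (hypothetical helper names, for illustration only):

def add_unit_dim(sizes):
    # Mirrors addDimOneToSizes: append a trailing unit dimension.
    return list(sizes) + [1]

def conv1d_params_to_conv2d(stride, padding, dilation, output_padding):
    # Mirrors getListConstructElementsPlusValue: stride and dilation get a
    # neutral 1 for the new dimension, the paddings get a neutral 0.
    return stride + [1], padding + [0], dilation + [1], output_padding + [0]

assert add_unit_dim([1, 768, 768]) == [1, 768, 768, 1]
assert conv1d_params_to_conv2d([2], [3], [1], [0]) == \
    ([2, 1], [3, 0], [1, 1], [0, 0])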
43 changes: 43 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/conv.py
@@ -10,6 +10,49 @@

# ==============================================================================

class Conv1dNoPaddingModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 768, 768], torch.float32, True),
([768, 768, 1], torch.float32, True),
([768], torch.float32, True),
])
def forward(self, x, weights, bias):
return torch.ops.aten.convolution(x, weights, bias, [1], [0], [1], False, [0], 1)


@register_test_case(module_factory=lambda: Conv1dNoPaddingModule())
def Conv1dNoPaddingModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 768, 768), tu.rand(768, 768, 1), torch.ones(768))

# ==============================================================================

class Conv1dNoPaddingTransposeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 768, 768], torch.float32, True),
([768, 768, 1], torch.float32, True),
([768], torch.float32, True),
])
def forward(self, x, weights, bias):
return torch.ops.aten.convolution(x, weights, bias, [1], [0], [1], True, [0], 1)


@register_test_case(module_factory=lambda: Conv1dNoPaddingTransposeModule())
def Conv1dNoPaddingTransposeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 768, 768), tu.rand(768, 768, 1), torch.ones(768))

# ==============================================================================

class Conv2dNoPaddingModule(torch.nn.Module):

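The positional arguments in the new tests map onto torch.ops.aten.convolution as (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups). A quick check (illustrative, not part of the test suite) that the non-transposed call above agrees with torch.nn.functional.conv1d:

import torch
import torch.nn.functional as F

x, w, b = torch.rand(1, 768, 768), torch.rand(768, 768, 1), torch.ones(768)
out = torch.ops.aten.convolution(x, w, b, [1], [0], [1], False, [0], 1)
ref = F.conv1d(x, w, b, stride=1, padding=0, dilation=1)
assert torch.allclose(out, ref)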
