Skip to content

Commit

Permalink
Merge pull request #280 from Xilinx/bump_to_e0a5adb1
Browse files Browse the repository at this point in the history
[AutoBump] Merge with fixes of e0a5adb (May 27) (47)
  • Loading branch information
mgehre-amd authored Sep 9, 2024
2 parents 5d648ed + 483e32b commit 5b76ad7
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 27 deletions.
73 changes: 46 additions & 27 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5513,38 +5513,57 @@ class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
Value bias = op.getBias();

BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
return rewriter.notifyMatchFailure(
op, "expected input to be rank 2 or greater");
if (!inputType.hasSizes())
return rewriter.notifyMatchFailure(op, "expected input to have sizes");

BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
// `weight` must be a rank 2 matrix.
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");

SmallVector<int64_t> transposeShape =
llvm::to_vector(llvm::reverse(weightType.getSizes()));
Type transposeType = weightType.getWithSizesAndDtype(
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);

Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
transposeWeight);
if (!weightType.hasSizes())
return rewriter.notifyMatchFailure(op, "expected weight to have sizes");

auto transposeWeight = [&]() -> Value {
SmallVector<int64_t> transposeShape =
llvm::to_vector(llvm::reverse(weightType.getSizes()));
Type transposeType = weightType.getWithSizesAndDtype(
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);
return transposeWeight;
};

if (bias.getType().isa<Torch::NoneType>()) {
rewriter.replaceOp(op, matmul);
return success();
}
auto weightRank = weightType.getSizes().size();
if (weightRank > 2 || weightRank <= 0)
return rewriter.notifyMatchFailure(
op, "expected weight's rank <= 2 && >= 1");
if (weightRank == 1) {
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), input,
weight);
return success();
} else if (weightRank == 2) {
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), input,
transposeWeight());
return success();
}
llvm_unreachable("unsupported weightRank");
} else {
BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");

BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
// `weight` must be a rank 2 matrix.
auto weightRank = weightType.getSizes().size();
if (weightRank != 2)
return rewriter.notifyMatchFailure(op,
"expected weight to be a rank 2");

Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
op.getBias(), alpha);
return success();
Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
transposeWeight());
Value alpha =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
op.getBias(), alpha);
return success();
}
}
};
} // namespace
Expand Down
11 changes: 11 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,12 @@
}

STABLEHLO_PASS_SET = {
"AtenLinear1D_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"AtenLinearVecMat_basic",
"ReduceAminSingleDim_basic",
"AtenDotModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
Expand Down Expand Up @@ -1506,6 +1512,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseDivTensorFloatModule_basic",
"ElementwiseMulTensorFloatModule_basic",
Expand Down Expand Up @@ -2092,6 +2100,9 @@
"CumsumStaticNegativeDimModule_basic",
"CumsumInputDtypeInt32Module_basic",
"EyeStaticModule_basic",
"AtenLinear1D_basic",
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
Expand Down
125 changes: 125 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,131 @@ def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils):
# ==============================================================================


class AtenLinear1D(torch.nn.Module):
@export
@annotate_args(
[
None,
([3], torch.float32, True),
([3], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)


@register_test_case(module_factory=lambda: AtenLinear1D())
def AtenLinear1D_basic(module, tu: TestUtils):
module.forward(tu.rand(3), tu.rand(3))


# ==============================================================================


class AtenLinearMatVec(torch.nn.Module):
@export
@annotate_args(
[
None,
([3, 4], torch.float32, True),
([4], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)


@register_test_case(module_factory=lambda: AtenLinearMatVec())
def AtenLinearMatVec_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(4))


# ==============================================================================


class AtenLinearVecMat(torch.nn.Module):
@export
@annotate_args(
[
None,
([4], torch.float32, True),
([3, 4], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)


@register_test_case(module_factory=lambda: AtenLinearVecMat())
def AtenLinearVecMat_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(3, 4))


class AtenLinearVecMatBias(torch.nn.Module):
@export
@annotate_args(
[
None,
([4], torch.float32, True),
([3, 4], torch.float32, True),
([3], torch.float32, True),
]
)
def forward(self, a, b, c):
return torch.ops.aten.linear(a, b, c)


@register_test_case(module_factory=lambda: AtenLinearVecMatBias())
def AtenLinearVecMatBias_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(3, 4), tu.rand(3))


# ==============================================================================


class AtenLinear2D(torch.nn.Module):
@export
@annotate_args(
[
None,
([3, 4], torch.float32, True),
([5, 4], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)


@register_test_case(module_factory=lambda: AtenLinear2D())
def AtenLinear2D_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(5, 4))


# ==============================================================================


class AtenLinear3DBias(torch.nn.Module):
@export
@annotate_args(
[
None,
([3, 6, 4], torch.float32, True),
([5, 4], torch.float32, True),
([5], torch.float32, True),
]
)
def forward(self, a, b, c):
return torch.ops.aten.linear(a, b, c)


@register_test_case(module_factory=lambda: AtenLinear3DBias())
def AtenLinear3DBias_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 6, 4), tu.rand(5, 4), tu.rand(5))


# ==============================================================================


class AtenLinalgCrossInt(torch.nn.Module):
@export
@annotate_args(
Expand Down

0 comments on commit 5b76ad7

Please sign in to comment.