diff --git a/BUILD.bazel b/BUILD.bazel
index b755374dc7..8d2b26c32a 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -276,6 +276,7 @@ cc_library(
         ":chlo_attrs_inc_gen",
         ":chlo_enums_inc_gen",
         ":chlo_ops_inc_gen",
+        ":stablehlo_assembly_format",
         ":stablehlo_type_inference",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:BytecodeOpInterface",
diff --git a/stablehlo/dialect/ChloEnums.td b/stablehlo/dialect/ChloEnums.td
index c6d918cba6..d52b7803c5 100644
--- a/stablehlo/dialect/ChloEnums.td
+++ b/stablehlo/dialect/ChloEnums.td
@@ -70,4 +70,29 @@ def CHLO_ComparisonType : I32EnumAttr<"ComparisonType",
 def CHLO_ComparisonTypeAttr : EnumAttr<CHLO_Dialect, CHLO_ComparisonType,
                                        "comparison_type">;
 
+//===----------------------------------------------------------------------===//
+// Ragged dot op definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA PrecisionConfig proto enum.
+def CHLO_PRECISION_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
+def CHLO_PRECISION_HIGH    : I32EnumAttrCase<"HIGH", 1>;
+def CHLO_PRECISION_HIGHEST : I32EnumAttrCase<"HIGHEST", 2>;
+
+def CHLO_Precision : I32EnumAttr<"Precision",
+    "XLA precision for an operand. Has backend specific meaning.",
+    [
+      CHLO_PRECISION_DEFAULT,
+      CHLO_PRECISION_HIGH,
+      CHLO_PRECISION_HIGHEST
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::chlo";
+}
+
+def CHLO_PrecisionAttr : EnumAttr<CHLO_Dialect, CHLO_Precision, "precision">;
+
+def CHLO_PrecisionConfigAttr:
+    TypedArrayAttrBase<CHLO_PrecisionAttr, "Precision Config attribute">;
+
 #endif  // STABLEHLO_DIALECT_CHLO_ENUMS
diff --git a/stablehlo/dialect/ChloOps.cpp b/stablehlo/dialect/ChloOps.cpp
index 1868204647..b175006fcf 100644
--- a/stablehlo/dialect/ChloOps.cpp
+++ b/stablehlo/dialect/ChloOps.cpp
@@ -42,6 +42,7 @@ limitations under the License.
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "stablehlo/dialect/AssemblyFormat.h"
 #include "stablehlo/dialect/Base.h"
 #include "stablehlo/dialect/BroadcastUtils.h"
 #include "stablehlo/dialect/ChloBytecode.h"
@@ -415,6 +416,242 @@ LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// RaggedDotOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// RaggedDot has three general modes, based on the kind of the ragged dimension.
+// Mode 1, where the ragged dimension is an lhs non-contracting dim (m).
+//   lhs : [b, m, k]
+//   rhs : [g, b, k, n]
+//   group_sizes : [g]
+//   result : [b, m, n]
+// Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k).
+//   lhs : [b, m, k]
+//   rhs : [b, k, n]
+//   group_sizes : [g]
+//   result : [g, b, m, n]
+// Mode 3, where the ragged dimension is an lhs/rhs batch dim (b).
+//   lhs : [b, m, k]
+//   rhs : [b, k, n]
+//   group_sizes : [g]
+//   result : [b, m, n]
+// As with dot_general, the lhs and rhs can have arbitrary batching,
+// contracting and non-contracting dimensions.
+// Additionally:
+//   - In all modes, the lhs must have exactly one ragged dimension.
+//   - In mode 1, the rhs must have exactly one group dimension.
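+// For example (shapes taken from the positive tests in ops_chlo.mlir): in
+// mode 1, lhs : tensor<2x11x5xf32>, rhs : tensor<3x2x5x7xf32>, and
+// group_sizes : tensor<3xi64> produce a tensor<2x11x7xf32>, where the g = 3
+// group sizes partition the ragged dimension m = 11.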
+LogicalResult checkRaggedDotConstraints(
+    std::optional<Location> location, RankedTensorType rankedLhsType,
+    RankedTensorType rankedRhsType, RankedTensorType rankedGroupSizesType,
+    ArrayRef<int64_t> lhsBatchingDimensions,
+    ArrayRef<int64_t> rhsBatchingDimensions,
+    ArrayRef<int64_t> lhsContractingDimensions,
+    ArrayRef<int64_t> rhsContractingDimensions,
+    ArrayRef<int64_t> lhsRaggedDimensions,
+    ArrayRef<int64_t> rhsGroupDimensions) {
+  // Check that group_sizes has rank 1.
+  if (rankedGroupSizesType.getRank() != 1) {
+    return emitOptionalError(
+        location, "expected rank of group_sizes of ragged dot to be 1, got ",
+        rankedGroupSizesType.getRank());
+  }
+  auto numGroups = rankedGroupSizesType.getDimSize(0);
+
+  // Check that there is exactly one lhs ragged dimension.
+  if (lhsRaggedDimensions.size() != 1) {
+    return emitOptionalError(
+        location, "There must be exactly one ragged dimension in the lhs.");
+  }
+  const int64_t lhsRaggedDim = lhsRaggedDimensions[0];
+
+  // Check that the lhs ragged dimension is in range.
+  if (failed(hlo::checkDimInBounds(location, lhsRaggedDim,
+                                   rankedLhsType.getRank(), "lhs_ragged_dim",
+                                   "lhs_rank"))) {
+    return failure();
+  }
+
+  // Validate basic properties of the rhs group dimension(s).
+  for (auto rhsGroupDim : rhsGroupDimensions) {
+    if (failed(hlo::checkDimInBounds(location, rhsGroupDim,
+                                     rankedRhsType.getRank(), "rhs_group_dim",
+                                     "rhs_rank"))) {
+      return failure();
+    }
+  }
+  if (failed(hlo::checkDimsDistinct(
+          location, rhsGroupDimensions, rhsBatchingDimensions,
+          "rhs_group_dimensions", "rhs_batching_dimensions")) ||
+      failed(hlo::checkDimsDistinct(
+          location, rhsGroupDimensions, rhsContractingDimensions,
+          "rhs_group_dimensions", "rhs_contracting_dimensions"))) {
+    return failure();
+  }
+
+  if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim) ||
+      llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
+    // Ragged batch (b):       [b,m,k], [b,k,n], [g] -> [b,m,n].
+    // Ragged contracting (k): [b,m,k], [b,k,n], [g] -> [g,b,m,n].
+    if (!rhsGroupDimensions.empty()) {
+      return emitOptionalError(
+          location,
+          "There must be zero group dimensions in the rhs when the "
+          "ragged dimension is batch or contracting.");
+    }
+  } else {
+    // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [g] -> [b,m,n].
+    if (rhsGroupDimensions.size() != 1) {
+      return emitOptionalError(
+          location,
+          "There must be exactly one group dimension in the rhs when the lhs "
+          "ragged dimension is non-contracting.");
+    }
+    // Compare the group dimension size with the number of groups.
+    const int64_t rhsGroupDim = rhsGroupDimensions[0];
+    if (!hlo::verifyCompatibleDims(numGroups,
+                                   rankedRhsType.getDimSize(rhsGroupDim))) {
+      return emitOptionalError(
+          location, "group_sizes is expected to have shape=[",
+          rankedRhsType.getDimSize(rhsGroupDim), "], got [", numGroups, "]");
+    }
+  }
+  return success();
+}
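+// The result dimensions are emitted in order: the group dimension (only when
+// the ragged dimension is contracting), the batch dimensions, the lhs
+// non-contracting dimensions, and the rhs non-contracting dimensions. For
+// example (mode 2, shapes from the tests): lhs [2,11,5], rhs [2,5,7], and
+// group_sizes [3] infer the result shape [3,2,11,7].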
+SmallVector<int64_t> inferRaggedDotOutputDimensions(
+    RankedTensorType rankedLhsType, RankedTensorType rankedRhsType,
+    RankedTensorType rankedGroupSizesType,
+    ArrayRef<int64_t> lhsBatchingDimensions,
+    ArrayRef<int64_t> rhsBatchingDimensions,
+    ArrayRef<int64_t> lhsContractingDimensions,
+    ArrayRef<int64_t> rhsContractingDimensions,
+    ArrayRef<int64_t> lhsRaggedDimensions,
+    ArrayRef<int64_t> rhsGroupDimensions) {
+  // Must have already checked that group_sizes is 1-D.
+  const int64_t numGroups = rankedGroupSizesType.getDimSize(0);
+  // Must have already checked that there is exactly one lhs ragged dim.
+  const int64_t lhsRaggedDim = lhsRaggedDimensions[0];
+
+  SmallVector<int64_t> dimensions;
+  // Add the group dimension to the result shape in case of ragged contracting.
+  if (llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
+    dimensions.push_back(numGroups);
+  }
+  auto lhsShape = rankedLhsType.getShape();
+  auto rhsShape = rankedRhsType.getShape();
+  for (const int64_t lhsBatchingDim : lhsBatchingDimensions)
+    dimensions.push_back(lhsShape[lhsBatchingDim]);
+  for (int64_t i = 0; i < rankedLhsType.getRank(); i++)
+    if (!llvm::is_contained(lhsBatchingDimensions, i) &&
+        !llvm::is_contained(lhsContractingDimensions, i))
+      dimensions.push_back(lhsShape[i]);
+  for (int64_t i = 0; i < rankedRhsType.getRank(); i++)
+    if (!llvm::is_contained(rhsBatchingDimensions, i) &&
+        !llvm::is_contained(rhsContractingDimensions, i) &&
+        !llvm::is_contained(rhsGroupDimensions, i))
+      dimensions.push_back(rhsShape[i]);
+  return dimensions;
+}
+
+LogicalResult inferRaggedDotOp(
+    std::optional<Location> location, Value lhs, Value rhs, Value groupSizes,
+    ArrayRef<int64_t> lhsBatchingDimensions,
+    ArrayRef<int64_t> rhsBatchingDimensions,
+    ArrayRef<int64_t> lhsContractingDimensions,
+    ArrayRef<int64_t> rhsContractingDimensions,
+    ArrayRef<int64_t> lhsRaggedDimensions,
+    ArrayRef<int64_t> rhsGroupDimensions,
+    std::optional<ArrayAttr> precisionConfig,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  if (failed(hlo::verifyPrecisionConfig(location, precisionConfig))) {
+    return failure();
+  }
+
+  // Validate basic properties of dot dimension numbers.
+  if (failed(hlo::checkDotGeneralConstraints(
+          location, lhs.getType(), rhs.getType(), lhsBatchingDimensions,
+          rhsBatchingDimensions, lhsContractingDimensions,
+          rhsContractingDimensions, precisionConfig))) {
+    return failure();
+  }
+
+  // Validate ragged dot constraints.
+  auto rankedLhsType = cast<RankedTensorType>(lhs.getType());
+  auto rankedRhsType = cast<RankedTensorType>(rhs.getType());
+  auto rankedGroupSizesType = cast<RankedTensorType>(groupSizes.getType());
+  if (failed(checkRaggedDotConstraints(
+          location, rankedLhsType, rankedRhsType, rankedGroupSizesType,
+          lhsBatchingDimensions, rhsBatchingDimensions,
+          lhsContractingDimensions, rhsContractingDimensions,
+          lhsRaggedDimensions, rhsGroupDimensions))) {
+    return failure();
+  }
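+  // Both the dot_general and the ragged-specific constraints hold at this
+  // point, so the shape computation below cannot fail.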
+  // Infer the output dimensions of the ragged dot operation.
+  inferredReturnShapes.emplace_back(inferRaggedDotOutputDimensions(
+      rankedLhsType, rankedRhsType, rankedGroupSizesType,
+      lhsBatchingDimensions, rhsBatchingDimensions, lhsContractingDimensions,
+      rhsContractingDimensions, lhsRaggedDimensions, rhsGroupDimensions));
+  return success();
+}
+
+}  // namespace
+
+LogicalResult RaggedDotOp::verify() {
+  auto location = getLoc();
+  auto raggedDotDimNums = getRaggedDotDimensionNumbers();
+
+  SmallVector<ShapedTypeComponents> inferredReturnShapes;
+  if (failed(inferRaggedDotOp(location, getLhs(), getRhs(), getGroupSizes(),
+                              raggedDotDimNums.getLhsBatchingDimensions(),
+                              raggedDotDimNums.getRhsBatchingDimensions(),
+                              raggedDotDimNums.getLhsContractingDimensions(),
+                              raggedDotDimNums.getRhsContractingDimensions(),
+                              raggedDotDimNums.getLhsRaggedDimensions(),
+                              raggedDotDimNums.getRhsGroupDimensions(),
+                              getPrecisionConfig(), inferredReturnShapes)))
+    return failure();
+  auto inferredShape = inferredReturnShapes[0];
+
+  auto resultType = cast<RankedTensorType>(getResult().getType());
+  if (failed(verifyCompatibleShape(inferredShape.getDims(),
+                                   resultType.getShape()))) {
+    return emitOptionalError(
+        location, "inferred shape '",
+        hlo::dimSizesToString(inferredShape.getDims()), "' ",
+        "is incompatible with return type of operation ", resultType, "");
+  }
+
+  return success();
+}
+
+LogicalResult RaggedDotOp::inferReturnTypes(
+    MLIRContext*, std::optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type>& inferredReturnTypes) {
+  RaggedDotOp::Adaptor op(operands, attributes, properties, regions);
+
+  auto rankedLhsType = cast<RankedTensorType>(op.getLhs().getType());
+  auto rankedRhsType = cast<RankedTensorType>(op.getRhs().getType());
+  auto rankedGroupSizesType =
+      cast<RankedTensorType>(op.getGroupSizes().getType());
+  auto raggedDotDimNums = op.getRaggedDotDimensionNumbers();
+
+  inferredReturnTypes.push_back(RankedTensorType::get(
+      inferRaggedDotOutputDimensions(
+          rankedLhsType, rankedRhsType, rankedGroupSizesType,
+          raggedDotDimNums.getLhsBatchingDimensions(),
+          raggedDotDimNums.getRhsBatchingDimensions(),
+          raggedDotDimNums.getLhsContractingDimensions(),
+          raggedDotDimNums.getRhsContractingDimensions(),
+          raggedDotDimNums.getLhsRaggedDimensions(),
+          raggedDotDimNums.getRhsGroupDimensions()),
+      rankedLhsType.getElementType()));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TopKOp
 //===----------------------------------------------------------------------===//
@@ -523,5 +760,140 @@ void ChloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
   assert(succeeded(result));
 }
 
+/// Helpers for attributes parsing.
+
+static ParseResult parseDims(AsmParser& parser,
+                             SmallVector<int64_t>& dimSizes) {
+  dimSizes.clear();
+  auto failOrDims = hlo::parseDimSizes(parser);
+  if (failed(failOrDims)) return failure();
+  dimSizes = std::move(*failOrDims);
+  return success();
+}
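+// For example, `parseDims` is used below to turn a textual dimension list
+// such as `[0, 2]` into the int64_t vector {0, 2} (assuming the bracketed,
+// comma-separated form accepted by hlo::parseDimSizes).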
+/// Parse a custom attribute that resembles a struct of the form
+/// <
+///   foo = something_parsed_by_custom_parser,
+///   bar = something_parsed_by_different_custom_parser,
+///   baz something_parsed_by_another_custom_parser
+/// >
+/// The optional `parseEqual` array can be used to denote if '=' follows the
+/// keyword (see baz in the example above) for a field. If not provided, all
+/// fields must be followed by a '='.
+static ParseResult parseStruct(
+    AsmParser& parser, ArrayRef<StringRef> keywords,
+    ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs,
+    ArrayRef<bool> parseEqual = {}) {
+  assert(keywords.size() == parseFuncs.size());
+  assert(parseEqual.empty() || parseEqual.size() == keywords.size());
+  SmallVector<bool> seen(keywords.size(), false);
+  while (failed(parser.parseOptionalGreater())) {
+    bool foundOne = false;
+    for (const auto& it : llvm::enumerate(keywords)) {
+      size_t index = it.index();
+      StringRef keyword = it.value();
+      if (failed(parser.parseOptionalKeyword(keyword))) continue;
+      if (seen[index])
+        return parser.emitError(parser.getCurrentLocation())
+               << "duplicated `" << keyword << "` entry";
+      if (parseEqual.empty() || parseEqual[index]) {
+        if (failed(parser.parseEqual())) return failure();
+      }
+      if (failed(parseFuncs[index]())) return failure();
+      if (failed(parser.parseOptionalComma())) return parser.parseGreater();
+      seen[index] = true;
+      foundOne = true;
+    }
+    if (!foundOne) {
+      auto parseError = parser.emitError(parser.getCurrentLocation())
+                        << "expected one of: ";
+      llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
+        parseError << '`' << kw << '`';
+      });
+      return parseError;
+    }
+  }
+  return success();
+}
+
+// Helpers to print an optional array or integer field, to simplify writing
+// attribute printers.
+template <typename T>
+static void printField(AsmPrinter& printer, StringRef name, T field,
+                       StringRef& separator) {
+  if (field != 0) {
+    printer << separator << name << " = " << field;
+    separator = ", ";
+  }
+}
+template <typename T>
+static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field,
+                       StringRef& separator) {
+  if (!field.empty()) {
+    printer << separator << name << " = [";
+    llvm::interleaveComma(field, printer);
+    printer << "]";
+    separator = ", ";
+  }
+}
+template <typename... Ts>
+static void printStruct(AsmPrinter& printer, StringRef name,
+                        Ts... printFields) {
+  printer << "<";
+  StringRef separator = "";
+  // Fold expression to print each entry in the parameter pack.
+  // TODO(stablehlo-team): this can be simplified when TF moves to C++17.
+  using unused = int[];
+  (void)unused{0, (printField(printer, std::get<0>(printFields),
+                              std::get<1>(printFields), separator),
+                   0)...};
+  printer << ">";
+}
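+// Together, the print/parse methods below round-trip the textual form used
+// in the tests, for example:
+//   #chlo.ragged_dot<
+//     lhs_batching_dimensions = [0],
+//     rhs_batching_dimensions = [1],
+//     lhs_contracting_dimensions = [2],
+//     rhs_contracting_dimensions = [2],
+//     lhs_ragged_dimensions = [1],
+//     rhs_group_dimensions = [0]
+//   >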
+// Custom printer and parser for RaggedDotDimensionNumbersAttr.
+void RaggedDotDimensionNumbersAttr::print(AsmPrinter& printer) const {
+  printStruct(
+      printer, "ragged_dot",
+      std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()),
+      std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()),
+      std::make_pair("lhs_contracting_dimensions",
+                     getLhsContractingDimensions()),
+      std::make_pair("rhs_contracting_dimensions",
+                     getRhsContractingDimensions()),
+      std::make_pair("lhs_ragged_dimensions", getLhsRaggedDimensions()),
+      std::make_pair("rhs_group_dimensions", getRhsGroupDimensions()));
+}
+
+Attribute RaggedDotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
+  if (failed(parser.parseLess())) return {};
+
+  SmallVector<int64_t> lhsBatchingDimensions;
+  SmallVector<int64_t> rhsBatchingDimensions;
+  SmallVector<int64_t> lhsContractingDimensions;
+  SmallVector<int64_t> rhsContractingDimensions;
+  SmallVector<int64_t> lhsRaggedDimensions;
+  SmallVector<int64_t> rhsGroupDimensions;
+
+  if (failed(parseStruct(
+          parser,
+          {"lhs_batching_dimensions", "rhs_batching_dimensions",
+           "lhs_contracting_dimensions", "rhs_contracting_dimensions",
+           "lhs_ragged_dimensions", "rhs_group_dimensions"},
+          {[&]() { return parseDims(parser, lhsBatchingDimensions); },
+           [&]() { return parseDims(parser, rhsBatchingDimensions); },
+           [&]() { return parseDims(parser, lhsContractingDimensions); },
+           [&]() { return parseDims(parser, rhsContractingDimensions); },
+           [&]() { return parseDims(parser, lhsRaggedDimensions); },
+           [&]() { return parseDims(parser, rhsGroupDimensions); }}))) {
+    parser.emitError(parser.getCurrentLocation())
+        << "failed parsing ragged dot dimension numbers attribute";
+    return {};
+  }
+  return RaggedDotDimensionNumbersAttr::get(
+      parser.getContext(), lhsBatchingDimensions, rhsBatchingDimensions,
+      lhsContractingDimensions, rhsContractingDimensions, lhsRaggedDimensions,
+      rhsGroupDimensions);
+}
+
 }  // namespace chlo
 }  // namespace mlir
diff --git a/stablehlo/dialect/ChloOps.td b/stablehlo/dialect/ChloOps.td
index 77ed7b5493..1cca13b2ac 100644
--- a/stablehlo/dialect/ChloOps.td
+++ b/stablehlo/dialect/ChloOps.td
@@ -833,6 +833,67 @@ def CHLO_BroadcastSelectOp : CHLO_Op<"broadcast_select",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Ragged dot op
+//===----------------------------------------------------------------------===//
+
+def CHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
+  let parser = "parseDimSizes($_parser)";
+  let printer = "printDimSizes($_printer, $_self)";
+}
+
+def CHLO_RaggedDotDimensionNumbers
+    : AttrDef<CHLO_Dialect, "RaggedDotDimensionNumbers"> {
+  let mnemonic = "ragged_dot";
+  let summary = "Attribute that models the dimension information for ragged dot.";
+  let parameters = (ins
+    CHLO_Dims:$lhsBatchingDimensions,
+    CHLO_Dims:$rhsBatchingDimensions,
+    CHLO_Dims:$lhsContractingDimensions,
+    CHLO_Dims:$rhsContractingDimensions,
+    CHLO_Dims:$lhsRaggedDimensions,
+    CHLO_Dims:$rhsGroupDimensions
+  );
+  let hasCustomAssemblyFormat = 1;
+}
+
+def CHLO_RaggedDotOp : CHLO_Op<"ragged_dot",
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  string summary = "Computes a matmul over a single ragged dimension";
+
+  string description = [{
+    This operation takes three tensor arguments---lhs, rhs, and
+    group_sizes---and a "ragged_dot_dimension_numbers" attribute. Like
+    dot_general, the lhs and rhs are allowed arbitrary batch and contracting
+    dimensions. Additionally, the lhs is required to have one ragged
+    dimension, and the rhs may have at most one group dimension. The op has
+    three modes, depending on the kind of the lhs ragged dimension.
+
+    In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`.
+    Here the ragged dimension is an lhs non-contracting dimension (`m`). The
+    dimensions `b` and `k` represent batch and contracting dimensions
+    respectively. The rhs is required to have a group dimension (`g`).
+
+    In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`.
+    Here the ragged dimension is an lhs/rhs contracting dimension (`k`).
+
+    In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here
+    the ragged dimension is an lhs/rhs batch dimension (`b`).
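+
+    For example, the following is a mode 1 ragged dot (shapes mirror the
+    tests; the optional precision_config is omitted):
+
+    ```mlir
+    %result = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+      ragged_dot_dimension_numbers = #chlo.ragged_dot<
+        lhs_batching_dimensions = [0],
+        rhs_batching_dimensions = [1],
+        lhs_contracting_dimensions = [2],
+        rhs_contracting_dimensions = [2],
+        lhs_ragged_dimensions = [1],
+        rhs_group_dimensions = [0]
+      >
+    } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
+    ```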
+  }];
+
+  let arguments = (ins
+    HLO_AnyTensor:$lhs,
+    HLO_AnyTensor:$rhs,
+    Arg<I64Tensor>:$group_sizes,
+    CHLO_RaggedDotDimensionNumbers:$ragged_dot_dimension_numbers,
+    OptionalAttr<CHLO_PrecisionConfigAttr>:$precision_config
+  );
+
+  let results = (outs HLO_AnyTensor:$result);
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Miscellaneous ops
 //===----------------------------------------------------------------------===//
diff --git a/stablehlo/tests/ops_chlo.mlir b/stablehlo/tests/ops_chlo.mlir
index 16f399f0c3..b5a021043b 100644
--- a/stablehlo/tests/ops_chlo.mlir
+++ b/stablehlo/tests/ops_chlo.mlir
@@ -73,6 +73,222 @@ func.func @constant_like(%arg0: tensor<1x2xi64>) -> (tensor<1x2xi32>) {
 
 // -----
 
+// ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n]
+func.func @ragged_dot_non_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [0],
+      rhs_batching_dimensions = [1],
+      lhs_contracting_dimensions = [2],
+      rhs_contracting_dimensions = [2],
+      lhs_ragged_dimensions = [1],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
+  func.return %0 : tensor<2x11x7xf32>
+}
+
+// -----
+
+// ragged_dot mode 2: [b,m,k], [b,k,n], [g] -> [g,b,m,n]
+func.func @ragged_dot_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x2x11x7xf32> {
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [0],
+      rhs_batching_dimensions = [0],
+      lhs_contracting_dimensions = [2],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [2],
+      rhs_group_dimensions = []
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<2x11x5xf32>, tensor<2x5x7xf32>, tensor<3xi64>) -> tensor<3x2x11x7xf32>
+  func.return %0 : tensor<3x2x11x7xf32>
+}
+
+// -----
+
+// ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n]
+func.func @ragged_dot_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> {
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [0],
+      rhs_batching_dimensions = [0],
+      lhs_contracting_dimensions = [2],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = []
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32>
+  func.return %0 : tensor<3x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rhs : tensor<3x2x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{contracting dimension sizes must match}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x5xf32>, tensor<3x2x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3x2xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<2xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{There must be exactly one ragged dimension in the lhs}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [0, 1],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_rhs_group_dim_is_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> {
+  // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [0],
+      rhs_batching_dimensions = [0],
+      lhs_contracting_dimensions = [2],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [1],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32>
+  func.return %0 : tensor<3x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_rhs_group_dim_is_contracting(%lhs : tensor<11x3xf32>, %rhs : tensor<3x3x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = [1]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x3xf32>, tensor<3x3x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
+  // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [0],
+      rhs_batching_dimensions = [1],
+      lhs_contracting_dimensions = [2],
+      rhs_contracting_dimensions = [2],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
+  func.return %0 : tensor<2x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [1],
+      lhs_ragged_dimensions = [1],
+      rhs_group_dimensions = [0]
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tensor<11x5xf32>, %rhs : tensor<5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
+  // @expected-error@+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}}
+  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+    ragged_dot_dimension_numbers = #chlo.ragged_dot<
+      lhs_batching_dimensions = [],
+      rhs_batching_dimensions = [],
+      lhs_contracting_dimensions = [1],
+      rhs_contracting_dimensions = [0],
+      lhs_ragged_dimensions = [0],
+      rhs_group_dimensions = []
+    >,
+    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+  } : (tensor<11x5xf32>, tensor<5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
+  func.return %0 : tensor<11x7xf32>
+}
+
+// -----
+
 func.func @top_k(%arg0 : tensor<f32>) {
   // expected-error @+2 {{failed to infer returned types}}
   // @expected-error @+1{{operand's rank must be at least 1}}