Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for SPV_INTEL_subgroup_matrix_multiply_accumulate #5928

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/spirv-tools/libspirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ typedef enum spv_operand_type_t {
SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE,
// Enum type from SPV_NV_cooperative_matrix2
SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS,
// Optional types from SPV_INTEL_subgroup_matrix_multiply_accumulate
SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,
SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,

// This is a sentinel value, and does not represent an operand type.
// It should come last.
Expand Down
7 changes: 6 additions & 1 deletion source/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS: {
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
case SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS: {
// This operand is a mask.

// Map an optional operand type to its corresponding concrete type.
Expand All @@ -738,6 +740,9 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
if (type == SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS)
parsed_operand.type = SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS;
if (type == SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS)
parsed_operand.type =
SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS;

// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
Expand Down
5 changes: 5 additions & 0 deletions source/operand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "cooperative matrix reduce";
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
return "tensor addressing operands";
case SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
return "matrix multiply accumulate operands";
case SPV_OPERAND_TYPE_INITIALIZATION_MODE_QUALIFIER:
return "initialization mode qualifier";
case SPV_OPERAND_TYPE_HOST_ACCESS_QUALIFIER:
Expand Down Expand Up @@ -415,6 +418,7 @@ bool spvOperandIsConcreteMask(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
Expand All @@ -437,6 +441,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_CIV:
case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_FPENCODING:
Expand Down
3 changes: 2 additions & 1 deletion source/text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar,
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_TENSOR_ADDRESSING_OPERANDS:
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE: {
case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_REDUCE:
case SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS: {
uint32_t value;
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
return context->diagnostic(error)
Expand Down
27 changes: 27 additions & 0 deletions test/binary_to_text_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,33 @@ INSTANTIATE_TEST_SUITE_P(
"OpDecorate %1 HostAccessINTEL ReadWriteINTEL \"readwrite\"\n",
})));

// clang-format off
INSTANTIATE_TEST_SUITE_P(
MatrixMultiplyAccumulateOperands, RoundTripInstructionsTest,
Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0),
::testing::ValuesIn(std::vector<std::string>{
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 None\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixASignedComponentsINTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBSignedComponentsINTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixCBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixResultBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedInt8INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedInt8INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedInt4INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedInt4INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixATF32INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBTF32INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixCBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixAPackedBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 MatrixBPackedBFloat16INTEL\n",
"%2 = OpSubgroupMatrixMultiplyAccumulateINTEL %1 %3 %4 %5 %6 "
"MatrixASignedComponentsINTEL|MatrixBSignedComponentsINTEL|MatrixAPackedInt8INTEL|MatrixBPackedInt8INTEL\n",
})));
// clang-format on

using MaskSorting = TextToBinaryTest;

TEST_F(MaskSorting, MasksAreSortedFromLSBToMSB) {
Expand Down
2 changes: 1 addition & 1 deletion utils/generate_grammar_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def generate_operand_kind_table(enums):

# We have a few operand kinds that require their optional counterpart to
# exist in the operand info table.
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands', 'RawAccessChainOperands', 'FPEncoding']
optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands', 'MatrixMultiplyAccumulateOperands', 'RawAccessChainOperands', 'FPEncoding']
optional_enums = [e for e in enums if e[0] in optional_enums]
enums.extend(optional_enums)

Expand Down
Loading