Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Add profile-based operation validation #126992

Merged
merged 1 commit into from
Feb 20, 2025

Conversation

tatwaichong
Copy link
Contributor

TOSA MLIR profile-based validation is designed to identify the profile/extension requirements for each operation in TOSA MLIR graph, ensuring that TOSA operators conform to the profiles and extensions enabled by the target implementation.

The available profiles/extensions are reflected in the availability property attached to each TOSA operator in the dialect. The design of availability, the profile/extension classes, and their interface, is inspired by the SPIRV implementation.

This patch includes the following changes:

  • Introduces profile and extension knowledge within the dialect and establishes an interface to query this information.
  • Implements profile-based validation logic in the pass.
  • Adds a TargetEnv class that represents the capabilities enabled in the target implementation, such as profiles, extensions, and levels.
  • Adds a set of tests to ensure that profile and extension requirements are properly attached to the operations and that validation correctly verifies the requirements of a given operation against the target implementation.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir mlir:tosa bazel "Peripheral" support tier build system: utils/bazel labels Feb 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2025

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: TatWai Chong (tatwaichong)

Changes

TOSA MLIR profile-based validation is designed to identify the profile/extension requirements for each operation in TOSA MLIR graph, ensuring that TOSA operators conform to the profiles and extensions enabled by the target implementation.

The available profiles/extensions are reflected in the availability property attached to each TOSA operator in the dialect. The design of availability, the profile/extension classes, and their interface, is inspired by the SPIRV implementation.

This patch includes the following changes:

  • Introduces profile and extension knowledge within the dialect and establishes an interface to query this information.
  • Implements profile-based validation logic in the pass.
  • Adds a TargetEnv class that represents the capabilities enabled in the target implementation, such as profiles, extensions, and levels.
  • Adds a set of tests to ensure that profile and extension requirements are properly attached to the operations and that validation correctly verifies the requirements of a given operation against the target implementation.

Patch is 146.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126992.diff

31 Files Affected:

  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt (+11)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h (+84)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h (+403)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+200)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+361)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+163)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+4)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+15)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (+1-22)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+3-11)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+485)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+36-12)
  • (added) mlir/test/Dialect/Tosa/availability.mlir (+690)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-2)
  • (added) mlir/test/Dialect/Tosa/invalid_extension.mlir (+38)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-3)
  • (added) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+83)
  • (added) mlir/test/Dialect/Tosa/profile_bi_unsupported.mlir (+26)
  • (added) mlir/test/Dialect/Tosa/profile_mi_unsupported.mlir (+62)
  • (modified) mlir/test/lib/Dialect/Tosa/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/Tosa/TestAvailability.cpp (+78)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
  • (modified) mlir/tools/mlir-tblgen/CMakeLists.txt (+1)
  • (added) mlir/tools/mlir-tblgen/TosaUtilsGen.cpp (+226)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+20)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index a1eb22eba6987..195a58432737b 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
     // Note: Default to 'none' level unless otherwise specified.
     std::optional<tosa::TosaValidationOptions> validationOptions =
         tosa::TosaValidationOptions{
-            {"none"}, false, tosa::TosaLevelEnum::None});
+            {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
 /// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index cc8d5ed9b0044..0a855d701d7b8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -12,3 +12,14 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
 mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
 add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOpBase.td)
+mlir_tablegen(TosaEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TosaEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTosaEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaAvailability.h.inc -gen-avail-interface-decls)
+mlir_tablegen(TosaAvailability.cpp.inc -gen-avail-interface-defs)
+mlir_tablegen(TosaOpAvailabilityImpl.inc -gen-tosa-avail-impls)
+add_public_tablegen_target(MLIRTosaAvailabilityIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
new file mode 100644
index 0000000000000..86fb4077b9207
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -0,0 +1,84 @@
+//===- TargetEnv.h - Tosa target environment utilities ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for Tosa target environment (implementation).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_IR_TARGETENV_H
+#define MLIR_DIALECT_TOSA_IR_TARGETENV_H
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace mlir {
+namespace tosa {
+
+/// This class represents the capability enabled in the target implementation
+/// such as profile, extension, and level.
+class TargetEnv {
+public:
+  TargetEnv() {}
+  explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
+                     const SmallVectorImpl<Extension> &extensions) {
+    for (Profile prof : profiles)
+      enabledProfiles.insert(prof);
+
+    for (Extension ext : extensions)
+      enabledExtensions.insert(ext);
+  }
+
+  void addProfile(Profile p) { enabledProfiles.insert(p); }
+  void addExtension(Extension e) { enabledExtensions.insert(e); }
+
+  // TODO implement the following utilities.
+  // Version getSpecVersion() const;
+  // TosaLevel getLevel() const;
+
+  // Returns true if the given profile is allowed.
+  bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
+
+  bool allowsAnyOf(ArrayRef<Profile> profs) const {
+    const auto *chosen = llvm::find_if(
+        profs, [this](tosa::Profile prof) { return allows(prof); });
+    return chosen != profs.end() ? true : false;
+  }
+
+  bool allowsAllOf(ArrayRef<Profile> profs) const {
+    bool is_allowed = true;
+    llvm::for_each(profs,
+                   [&](tosa::Profile prof) { is_allowed &= allows(prof); });
+    return is_allowed;
+  }
+
+  // Returns true if the given extension is allowed.
+  bool allows(Extension ext) const { return enabledExtensions.count(ext) != 0; }
+
+  bool allowsAnyOf(ArrayRef<Extension> exts) const {
+    const auto *chosen = llvm::find_if(
+        exts, [this](tosa::Extension ext) { return allows(ext); });
+    return chosen != exts.end() ? true : false;
+  }
+
+  bool allowsAllOf(ArrayRef<Extension> exts) const {
+    bool is_allowed = true;
+    llvm::for_each(exts,
+                   [&](tosa::Extension ext) { is_allowed &= allows(ext); });
+    return is_allowed;
+  }
+
+private:
+  llvm::SmallSet<Profile, 3> enabledProfiles;
+  llvm::SmallSet<Extension, 8> enabledExtensions;
+};
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_IR_TARGETENV_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
new file mode 100644
index 0000000000000..1a10d8579962d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
@@ -0,0 +1,403 @@
+// The profile-based compliance content below is auto-generated by a script
+// in https://git.mlplatform.org/tosa/specification.git
+profileComplianceMap = {
+    {"tosa.argmax",
+     {{{Profile::pro_int}, {{i8T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+    {"tosa.avg_pool2d",
+     {{{Profile::pro_int}, {{i8T, i32T, i8T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.conv3d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.depthwise_conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.fully_connected",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp32T, fp32T},
+        {fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.matmul",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.max_pool2d",
+     {{{Profile::pro_int}, {{i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.transpose_conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.clamp",
+     {{{Profile::pro_int}, {{i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.add",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.arithmetic_right_shift",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_and",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_or",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_xor",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.intdiv",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}}},
+    {"tosa.logical_and",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.logical_left_shift",
+     {{{Profile::pro_int, Profile::pro_fp},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.logical_right_shift",
+     {{{Profile::pro_int, Profile::pro_fp},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.logical_or",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.logical_xor",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.maximum",
+     {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.minimum",
+     {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.mul",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
+      {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.pow",
+     {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.sub",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+    {"tosa.abs",
+     {{{Profile::pro_int}, {{i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.bitwise_not",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
+    {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
+    {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.logical_not",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.negate",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reciprocal",
+     {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.select",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}},
+      {{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.equal",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.greater",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.greater_equal",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.reduce_all",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.reduce_any",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.reduce_max",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_min",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_product",
+     {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_sum",
+     {{{Profile::pro_int}, {{i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.concat",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.pad",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reshape",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reverse",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.slice",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.tile",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.transpose",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.gather",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.scatter",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.resize",
+     {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.cast",
+     {{{Profile::pro_int},
+       {{boolT, i8T},
+        {boolT, i16T},
+        {boolT, i32T},
+        {i8T, boolT},
+        {i8T, i16T},
+        {i8T, i32T},
+        {i16T, boolT},
+        {i16T, i8T},
+        {i16T, i32T},
+        {i32T, boolT},
+        {i32T, i8T},
+        {i32T, i16T}}},
+      {{Profile::pro_fp},
+       {{i8T, fp16T},
+        {i8T, fp32T},
+        {i16T, fp16T},
+        {i16T, fp32T},
+        {i32T, fp16T},
+        {i32T, fp32T},
+        {fp16T, i8T},
+        {fp16T, i16T},
+        {fp16T, i32T},
+        {fp16T, fp32T},
+        {fp32T, i8T},
+        {fp32T, i16T},
+        {fp32T, i32T},
+        {fp32T, fp16T}}}}},
+    {"tosa.rescale",
+     {{{Profile::pro_int},
+       {{i8T, i8T},
+        {i8T, i16T},
+        {i8T, i32T},
+        {i16T, i8T},
+        {i16T, i16T},
+        {i16T, i32T},
+        {i32T, i8T},
+        {i32T, i16T},
+        {i32T, i32T}}}}},
+    {"tosa.const",
+     {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
+      {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+    {"tosa.identity",
+     {{{Profile::pro_int},
+       {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.dim",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT}}},
+      {{Profile::pro_int}, {{i8T}, {i16T}, {i32T}}},
+      {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+};
+
+extensionComplianceMap = {
+    {"tosa.argmax",
+     {{{Extension::int16}, {{i16T, i32T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
+      {{Extension::bf16}, {{bf16T, i32T}}}}},
+    {"tosa.avg_pool2d",
+     {{{Extension::int16}, {{i16T, i32T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
+    {"tosa.conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.conv3d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.depthwise_conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.fully_connected",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
+    {"tosa.matmul",
+     {{{Extension::int16}, {{i16T, i16T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+    {"tosa.max_pool2d",
+     {{{Extension::int16}, {{i16T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+    {"tosa.transpose_conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.clamp",
+     {{{Extension::int16}, {{i16T, i16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
+    {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_product", {{{Ext...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2025

@llvm/pr-subscribers-mlir-linalg

Author: TatWai Chong (tatwaichong)

Changes

TOSA MLIR profile-based validation is designed to identify the profile/extension requirements for each operation in TOSA MLIR graph, ensuring that TOSA operators conform to the profiles and extensions enabled by the target implementation.

The available profiles/extensions are reflected in the availability property attached to each TOSA operator in the dialect. The design of availability, the profile/extension classes, and their interface, is inspired by the SPIRV implementation.

This patch includes the following changes:

  • Introduces profile and extension knowledge within the dialect and establishes an interface to query this information.
  • Implements profile-based validation logic in the pass.
  • Adds a TargetEnv class that represents the capabilities enabled in the target implementation, such as profiles, extensions, and levels.
  • Adds a set of tests to ensure that profile and extension requirements are properly attached to the operations and that validation correctly verifies the requirements of a given operation against the target implementation.

Patch is 146.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126992.diff

31 Files Affected:

  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt (+11)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h (+84)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h (+403)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+200)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+361)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+163)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+4)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+15)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (+1-22)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+3-11)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+485)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+36-12)
  • (added) mlir/test/Dialect/Tosa/availability.mlir (+690)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-2)
  • (added) mlir/test/Dialect/Tosa/invalid_extension.mlir (+38)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-3)
  • (added) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+83)
  • (added) mlir/test/Dialect/Tosa/profile_bi_unsupported.mlir (+26)
  • (added) mlir/test/Dialect/Tosa/profile_mi_unsupported.mlir (+62)
  • (modified) mlir/test/lib/Dialect/Tosa/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/Tosa/TestAvailability.cpp (+78)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
  • (modified) mlir/tools/mlir-tblgen/CMakeLists.txt (+1)
  • (added) mlir/tools/mlir-tblgen/TosaUtilsGen.cpp (+226)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+20)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index a1eb22eba6987..195a58432737b 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
     // Note: Default to 'none' level unless otherwise specified.
     std::optional<tosa::TosaValidationOptions> validationOptions =
         tosa::TosaValidationOptions{
-            {"none"}, false, tosa::TosaLevelEnum::None});
+            {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
 /// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index cc8d5ed9b0044..0a855d701d7b8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -12,3 +12,14 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
 mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
 add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOpBase.td)
+mlir_tablegen(TosaEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TosaEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTosaEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaAvailability.h.inc -gen-avail-interface-decls)
+mlir_tablegen(TosaAvailability.cpp.inc -gen-avail-interface-defs)
+mlir_tablegen(TosaOpAvailabilityImpl.inc -gen-tosa-avail-impls)
+add_public_tablegen_target(MLIRTosaAvailabilityIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
new file mode 100644
index 0000000000000..86fb4077b9207
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -0,0 +1,84 @@
+//===- TargetEnv.h - Tosa target environment utilities ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for Tosa target environment (implementation).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_IR_TARGETENV_H
+#define MLIR_DIALECT_TOSA_IR_TARGETENV_H
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace mlir {
+namespace tosa {
+
+/// This class represents the capability enabled in the target implementation
+/// such as profile, extension, and level.
+class TargetEnv {
+public:
+  TargetEnv() {}
+  explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
+                     const SmallVectorImpl<Extension> &extensions) {
+    for (Profile prof : profiles)
+      enabledProfiles.insert(prof);
+
+    for (Extension ext : extensions)
+      enabledExtensions.insert(ext);
+  }
+
+  void addProfile(Profile p) { enabledProfiles.insert(p); }
+  void addExtension(Extension e) { enabledExtensions.insert(e); }
+
+  // TODO implement the following utilities.
+  // Version getSpecVersion() const;
+  // TosaLevel getLevel() const;
+
+  // Returns true if the given profile is allowed.
+  bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
+
+  bool allowsAnyOf(ArrayRef<Profile> profs) const {
+    const auto *chosen = llvm::find_if(
+        profs, [this](tosa::Profile prof) { return allows(prof); });
+    return chosen != profs.end() ? true : false;
+  }
+
+  bool allowsAllOf(ArrayRef<Profile> profs) const {
+    bool is_allowed = true;
+    llvm::for_each(profs,
+                   [&](tosa::Profile prof) { is_allowed &= allows(prof); });
+    return is_allowed;
+  }
+
+  // Returns true if the given extension is allowed.
+  bool allows(Extension ext) const { return enabledExtensions.count(ext) != 0; }
+
+  bool allowsAnyOf(ArrayRef<Extension> exts) const {
+    const auto *chosen = llvm::find_if(
+        exts, [this](tosa::Extension ext) { return allows(ext); });
+    return chosen != exts.end() ? true : false;
+  }
+
+  bool allowsAllOf(ArrayRef<Extension> exts) const {
+    bool is_allowed = true;
+    llvm::for_each(exts,
+                   [&](tosa::Extension ext) { is_allowed &= allows(ext); });
+    return is_allowed;
+  }
+
+private:
+  llvm::SmallSet<Profile, 3> enabledProfiles;
+  llvm::SmallSet<Extension, 8> enabledExtensions;
+};
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_IR_TARGETENV_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
new file mode 100644
index 0000000000000..1a10d8579962d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h
@@ -0,0 +1,403 @@
+// The profile-based compliance content below is auto-generated by a script
+// in https://git.mlplatform.org/tosa/specification.git
+profileComplianceMap = {
+    {"tosa.argmax",
+     {{{Profile::pro_int}, {{i8T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+    {"tosa.avg_pool2d",
+     {{{Profile::pro_int}, {{i8T, i32T, i8T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.conv3d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.depthwise_conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.fully_connected",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp32T, fp32T},
+        {fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.matmul",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.max_pool2d",
+     {{{Profile::pro_int}, {{i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.transpose_conv2d",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp32T, fp16T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.clamp",
+     {{{Profile::pro_int}, {{i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.add",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.arithmetic_right_shift",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_and",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_or",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.bitwise_xor",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.intdiv",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}}},
+    {"tosa.logical_and",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.logical_left_shift",
+     {{{Profile::pro_int, Profile::pro_fp},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.logical_right_shift",
+     {{{Profile::pro_int, Profile::pro_fp},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+    {"tosa.logical_or",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.logical_xor",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}},
+    {"tosa.maximum",
+     {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.minimum",
+     {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.mul",
+     {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
+      {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.pow",
+     {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.sub",
+     {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+    {"tosa.abs",
+     {{{Profile::pro_int}, {{i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.bitwise_not",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
+    {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
+    {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.logical_not",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.negate",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reciprocal",
+     {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.select",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}},
+      {{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.equal",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.greater",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.greater_equal",
+     {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+    {"tosa.reduce_all",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.reduce_any",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
+    {"tosa.reduce_max",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_min",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_product",
+     {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reduce_sum",
+     {{{Profile::pro_int}, {{i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.concat",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.pad",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reshape",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.reverse",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.slice",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.tile",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.transpose",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}},
+      {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.gather",
+     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.scatter",
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+    {"tosa.resize",
+     {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.cast",
+     {{{Profile::pro_int},
+       {{boolT, i8T},
+        {boolT, i16T},
+        {boolT, i32T},
+        {i8T, boolT},
+        {i8T, i16T},
+        {i8T, i32T},
+        {i16T, boolT},
+        {i16T, i8T},
+        {i16T, i32T},
+        {i32T, boolT},
+        {i32T, i8T},
+        {i32T, i16T}}},
+      {{Profile::pro_fp},
+       {{i8T, fp16T},
+        {i8T, fp32T},
+        {i16T, fp16T},
+        {i16T, fp32T},
+        {i32T, fp16T},
+        {i32T, fp32T},
+        {fp16T, i8T},
+        {fp16T, i16T},
+        {fp16T, i32T},
+        {fp16T, fp32T},
+        {fp32T, i8T},
+        {fp32T, i16T},
+        {fp32T, i32T},
+        {fp32T, fp16T}}}}},
+    {"tosa.rescale",
+     {{{Profile::pro_int},
+       {{i8T, i8T},
+        {i8T, i16T},
+        {i8T, i32T},
+        {i16T, i8T},
+        {i16T, i16T},
+        {i16T, i32T},
+        {i32T, i8T},
+        {i32T, i16T},
+        {i32T, i32T}}}}},
+    {"tosa.const",
+     {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
+      {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+    {"tosa.identity",
+     {{{Profile::pro_int},
+       {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
+      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+    {"tosa.dim",
+     {{{Profile::pro_int, Profile::pro_fp}, {{boolT}}},
+      {{Profile::pro_int}, {{i8T}, {i16T}, {i32T}}},
+      {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+};
+
+extensionComplianceMap = {
+    {"tosa.argmax",
+     {{{Extension::int16}, {{i16T, i32T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
+      {{Extension::bf16}, {{bf16T, i32T}}}}},
+    {"tosa.avg_pool2d",
+     {{{Extension::int16}, {{i16T, i32T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
+    {"tosa.conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.conv3d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.depthwise_conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+    {"tosa.fully_connected",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
+    {"tosa.matmul",
+     {{{Extension::int16}, {{i16T, i16T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+    {"tosa.max_pool2d",
+     {{{Extension::int16}, {{i16T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+    {"tosa.transpose_conv2d",
+     {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
+      {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+    {"tosa.clamp",
+     {{{Extension::int16}, {{i16T, i16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
+    {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+    {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
+    {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.reduce_product", {{{Ext...
[truncated]

@tatwaichong tatwaichong force-pushed the prof_validation branch 4 times, most recently from 0ff7829 to e09a503 Compare February 19, 2025 07:29
@@ -99,7 +89,9 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let options = [
ListOption<"profile", "profile", "std::string",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this allow multi-profile checking?

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tatwaichong, very nice patch.

@GeorgeARM GeorgeARM requested a review from lhutton1 February 19, 2025 12:18
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just had a few nitpicks

} else {
llvm::errs() << "warning: unknown profile name passed in, supported "
"profile are bi and mi\n";
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth signalPassFailure() when an invalid option is passed in? Just thinking of the case someone mistypes a name and subsequently spammed by lots of op errors

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, fair enough. I forgot to add it back.

Copy link
Contributor

@lhutton1 lhutton1 Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think warning should also be updated to Error or something similar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. The term warning is removed.

@tatwaichong tatwaichong force-pushed the prof_validation branch 2 times, most recently from cea8e60 to b407444 Compare February 20, 2025 16:28
Copy link

github-actions bot commented Feb 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

TOSA MLIR profile-based validation is designed to identify the
profile/extension requirements for each operation in TOSA MLIR
graph, ensuring that TOSA operators conform to the profiles and
extensions enabled by the target implementation.

The available profiles/extensions are reflected in the availability
property attached to each TOSA operator in the dialect. The design
of availability, the profile/extension classes, and their interface,
is inspired by the SPIRV implementation.

This patch includes the following changes:
 - Introduces profile and extension knowledge within the dialect
   and establishes an interface to query this information.
 - Implements profile-based validation logic in the pass.
 - Adds a TargetEnv class that represents the capabilities enabled
   in the target implementation, such as profiles, extensions, and
   levels.
 - Adds a set of tests to ensure that profile and extension
   requirements are properly attached to the operations and that
   validation correctly verifies the requirements of a given
   operation against the target implementation.
@Jerry-Ge Jerry-Ge merged commit 11468c3 into llvm:main Feb 20, 2025
8 checks passed
@tatwaichong tatwaichong deleted the prof_validation branch February 20, 2025 21:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:core MLIR Core Infrastructure mlir:linalg mlir:tosa mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants