diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 85d606f4302..c3749f7f6cf 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "af20aff35ec37ead88903bc3e44f6a81c5c9ca4e" +LLVM_COMMIT = "e86910337f98e57f5b9253f7d80d5b916eb1d97e" -LLVM_SHA256 = "6e31682011d8c483c6a41adf5389eb09ad7db84331ca985d33a5d59efd0388f6" +LLVM_SHA256 = "4ca0eff0ca86ed6f2fdb7682354fdf4c85151d90ac9fb6e55a868e4191359e9f" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 52c2904470c..ddf3f30d6fd 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -af20aff35ec37ead88903bc3e44f6a81c5c9ca4e +e86910337f98e57f5b9253f7d80d5b916eb1d97e diff --git a/stablehlo/integrations/python/CheckModule.cpp b/stablehlo/integrations/python/CheckModule.cpp index c67c7911a8e..a2db7497b6d 100644 --- a/stablehlo/integrations/python/CheckModule.cpp +++ b/stablehlo/integrations/python/CheckModule.cpp @@ -11,12 +11,13 @@ limitations under the License. ==============================================================================*/ #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" #include "stablehlo/integrations/c/CheckDialect.h" -namespace py = pybind11; +namespace nb = nanobind; -PYBIND11_MODULE(_check, m) { +NB_MODULE(_check, m) { m.doc() = "check main python extension"; // @@ -32,5 +33,5 @@ PYBIND11_MODULE(_check, m) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); } diff --git a/stablehlo/integrations/python/ChloModule.cpp b/stablehlo/integrations/python/ChloModule.cpp index cbc0f47c1ab..68e6743b8b8 100644 --- a/stablehlo/integrations/python/ChloModule.cpp +++ b/stablehlo/integrations/python/ChloModule.cpp @@ -12,21 +12,23 @@ limitations under the License. ==============================================================================*/ #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" #include "stablehlo/integrations/c/ChloAttributes.h" #include "stablehlo/integrations/c/ChloDialect.h" -namespace py = pybind11; +namespace nb = nanobind; namespace { auto toPyString(MlirStringRef mlirStringRef) { - return py::str(mlirStringRef.data, mlirStringRef.length); + return nb::str(mlirStringRef.data, mlirStringRef.length); } } // namespace -PYBIND11_MODULE(_chlo, m) { +NB_MODULE(_chlo, m) { m.doc() = "chlo main python extension"; // @@ -42,35 +44,37 @@ PYBIND11_MODULE(_chlo, m) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); // // Attributes. // - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ComparisonDirectionAttr", chloAttributeIsAComparisonDirectionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, std::string_view value, MlirContext ctx) { return cls(chloComparisonDirectionAttrGet( - ctx, mlirStringRefCreate(value.c_str(), value.size()))); + ctx, mlirStringRefCreate(value.data(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a ComparisonDirection attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(chloComparisonDirectionAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ComparisonTypeAttr", chloAttributeIsAComparisonTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, std::string_view value, MlirContext ctx) { return cls(chloComparisonTypeAttrGet( - ctx, mlirStringRefCreate(value.c_str(), value.size()))); + ctx, mlirStringRefCreate(value.data(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a ComparisonType attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(chloComparisonTypeAttrGetValue(self)); diff --git a/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/integrations/python/StablehloApi.cpp index 68ff3fa6004..e36b2704c19 100644 --- a/stablehlo/integrations/python/StablehloApi.cpp +++ b/stablehlo/integrations/python/StablehloApi.cpp @@ -15,6 +15,7 @@ limitations under the License. #include "stablehlo/integrations/python/StablehloApi.h" +#include #include #include @@ -22,10 +23,14 @@ limitations under the License. #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" +#include "nanobind/stl/string_view.h" +#include "nanobind/stl/vector.h" #include "stablehlo/integrations/c/StablehloApi.h" -namespace py = pybind11; +namespace nb = nanobind; namespace mlir { namespace stablehlo { @@ -63,14 +68,18 @@ static MlirStringRef toMlirStringRef(std::string_view s) { return mlirStringRefCreate(s.data(), s.size()); } -void AddStablehloApi(py::module &m) { +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + +void AddStablehloApi(nb::module_ &m) { // Portable API is a subset of StableHLO API AddPortableApi(m); // // Utility APIs. // - py::enum_( + nb::enum_( m, "StablehloCompatibilityRequirement") .value("NONE", MlirStablehloCompatibilityRequirement::NONE) .value("WEEK_4", MlirStablehloCompatibilityRequirement::WEEK_4) @@ -79,34 +88,34 @@ void AddStablehloApi(py::module &m) { m.def( "get_version_from_compatibility_requirement", - [](MlirStablehloCompatibilityRequirement requirement) -> py::str { + [](MlirStablehloCompatibilityRequirement requirement) -> std::string { StringWriterHelper accumulator; stablehloVersionFromCompatibilityRequirement( requirement, accumulator.getMlirStringCallback(), accumulator.getUserData()); return accumulator.toString(); }, - py::arg("requirement")); + nb::arg("requirement")); // // Serialization APIs. // m.def( "serialize_portable_artifact", - [](MlirModule module, std::string_view target) -> py::bytes { + [](MlirModule module, std::string_view target) -> nb::bytes { StringWriterHelper accumulator; if (mlirLogicalResultIsFailure( stablehloSerializePortableArtifactFromModule( module, toMlirStringRef(target), accumulator.getMlirStringCallback(), accumulator.getUserData()))) { - PyErr_SetString(PyExc_ValueError, "failed to serialize module"); - return ""; + throw nb::value_error("failed to serialize module"); } - return py::bytes(accumulator.toString()); + std::string serialized = accumulator.toString(); + return nb::bytes(serialized.data(), serialized.size()); }, - py::arg("module"), py::arg("target")); + nb::arg("module"), nb::arg("target")); m.def( "deserialize_portable_artifact", @@ -114,13 +123,22 @@ void AddStablehloApi(py::module &m) { auto module = stablehloDeserializePortableArtifactNoError( toMlirStringRef(artifact), context); if (mlirModuleIsNull(module)) { - PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); - return {}; + throw nb::value_error("failed to deserialize module"); } return module; }, - py::arg("context"), py::arg("artifact")); - + nb::arg("context"), nb::arg("artifact")); + m.def( + "deserialize_portable_artifact", + [](MlirContext context, nb::bytes artifact) -> MlirModule { + auto module = stablehloDeserializePortableArtifactNoError( + toMlirStringRef(artifact), context); + if (mlirModuleIsNull(module)) { + throw nb::value_error("failed to deserialize module"); + } + return module; + }, + nb::arg("context"), nb::arg("artifact")); // // Reference APIs // @@ -130,9 +148,7 @@ void AddStablehloApi(py::module &m) { std::vector &args) -> std::vector { for (auto arg : args) { if (!mlirAttributeIsADenseElements(arg)) { - PyErr_SetString(PyExc_ValueError, - "input args must be DenseElementsAttr"); - return {}; + throw nb::value_error("input args must be DenseElementsAttr"); } } @@ -141,8 +157,7 @@ void AddStablehloApi(py::module &m) { stablehloEvalModule(module, args.size(), args.data(), &errorCode); if (errorCode != 0) { - PyErr_SetString(PyExc_ValueError, "interpreter failed"); - return {}; + throw nb::value_error("interpreter failed"); } std::vector pyResults; @@ -151,10 +166,10 @@ void AddStablehloApi(py::module &m) { } return pyResults; }, - py::arg("module"), py::arg("args")); + nb::arg("module"), nb::arg("args")); } -void AddPortableApi(py::module &m) { +void AddPortableApi(nb::module_ &m) { // // Utility APIs. // @@ -162,28 +177,28 @@ void AddPortableApi(py::module &m) { m.def( "get_smaller_version", - [](const std::string &version1, const std::string &version2) -> py::str { + [](const std::string &version1, + const std::string &version2) -> std::string { StringWriterHelper accumulator; if (mlirLogicalResultIsFailure(stablehloGetSmallerVersion( toMlirStringRef(version1), toMlirStringRef(version2), accumulator.getMlirStringCallback(), accumulator.getUserData()))) { - PyErr_SetString(PyExc_ValueError, - "failed to convert version to stablehlo version"); - return ""; + throw nb::value_error( + "failed to convert version to stablehlo version"); } return accumulator.toString(); }, - py::arg("version1"), py::arg("version2")); + nb::arg("version1"), nb::arg("version2")); - m.def("get_current_version", []() -> py::str { + m.def("get_current_version", []() -> std::string { StringWriterHelper accumulator; stablehloGetCurrentVersion(accumulator.getMlirStringCallback(), accumulator.getUserData()); return accumulator.toString(); }); - m.def("get_minimum_version", []() -> py::str { + m.def("get_minimum_version", []() -> std::string { StringWriterHelper accumulator; stablehloGetMinimumVersion(accumulator.getMlirStringCallback(), accumulator.getUserData()); @@ -196,7 +211,24 @@ void AddPortableApi(py::module &m) { m.def( "serialize_portable_artifact_str", [](std::string_view moduleStrOrBytecode, - std::string_view targetVersion) -> py::bytes { + std::string_view targetVersion) -> nb::bytes { + StringWriterHelper accumulator; + if (mlirLogicalResultIsFailure( + stablehloSerializePortableArtifactFromStringRef( + toMlirStringRef(moduleStrOrBytecode), + toMlirStringRef(targetVersion), + accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { + throw nb::value_error("failed to serialize module"); + } + std::string serialized = accumulator.toString(); + return nb::bytes(serialized.data(), serialized.size()); + }, + nb::arg("module_str"), nb::arg("target_version")); + m.def( + "serialize_portable_artifact_str", + [](nb::bytes moduleStrOrBytecode, + std::string_view targetVersion) -> nb::bytes { StringWriterHelper accumulator; if (mlirLogicalResultIsFailure( stablehloSerializePortableArtifactFromStringRef( @@ -204,26 +236,39 @@ void AddPortableApi(py::module &m) { toMlirStringRef(targetVersion), accumulator.getMlirStringCallback(), accumulator.getUserData()))) { - PyErr_SetString(PyExc_ValueError, "failed to serialize module"); - return ""; + throw nb::value_error("failed to serialize module"); } - return py::bytes(accumulator.toString()); + std::string serialized = accumulator.toString(); + return nb::bytes(serialized.data(), serialized.size()); }, - py::arg("module_str"), py::arg("target_version")); + nb::arg("module_str"), nb::arg("target_version")); m.def( "deserialize_portable_artifact_str", - [](std::string_view artifact) -> py::bytes { + [](std::string_view artifact) -> nb::bytes { + StringWriterHelper accumulator; + if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact( + toMlirStringRef(artifact), accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { + throw nb::value_error("failed to deserialize module"); + } + std::string serialized = accumulator.toString(); + return nb::bytes(serialized.data(), serialized.size()); + }, + nb::arg("artifact_str")); + m.def( + "deserialize_portable_artifact_str", + [](const nb::bytes &artifact) -> nb::bytes { StringWriterHelper accumulator; if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact( toMlirStringRef(artifact), accumulator.getMlirStringCallback(), accumulator.getUserData()))) { - PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); - return ""; + throw nb::value_error("failed to deserialize module"); } - return py::bytes(accumulator.toString()); + std::string serialized = accumulator.toString(); + return nb::bytes(serialized.data(), serialized.size()); }, - py::arg("artifact_str")); + nb::arg("artifact_str")); } } // namespace stablehlo diff --git a/stablehlo/integrations/python/StablehloApi.h b/stablehlo/integrations/python/StablehloApi.h index e0a96a122f9..9987086ddda 100644 --- a/stablehlo/integrations/python/StablehloApi.h +++ b/stablehlo/integrations/python/StablehloApi.h @@ -16,20 +16,20 @@ limitations under the License. #ifndef STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H #define STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H -#include "pybind11/pybind11.h" +#include "nanobind/nanobind.h" namespace mlir { namespace stablehlo { -// Add StableHLO APIs to the pybind11 module. +// Add StableHLO APIs to the nanobind module. // Signatures of these APIs have no dependency on C++ MLIR types and all must // use C API passthrough. -void AddStablehloApi(pybind11::module& m); +void AddStablehloApi(nanobind::module_& m); // Adds a subset of the StableHLO API that doesn't use MLIR in any definitions, // and is methods only, introducing no new objects / enums to avoid potential // redefinition issues in complex build environments. -void AddPortableApi(pybind11::module& m); +void AddPortableApi(nanobind::module_& m); } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/integrations/python/StablehloModule.cpp index a3f05b8a746..08c6ebd6cfe 100644 --- a/stablehlo/integrations/python/StablehloModule.cpp +++ b/stablehlo/integrations/python/StablehloModule.cpp @@ -15,14 +15,17 @@ limitations under the License. #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" +#include "nanobind/stl/vector.h" #include "stablehlo/integrations/c/StablehloAttributes.h" #include "stablehlo/integrations/c/StablehloDialect.h" #include "stablehlo/integrations/c/StablehloPasses.h" #include "stablehlo/integrations/c/StablehloTypes.h" #include "stablehlo/integrations/python/StablehloApi.h" -namespace py = pybind11; +namespace nb = nanobind; namespace { // Returns a vector containing integers extracted from an attribute using the @@ -40,12 +43,12 @@ std::vector attributePropertyVector( } auto toPyString(MlirStringRef mlirStringRef) { - return py::str(mlirStringRef.data, mlirStringRef.length); + return nb::str(mlirStringRef.data, mlirStringRef.length); } } // namespace -PYBIND11_MODULE(_stablehlo, m) { +NB_MODULE(_stablehlo, m) { m.doc() = "stablehlo main python extension"; // @@ -61,7 +64,7 @@ PYBIND11_MODULE(_stablehlo, m) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); // // Passes. @@ -74,14 +77,14 @@ PYBIND11_MODULE(_stablehlo, m) { // Types. // - mlir::python::adaptors::mlir_type_subclass(m, "TokenType", - stablehloTypeIsAToken) + mlir::python::nanobind_adaptors::mlir_type_subclass(m, "TokenType", + stablehloTypeIsAToken) .def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(stablehloTokenTypeGet(ctx)); }, - py::arg("cls"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("context").none() = nb::none(), "Creates a Token type."); // @@ -94,12 +97,12 @@ PYBIND11_MODULE(_stablehlo, m) { stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem); }; - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ScatterDimensionNumbers", stablehloAttributeIsAScatterDimensionNumbers) .def_classmethod( "get", - [](py::object cls, const std::vector &updateWindowDims, + [](nb::object cls, const std::vector &updateWindowDims, const std::vector &insertedWindowDims, const std::vector &inputBatchingDims, const std::vector &scatterIndicesBatchingDims, @@ -114,11 +117,11 @@ PYBIND11_MODULE(_stablehlo, m) { scatteredDimsToOperandDims.size(), scatteredDimsToOperandDims.data(), indexVectorDim)); }, - py::arg("cls"), py::arg("update_window_dims"), - py::arg("inserted_window_dims"), py::arg("input_batching_dims"), - py::arg("scatter_indices_batching_dims"), - py::arg("scattered_dims_to_operand_dims"), - py::arg("index_vector_dim"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("update_window_dims"), + nb::arg("inserted_window_dims"), nb::arg("input_batching_dims"), + nb::arg("scatter_indices_batching_dims"), + nb::arg("scattered_dims_to_operand_dims"), + nb::arg("index_vector_dim"), nb::arg("context").none() = nb::none(), "Creates a ScatterDimensionNumbers with the given dimension " "configuration.") .def_property_readonly( @@ -156,11 +159,11 @@ PYBIND11_MODULE(_stablehlo, m) { return stablehloDimensionNumbersGetIndexVectorDim(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "GatherDimensionNumbers", stablehloAttributeIsAGatherDimensionNumbers) .def_classmethod( "get", - [](py::object cls, const std::vector &offsetDims, + [](nb::object cls, const std::vector &offsetDims, const std::vector &collapsedSliceDims, const std::vector &operandBatchingDims, const std::vector &startIndicesBatchingDims, @@ -174,10 +177,10 @@ PYBIND11_MODULE(_stablehlo, m) { startIndicesBatchingDims.data(), startIndexMap.size(), startIndexMap.data(), indexVectorDim)); }, - py::arg("cls"), py::arg("offset_dims"), - py::arg("collapsed_slice_dims"), py::arg("operand_batching_dims"), - py::arg("start_indices_batching_dims"), py::arg("start_index_map"), - py::arg("index_vector_dim"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("offset_dims"), + nb::arg("collapsed_slice_dims"), nb::arg("operand_batching_dims"), + nb::arg("start_indices_batching_dims"), nb::arg("start_index_map"), + nb::arg("index_vector_dim"), nb::arg("context").none() = nb::none(), "Creates a GatherDimensionNumbers attribute with the given dimension " "configuration.") .def_property_readonly( @@ -220,11 +223,11 @@ PYBIND11_MODULE(_stablehlo, m) { return stablehloGatherDimensionNumbersGetIndexVectorDim(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm) .def_classmethod( "get", - [](py::object cls, MlirType lhsPrecisionType, + [](nb::object cls, MlirType lhsPrecisionType, MlirType rhsPrecisionType, MlirType accumulationType, int64_t lhsComponentCount, int64_t rhsComponentCount, int64_t numPrimitiveOperations, bool allowImpreciseAccumulation, @@ -234,11 +237,12 @@ PYBIND11_MODULE(_stablehlo, m) { lhsComponentCount, rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation)); }, - py::arg("cls"), py::arg("lhs_precision_type"), - py::arg("rhs_precision_type"), py::arg("accumulation_type"), - py::arg("lhs_component_count"), py::arg("rhs_component_count"), - py::arg("num_primitive_operations"), - py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(), + nb::arg("cls"), nb::arg("lhs_precision_type"), + nb::arg("rhs_precision_type"), nb::arg("accumulation_type"), + nb::arg("lhs_component_count"), nb::arg("rhs_component_count"), + nb::arg("num_primitive_operations"), + nb::arg("allow_imprecise_accumulation"), + nb::arg("ctx").none() = nb::none(), "Creates a DotAlgorithm attribute with the given dimension " "configuration.") .def_property_readonly( @@ -276,11 +280,11 @@ PYBIND11_MODULE(_stablehlo, m) { return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers) .def_classmethod( "get", - [](py::object cls, const std::vector &lhsBatchingDims, + [](nb::object cls, const std::vector &lhsBatchingDims, const std::vector &rhsBatchingDims, const std::vector &lhsContractingDims, const std::vector &rhsContractingDims, MlirContext ctx) { @@ -290,11 +294,11 @@ PYBIND11_MODULE(_stablehlo, m) { lhsContractingDims.size(), lhsContractingDims.data(), rhsContractingDims.size(), rhsContractingDims.data())); }, - py::arg("cls"), py::arg("lhs_batching_dimensions"), - py::arg("rhs_batching_dimensions"), - py::arg("lhs_contracting_dimensions"), - py::arg("rhs_contracting_dimensions"), - py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("lhs_batching_dimensions"), + nb::arg("rhs_batching_dimensions"), + nb::arg("lhs_contracting_dimensions"), + nb::arg("rhs_contracting_dimensions"), + nb::arg("context").none() = nb::none(), "Creates a DotDimensionNumbers attribute with the given dimension " "configuration.") .def_property_readonly( @@ -327,11 +331,11 @@ PYBIND11_MODULE(_stablehlo, m) { stablehloDotDimensionNumbersGetRhsContractingDimensionsElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ConvDimensionNumbers", stablehloAttributeIsAConvDimensionNumbers) .def_classmethod( "get", - [](py::object cls, int64_t inputBatchDimension, + [](nb::object cls, int64_t inputBatchDimension, int64_t inputFeatureDimension, const std::vector inputSpatialDimensions, int64_t kernelInputFeatureDimension, @@ -349,15 +353,16 @@ PYBIND11_MODULE(_stablehlo, m) { outputSpatialDimensions.size(), outputSpatialDimensions.data())); }, - py::arg("cls"), py::arg("input_batch_dimension"), - py::arg("input_feature_dimension"), - py::arg("input_spatial_dimensions"), - py::arg("kernel_input_feature_dimension"), - py::arg("kernel_output_feature_dimension"), - py::arg("kernel_spatial_dimensions"), - py::arg("output_batch_dimension"), - py::arg("output_feature_dimension"), - py::arg("output_spatial_dimensions"), py::arg("ctx") = py::none(), + nb::arg("cls"), nb::arg("input_batch_dimension"), + nb::arg("input_feature_dimension"), + nb::arg("input_spatial_dimensions"), + nb::arg("kernel_input_feature_dimension"), + nb::arg("kernel_output_feature_dimension"), + nb::arg("kernel_spatial_dimensions"), + nb::arg("output_batch_dimension"), + nb::arg("output_feature_dimension"), + nb::arg("output_spatial_dimensions"), + nb::arg("ctx").none() = nb::none(), "Creates a ConvDimensionNumbers attribute with the given dimension " "configuration.") .def_property_readonly( @@ -416,11 +421,11 @@ PYBIND11_MODULE(_stablehlo, m) { stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "OutputOperandAlias", stablehloAttributeIsAOutputOperandAlias) .def_classmethod( "get", - [](py::object cls, const std::vector outputTupleIndices, + [](nb::object cls, const std::vector outputTupleIndices, int64_t operandIndex, const std::vector operandTupleIndices, MlirContext ctx) { return cls(stablehloOutputOperandAliasGet( @@ -428,9 +433,9 @@ PYBIND11_MODULE(_stablehlo, m) { operandIndex, operandTupleIndices.size(), operandTupleIndices.data())); }, - py::arg("cls"), py::arg("output_tuple_indices"), - py::arg("operand_index"), py::arg("operand_tuple_indices"), - py::arg("ctx") = py::none(), + nb::arg("cls"), nb::arg("output_tuple_indices"), + nb::arg("operand_index"), nb::arg("operand_tuple_indices"), + nb::arg("ctx").none() = nb::none(), "Creates a OutputOperandAlias attribute with the given tuple index.") .def_property_readonly( "output_tuple_indices", @@ -450,114 +455,122 @@ PYBIND11_MODULE(_stablehlo, m) { stablehloOutputOperandAliasGetOperandTupleIndicesElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ComparisonDirectionAttr", stablehloAttributeIsAComparisonDirectionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloComparisonDirectionAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a ComparisonDirection attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloComparisonDirectionAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ComparisonTypeAttr", stablehloAttributeIsAComparisonTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloComparisonTypeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a ComparisonType attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloComparisonTypeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "PrecisionAttr", stablehloAttributeIsAPrecisionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloPrecisionAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a Precision attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloPrecisionAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "FftTypeAttr", stablehloAttributeIsAFftTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloFftTypeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a FftType attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloFftTypeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TransposeAttr", stablehloAttributeIsATransposeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloTransposeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a Transpose attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloTransposeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "RngDistributionAttr", stablehloAttributeIsARngDistributionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloRngDistributionAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a RngDistribution attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloRngDistributionAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "RngAlgorithmAttr", stablehloAttributeIsARngAlgorithmAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(stablehloRngAlgorithmAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a RngAlgorithm attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(stablehloRngAlgorithmAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ChannelHandle", stablehloAttributeIsChannelHandle) .def_classmethod( "get", - [](py::object cls, int64_t handle, int64_t type, MlirContext ctx) { + [](nb::object cls, int64_t handle, int64_t type, MlirContext ctx) { return cls(stablehloChannelHandleGet(ctx, handle, type)); }, - py::arg("cls"), py::arg("handle"), py::arg("type"), - py::arg("context") = py::none(), "Creates a ChannelHandle attribute.") + nb::arg("cls"), nb::arg("handle"), nb::arg("type"), + nb::arg("context").none() = nb::none(), + "Creates a ChannelHandle attribute.") .def_property_readonly("handle", [](MlirAttribute self) { return stablehloChannelHandleGetHandle(self); @@ -568,16 +581,17 @@ PYBIND11_MODULE(_stablehlo, m) { return stablehloChannelHandleGetType(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TypeExtensions", stablehloAttributeIsTypeExtensions) .def_classmethod( "get", - [](py::object cls, const std::vector &bounds, + [](nb::object cls, const std::vector &bounds, MlirContext ctx) { return cls( stablehloTypeExtensionsGet(ctx, bounds.size(), bounds.data())); }, - py::arg("cls"), py::arg("bounds"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("bounds"), + nb::arg("context").none() = nb::none(), "Creates a TypeExtensions with the given bounds.") .def_property_readonly("bounds", [](MlirAttribute self) { return attributePropertyVector(self, diff --git a/stablehlo/integrations/python/VhloModule.cpp b/stablehlo/integrations/python/VhloModule.cpp index affc040ca63..433c593568e 100644 --- a/stablehlo/integrations/python/VhloModule.cpp +++ b/stablehlo/integrations/python/VhloModule.cpp @@ -11,12 +11,13 @@ limitations under the License. ==============================================================================*/ #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" #include "stablehlo/integrations/c/VhloDialect.h" -namespace py = pybind11; +namespace nb = nanobind; -PYBIND11_MODULE(_vhlo, m) { +NB_MODULE(_vhlo, m) { m.doc() = "vhlo main python extension"; // @@ -32,5 +33,5 @@ PYBIND11_MODULE(_vhlo, m) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); } diff --git a/stablehlo/tests/ops_chlo.mlir b/stablehlo/tests/ops_chlo.mlir index b5a021043b5..f3e7f77736a 100644 --- a/stablehlo/tests/ops_chlo.mlir +++ b/stablehlo/tests/ops_chlo.mlir @@ -289,6 +289,222 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso // ----- +// ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n] +func.func @ragged_dot_non_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> + func.return %0 : tensor<2x11x7xf32> +} + +// ----- + +// ragged_dot mode 2: [m,k], [k,n], [g] -> [g,m,n] +func.func @ragged_dot_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x2x11x7xf32> { + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [2], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<2x5x7xf32>, tensor<3xi64>) -> tensor<3x2x11x7xf32> + func.return %0 : tensor<3x2x11x7xf32> +} + +// ----- + +// ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n] +func.func @ragged_dot_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> { + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32> + func.return %0 : tensor<3x11x7xf32> +} + +// ----- + +func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rhs : tensor<3x2x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{contracting dimension sizes must match}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<3x2x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + +func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3x2xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + +func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<2xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + +func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{There must be exactly one ragged dimension in the lhs}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0, 1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + +func.func @ragged_dot_rhs_group_dim_is_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> { + // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32> + func.return %0 : tensor<3x11x7xf32> +} + +// ----- + +func.func @ragged_dot_rhs_group_dim_is_contracting(%lhs : tensor<11x3xf32>, %rhs : tensor<3x3x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [1] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x3xf32>, tensor<3x3x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + +func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { + // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> + func.return %0 : tensor<2x11x7xf32> +} + +// ----- + +func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + +func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tensor<11x5xf32>, %rhs : tensor<5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> + func.return %0 : tensor<11x7xf32> +} + +// ----- + func.func @top_k(%arg0 : tensor) { // expected-error @+2 {{failed to infer returned types}} // @expected-error @+1{{operand's rank must be at least 1}}