Skip to content

Commit

Permalink
Integrate LLVM at llvm/llvm-project@e86910337f98
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Dec 18, 2024
1 parent 33f19a4 commit 5db82a1
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 147 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "af20aff35ec37ead88903bc3e44f6a81c5c9ca4e"
LLVM_COMMIT = "e86910337f98e57f5b9253f7d80d5b916eb1d97e"

LLVM_SHA256 = "6e31682011d8c483c6a41adf5389eb09ad7db84331ca985d33a5d59efd0388f6"
LLVM_SHA256 = "4ca0eff0ca86ed6f2fdb7682354fdf4c85151d90ac9fb6e55a868e4191359e9f"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
af20aff35ec37ead88903bc3e44f6a81c5c9ca4e
e86910337f98e57f5b9253f7d80d5b916eb1d97e
9 changes: 5 additions & 4 deletions stablehlo/integrations/python/CheckModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ limitations under the License.
==============================================================================*/

#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
#include "stablehlo/integrations/c/CheckDialect.h"

namespace py = pybind11;
namespace nb = nanobind;

PYBIND11_MODULE(_check, m) {
NB_MODULE(_check, m) {
m.doc() = "check main python extension";

//
Expand All @@ -32,5 +33,5 @@ PYBIND11_MODULE(_check, m) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
nb::arg("context"), nb::arg("load") = true);
}
30 changes: 17 additions & 13 deletions stablehlo/integrations/python/ChloModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@ limitations under the License.
==============================================================================*/

#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/string_view.h"
#include "stablehlo/integrations/c/ChloAttributes.h"
#include "stablehlo/integrations/c/ChloDialect.h"

namespace py = pybind11;
namespace nb = nanobind;

namespace {

auto toPyString(MlirStringRef mlirStringRef) {
return py::str(mlirStringRef.data, mlirStringRef.length);
return nb::str(mlirStringRef.data, mlirStringRef.length);
}

} // namespace

PYBIND11_MODULE(_chlo, m) {
NB_MODULE(_chlo, m) {
m.doc() = "chlo main python extension";

//
Expand All @@ -42,35 +44,37 @@ PYBIND11_MODULE(_chlo, m) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
nb::arg("context"), nb::arg("load") = true);

//
// Attributes.
//

mlir::python::adaptors::mlir_attribute_subclass(
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "ComparisonDirectionAttr", chloAttributeIsAComparisonDirectionAttr)
.def_classmethod(
"get",
[](py::object cls, const std::string &value, MlirContext ctx) {
[](nb::object cls, std::string_view value, MlirContext ctx) {
return cls(chloComparisonDirectionAttrGet(
ctx, mlirStringRefCreate(value.c_str(), value.size())));
ctx, mlirStringRefCreate(value.data(), value.size())));
},
py::arg("cls"), py::arg("value"), py::arg("context") = py::none(),
nb::arg("cls"), nb::arg("value"),
nb::arg("context").none() = nb::none(),
"Creates a ComparisonDirection attribute with the given value.")
.def_property_readonly("value", [](MlirAttribute self) {
return toPyString(chloComparisonDirectionAttrGetValue(self));
});

mlir::python::adaptors::mlir_attribute_subclass(
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "ComparisonTypeAttr", chloAttributeIsAComparisonTypeAttr)
.def_classmethod(
"get",
[](py::object cls, const std::string &value, MlirContext ctx) {
[](nb::object cls, std::string_view value, MlirContext ctx) {
return cls(chloComparisonTypeAttrGet(
ctx, mlirStringRefCreate(value.c_str(), value.size())));
ctx, mlirStringRefCreate(value.data(), value.size())));
},
py::arg("cls"), py::arg("value"), py::arg("context") = py::none(),
nb::arg("cls"), nb::arg("value"),
nb::arg("context").none() = nb::none(),
"Creates a ComparisonType attribute with the given value.")
.def_property_readonly("value", [](MlirAttribute self) {
return toPyString(chloComparisonTypeAttrGetValue(self));
Expand Down
123 changes: 84 additions & 39 deletions stablehlo/integrations/python/StablehloApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@ limitations under the License.

#include "stablehlo/integrations/python/StablehloApi.h"

#include <stdexcept>
#include <string>
#include <string_view>

#include "llvm/Support/raw_ostream.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h"
#include "nanobind/stl/string_view.h"
#include "nanobind/stl/vector.h"
#include "stablehlo/integrations/c/StablehloApi.h"

namespace py = pybind11;
namespace nb = nanobind;

namespace mlir {
namespace stablehlo {
Expand Down Expand Up @@ -63,14 +68,18 @@ static MlirStringRef toMlirStringRef(std::string_view s) {
return mlirStringRefCreate(s.data(), s.size());
}

void AddStablehloApi(py::module &m) {
static MlirStringRef toMlirStringRef(const nb::bytes &s) {
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
}

void AddStablehloApi(nb::module_ &m) {
// Portable API is a subset of StableHLO API
AddPortableApi(m);

//
// Utility APIs.
//
py::enum_<MlirStablehloCompatibilityRequirement>(
nb::enum_<MlirStablehloCompatibilityRequirement>(
m, "StablehloCompatibilityRequirement")
.value("NONE", MlirStablehloCompatibilityRequirement::NONE)
.value("WEEK_4", MlirStablehloCompatibilityRequirement::WEEK_4)
Expand All @@ -79,48 +88,57 @@ void AddStablehloApi(py::module &m) {

m.def(
"get_version_from_compatibility_requirement",
[](MlirStablehloCompatibilityRequirement requirement) -> py::str {
[](MlirStablehloCompatibilityRequirement requirement) -> std::string {
StringWriterHelper accumulator;
stablehloVersionFromCompatibilityRequirement(
requirement, accumulator.getMlirStringCallback(),
accumulator.getUserData());
return accumulator.toString();
},
py::arg("requirement"));
nb::arg("requirement"));

//
// Serialization APIs.
//
m.def(
"serialize_portable_artifact",
[](MlirModule module, std::string_view target) -> py::bytes {
[](MlirModule module, std::string_view target) -> nb::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(
stablehloSerializePortableArtifactFromModule(
module, toMlirStringRef(target),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
return "";
throw nb::value_error("failed to serialize module");
}

return py::bytes(accumulator.toString());
std::string serialized = accumulator.toString();
return nb::bytes(serialized.data(), serialized.size());
},
py::arg("module"), py::arg("target"));
nb::arg("module"), nb::arg("target"));

m.def(
"deserialize_portable_artifact",
[](MlirContext context, std::string_view artifact) -> MlirModule {
auto module = stablehloDeserializePortableArtifactNoError(
toMlirStringRef(artifact), context);
if (mlirModuleIsNull(module)) {
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
return {};
throw nb::value_error("failed to deserialize module");
}
return module;
},
py::arg("context"), py::arg("artifact"));

nb::arg("context"), nb::arg("artifact"));
m.def(
"deserialize_portable_artifact",
[](MlirContext context, nb::bytes artifact) -> MlirModule {
auto module = stablehloDeserializePortableArtifactNoError(
toMlirStringRef(artifact), context);
if (mlirModuleIsNull(module)) {
throw nb::value_error("failed to deserialize module");
}
return module;
},
nb::arg("context"), nb::arg("artifact"));
//
// Reference APIs
//
Expand All @@ -130,9 +148,7 @@ void AddStablehloApi(py::module &m) {
std::vector<MlirAttribute> &args) -> std::vector<MlirAttribute> {
for (auto arg : args) {
if (!mlirAttributeIsADenseElements(arg)) {
PyErr_SetString(PyExc_ValueError,
"input args must be DenseElementsAttr");
return {};
throw nb::value_error("input args must be DenseElementsAttr");
}
}

Expand All @@ -141,8 +157,7 @@ void AddStablehloApi(py::module &m) {
stablehloEvalModule(module, args.size(), args.data(), &errorCode);

if (errorCode != 0) {
PyErr_SetString(PyExc_ValueError, "interpreter failed");
return {};
throw nb::value_error("interpreter failed");
}

std::vector<MlirAttribute> pyResults;
Expand All @@ -151,39 +166,39 @@ void AddStablehloApi(py::module &m) {
}
return pyResults;
},
py::arg("module"), py::arg("args"));
nb::arg("module"), nb::arg("args"));
}

void AddPortableApi(py::module &m) {
void AddPortableApi(nb::module_ &m) {
//
// Utility APIs.
//
m.def("get_api_version", []() { return stablehloGetApiVersion(); });

m.def(
"get_smaller_version",
[](const std::string &version1, const std::string &version2) -> py::str {
[](const std::string &version1,
const std::string &version2) -> std::string {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloGetSmallerVersion(
toMlirStringRef(version1), toMlirStringRef(version2),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError,
"failed to convert version to stablehlo version");
return "";
throw nb::value_error(
"failed to convert version to stablehlo version");
}
return accumulator.toString();
},
py::arg("version1"), py::arg("version2"));
nb::arg("version1"), nb::arg("version2"));

m.def("get_current_version", []() -> py::str {
m.def("get_current_version", []() -> std::string {
StringWriterHelper accumulator;
stablehloGetCurrentVersion(accumulator.getMlirStringCallback(),
accumulator.getUserData());
return accumulator.toString();
});

m.def("get_minimum_version", []() -> py::str {
m.def("get_minimum_version", []() -> std::string {
StringWriterHelper accumulator;
stablehloGetMinimumVersion(accumulator.getMlirStringCallback(),
accumulator.getUserData());
Expand All @@ -196,34 +211,64 @@ void AddPortableApi(py::module &m) {
m.def(
"serialize_portable_artifact_str",
[](std::string_view moduleStrOrBytecode,
std::string_view targetVersion) -> py::bytes {
std::string_view targetVersion) -> nb::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(
stablehloSerializePortableArtifactFromStringRef(
toMlirStringRef(moduleStrOrBytecode),
toMlirStringRef(targetVersion),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
throw nb::value_error("failed to serialize module");
}
std::string serialized = accumulator.toString();
return nb::bytes(serialized.data(), serialized.size());
},
nb::arg("module_str"), nb::arg("target_version"));
m.def(
"serialize_portable_artifact_str",
[](nb::bytes moduleStrOrBytecode,
std::string_view targetVersion) -> nb::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(
stablehloSerializePortableArtifactFromStringRef(
toMlirStringRef(moduleStrOrBytecode),
toMlirStringRef(targetVersion),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
return "";
throw nb::value_error("failed to serialize module");
}
return py::bytes(accumulator.toString());
std::string serialized = accumulator.toString();
return nb::bytes(serialized.data(), serialized.size());
},
py::arg("module_str"), py::arg("target_version"));
nb::arg("module_str"), nb::arg("target_version"));

m.def(
"deserialize_portable_artifact_str",
[](std::string_view artifact) -> py::bytes {
[](std::string_view artifact) -> nb::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
throw nb::value_error("failed to deserialize module");
}
std::string serialized = accumulator.toString();
return nb::bytes(serialized.data(), serialized.size());
},
nb::arg("artifact_str"));
m.def(
"deserialize_portable_artifact_str",
[](const nb::bytes &artifact) -> nb::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
return "";
throw nb::value_error("failed to deserialize module");
}
return py::bytes(accumulator.toString());
std::string serialized = accumulator.toString();
return nb::bytes(serialized.data(), serialized.size());
},
py::arg("artifact_str"));
nb::arg("artifact_str"));
}

} // namespace stablehlo
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/integrations/python/StablehloApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ limitations under the License.
#ifndef STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H
#define STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H

#include "pybind11/pybind11.h"
#include "nanobind/nanobind.h"

namespace mlir {
namespace stablehlo {

// Add StableHLO APIs to the pybind11 module.
// Add StableHLO APIs to the nanobind module.
// Signatures of these APIs have no dependency on C++ MLIR types and all must
// use C API passthrough.
void AddStablehloApi(pybind11::module& m);
void AddStablehloApi(nanobind::module_& m);

// Adds a subset of the StableHLO API that doesn't use MLIR in any definitions,
// and is methods only, introducing no new objects / enums to avoid potential
// redefinition issues in complex build environments.
void AddPortableApi(pybind11::module& m);
void AddPortableApi(nanobind::module_& m);

} // namespace stablehlo
} // namespace mlir
Expand Down
Loading

0 comments on commit 5db82a1

Please sign in to comment.