Skip to content

Commit

Permalink
Plumb session flags and a first draft ctypes based Python API. (#93)
Browse files Browse the repository at this point in the history
* Wires the flag --openxla-partitioner-gspmd-num-partitions for both CLI
and API use.
* The Python API is used to test the flags functionality.
* A followup can implement the rest of the Python API and niceties for
packaging/releasing it.
* The partitioner is now minimally functional for both API and CLI use.

Progress on #82
  • Loading branch information
Stella Laurenzo authored May 8, 2023
1 parent 9e62002 commit d773141
Show file tree
Hide file tree
Showing 9 changed files with 774 additions and 23 deletions.
17 changes: 17 additions & 0 deletions partitioner/BUILD → partitioner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ cc_library(
],
)

cc_library(
name = "support",
srcs = [
"src/openxla/partitioner/Support/OptionUtils.cpp",
],
hdrs = [
"src/openxla/partitioner/Support/OptionUtils.h",
],
deps = [
":defs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "gspmd_pipeline",
srcs = [
Expand All @@ -34,6 +49,7 @@ cc_library(
],
deps = [
":defs",
":support",
"@xla//xla/hlo/transforms:hlo_constant_splitter",
"@xla//xla/mlir_hlo:hlo_legalize_to_stablehlo",
"@xla//xla/mlir_hlo:mhlo_passes",
Expand Down Expand Up @@ -78,6 +94,7 @@ cc_library(
":c_headers",
":defs",
":gspmd_pipeline",
":support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
Expand Down
7 changes: 7 additions & 0 deletions partitioner/bindings/python/openxla/partitioner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._dl import *
125 changes: 125 additions & 0 deletions partitioner/bindings/python/openxla/partitioner/_dl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Sequence

import ctypes
import os

__all__ = [
"Invocation",
"Session",
]

_dylib = None

_GET_FLAG_CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_size_t,
ctypes.c_void_p)


def _setsig(f, restype, argtypes):
f.restype = restype
f.argtypes = argtypes


def _init_dylib():
global _dylib
if _dylib:
return
dylib_path = os.getenv("OPENXLA_PARTITIONER_LIB_PATH")
if dylib_path is None:
# TODO: Look for a bundled dylib.
raise RuntimeError("Could not find libOpenXLAPartitioner.so: "
"set OPENXLA_PARTITIONER_LIB_PATH")
_dylib = ctypes.cdll.LoadLibrary(dylib_path)

# Setup signatures.
_setsig(_dylib.openxlaPartitionerErrorDestroy, None, [ctypes.c_void_p])
_setsig(_dylib.openxlaPartitionerErrorGetMessage, ctypes.c_char_p,
[ctypes.c_void_p])
_setsig(_dylib.openxlaPartitionerInvocationCreate, ctypes.c_void_p,
[ctypes.c_void_p])
_setsig(_dylib.openxlaPartitionerInvocationDestroy, None, [ctypes.c_void_p])
_setsig(_dylib.openxlaPartitionerSessionCreate, ctypes.c_void_p, [])
_setsig(_dylib.openxlaPartitionerSessionDestroy, None, [ctypes.c_void_p])
_setsig(_dylib.openxlaPartitionerSessionGetFlags, None,
[ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p])
_setsig(_dylib.openxlaPartitionerSessionSetFlags, ctypes.c_void_p,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p])


def _handle_error(err_p, exc_type=ValueError):
if err_p is None:
return
message = _dylib.openxlaPartitionerErrorGetMessage(err_p).decode("UTF-8")
_dylib.openxlaPartitionerErrorDestroy(err_p)
raise exc_type(message)


def _global_initialize():
_dylib.openxlaPartitionerGlobalInitialize()


def _global_shutdown():
_dylib.openxlaPartitionerGlobalShutdown()


class _GlobalInit:

def __init__(self):
_init_dylib()
_dylib.openxlaPartitionerGlobalInitialize()

def __del__(self):
_dylib.openxlaPartitionerGlobalShutdown()


# Keep one reference for the life of the module.
_global_init = _GlobalInit()


class Session:

def __init__(self):
self._global_init = _global_init
self._session_p = _dylib.openxlaPartitionerSessionCreate()

def __del__(self):
_dylib.openxlaPartitionerSessionDestroy(self._session_p)

def invocation(self):
return Invocation(self)

def get_flags(self, non_default_only: bool = False) -> Sequence[str]:
results = []

@_GET_FLAG_CALLBACK
def callback(flag_pointer, length, user_data):
flag_bytes = ctypes.string_at(flag_pointer, length)
flag_value = flag_bytes.decode("UTF-8")
results.append(flag_value)

_dylib.openxlaPartitionerSessionGetFlags(self._session_p, non_default_only,
callback, ctypes.c_void_p(0))
return results

def set_flags(self, *flags: Sequence[str]):
argv_type = ctypes.c_char_p * len(flags)
argv = argv_type(*[flag.encode("UTF-8") for flag in flags])
_handle_error(
_dylib.openxlaPartitionerSessionSetFlags(self._session_p, len(argv),
argv))


class Invocation:

def __init__(self, session: Session):
self._session = session
self._inv_p = _dylib.openxlaPartitionerInvocationCreate(
self._session._session_p)

def __del__(self):
_dylib.openxlaPartitionerInvocationDestroy(self._inv_p)
50 changes: 50 additions & 0 deletions partitioner/bindings/python/test/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import unittest

from openxla.partitioner import *


class FlagsTest(unittest.TestCase):

def testDefaultFlags(self):
session = Session()
flags = session.get_flags()
self.assertIn("--openxla-partitioner-gspmd-num-partitions=1", flags)

def testNonDefaultFlags(self):
session = Session()
flags = session.get_flags(non_default_only=True)
self.assertEqual(flags, [])
session.set_flags("--openxla-partitioner-gspmd-num-partitions=2")
flags = session.get_flags(non_default_only=True)
self.assertIn("--openxla-partitioner-gspmd-num-partitions=2", flags)

def testFlagsAreScopedToSession(self):
session1 = Session()
session2 = Session()
session1.set_flags("--openxla-partitioner-gspmd-num-partitions=2")
session2.set_flags("--openxla-partitioner-gspmd-num-partitions=3")
self.assertIn("--openxla-partitioner-gspmd-num-partitions=2",
session1.get_flags())
self.assertIn("--openxla-partitioner-gspmd-num-partitions=3",
session2.get_flags())

def testFlagError(self):
session = Session()
with self.assertRaises(ValueError):
session.set_flags("--does-not-exist=1")

class InvocationTest(unittest.TestCase):

def testCreate(self):
session = Session()
inv = session.invocation()


if __name__ == "__main__":
unittest.main()
73 changes: 58 additions & 15 deletions partitioner/src/openxla/partitioner/EmbeddingLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ struct GlobalInit {
// Stash the revision for the life of the instance.
// TODO: Get release revision stamp.
std::string revision;

// Our session options can optionally be bound to the global command-line
// environment. If that is not the case, then these will be nullptr, and
// they should be default initialized at the session level.
GSPMDOptions *clGSPMDOptions = nullptr;
};

GlobalInit::GlobalInit() {
Expand Down Expand Up @@ -156,18 +161,60 @@ void GlobalInit::registerCommandLineOptions() {
// Register pass manager command-line options like -mlir-print-ir-*.
mlir::registerPassManagerCLOptions();
mlir::registerDefaultTimingManagerCLOptions();

// Bind session options to the command line environment.
clGSPMDOptions = &GSPMDOptions::FromFlags::get();
}

// Sessions bring together an initialized context and set of flags.
struct Session {
Session(GlobalInit &globalInit);

Error *setFlags(int argc, const char *const *argv) {
std::string errorMessage;
auto callback = [&](llvm::StringRef message) {
if (errorMessage.empty()) {
errorMessage = "Error parsing flags:";
}
errorMessage.append("\n ");
errorMessage.append(message.data(), message.size());
};

if (failed(binder.parseArguments(argc, argv, callback))) {
return new Error(std::move(errorMessage));
}
return nullptr;
}

void getFlags(bool nonDefaultOnly,
void (*onFlag)(const char *flag, size_t length, void *),
void *userData) {
auto flagVector = binder.printArguments(nonDefaultOnly);
for (std::string &value : flagVector) {
onFlag(value.c_str(), value.size(), userData);
}
}

GlobalInit &globalInit;
support::OptionsBinder binder;
MLIRContext context;

// Options structs.
GSPMDOptions gspmdOptions;
};

Session::Session(GlobalInit &globalInit) : globalInit(globalInit) {
Session::Session(GlobalInit &globalInit)
: globalInit(globalInit), binder(support::OptionsBinder::local()) {
context.appendDialectRegistry(globalInit.getRegistry());

// Bootstrap session options from the cl environment, if enabled.
if (globalInit.usesCommandLine) {
gspmdOptions = *globalInit.clGSPMDOptions;
}

// Register each options struct with the binder so we can manipulate
// mnemonically via the API.
gspmdOptions.bindOptions(binder);
}

// A source is instantiated against a session and is used to access an llvm
Expand Down Expand Up @@ -413,7 +460,7 @@ bool Invocation::runPipeline(llvm::StringRef pipeline) {
}

bool Invocation::runGSPMDPipeline() {
buildGSPMDPipeline(passManager);
buildGSPMDPipeline(passManager, session.gspmdOptions);
passManager.enableVerifier(enableVerifier);
if (failed(passManager.run(parsedModule.get()))) {
return false;
Expand Down Expand Up @@ -583,20 +630,16 @@ void openxlaPartitionerSessionDestroy(openxla_partitioner_session_t *session) {
delete unwrap(session);
}

// TODO: Finish implementing
// openxla_partitioner_error_t *openxlaPartitionerSessionSetFlags(
// openxla_partitioner_session_t *session, int argc, const char *const
// *argv) {
// return wrap(unwrap(session)->setFlags(argc, argv));
// }
openxla_partitioner_error_t *openxlaPartitionerSessionSetFlags(
openxla_partitioner_session_t *session, int argc, const char *const *argv) {
return wrap(unwrap(session)->setFlags(argc, argv));
}

// TODO: Finish implementing
// void openxlaPartitionerSessionGetFlags(
// openxla_partitioner_session_t *session, bool nonDefaultOnly,
// void (*onFlag)(const char *flag, size_t length, void *), void *userData)
// {
// unwrap(session)->getFlags(nonDefaultOnly, onFlag, userData);
// }
void openxlaPartitionerSessionGetFlags(
openxla_partitioner_session_t *session, bool nonDefaultOnly,
void (*onFlag)(const char *flag, size_t length, void *), void *userData) {
unwrap(session)->getFlags(nonDefaultOnly, onFlag, userData);
}

openxla_partitioner_invocation_t *openxlaPartitionerInvocationCreate(
openxla_partitioner_session_t *session) {
Expand Down
Loading

0 comments on commit d773141

Please sign in to comment.