Skip to content

Commit

Permalink
Pass manager API (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Sep 14, 2022
1 parent 8d36c17 commit 07930ad
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 10 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"addf",
"addi",
"femtomc",
"indoc",
"insta",
"libm",
"linalg",
Expand Down
7 changes: 4 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::r#type::Type;
use std::error;
use std::fmt::Display;
use std::fmt::{self, Formatter};
use std::{
error,
fmt::{self, Display, Formatter},
};

#[derive(Debug, Eq, PartialEq)]
pub enum Error<'c> {
Expand Down
37 changes: 35 additions & 2 deletions src/operation_pass_manager.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
use crate::{pass::Pass, pass_manager::PassManager, string_ref::StringRef};
use mlir_sys::{mlirOpPassManagerAddOwnedPass, mlirOpPassManagerGetNestedUnder, MlirOpPassManager};
use std::marker::PhantomData;
use mlir_sys::{
mlirOpPassManagerAddOwnedPass, mlirOpPassManagerGetNestedUnder, mlirPrintPassPipeline,
MlirOpPassManager, MlirStringRef,
};
use std::{
ffi::c_void,
fmt::{self, Display, Formatter},
marker::PhantomData,
};

/// An operation pass manager.
#[derive(Clone, Copy, Debug)]
pub struct OperationPassManager<'a> {
raw: MlirOpPassManager,
_parent: PhantomData<&'a PassManager<'a>>,
Expand All @@ -25,10 +33,35 @@ impl<'a> OperationPassManager<'a> {
unsafe { mlirOpPassManagerAddOwnedPass(self.raw, pass.to_raw()) }
}

pub(crate) unsafe fn to_raw(self) -> MlirOpPassManager {
self.raw
}

pub(crate) unsafe fn from_raw(raw: MlirOpPassManager) -> Self {
Self {
raw,
_parent: Default::default(),
}
}
}

impl<'a> Display for OperationPassManager<'a> {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
let mut data = (formatter, Ok(()));

unsafe extern "C" fn callback(string: MlirStringRef, data: *mut c_void) {
let data = &mut *(data as *mut (&mut Formatter, fmt::Result));
let result = write!(data.0, "{}", StringRef::from_raw(string).as_str());

if data.1.is_ok() {
data.1 = result;
}
}

unsafe {
mlirPrintPassPipeline(self.raw, Some(callback), &mut data as *mut _ as *mut c_void);
}

data.1
}
}
7 changes: 6 additions & 1 deletion src/pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use mlir_sys::{
mlirCreateConversionConvertArithmeticToLLVM, mlirCreateConversionConvertControlFlowToLLVM,
mlirCreateConversionConvertControlFlowToSPIRV, mlirCreateConversionConvertFuncToLLVM,
mlirCreateConversionConvertMathToLLVM, mlirCreateConversionConvertMathToLibm,
mlirCreateConversionConvertMathToSPIRV, MlirPass,
mlirCreateConversionConvertMathToSPIRV, mlirCreateTransformsPrintOpStats, MlirPass,
};

/// A pass.
Expand Down Expand Up @@ -46,6 +46,11 @@ impl Pass {
Self::from_raw_fn(mlirCreateConversionConvertMathToLibm)
}

/// Creates a pass to print operation statistics.
pub fn print_operation_stats() -> Self {
Self::from_raw_fn(mlirCreateTransformsPrintOpStats)
}

// TODO Add more passes.

fn from_raw_fn(create_raw: unsafe extern "C" fn() -> MlirPass) -> Self {
Expand Down
153 changes: 151 additions & 2 deletions src/pass_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
};
use mlir_sys::{
mlirPassManagerAddOwnedPass, mlirPassManagerCreate, mlirPassManagerDestroy,
mlirPassManagerEnableIRPrinting, mlirPassManagerEnableVerifier,
mlirPassManagerGetAsOpPassManager, mlirPassManagerGetNestedUnder, mlirPassManagerRun,
MlirPassManager,
};
Expand Down Expand Up @@ -40,13 +41,23 @@ impl<'c> PassManager<'c> {
unsafe { mlirPassManagerAddOwnedPass(self.raw, pass.to_raw()) }
}

/// Enables a verifier.
pub fn enable_verifier(&self, enabled: bool) {
unsafe { mlirPassManagerEnableVerifier(self.raw, enabled) }
}

/// Enables IR printing.
pub fn enable_ir_printing(&self) {
unsafe { mlirPassManagerEnableIRPrinting(self.raw) }
}

/// Runs passes added to a pass manager against a module.
pub fn run(&self, module: &Module) -> LogicalResult {
LogicalResult::from_raw(unsafe { mlirPassManagerRun(self.raw, module.to_raw()) })
}

/// Converts a pass manager to an operation pass manager.
pub fn as_operation_pass_manager(&mut self) -> OperationPassManager {
pub fn as_operation_pass_manager(&self) -> OperationPassManager {
unsafe { OperationPassManager::from_raw(mlirPassManagerGetAsOpPassManager(self.raw)) }
}
}
Expand All @@ -60,7 +71,19 @@ impl<'c> Drop for PassManager<'c> {
#[cfg(test)]
mod tests {
use super::*;
use crate::location::Location;
use crate::{
dialect_registry::DialectRegistry,
location::Location,
utility::{parse_pass_pipeline, register_all_dialects, register_print_operation_stats},
};
use indoc::indoc;
use pretty_assertions::assert_eq;

fn register_all_upstream_dialects(context: &Context) {
let registry = DialectRegistry::new();
register_all_dialects(&registry);
context.append_dialect_registry(&registry);
}

#[test]
fn new() {
Expand All @@ -76,6 +99,21 @@ mod tests {
PassManager::new(&context).add_pass(Pass::convert_func_to_llvm());
}

#[test]
fn enable_verifier() {
let context = Context::new();

PassManager::new(&context).enable_verifier(true);
}

// TODO Enable this test.
// #[test]
// fn enable_ir_printing() {
// let context = Context::new();

// PassManager::new(&context).enable_ir_printing();
// }

#[test]
fn run() {
let context = Context::new();
Expand All @@ -84,4 +122,115 @@ mod tests {
manager.add_pass(Pass::convert_func_to_llvm());
manager.run(&Module::new(Location::unknown(&context)));
}

#[test]
fn run_on_function() {
let context = Context::new();
register_all_upstream_dialects(&context);

let module = Module::parse(
&context,
indoc!(
"
func.func @foo(%arg0 : i32) -> i32 {
%res = arith.addi %arg0, %arg0 : i32
return %res : i32
}
"
),
);

let manager = PassManager::new(&context);
manager.add_pass(Pass::print_operation_stats());

assert!(manager.run(&module).is_success());
}

#[test]
fn run_on_function_in_nested_module() {
let context = Context::new();
register_all_upstream_dialects(&context);

let module = Module::parse(
&context,
indoc!(
"
func.func @foo(%arg0 : i32) -> i32 {
%res = arith.addi %arg0, %arg0 : i32
return %res : i32
}
module {
func.func @bar(%arg0 : f32) -> f32 {
%res = arith.addf %arg0, %arg0 : f32
return %res : f32
}
}
"
),
);

let manager = PassManager::new(&context);
manager
.nested_under("func.func")
.add_pass(Pass::print_operation_stats());

assert!(manager.run(&module).is_success());

let manager = PassManager::new(&context);
manager
.nested_under("builtin.module")
.nested_under("func.func")
.add_pass(Pass::print_operation_stats());

assert!(manager.run(&module).is_success());
}

#[test]
fn print_pass_pipeline() {
let context = Context::new();
let manager = PassManager::new(&context);
let module_manager = manager.nested_under("builtin.module");
let function_manager = module_manager.nested_under("func.func");

function_manager.add_pass(Pass::print_operation_stats());

assert_eq!(
manager.as_operation_pass_manager().to_string(),
"builtin.module(func.func(print-op-stats{json=false}))"
);
assert_eq!(
module_manager.to_string(),
"func.func(print-op-stats{json=false})"
);
assert_eq!(function_manager.to_string(), "print-op-stats{json=false}");
}

#[test]
fn parse_pass_pipeline_() {
let context = Context::new();
let manager = PassManager::new(&context);

assert!(parse_pass_pipeline(
manager.as_operation_pass_manager(),
"builtin.module(func.func(print-op-stats{json=false}),\
func.func(print-op-stats{json=false}))"
)
.is_failure());

register_print_operation_stats();

assert!(parse_pass_pipeline(
manager.as_operation_pass_manager(),
"builtin.module(func.func(print-op-stats{json=false}),\
func.func(print-op-stats{json=false}))"
)
.is_success());

assert_eq!(
manager.as_operation_pass_manager().to_string(),
"builtin.module(func.func(print-op-stats{json=false}),\
func.func(print-op-stats{json=false}))"
);
}
}
27 changes: 25 additions & 2 deletions src/utility.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use crate::{context::Context, dialect_registry::DialectRegistry};
use mlir_sys::{mlirRegisterAllDialects, mlirRegisterAllLLVMTranslations, mlirRegisterAllPasses};
use crate::{
context::Context, dialect_registry::DialectRegistry, logical_result::LogicalResult,
operation_pass_manager::OperationPassManager, string_ref::StringRef,
};
use mlir_sys::{
mlirParsePassPipeline, mlirRegisterAllDialects, mlirRegisterAllLLVMTranslations,
mlirRegisterAllPasses, mlirRegisterTransformsCSE, mlirRegisterTransformsPrintOpStats,
};
use std::sync::Once;

/// Registers all dialects to a dialect registry.
Expand All @@ -20,6 +26,23 @@ pub fn register_all_passes() {
ONCE.call_once(|| unsafe { mlirRegisterAllPasses() });
}

/// Parses a pass pipeline.
pub fn parse_pass_pipeline(manager: OperationPassManager, source: &str) -> LogicalResult {
LogicalResult::from_raw(unsafe {
mlirParsePassPipeline(manager.to_raw(), StringRef::from(source).to_raw())
})
}

/// Registers a pass to print operation stats.
pub fn register_print_operation_stats() {
unsafe { mlirRegisterTransformsPrintOpStats() }
}

/// Registers a pass to print operation stats.
pub fn register_cse() {
unsafe { mlirRegisterTransformsCSE() }
}

// TODO Use into_raw_parts.
pub(crate) unsafe fn into_raw_array<T>(xs: Vec<T>) -> *mut T {
xs.leak().as_mut_ptr()
Expand Down

0 comments on commit 07930ad

Please sign in to comment.