Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block API #34

Merged
merged 11 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "melior"
description = "The rustic MLIR bindings in Rust"
version = "0.1.0"
edition = "2021"
license-file = "LICENSE"
license = "Apache-2.0"
repository = "https://github.com/raviqqe/melior"

[dependencies]
Expand Down
268 changes: 234 additions & 34 deletions src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,25 @@ use crate::{
operation::{Operation, OperationRef},
r#type::Type,
region::RegionRef,
string_ref::StringRef,
utility::into_raw_array,
value::Value,
};
use mlir_sys::{
mlirBlockAddArgument, mlirBlockAppendOwnedOperation, mlirBlockCreate, mlirBlockDestroy,
mlirBlockEqual, mlirBlockGetArgument, mlirBlockGetFirstOperation, mlirBlockGetNumArguments,
mlirBlockGetParentRegion, mlirBlockInsertOwnedOperation, MlirBlock,
mlirBlockDetach, mlirBlockEqual, mlirBlockGetArgument, mlirBlockGetFirstOperation,
mlirBlockGetNextInRegion, mlirBlockGetNumArguments, mlirBlockGetParentOperation,
mlirBlockGetParentRegion, mlirBlockGetTerminator, mlirBlockInsertOwnedOperation,
mlirBlockInsertOwnedOperationAfter, mlirBlockInsertOwnedOperationBefore, mlirBlockPrint,
MlirBlock, MlirStringRef,
};
use std::{
ffi::c_void,
fmt::{self, Display, Formatter},
marker::PhantomData,
mem::forget,
ops::Deref,
};
use std::{marker::PhantomData, mem::forget, ops::Deref};

/// A block
#[derive(Debug)]
Expand All @@ -25,24 +35,28 @@ impl<'c> Block<'c> {
/// Creates a block.
pub fn new(arguments: &[(Type<'c>, Location<'c>)]) -> Self {
unsafe {
Self {
r#ref: BlockRef::from_raw(mlirBlockCreate(
arguments.len() as isize,
into_raw_array(
arguments
.iter()
.map(|(argument, _)| argument.to_raw())
.collect(),
),
into_raw_array(
arguments
.iter()
.map(|(_, location)| location.to_raw())
.collect(),
),
)),
_context: Default::default(),
}
Self::from_raw(mlirBlockCreate(
arguments.len() as isize,
into_raw_array(
arguments
.iter()
.map(|(argument, _)| argument.to_raw())
.collect(),
),
into_raw_array(
arguments
.iter()
.map(|(_, location)| location.to_raw())
.collect(),
),
))
}
}

pub(crate) unsafe fn from_raw(raw: MlirBlock) -> Self {
Self {
r#ref: BlockRef::from_raw(raw),
_context: Default::default(),
}
}

Expand Down Expand Up @@ -78,8 +92,6 @@ impl<'c> Deref for Block<'c> {
}

/// A reference of a block.
// TODO Should we split context lifetimes? Or, is it transitively proven that 'c
// > 'a?
#[derive(Clone, Copy, Debug)]
pub struct BlockRef<'a> {
raw: MlirBlock,
Expand All @@ -106,11 +118,6 @@ impl<'c> BlockRef<'c> {
unsafe { mlirBlockGetNumArguments(self.raw) as usize }
}

/// Gets a parent region.
pub fn parent_region(&self) -> Option<RegionRef> {
unsafe { RegionRef::from_option_raw(mlirBlockGetParentRegion(self.raw)) }
}

/// Gets the first operation.
pub fn first_operation(&self) -> Option<OperationRef> {
unsafe {
Expand All @@ -124,6 +131,21 @@ impl<'c> BlockRef<'c> {
}
}

/// Gets a terminator operation.
pub fn terminator(&self) -> Option<OperationRef> {
unsafe { OperationRef::from_option_raw(mlirBlockGetTerminator(self.raw)) }
}

/// Gets a parent region.
pub fn parent_region(&self) -> Option<RegionRef> {
unsafe { RegionRef::from_option_raw(mlirBlockGetParentRegion(self.raw)) }
}

/// Gets a parent operation.
pub fn parent_operation(&self) -> Option<OperationRef> {
unsafe { OperationRef::from_option_raw(mlirBlockGetParentOperation(self.raw)) }
}

/// Adds an argument.
pub fn add_argument(&self, r#type: Type<'c>, location: Location<'c>) -> Value {
unsafe {
Expand All @@ -135,6 +157,17 @@ impl<'c> BlockRef<'c> {
}
}

/// Appends an operation.
pub fn append_operation(&self, operation: Operation) -> OperationRef {
unsafe {
let operation = operation.into_raw();

mlirBlockAppendOwnedOperation(self.raw, operation);

OperationRef::from_raw(operation)
}
}

/// Inserts an operation.
// TODO How can we make those update functions take `&mut self`?
// TODO Use cells?
Expand All @@ -148,17 +181,46 @@ impl<'c> BlockRef<'c> {
}
}

/// Appends an operation.
pub fn append_operation(&self, operation: Operation) -> OperationRef {
/// Inserts an operation after another.
pub fn insert_operation_after(&self, one: OperationRef, other: Operation) -> OperationRef {
unsafe {
let operation = operation.into_raw();
let other = other.into_raw();

mlirBlockAppendOwnedOperation(self.raw, operation);
mlirBlockInsertOwnedOperationAfter(self.raw, one.to_raw(), other);

OperationRef::from_raw(operation)
OperationRef::from_raw(other)
}
}

/// Inserts an operation before another.
pub fn insert_operation_before(&self, one: OperationRef, other: Operation) -> OperationRef {
unsafe {
let other = other.into_raw();

mlirBlockInsertOwnedOperationBefore(self.raw, one.to_raw(), other);

OperationRef::from_raw(other)
}
}

/// Detaches a block from a region and assumes its ownership.
pub fn detach(&self) -> Option<Block> {
if self.parent_region().is_some() {
unsafe {
mlirBlockDetach(self.raw);

Some(Block::from_raw(self.raw))
}
} else {
None
}
}

/// Gets a next block in a region.
pub fn next_in_region(&self) -> Option<BlockRef> {
unsafe { BlockRef::from_option_raw(mlirBlockGetNextInRegion(self.raw)) }
}

pub(crate) unsafe fn from_raw(raw: MlirBlock) -> Self {
Self {
raw,
Expand Down Expand Up @@ -187,10 +249,34 @@ impl<'a> PartialEq for BlockRef<'a> {

impl<'a> Eq for BlockRef<'a> {}

impl<'a> Display for BlockRef<'a> {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
let mut data = (formatter, Ok(()));

unsafe extern "C" fn callback(string: MlirStringRef, data: *mut c_void) {
let data = &mut *(data as *mut (&mut Formatter, fmt::Result));
let result = write!(data.0, "{}", StringRef::from_raw(string).as_str());

if data.1.is_ok() {
data.1 = result;
}
}

unsafe {
mlirBlockPrint(self.raw, Some(callback), &mut data as *mut _ as *mut c_void);
}

data.1
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{operation_state::OperationState, region::Region};
use crate::{
dialect_registry::DialectRegistry, module::Module, operation_state::OperationState,
region::Region, utility::register_all_dialects,
};

#[test]
fn new() {
Expand Down Expand Up @@ -236,6 +322,48 @@ mod tests {
assert_eq!(block.parent_region(), None);
}

#[test]
fn parent_operation() {
let context = Context::new();
let module = Module::new(Location::unknown(&context));

assert_eq!(
module.body().parent_operation(),
Some(module.as_operation())
);
}

#[test]
fn parent_operation_none() {
let block = Block::new(&[]);

assert_eq!(block.parent_operation(), None);
}

#[test]
fn terminator() {
let registry = DialectRegistry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();

let block = Block::new(&[]);

let operation = block.append_operation(Operation::new(OperationState::new(
"func.return",
Location::unknown(&context),
)));

assert_eq!(block.terminator(), Some(operation));
}

#[test]
fn terminator_none() {
assert_eq!(Block::new(&[]).terminator(), None);
}

#[test]
fn first_operation() {
let context = Context::new();
Expand Down Expand Up @@ -277,4 +405,76 @@ mod tests {
Operation::new(OperationState::new("foo", Location::unknown(&context))),
);
}

#[test]
fn insert_operation_after() {
let context = Context::new();
let block = Block::new(&[]);

let first_operation = block.append_operation(Operation::new(OperationState::new(
"foo",
Location::unknown(&context),
)));
let second_operation = block.insert_operation_after(
first_operation,
Operation::new(OperationState::new("foo", Location::unknown(&context))),
);

assert_eq!(block.first_operation(), Some(first_operation));
assert_eq!(
block.first_operation().unwrap().next_in_block(),
Some(second_operation)
);
}

#[test]
fn insert_operation_before() {
let context = Context::new();
let block = Block::new(&[]);

let second_operation = block.append_operation(Operation::new(OperationState::new(
"foo",
Location::unknown(&context),
)));
let first_operation = block.insert_operation_before(
second_operation,
Operation::new(OperationState::new("foo", Location::unknown(&context))),
);

assert_eq!(block.first_operation(), Some(first_operation));
assert_eq!(
block.first_operation().unwrap().next_in_block(),
Some(second_operation)
);
}

#[test]
fn next_in_region() {
let region = Region::new();

let first_block = region.append_block(Block::new(&[]));
let second_block = region.append_block(Block::new(&[]));

assert_eq!(first_block.next_in_region(), Some(second_block));
}

#[test]
fn detach() {
let region = Region::new();
let block = region.append_block(Block::new(&[]));

assert_eq!(block.detach().unwrap().to_string(), "<<UNLINKED BLOCK>>\n");
}

#[test]
fn detach_detached() {
let block = Block::new(&[]);

assert!(block.detach().is_none());
}

#[test]
fn display() {
assert_eq!(Block::new(&[]).to_string(), "<<UNLINKED BLOCK>>\n");
}
}
Loading