diff --git a/.gitignore b/.gitignore index 6936990..98e5fcf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ -/target +target **/*.rs.bk Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 630c352..a22e1c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ edition = "2018" [dependencies] rustc-hash = "1.0.0" -syn = "1.0.8" +syn = { version = "1.0.8", features = ["extra-traits"] } proc-macro2 = { version = "1.0.6", features = ["span-locations"] } quote = "1.0.2" log = "0.4.8" diff --git a/datapond-derive/Cargo.toml b/datapond-derive/Cargo.toml new file mode 100644 index 0000000..216aa51 --- /dev/null +++ b/datapond-derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "datapond-derive" +version = "0.1.0" +authors = ["Vytautas Astrauskas "] +edition = "2018" + +[dependencies] +proc-macro-hack = "0.5" +datapond-macro = { path = "../datapond-macro" } + +[dev-dependencies] +trybuild = "1.0" +datafrog = "2" diff --git a/datapond-derive/src/lib.rs b/datapond-derive/src/lib.rs new file mode 100644 index 0000000..803c5bd --- /dev/null +++ b/datapond-derive/src/lib.rs @@ -0,0 +1,4 @@ +use proc_macro_hack::proc_macro_hack; + +#[proc_macro_hack] +pub use datapond_macro::datapond; diff --git a/datapond-derive/tests/fail/arg_mismatch.rs b/datapond-derive/tests/fail/arg_mismatch.rs new file mode 100644 index 0000000..a8f3a03 --- /dev/null +++ b/datapond-derive/tests/fail/arg_mismatch.rs @@ -0,0 +1,11 @@ +use datapond_derive::datapond; + +fn main() { + let inp = vec![(1, 2), (2, 3)]; + let out; + datapond! { + input inp(x: u32, y: u32, z: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(y, x). + }; +} diff --git a/datapond-derive/tests/fail/arg_mismatch.stderr b/datapond-derive/tests/fail/arg_mismatch.stderr new file mode 100644 index 0000000..9e6f671 --- /dev/null +++ b/datapond-derive/tests/fail/arg_mismatch.stderr @@ -0,0 +1,11 @@ +error: Wrong number of arguments for inp: expected 2, found 3. + --> $DIR/arg_mismatch.rs:9:22 + | +9 | out(x, y) :- inp(y, x). + | ^^^ + +error: The predicate inp was declared here. + --> $DIR/arg_mismatch.rs:7:15 + | +7 | input inp(x: u32, y: u32, z: u32) + | ^^^ diff --git a/datapond-derive/tests/fail/kwargs.rs b/datapond-derive/tests/fail/kwargs.rs new file mode 100644 index 0000000..69836c6 --- /dev/null +++ b/datapond-derive/tests/fail/kwargs.rs @@ -0,0 +1,28 @@ +use datapond_derive::datapond; + +fn test1() { + let inp = vec![(1, 2, 0), (2, 3, 0)]; + let out; + datapond! { + input inp(x: u32, y: u32, z: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(.y=y, .y=x). + }; + assert_eq!(out.len(), 2); +} + +fn test2() { + let inp = vec![(1, 2, 0), (2, 3, 0)]; + let out; + datapond! { + input inp(x: u32, y: u32, z: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(.a=y, .y=x). + }; + assert_eq!(out.len(), 2); +} + +fn main() { + test1(); + test2(); +} \ No newline at end of file diff --git a/datapond-derive/tests/fail/kwargs.stderr b/datapond-derive/tests/fail/kwargs.stderr new file mode 100644 index 0000000..1937032 --- /dev/null +++ b/datapond-derive/tests/fail/kwargs.stderr @@ -0,0 +1,11 @@ +error: Parameter already bound: y + --> $DIR/kwargs.rs:9:33 + | +9 | out(x, y) :- inp(.y=y, .y=x). + | ^ + +error: Unknown parameter a in predicate inp. Available parameters are: x,y,z. + --> $DIR/kwargs.rs:20:27 + | +20 | out(x, y) :- inp(.a=y, .y=x). + | ^ diff --git a/datapond-derive/tests/pass/example1.rs b/datapond-derive/tests/pass/example1.rs new file mode 100644 index 0000000..22e1a7d --- /dev/null +++ b/datapond-derive/tests/pass/example1.rs @@ -0,0 +1,81 @@ +use datapond_derive::datapond; + +#[derive(PartialOrd, Ord, PartialEq, Eq, Clone, Copy)] +struct Origin(u64); +#[derive(PartialOrd, Ord, PartialEq, Eq, Clone, Copy)] +struct Loan(u64); +#[derive(PartialOrd, Ord, PartialEq, Eq, Clone, Copy)] +struct Point(u64); + + +fn main() { + let borrow_region = vec![]; + let cfg_edge = vec![]; + let killed = vec![]; + let outlives = vec![]; + let region_live_at = vec![]; + let invalidates = vec![]; + let errors; + datapond! { + input borrow_region(O: Origin, L: Loan, P: Point) + input cfg_edge(P: Point, Q: Point) + input killed(L: Loan, P: Point) + input outlives(O1: Origin, O2: Origin, P: Point) + input region_live_at(O: Origin, P: Point) + input invalidates(L: Loan, P: Point) + internal subset(O1: Origin, O2: Origin, P: Point) + internal requires(O: Origin, L: Loan, P: Point) + internal borrow_live_at(L: Loan, P: Point) + internal equals(O1: Origin, O2: Origin, P: Point) + output errors(L: Loan, P: Point) + + // R1 + subset(O1, O2, P) :- outlives(O1, O2, P). + + // R2 + subset(O1, O3, P) :- + subset(O1, O2, P), + outlives(O2, O3, P). + + // R3: this is the transitive relation + equals(O1, O2, P) :- + subset(O1, O2, P), + subset(O2, O1, P). + + // R4 + equals(O1, O2, Q) :- + equals(O1, O2, P), + cfg_edge(P, Q). + + // R5 + requires(O2, L, P) :- + requires(O1, L, P), + equals(O1, O2, P). + + // R6 + requires(O, L, P) :- borrow_region(O, L, P). + + // R7 + requires(O2, L, P) :- + requires(O1, L, P), + subset(O1, O2, P). + + // R8 + requires(O, L, Q) :- + requires(O, L, P), + !killed(L, P), + cfg_edge(P, Q), + region_live_at(O, Q). + + // R9 + borrow_live_at(L, P) :- + requires(O, L, P), + region_live_at(O, P). + + // R10 + errors(L, P) :- + borrow_live_at(L, P), + invalidates(L, P). + }; + assert!(errors.is_empty()); +} diff --git a/datapond-derive/tests/pass/kwargs.rs b/datapond-derive/tests/pass/kwargs.rs new file mode 100644 index 0000000..fbec88a --- /dev/null +++ b/datapond-derive/tests/pass/kwargs.rs @@ -0,0 +1,23 @@ +use datapond_derive::datapond; + +fn main() { + let inp = vec![(1, 2, 0), (2, 3, 0)]; + let out; + let out2; + datapond! { + input inp(x: u32, y: u32, z: u32) + + output out(x: u32, y: u32) + out(x, y) :- inp(.y=y, .x=x). + + output out2(x: u32, y: u32) + out2(a, b) :- inp(.y=a, .x=b). + }; + assert_eq!(out.len(), 2); + assert_eq!(out[0], (1, 2)); + assert_eq!(out[1], (2, 3)); + + assert_eq!(out2.len(), 2); + assert_eq!(out2[0], (2, 1)); + assert_eq!(out2[1], (3, 2)); +} diff --git a/datapond-derive/tests/pass/missing_args.rs b/datapond-derive/tests/pass/missing_args.rs new file mode 100644 index 0000000..925ec69 --- /dev/null +++ b/datapond-derive/tests/pass/missing_args.rs @@ -0,0 +1,31 @@ +use datapond_derive::datapond; + +fn test1() { + let inp = vec![(1, 2), (2, 3)]; + let out; + datapond! { + input inp(x: u32, y: u32) + output out(x: u32) + out(x) :- inp(x, _). + }; + assert!(out.len() == 2); + assert!(out[0] == (1,)); + assert!(out[1] == (2,)); +} + +fn test2() { + let inp = vec![(1, 2), (2, 3)]; + let out; + datapond! { + input inp(x: u32, y: u32) + output out(x: u32) + out(x) :- inp(x, _), inp(_, x). + }; + assert!(out.len() == 1); + assert!(out[0] == (2,)); +} + +fn main() { + test1(); + test2(); +} diff --git a/datapond-derive/tests/pass/negation.rs b/datapond-derive/tests/pass/negation.rs new file mode 100644 index 0000000..646643f --- /dev/null +++ b/datapond-derive/tests/pass/negation.rs @@ -0,0 +1,15 @@ +use datapond_derive::datapond; + +fn main() { + let inp = vec![(1, 2), (2, 3)]; + let kill = vec![(3,), (4,), (5,)]; + let out; + datapond! { + input inp(x: u32, y: u32) + input kill(y: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(x, y), !kill(y). + }; + assert!(out.len() == 1); + assert!(out[0] == (1, 2)); +} diff --git a/datapond-derive/tests/pass/simple1.rs b/datapond-derive/tests/pass/simple1.rs new file mode 100644 index 0000000..2701ddd --- /dev/null +++ b/datapond-derive/tests/pass/simple1.rs @@ -0,0 +1,14 @@ +use datapond_derive::datapond; + +fn main() { + let inp = vec![(1, 2), (2, 3)]; + let out; + datapond! { + input inp(x: u32, y: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(y, x). + }; + assert!(out.len() == 2); + assert!(out[0] == (2, 1)); + assert!(out[1] == (3, 2)); +} diff --git a/datapond-derive/tests/pass/transitive_closure.rs b/datapond-derive/tests/pass/transitive_closure.rs new file mode 100644 index 0000000..ae2b3a8 --- /dev/null +++ b/datapond-derive/tests/pass/transitive_closure.rs @@ -0,0 +1,16 @@ +use datapond_derive::datapond; + +fn main() { + let inp = vec![(1, 2), (2, 3)]; + let out; + datapond! { + input inp(x: u32, y: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(x, y). + out(x, y) :- out(x, z), out(z, y). + }; + assert!(out.len() == 3); + assert!(out[0] == (1, 2)); + assert!(out[1] == (1, 3)); + assert!(out[2] == (2, 3)); +} diff --git a/datapond-derive/tests/test.rs b/datapond-derive/tests/test.rs new file mode 100644 index 0000000..35ec0e5 --- /dev/null +++ b/datapond-derive/tests/test.rs @@ -0,0 +1,6 @@ +#[test] +fn tests() { + let runner = trybuild::TestCases::new(); + runner.pass("tests/pass/*.rs"); + runner.compile_fail("tests/fail/*.rs"); +} diff --git a/datapond-macro/Cargo.toml b/datapond-macro/Cargo.toml new file mode 100644 index 0000000..0e43391 --- /dev/null +++ b/datapond-macro/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "datapond-macro" +version = "0.1.0" +authors = ["Vytautas Astrauskas "] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +datapond = { path = ".." } +proc-macro-hack = "0.5" diff --git a/datapond-macro/src/lib.rs b/datapond-macro/src/lib.rs new file mode 100644 index 0000000..a432f5a --- /dev/null +++ b/datapond-macro/src/lib.rs @@ -0,0 +1,7 @@ +use proc_macro::TokenStream; +use proc_macro_hack::proc_macro_hack; + +#[proc_macro_hack] +pub fn datapond(input: TokenStream) -> TokenStream { + datapond::generate_datafrog(input.into()).into() +} diff --git a/examples/generate_skeleton.rs b/examples/generate_skeleton.rs index b9044b5..6fffcf2 100644 --- a/examples/generate_skeleton.rs +++ b/examples/generate_skeleton.rs @@ -1,5 +1,5 @@ -use std::env; use datapond; +use std::env; fn main() { if env::var("RUST_LOG").is_ok() { diff --git a/src/ast.rs b/src/ast.rs index 3190548..1800f24 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,8 +1,8 @@ //! This file contains the typed AST. +use crate::data_structures::OrderedMap; use proc_macro2::Ident; use quote::ToTokens; -use std::collections::HashMap; use std::fmt; /// The predicate kind regarding IO. @@ -98,6 +98,21 @@ pub enum Arg { Wildcard, } +impl Arg { + pub fn to_ident(&self) -> syn::Ident { + match self { + Arg::Ident(ident) => ident.clone(), + Arg::Wildcard => syn::Ident::new("_", proc_macro2::Span::call_site()), + } + } + pub fn is_wildcard(&self) -> bool { + match self { + Arg::Ident(_) => false, + Arg::Wildcard => true, + } + } +} + impl fmt::Display for Arg { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -188,6 +203,6 @@ impl fmt::Display for Rule { /// A Datalog program. #[derive(Debug, Clone)] pub struct Program { - pub decls: HashMap, + pub decls: OrderedMap, pub rules: Vec, } diff --git a/src/data_structures.rs b/src/data_structures.rs new file mode 100644 index 0000000..b25c35f --- /dev/null +++ b/src/data_structures.rs @@ -0,0 +1,54 @@ +use std::collections::HashMap; + +/// A map that tracks insertion order. +#[derive(Debug, Clone)] +pub struct OrderedMap +where + K: Eq + std::hash::Hash, +{ + key_order: Vec, + map: HashMap, +} + +impl OrderedMap { + pub fn new() -> Self { + Self { + key_order: Vec::new(), + map: HashMap::new(), + } + } + pub fn len(&self) -> usize { + self.map.len() + } + pub fn insert(&mut self, k: K, v: V) { + assert!(self.map.insert(k.clone(), v).is_none()); + self.key_order.push(k); + } + pub fn get(&self, k: &K) -> Option<&V> { + self.map.get(k) + } + pub fn values<'a>(&'a self) -> Vec<&'a V> { + self.key_order.iter().map(|k| &self.map[k]).collect() + } +} + +impl std::iter::FromIterator<(K, V)> for OrderedMap { + fn from_iter>(iter: I) -> Self { + let mut s = Self { + key_order: Vec::new(), + map: HashMap::new(), + }; + for (k, v) in iter { + s.insert(k, v); + } + s + } +} + +impl std::ops::Index<&K> for OrderedMap { + type Output = V; + + fn index(&self, key: &K) -> &Self::Output { + &self.map[key] + } +} diff --git a/src/generator.rs b/src/generator.rs index 7666a05..0ac7019 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,9 +1,11 @@ +use crate::data_structures::OrderedMap; use crate::{ast, parser, typechecker}; use quote::ToTokens; use rustc_hash::{FxHashMap, FxHashSet}; -use std::collections::HashMap; use std::fmt::{self, Write}; +type HashMap = OrderedMap; + /// The representation of what a datalog rule does in datafrog terms enum Operation { StaticMap(), @@ -909,26 +911,25 @@ fn generate_skeleton_code( let relation_args = join_args_as_tuple(&declared_args, &key, &args); format!("({}, _)", relation_args) } else { - let arg_names: Vec<_> = indexed_literal - .args - .iter() - .map(|v| v.to_string()) - .collect(); - + let arg_names: Vec<_> = indexed_literal.args.iter().map(|v| v.to_string()).collect(); + let canonicalized_key: Vec<_> = key .iter() .map(|v| canonicalize_arg_name(&decls, &indexed_literal.predicate, &arg_names, v)) .collect(); - + let canonicalized_args: Vec<_> = args .iter() .map(|v| canonicalize_arg_name(&decls, &indexed_literal.predicate, &arg_names, v)) .collect(); - - produced_key = join_args_as_tuple(&canonicalized_key, &canonicalized_key, &canonicalized_args); - produced_args = join_args_as_tuple(&canonicalized_args, &canonicalized_key, &canonicalized_args); - let relation_args = join_args_as_tuple(&declared_args, &canonicalized_key, &canonicalized_args); + produced_key = + join_args_as_tuple(&canonicalized_key, &canonicalized_key, &canonicalized_args); + produced_args = + join_args_as_tuple(&canonicalized_args, &canonicalized_key, &canonicalized_args); + + let relation_args = + join_args_as_tuple(&declared_args, &canonicalized_key, &canonicalized_args); relation_args }; @@ -1073,13 +1074,10 @@ fn find_arg_decl<'a>( args: &Vec, variable: &str, ) -> &'a ast::ParamDecl { - let idx = args - .iter() - .position(|arg| arg == variable) - .expect(&format!( - "Couldn't find variable {:?} in the specified args: {:?}", - variable, args - )); + let idx = args.iter().position(|arg| arg == variable).expect(&format!( + "Couldn't find variable {:?} in the specified args: {:?}", + variable, args + )); let predicate_arg_decls = &global_decls[&predicate.to_string()]; let arg_decl = &predicate_arg_decls.parameters[idx]; diff --git a/src/generator_new/ast.rs b/src/generator_new/ast.rs new file mode 100644 index 0000000..bb1ac52 --- /dev/null +++ b/src/generator_new/ast.rs @@ -0,0 +1,346 @@ +//! # Examples +//! +//! ## Example 1 +//! +//! ```Datalog +//! input in(x: u32, y: u32); +//! output r(x: u32, y: u32); +//! r(x, y) = in(y, x); +//! ``` +//! ``in`` is assumed to be a variable of type ``&Vec<(u32, u32)>``. +//! ```ignore +//! let r = in.iter().map(|(y, x)| {(x, y)}); +//! ``` +//! +//! ## Example 2 +//! +//! ```Datalog +//! input in(x: u32, y: u32); +//! output r(x: u32, y: u32); +//! r(x, y) = in(y, x); +//! r(x, y) = r(x, z), r(z, y); +//! ``` +//! ``in`` is assumed to be a variable of type ``&Vec<(u32, u32)>``. +//! ```ignore +//! let mut iteration = Iteration::new(); +//! let r = iteration.variable::<(u32, u32)>("r"); +//! let r_1 = iteration.variable::<(u32, u32)>("r_1"); +//! let r_2 = iteration.variable::<(u32, u32)>("r_2"); +//! while iteration.changed() { +//! r_1.from_map(&r, |(x, z)| {(z, x)}); +//! r_2.from_map(&r, |(z, y)| {(z, y)}); +//! r.from_join(&r_1, &r_2, |(z, x, y)| {z, x, y}); +//! } +//! let r = in.iter().map(|(y, x)| {(x, y)}); +//! ``` + +use crate::data_structures::OrderedMap; +use std::collections::HashMap; + +/// A Datalog variable. +/// +/// For example, `x` in the following: +/// ```ignore +/// r_1.from_map(&r, |(x, z)| {(z, x)}); +/// ``` +#[derive(Debug, Clone)] +pub(crate) struct DVar { + pub name: syn::Ident, +} + +impl DVar { + pub fn new(name: syn::Ident) -> Self { + Self { name: name } + } +} + +/// A flat tuple of `DVar`s. Typically used to represent the user defined types. +#[derive(Debug)] +pub(crate) struct DVarTuple { + pub vars: Vec, +} + +impl DVarTuple { + pub fn new(args: Vec) -> Self { + Self { + vars: args.into_iter().map(|ident| DVar::new(ident)).collect(), + } + } +} + +/// A (key, value) representation of `DVar`s. It is used for joins. +#[derive(Debug)] +pub(crate) struct DVarKeyVal { + pub key: Vec, + pub value: Vec, +} + +/// An ordered set of `DVar`s. +#[derive(Debug)] +pub(crate) enum DVars { + Tuple(DVarTuple), + KeyVal(DVarKeyVal), +} + +impl DVars { + pub fn new_tuple(args: Vec) -> Self { + DVars::Tuple(DVarTuple::new(args)) + } + pub fn new_key_val(key: Vec, value: Vec) -> Self { + DVars::KeyVal(DVarKeyVal { + key: key.into_iter().map(|ident| DVar::new(ident)).collect(), + value: value.into_iter().map(|ident| DVar::new(ident)).collect(), + }) + } +} + +/// A type that matches some `DVars`. +#[derive(Debug)] +pub(crate) enum DVarTypes { + Tuple(Vec), + KeyVal { + key: Vec, + value: Vec, + }, +} + +impl std::convert::From> for DVarTypes { + fn from(types: Vec) -> Self { + DVarTypes::Tuple(types) + } +} + +/// A Datafrog relation. +#[derive(Debug)] +pub(crate) struct RelationDecl { + pub var: Variable, + pub typ: Vec, +} + +/// A Datafrog variable. +/// +/// For example, `rule` in the following: +/// ```ignore +/// let rule = iteration.variable::<(u32, u32)>("rule"); +/// ``` +#[derive(Debug)] +pub(crate) struct VariableDecl { + pub var: Variable, + /// The type by shape must match `DVarKeyVal`. + pub typ: DVarTypes, + pub is_output: bool, +} + +/// A reference to a Datafrog relation or variable. +#[derive(Debug, Clone)] +pub(crate) struct Variable { + pub name: syn::Ident, +} + +impl Variable { + pub fn with_counter(&self, counter: usize) -> Self { + Self { + name: syn::Ident::new( + &format!("{}_{}", self.name, counter), + proc_macro2::Span::call_site(), + ), + } + } +} + +/// An operation that reorders and potentially drops Datalog variables. +/// +/// It is encoded as a Datafrog `from_map`. +#[derive(Debug)] +pub(crate) struct ReorderOp { + /// A variable into which we write the result. + pub output: Variable, + /// A variable from which we read the input. + pub input: Variable, + pub input_vars: DVars, + pub output_vars: DVars, +} + +/// An operation that evaluates the given expression and adds it as a last output variable. +#[derive(Debug)] +pub(crate) struct BindVarOp { + /// A variable into which we write the result. + pub output: Variable, + /// A variable from which we read the input. + pub input: Variable, + /// Input variables that are copied to output and potentially used for evaluating `expr`. + pub vars: DVarTuple, + /// The expression whose result is bound to a new variable. + pub expr: syn::Expr, +} + +/// An operation that joins two variables. +#[derive(Debug)] +pub(crate) struct JoinOp { + /// A variable into which we write the result. + pub output: Variable, + /// The first variable, which we use in join. + pub input_first: Variable, + /// The second variable, which we use in join. + pub input_second: Variable, + /// Datalog variables used for joining. + pub key: DVarTuple, + /// Datalog value variables from the first variable. + pub value_first: DVarTuple, + /// Datalog value variables from the second variable. + pub value_second: DVarTuple, +} + +/// An operation that removes facts from the variable that belong to the relation. +#[derive(Debug)] +pub(crate) struct AntiJoinOp { + /// The variable into which we write the result. + pub output: Variable, + /// The variable from which we take facts. + pub input_variable: Variable, + /// The relation in which we check facts. + pub input_relation: Variable, + /// Datalog variables used for joining. + pub key: DVarTuple, + /// Datalog value variables from the variable. + pub value: DVarTuple, +} + +/// An operation that filters out facts. +#[derive(Debug)] +pub(crate) struct FilterOp { + /// A variable which we want to filter. + pub variable: Variable, + pub vars: DVars, + /// A boolean expression used for filtering. + pub expr: syn::Expr, +} + +/// An operation that inserts the relation into a variable. +#[derive(Debug)] +pub(crate) struct InsertOp { + /// The variable into which we want to insert the relation. + pub variable: Variable, + /// The relation to be inserted. + pub relation: Variable, +} + +#[derive(Debug)] +pub(crate) enum Operation { + Reorder(ReorderOp), + // BindVar(BindVarOp), + Join(JoinOp), + AntiJoin(AntiJoinOp), + // Filter(FilterOp), + Insert(InsertOp), +} + +/// A Datafrog iteration. +#[derive(Debug)] +pub(crate) struct Iteration { + /// Variables that are converted relations. + relation_variables: HashMap, + pub relations: OrderedMap, + pub variables: OrderedMap, + /// Operations performed before entering the iteration. + pub pre_operations: Vec, + /// Operations performed in the body of the iteration. + pub body_operations: Vec, + /// Operations performed after exiting the iteration. + pub post_operations: Vec, +} + +impl Iteration { + pub fn new(relations: Vec, variables: Vec) -> Self { + Self { + relation_variables: HashMap::new(), + relations: relations + .into_iter() + .map(|decl| (decl.var.name.clone(), decl)) + .collect(), + variables: variables + .into_iter() + .map(|decl| (decl.var.name.clone(), decl)) + .collect(), + pre_operations: Vec::new(), + body_operations: Vec::new(), + post_operations: Vec::new(), + } + } + /// Convert a Datafrog relation to a Datafrog variable and return its identifier. + pub fn convert_relation_to_variable(&mut self, variable: &Variable) -> Variable { + if let Some(name) = self.relation_variables.get(&variable.name) { + return self.variables[name].var.clone(); + } + let decl = &self.relations[&variable.name]; + let variable_decl = VariableDecl { + var: decl.var.with_counter(self.variables.len()), + typ: decl.typ.clone().into(), + is_output: false, + }; + let new_variable = variable_decl.var.clone(); + self.relation_variables + .insert(variable.name.clone(), new_variable.name.clone()); + self.variables + .insert(new_variable.name.clone(), variable_decl); + self.pre_operations.push(Operation::Insert(InsertOp { + variable: new_variable.clone(), + relation: decl.var.clone(), + })); + new_variable + } + /// Get Datafrog variable that corresponds to the given predicate name. If + /// we have only a relation, then convert it into a variable. + pub fn get_or_convert_variable(&mut self, predicate: &syn::Ident) -> Variable { + if let Some(variable) = self.get_relation_var(predicate) { + // TODO: Avoid converting the same relation multiple times. + self.convert_relation_to_variable(&variable) + } else { + self.get_variable(predicate) + } + } + pub fn get_relation_var(&self, variable_name: &syn::Ident) -> Option { + self.relations + .get(variable_name) + .map(|decl| decl.var.clone()) + } + pub fn get_variable(&self, variable_name: &syn::Ident) -> Variable { + self.variables[variable_name].var.clone() + } + pub fn add_operation(&mut self, operation: Operation) { + self.body_operations.push(operation); + } + pub fn get_variable_tuple_types(&self, variable: &Variable) -> Vec { + let decl = &self.variables[&variable.name]; + match &decl.typ { + DVarTypes::Tuple(types) => types.clone(), + DVarTypes::KeyVal { .. } => unreachable!(), + } + } + pub fn create_key_val_variable( + &mut self, + variable: &Variable, + key: Vec, + value: Vec, + ) -> Variable { + self.create_variable(variable, DVarTypes::KeyVal { key, value }) + } + pub fn create_tuple_variable( + &mut self, + variable: &Variable, + types: Vec, + ) -> Variable { + self.create_variable(variable, DVarTypes::Tuple(types)) + } + pub fn create_variable(&mut self, variable: &Variable, typ: DVarTypes) -> Variable { + let variable_decl = VariableDecl { + var: variable.with_counter(self.variables.len()), + typ: typ, + is_output: false, + }; + let new_variable = variable_decl.var.clone(); + self.variables + .insert(new_variable.name.clone(), variable_decl); + new_variable + } +} diff --git a/src/generator_new/encode.rs b/src/generator_new/encode.rs new file mode 100644 index 0000000..c544494 --- /dev/null +++ b/src/generator_new/encode.rs @@ -0,0 +1,405 @@ +use crate::ast; +use crate::generator_new::ast as gen; + +/// Divide the arguments into three sets: +/// +/// 1. `key` – the arguments that are common in `first` and `second`. +/// 2. `first_remainder` – the arguments that are unique in `first`. +/// 3. `second_remainder` – the arguments that are unique in `second`. +fn common_args( + first: &Vec, + first_types: &Vec, + second: &Vec, + second_types: &Vec, +) -> ( + (Vec, Vec), + (Vec, Vec), + (Vec, Vec), +) { + assert!(first.len() == first_types.len()); + assert!(second.len() == second_types.len()); + + let mut key = Vec::new(); + let mut key_types = Vec::new(); + let mut first_remainder = Vec::new(); + let mut first_remainder_types = Vec::new(); + + for (arg1, arg1_type) in first.iter().zip(first_types) { + if arg1.is_wildcard() { + continue; + } + let mut found = false; + for arg2 in second { + if arg1 == arg2 { + key.push(arg1.clone()); + key_types.push(arg1_type.clone()); + found = true; + break; + } + } + if !found { + first_remainder.push(arg1.clone()); + first_remainder_types.push(arg1_type.clone()); + } + } + let mut second_remainder = Vec::new(); + let mut second_remainder_types = Vec::new(); + for (arg2, arg2_type) in second.iter().zip(second_types) { + if arg2.is_wildcard() { + continue; + } + if !key.contains(arg2) { + second_remainder.push(arg2.clone()); + second_remainder_types.push(arg2_type.clone()); + } + } + + ( + (key, key_types), + (first_remainder, first_remainder_types), + (second_remainder, second_remainder_types), + ) +} + +pub(crate) fn encode(program: ast::Program) -> gen::Iteration { + let mut relations = Vec::new(); + let mut variables = Vec::new(); + for decl in program.decls.values() { + let var = gen::Variable { + name: decl.name.clone(), + }; + let typ = decl + .parameters + .iter() + .map(|param| param.typ.clone()) + .collect(); + match decl.kind { + ast::PredicateKind::Input => { + relations.push(gen::RelationDecl { var: var, typ: typ }); + } + ast::PredicateKind::Internal => { + variables.push(gen::VariableDecl { + var: var, + typ: gen::DVarTypes::Tuple(typ), + is_output: false, + }); + } + ast::PredicateKind::Output => { + variables.push(gen::VariableDecl { + var: var, + typ: gen::DVarTypes::Tuple(typ), + is_output: true, + }); + } + } + } + let mut iteration = gen::Iteration::new(relations, variables); + for rule in &program.rules { + let head_variable = iteration.get_variable(&rule.head.predicate); + let mut iter = rule.body.iter(); + let literal1 = iter.next().unwrap(); + assert!(!literal1.is_negated); + let mut variable = iteration.get_or_convert_variable(&literal1.predicate); + let mut args = literal1.args.clone(); + + while let Some(literal) = iter.next() { + // TODO: Check during the typechecking phase that no literal has two + // arguments with the same name. + let (new_variable, new_args) = if literal.is_negated { + encode_antijoin(&mut iteration, &head_variable, variable, args, literal) + } else { + encode_join(&mut iteration, &head_variable, variable, args, literal) + }; + variable = new_variable; + args = new_args; + } + let reorder_op = gen::ReorderOp { + output: head_variable, + input: variable, + input_vars: args.into(), + output_vars: rule.head.args.clone().into(), + }; + iteration.add_operation(gen::Operation::Reorder(reorder_op)); + } + iteration +} + +fn encode_antijoin( + iteration: &mut gen::Iteration, + head_variable: &gen::Variable, + variable: gen::Variable, + args: Vec, + literal: &ast::Literal, +) -> (gen::Variable, Vec) { + let relation_variable = iteration + .get_relation_var(&literal.predicate) + .expect("Negations are currently supported only on relations."); + let arg_types = iteration.get_variable_tuple_types(&variable); + let literal_arg_types = iteration.relations[&relation_variable.name].typ.clone(); + + // TODO: Lift this limitation. + for arg in &literal.args { + if !args.contains(arg) { + unimplemented!("Currently all variables from the negated relation must be used."); + } + } + let mut remainder = Vec::new(); + let mut remainder_types = Vec::new(); + for (arg, arg_type) in args.iter().zip(&arg_types) { + if !literal.args.contains(arg) { + remainder.push(arg.clone()); + remainder_types.push(arg_type.clone()); + } + } + + let first_variable = iteration.create_key_val_variable( + &variable, + literal_arg_types.clone(), + remainder_types.clone(), + ); + let reorder_first_op = gen::ReorderOp { + output: first_variable.clone(), + input: variable, + input_vars: args.into(), + output_vars: (literal.args.clone(), remainder.clone()).into(), + }; + iteration.add_operation(gen::Operation::Reorder(reorder_first_op)); + + let result_types = literal_arg_types + .into_iter() + .chain(remainder_types) + .collect(); + let args = literal + .args + .clone() + .into_iter() + .chain(remainder.clone()) + .collect(); + let variable = iteration.create_tuple_variable(&head_variable, result_types); + let join_op = gen::AntiJoinOp { + output: variable.clone(), + input_variable: first_variable, + input_relation: relation_variable, + key: literal.args.clone().into(), + value: remainder.into(), + }; + iteration.add_operation(gen::Operation::AntiJoin(join_op)); + (variable, args) +} + +fn encode_join( + iteration: &mut gen::Iteration, + head_variable: &gen::Variable, + variable: gen::Variable, + args: Vec, + literal: &ast::Literal, +) -> (gen::Variable, Vec) { + let joined_variable = iteration.get_or_convert_variable(&literal.predicate); + let arg_types = iteration.get_variable_tuple_types(&variable); + let literal_arg_types = iteration.get_variable_tuple_types(&joined_variable); + let ((key, key_types), (remainder1, remainder1_types), (remainder2, remainder2_types)) = + common_args(&args, &arg_types, &literal.args, &literal_arg_types); + let first_variable = + iteration.create_key_val_variable(&variable, key_types.clone(), remainder1_types.clone()); + let reorder_first_op = gen::ReorderOp { + output: first_variable.clone(), + input: variable, + input_vars: args.into(), + output_vars: (key.clone(), remainder1.clone()).into(), + }; + iteration.add_operation(gen::Operation::Reorder(reorder_first_op)); + let second_variable = iteration.create_key_val_variable( + &joined_variable, + key_types.clone(), + remainder2_types.clone(), + ); + let reorder_second_op = gen::ReorderOp { + output: second_variable.clone(), + input: joined_variable, + input_vars: literal.args.clone().into(), + output_vars: (key.clone(), remainder2.clone()).into(), + }; + iteration.add_operation(gen::Operation::Reorder(reorder_second_op)); + let result_types = key_types + .into_iter() + .chain(remainder1_types) + .chain(remainder2_types) + .collect(); + let args = key + .clone() + .into_iter() + .chain(remainder1.clone()) + .chain(remainder2.clone()) + .collect(); + let variable = iteration.create_tuple_variable(&head_variable, result_types); + let join_op = gen::JoinOp { + output: variable.clone(), + input_first: first_variable, + input_second: second_variable, + key: key.into(), + value_first: remainder1.into(), + value_second: remainder2.into(), + }; + iteration.add_operation(gen::Operation::Join(join_op)); + (variable, args) +} + +impl std::convert::From> for gen::DVars { + fn from(args: Vec) -> Self { + gen::DVars::new_tuple(args.into_iter().map(|arg| arg.to_ident()).collect()) + } +} + +impl std::convert::From> for gen::DVarTuple { + fn from(args: Vec) -> Self { + gen::DVarTuple::new(args.into_iter().map(|arg| arg.to_ident()).collect()) + } +} + +impl std::convert::From<(Vec, Vec)> for gen::DVars { + fn from((key, value): (Vec, Vec)) -> Self { + gen::DVars::new_key_val( + key.into_iter().map(|arg| arg.to_ident()).collect(), + value.into_iter().map(|arg| arg.to_ident()).collect(), + ) + } +} + +impl std::convert::From> for gen::DVars { + fn from(args: Vec) -> Self { + gen::DVars::new_tuple(args) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::parse; + use crate::typechecker::typecheck; + use proc_macro2::TokenStream; + use quote::ToTokens; + use std::str::FromStr; + + fn compare(datalog_source: &str, exptected_encoding: &str) { + let parsed_program = parse(datalog_source); + let typechecked_program = typecheck(parsed_program).unwrap(); + let iteration = encode(typechecked_program); + let tokens = iteration.to_token_stream().to_string(); + eprintln!("{}", tokens); + let expected_tokens = TokenStream::from_str(exptected_encoding).unwrap(); + assert_eq!(tokens.to_string(), expected_tokens.to_string()); + } + + #[test] + fn encode_simple1() { + compare( + " + input inp(x: u32, y: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(y, x). + ", + r##" + { + let mut iteration = datafrog::Iteration::new(); + let var_inp = datafrog::Relation:: <(u32, u32,)> ::from_vec(inp); + let var_out = iteration.variable:: <(u32, u32,)>("out"); + let var_inp_1 = iteration.variable:: <(u32, u32,)>("inp_1"); + var_inp_1.insert(var_inp); + while iteration.changed() { + var_out.from_map(&var_inp_1, | &(y, x,)| (x, y,)); + } + out = var_out.complete(); + } + "##, + ); + } + #[test] + fn encode_transitive_closure() { + compare( + " + input inp(x: u32, y: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(x, y). + out(x, y) :- out(x, z), out(z, y). + ", + r##" + { + let mut iteration = datafrog::Iteration::new(); + let var_inp = datafrog::Relation:: <(u32, u32,)> ::from_vec(inp); + let var_out = iteration.variable:: <(u32, u32,)>("out"); + let var_inp_1 = iteration.variable:: <(u32, u32,)>("inp_1"); + let var_out_2 = iteration.variable:: <((u32,), (u32,))>("out_2"); + let var_out_3 = iteration.variable:: <((u32,), (u32,))>("out_3"); + let var_out_4 = iteration.variable:: <(u32, u32, u32,)>("out_4"); + var_inp_1.insert(var_inp); + while iteration.changed() { + var_out.from_map(&var_inp_1, | &(x, y,)| (x, y,)); + var_out_2.from_map(&var_out, | &(x, z,)| ((z,), (x,))); + var_out_3.from_map(&var_out, | &(z, y,)| ((z,), (y,))); + var_out_4.from_join(&var_out_2, &var_out_3, | &(z,), &(x,), &(y,)| (z, x, y,)); + var_out.from_map(&var_out_4, | &(z, x, y,)| (x, y,)); + } + out = var_out.complete(); + } + "##, + ); + } + #[test] + fn encode_rule_with_wildcards() { + compare( + " + input inp(x: u32, y: u32) + output out(x: u32) + out(x) :- inp(x, _), inp(_, x). + ", + r##" + { + let mut iteration = datafrog::Iteration::new(); + let var_inp = datafrog::Relation:: <(u32, u32,)> ::from_vec(inp); + let var_out = iteration.variable:: <(u32,)>("out"); + let var_inp_1 = iteration.variable:: <(u32, u32,)>("inp_1"); + let var_inp_1_2 = iteration.variable:: <((u32,), ())>("inp_1_2"); + let var_inp_1_3 = iteration.variable:: <((u32,), ())>("inp_1_3"); + let var_out_4 = iteration.variable:: <(u32,)>("out_4"); + var_inp_1.insert(var_inp); + while iteration.changed() { + var_inp_1_2.from_map(&var_inp_1, | &(x, _,)| ((x,), ())); + var_inp_1_3.from_map(&var_inp_1, | &(_, x,)| ((x,), ())); + var_out_4.from_join(&var_inp_1_2, &var_inp_1_3, | &(x,), &(), &()| (x,)); + var_out.from_map(&var_out_4, | &(x,)| (x,)); + } + out = var_out.complete(); + } + "##, + ); + } + #[test] + fn encode_kill() { + compare( + " + input inp(x: u32, y: u32) + input kill(y: u32) + output out(x: u32, y: u32) + out(x, y) :- inp(x, y), !kill(y). + ", + r##" + { + let mut iteration = datafrog::Iteration::new(); + let var_inp = datafrog::Relation:: <(u32, u32,)> ::from_vec(inp); + let var_kill = datafrog::Relation:: <(u32,)> ::from_vec(kill); + let var_out = iteration.variable:: <(u32, u32,)>("out"); + let var_inp_1 = iteration.variable:: <(u32, u32,)>("inp_1"); + let var_inp_1_2 = iteration.variable:: <((u32,), (u32,))>("inp_1_2"); + let var_out_3 = iteration.variable:: <(u32, u32,)>("out_3"); + var_inp_1.insert(var_inp); + while iteration.changed() { + var_inp_1_2.from_map(&var_inp_1, | &(x, y,)| ((y,), (x,))); + var_out_3.from_antijoin(&var_inp_1_2, &var_kill, | &(y,), &(x,)| (y, x,)); + var_out.from_map(&var_out_3, | &(y, x,)| (x, y,)); + } + out = var_out.complete(); + } + "##, + ); + } +} diff --git a/src/generator_new/mod.rs b/src/generator_new/mod.rs new file mode 100644 index 0000000..7de57bd --- /dev/null +++ b/src/generator_new/mod.rs @@ -0,0 +1,25 @@ +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; + +mod ast; +mod encode; +mod to_tokens; + +pub fn generate_datafrog(input: TokenStream) -> TokenStream { + let parsed_program = match syn::parse2(input) { + Ok(program) => program, + Err(err) => { + let tokens = TokenStream::from(err.to_compile_error()); + return quote! { {#tokens }}; + } + }; + let typechecked_program = match crate::typechecker::typecheck(parsed_program) { + Ok(program) => program, + Err(err) => { + let tokens = TokenStream::from(err.to_syn_error().to_compile_error()); + return quote! { {#tokens }}; + } + }; + let encoded_program = encode::encode(typechecked_program); + encoded_program.to_token_stream() +} diff --git a/src/generator_new/to_tokens.rs b/src/generator_new/to_tokens.rs new file mode 100644 index 0000000..8df1e8f --- /dev/null +++ b/src/generator_new/to_tokens.rs @@ -0,0 +1,228 @@ +use crate::generator_new::ast::*; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use quote::ToTokens; + +fn type_vec_to_tokens(type_vec: &Vec) -> TokenStream { + let mut type_tokens = TokenStream::new(); + for typ in type_vec { + type_tokens.extend(quote! {#typ,}); + } + type_tokens +} + +fn var_vec_to_tokens(var_vec: &Vec) -> TokenStream { + let mut var_tokens = TokenStream::new(); + for var in var_vec { + var_tokens.extend(quote! {#var,}); + } + var_tokens +} + +impl ToTokens for DVar { + fn to_tokens(&self, tokens: &mut TokenStream) { + let name = &self.name; + tokens.extend(quote! {#name}); + } +} + +impl ToTokens for DVarTuple { + fn to_tokens(&self, tokens: &mut TokenStream) { + let vars = var_vec_to_tokens(&self.vars); + tokens.extend(quote! {(#vars)}); + } +} + +impl ToTokens for DVarKeyVal { + fn to_tokens(&self, tokens: &mut TokenStream) { + let key = var_vec_to_tokens(&self.key); + let value = var_vec_to_tokens(&self.value); + tokens.extend(quote! {((#key), (#value))}); + } +} + +impl ToTokens for DVars { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + DVars::Tuple(tuple) => tuple.to_tokens(tokens), + DVars::KeyVal(key_val) => key_val.to_tokens(tokens), + } + } +} + +impl ToTokens for DVarTypes { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + DVarTypes::Tuple(types) => { + let type_tokens = type_vec_to_tokens(types); + tokens.extend(quote! {(#type_tokens)}); + } + DVarTypes::KeyVal { key, value } => { + let key_tokens = type_vec_to_tokens(key); + let value_tokens = type_vec_to_tokens(value); + tokens.extend(quote! {((#key_tokens), (#value_tokens))}); + } + } + } +} + +impl ToTokens for Variable { + fn to_tokens(&self, tokens: &mut TokenStream) { + let var_name = format!("var_{}", self.name); + let ident = syn::Ident::new(&var_name, Span::call_site()); + tokens.extend(quote! {#ident}); + } +} + +impl ToTokens for ReorderOp { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ReorderOp { + output, + input, + input_vars, + output_vars, + } = self; + tokens.extend(quote! { + #output.from_map(&#input, |&#input_vars| #output_vars); + }); + } +} + +impl ToTokens for BindVarOp { + fn to_tokens(&self, _tokens: &mut TokenStream) { + unimplemented!(); + } +} + +impl ToTokens for JoinOp { + fn to_tokens(&self, tokens: &mut TokenStream) { + let JoinOp { + output, + input_first, + input_second, + key, + value_first, + value_second, + } = self; + let flattened = DVarTuple { + vars: key + .vars + .iter() + .chain(&value_first.vars) + .chain(&value_second.vars) + .cloned() + .collect(), + }; + + tokens.extend(quote! { + #output.from_join( + &#input_first, + &#input_second, + |&#key, &#value_first, &#value_second| #flattened); + }); + } +} + +impl ToTokens for AntiJoinOp { + fn to_tokens(&self, tokens: &mut TokenStream) { + let AntiJoinOp { + output, + input_variable, + input_relation, + key, + value, + } = self; + let flattened = DVarTuple { + vars: key.vars.iter().chain(&value.vars).cloned().collect(), + }; + + tokens.extend(quote! { + #output.from_antijoin( + &#input_variable, + &#input_relation, + |&#key, &#value| #flattened); + }); + } +} + +impl ToTokens for FilterOp { + fn to_tokens(&self, _tokens: &mut TokenStream) { + unimplemented!(); + } +} + +impl ToTokens for InsertOp { + fn to_tokens(&self, tokens: &mut TokenStream) { + let InsertOp { variable, relation } = self; + tokens.extend(quote! { + #variable.insert(#relation); + }); + } +} + +impl ToTokens for Operation { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Operation::Reorder(op) => op.to_tokens(tokens), + // Operation::BindVar(op) => op.to_tokens(tokens), + Operation::Join(op) => op.to_tokens(tokens), + Operation::AntiJoin(op) => op.to_tokens(tokens), + // Operation::Filter(op) => op.to_tokens(tokens), + Operation::Insert(op) => op.to_tokens(tokens), + } + } +} + +fn operation_vec_to_tokens(operations: &Vec) -> TokenStream { + let mut tokens = TokenStream::new(); + for operation in operations { + operation.to_tokens(&mut tokens); + } + tokens +} + +impl ToTokens for Iteration { + fn to_tokens(&self, tokens: &mut TokenStream) { + let mut declare_relations = TokenStream::new(); + for relation in self.relations.values() { + let vec_name = &relation.var.name; + let var = relation.var.to_token_stream(); + let typ = type_vec_to_tokens(&relation.typ); + declare_relations.extend(quote! { + let #var = datafrog::Relation::<(#typ)>::from_vec(#vec_name); + }); + } + let mut declare_variables = TokenStream::new(); + let mut output_results = TokenStream::new(); + for variable in self.variables.values() { + let var = variable.var.to_token_stream(); + let var_name = variable.var.name.to_string(); + let typ = variable.typ.to_token_stream(); + declare_variables.extend(quote! { + let #var = iteration.variable::<#typ>(#var_name); + }); + if variable.is_output { + let new_var = &variable.var.name; + output_results.extend(quote! { + #new_var = #var.complete(); + }); + } + } + let pre_operations = operation_vec_to_tokens(&self.pre_operations); + let body_operations = operation_vec_to_tokens(&self.body_operations); + let post_operations = operation_vec_to_tokens(&self.post_operations); + tokens.extend(quote! { + { + let mut iteration = datafrog::Iteration::new(); + #declare_relations + #declare_variables + #pre_operations + while iteration.changed() { + #body_operations + } + #post_operations + #output_results + } + }); + } +} diff --git a/src/lib.rs b/src/lib.rs index 910e76e..3edb3e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,11 @@ extern crate log; mod ast; +mod data_structures; mod generator; +mod generator_new; mod parser; mod typechecker; pub use generator::generate_skeleton_datafrog; +pub use generator_new::generate_datafrog; diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 74f87f5..56edc91 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -136,24 +136,29 @@ impl Parse for ast::RuleHead { impl Parse for ast::Rule { fn parse(input: ParseStream) -> syn::Result { let head = input.parse()?; - input.step(|cursor| { - let rest = match cursor.token_tree() { - Some((proc_macro2::TokenTree::Punct(ref punct), next)) - if punct.as_char() == ':' && punct.spacing() == proc_macro2::Spacing::Joint => - { - next - } - _ => return Err(cursor.error(":- expected")), - }; - match rest.token_tree() { - Some((proc_macro2::TokenTree::Punct(ref punct), next)) - if punct.as_char() == '-' => - { - Ok(((), next)) - } - _ => Err(cursor.error(":- expected")), - } - })?; + // FIXME: For some reason, when getting input from a procedural macro, + // a space is always inserted between `:` and `-`. Therefore, the parser + // needs to accept the variant with a space. + input.parse::()?; + input.parse::()?; + // input.step(|cursor| { + // let rest = match cursor.token_tree() { + // Some((proc_macro2::TokenTree::Punct(ref punct), next)) + // if punct.as_char() == ':' && punct.spacing() == proc_macro2::Spacing::Joint => + // { + // next + // } + // _ => return Err(cursor.error(":- expected")), + // }; + // match rest.token_tree() { + // Some((proc_macro2::TokenTree::Punct(ref punct), next)) + // if punct.as_char() == '-' => + // { + // Ok(((), next)) + // } + // _ => Err(cursor.error(":- expected")), + // } + // })?; let body: Punctuated = Punctuated::parse_separated_nonempty(input)?; // Allow trailing punctuation. diff --git a/src/typechecker.rs b/src/typechecker.rs index 20d22c8..107eba5 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1,4 +1,5 @@ use crate::ast; +use crate::data_structures::OrderedMap; use crate::parser::ast as past; use proc_macro2::Span; use std::collections::{HashMap, HashSet}; @@ -8,7 +9,7 @@ use std::fmt; pub struct Error { pub msg: String, pub span: Span, - pub hint_span: Option, + pub hint: Option<(String, Span)>, } impl Error { @@ -16,26 +17,34 @@ impl Error { Self { msg: msg, span: span, - hint_span: None, + hint: None, } } - fn with_hint_span(msg: String, span: Span, hint_span: Span) -> Self { + fn with_hint_span(msg: String, span: Span, hint_msg: String, hint_span: Span) -> Self { Self { msg: msg, span: span, - hint_span: Some(hint_span), + hint: Some((hint_msg, hint_span)), } } + pub fn to_syn_error(&self) -> syn::Error { + let mut error = syn::Error::new(self.span, &self.msg); + if let Some((hint_msg, hint_span)) = &self.hint { + error.combine(syn::Error::new(hint_span.clone(), hint_msg)); + } + error + } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(hint_span) = self.hint_span { + if let Some((hint_msg, hint_span)) = &self.hint { write!( f, - "{} at {:?} (hint: {:?})", + "{} at {:?} ({} at {:?})", self.msg, self.span.start(), + hint_msg, hint_span.start() ) } else { @@ -46,7 +55,7 @@ impl fmt::Display for Error { fn check_head( head: &past::RuleHead, - decls: &HashMap, + decls: &OrderedMap, ) -> Result<(), Error> { let decl = decls.get(&head.predicate.to_string()).ok_or_else(|| { Error::new( @@ -56,13 +65,15 @@ fn check_head( })?; if head.args.len() != decl.parameters.len() { let msg = format!( - "Wrong number of arguments: expected {}, found {}.", + "Wrong number of arguments for {}: expected {}, found {}.", + head.predicate, + decl.parameters.len(), head.args.len(), - decl.parameters.len() ); return Err(Error::with_hint_span( msg, head.predicate.span(), + format!("The predicate {} was declared here.", head.predicate), decl.name.span(), )); } @@ -71,7 +82,7 @@ fn check_head( fn check_body( body: Vec, - decls: &HashMap, + decls: &OrderedMap, ) -> Result, Error> { let mut new_body = Vec::new(); for literal in body { @@ -85,13 +96,15 @@ fn check_body( past::ArgList::Positional(positional_args) => { if positional_args.len() != decl.parameters.len() { let msg = format!( - "Wrong number of arguments: expected {}, found {}.", + "Wrong number of arguments for {}: expected {}, found {}.", + literal.predicate, positional_args.len(), decl.parameters.len() ); return Err(Error::with_hint_span( msg, literal.predicate.span(), + format!("The predicate {} was declared here.", decl.name), decl.name.span(), )); } @@ -104,39 +117,48 @@ fn check_body( .collect() } past::ArgList::Named(named_args) => { - let kwargs: HashMap<_, _> = named_args - .into_iter() - .map(|named_arg| (named_arg.param.to_string(), named_arg.arg)) - .collect(); - let mut args = Vec::new(); + let mut kwargs = HashMap::new(); let mut used_parameters = HashSet::new(); + for named_arg in named_args { + let param_name = named_arg.param.to_string(); + if used_parameters.contains(¶m_name) { + return Err(Error::new( + format!("Parameter already bound: {}", param_name), + named_arg.param.span(), + )); + } + used_parameters.insert(param_name.clone()); + kwargs.insert(param_name, named_arg); + } + let mut args = Vec::new(); + let mut available_parameters = HashSet::new(); for parameter in &decl.parameters { let param_name = parameter.name.to_string(); let arg = match kwargs.get(¶m_name) { - Some(ident) => { + Some(past::NamedArg { arg: ident, .. }) => { let ident_str = ident.to_string(); - if used_parameters.contains(&ident_str) { - return Err(Error::new( - format!("Parameter already bound: {}", ident_str), - ident.span(), - )); - } used_parameters.insert(ident_str); ast::Arg::Ident(ident.clone()) } None => ast::Arg::Wildcard, }; + available_parameters.insert(param_name); args.push(arg); } for key in kwargs.keys() { - if !used_parameters.contains(key) { + if !available_parameters.contains(key) { + let mut available_parameters: Vec<_> = + available_parameters.into_iter().collect(); + available_parameters.sort(); + let parameter_span = kwargs[key].param.span(); return Err(Error::new( - format!("Unknown parameter {} in predicate.", key), - literal.predicate.span(), + format!("Unknown parameter {} in predicate {}. Available parameters are: {}.", + key, literal.predicate, available_parameters.join(","), + ), + parameter_span, )); } } - if kwargs.len() != used_parameters.len() {} args } }; @@ -151,7 +173,7 @@ fn check_body( } pub(crate) fn typecheck(program: past::Program) -> Result { - let mut decls = HashMap::new(); + let mut decls = OrderedMap::new(); let mut rules = Vec::new(); for item in program.items { diff --git a/tests/cspa_rules.rs b/tests/cspa_rules.rs index 34871ac..f83b4cc 100644 --- a/tests/cspa_rules.rs +++ b/tests/cspa_rules.rs @@ -186,7 +186,6 @@ fn ensure_generated_rules_build() { memory_alias.extend(assign.iter().map(|&(x, _y)| (x, x))); while iteration.changed() { - // Index maintenance value_flow_b.from_map(&value_flow, |&(a, b)| (b, a)); value_flow_a.from_map(&value_flow, |&(a, b)| (a, b)); // useless index