From 409ebb1a980fa81d97bba3adb777c89459bf567e Mon Sep 17 00:00:00 2001 From: Aron Zwaan Date: Wed, 13 Nov 2024 15:44:35 +0100 Subject: [PATCH] Add parser --- scopegraphs/examples/overload/ast.rs | 37 +++ scopegraphs/examples/overload/main.rs | 43 ++- scopegraphs/examples/overload/parse.rs | 344 ++++++++++++++++++++ scopegraphs/examples/overload/union_find.rs | 152 +++++++++ 4 files changed, 575 insertions(+), 1 deletion(-) create mode 100644 scopegraphs/examples/overload/ast.rs create mode 100644 scopegraphs/examples/overload/parse.rs create mode 100644 scopegraphs/examples/overload/union_find.rs diff --git a/scopegraphs/examples/overload/ast.rs b/scopegraphs/examples/overload/ast.rs new file mode 100644 index 0000000..9bc9b15 --- /dev/null +++ b/scopegraphs/examples/overload/ast.rs @@ -0,0 +1,37 @@ +#[derive(Debug, Clone)] +pub struct Program { + pub functions: Vec, + pub main: Expr, +} + +#[derive(Debug, Clone)] +pub struct Function { + pub name: String, + pub args: Vec, + pub return_type: Option, + pub body: Expr, +} + +#[derive(Debug, Clone)] +pub struct Arg { + pub name: String, + pub type_ann: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Type { + IntT, + BoolT, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Expr { + IntLit(u64), + BoolLit(bool), + Ident(String), + Plus(Box, Box), + Lt(Box, Box), + IfThenElse(Box, Box, Box), + FunCall(String, Vec), + Ascribe(Box, Type), +} diff --git a/scopegraphs/examples/overload/main.rs b/scopegraphs/examples/overload/main.rs index df1bb01..b0fe7a5 100644 --- a/scopegraphs/examples/overload/main.rs +++ b/scopegraphs/examples/overload/main.rs @@ -1,3 +1,44 @@ +#![allow(unused)] + +use crate::ast::*; +use crate::parse::parse; + +mod ast; +mod parse; +mod union_find; + pub fn main() { - println!("Hello from overload example!") + let program = " + fun tt() = true; + fun not(b) = if b { false } else { true }; + fun and(b1: bool, b2): bool = + if b1 { + if b2 { + true + } else { + false + } + } else { + false + }; + + $ and(not(false), tt()) + "; + + assert!(parse(program).is_ok()) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PartialType { + // Types are not recursive, so no need to have separate constructors for each variant + Type(Type), + Variable(TypeVar), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TypeVar(usize); + +pub struct FunType { + return_type: PartialType, + arg_types: Vec<(String, PartialType)>, } diff --git a/scopegraphs/examples/overload/parse.rs b/scopegraphs/examples/overload/parse.rs new file mode 100644 index 0000000..31d7c80 --- /dev/null +++ b/scopegraphs/examples/overload/parse.rs @@ -0,0 +1,344 @@ +use winnow::ascii::multispace0; +use winnow::combinator::{ + alt, cut_err, delimited, eof, fail, opt, preceded, repeat, separated, terminated, trace, +}; +use winnow::error::{ErrMode, ParserError, StrContext}; + +use crate::ast::{Arg, Expr, Function, Program, Type}; +use winnow::prelude::*; +use winnow::seq; +use winnow::stream::AsChar; +use winnow::token::{one_of, take_while}; + +fn ws<'a, F, O, E: ParserError<&'a str>>(inner: F) -> impl Parser<&'a str, O, E> +where + F: Parser<&'a str, O, E>, +{ + delimited(multispace0, inner, multispace0) +} + +fn parse_ident(input: &mut &'_ str) -> PResult { + ws(( + one_of(|c: char| c.is_alpha() || c == '_'), + take_while(0.., |c: char| c.is_alphanum() || c == '_'), + ) + .recognize() + .verify(|i: &str| !["fun", "if", "else", "true", "false", "int", "bool"].contains(&i))) + .context(StrContext::Label("parse ident")) + .parse_next(input) + .map(|i| i.to_string()) +} + +fn parse_int(input: &mut &'_ str) -> PResult { + repeat( + 1.., + terminated(one_of('0'..='9'), repeat(0.., '_').map(|()| ())), + ) + .map(|()| ()) + .recognize() + .context(StrContext::Label("parse int")) + .parse_next(input) + .map(|i| i.parse().expect("not an integer")) +} + +fn parse_bool(input: &mut &'_ str) -> PResult { + ws(alt(("true".value(true), "false".value(false)))) + .context(StrContext::Label("parse boolean literal")) + .parse_next(input) +} + +fn parse_type(input: &mut &'_ str) -> PResult { + ws(alt(("int".value(Type::IntT), "bool".value(Type::BoolT)))) + .context(StrContext::Label("parse type literal")) + .parse_next(input) +} + +fn parse_type_anno(input: &mut &'_ str) -> PResult { + preceded(ws(":"), parse_type) + .context(StrContext::Label("parse type annotation")) + .parse_next(input) +} + +fn parse_base_expr(input: &mut &'_ str) -> PResult { + alt(( + parse_int.map(Expr::IntLit), + parse_bool.map(Expr::BoolLit), + parse_ident.map(Expr::Ident), + seq! { + _: ws("if"), + parse_expr, + _: ws("{"), + parse_expr, + _: ws("}"), + _: ws("else"), + _: ws("{"), + parse_expr, + _: ws("}"), + } + .map(|(c, if_branch, else_branch)| { + Expr::IfThenElse(Box::new(c), Box::new(if_branch), Box::new(else_branch)) + }), + seq! { + parse_ident, + _: ws("("), + parse_exprs, + _: ws(")") + } + .map(|(function_name, args)| Expr::FunCall(function_name, args)), + parse_bracketed, + )) + .context(StrContext::Label("parse base expr")) + .parse_next(input) +} + +fn parse_arith_expr(input: &mut &'_ str) -> PResult { + separated(1.., parse_base_expr, ws("+")) + .context(StrContext::Label("parse arith expression")) + .parse_next(input) + .map(|operands: Vec| { + operands + .into_iter() + .reduce(|acc, op| Expr::Plus(acc.into(), op.into())) + .expect("at least one occurrence, so unwrapping is safe!") + }) +} + +fn parse_cmp_expr(input: &mut &'_ str) -> PResult { + separated(1..3, parse_arith_expr, ws("<")) + .context(StrContext::Label("parse comparison expression")) + .parse_next(input) + .map(|operands: Vec| { + operands + .into_iter() + .reduce(|acc, op| Expr::Lt(acc.into(), op.into())) + .expect("at least one occurrence, so unwrapping is safe!") + }) +} + +fn parse_bracketed(input: &mut &'_ str) -> PResult { + delimited(ws("("), parse_expr, ws(")")) + .context(StrContext::Label("parse bracketed expression")) + .parse_next(input) +} + +fn parse_expr(input: &mut &'_ str) -> PResult { + seq! { + parse_cmp_expr, + opt(preceded(ws(":"), parse_type)), + } + .map(|(expr, type_opt)| { + let expr_clone = expr.clone(); + type_opt + .map(|t| Expr::Ascribe(expr_clone.into(), t)) + .unwrap_or(expr) + }) + .context(StrContext::Label("parse expression")) + .parse_next(input) +} + +fn parse_exprs(input: &mut &'_ str) -> PResult> { + terminated(separated(0.., ws(parse_expr), ws(",")), opt(ws(","))) + .context(StrContext::Label("parse expression list")) + .parse_next(input) +} + +fn parse_arg(input: &mut &'_ str) -> PResult { + seq! { + parse_ident, + opt(parse_type_anno) + } + .context(StrContext::Label("parse argument")) + .map(|(name, type_ann)| Arg { name, type_ann }) + .parse_next(input) +} + +fn parse_args(input: &mut &'_ str) -> PResult> { + terminated(separated(0.., ws(parse_arg), ws(",")), opt(ws(","))) + .context(StrContext::Label("parse argument list")) + .parse_next(input) +} + +fn parse_function(input: &mut &'_ str) -> PResult { + seq!( + _: ws("fun"), + parse_ident, + delimited(ws("("), parse_args, ws(")")), + opt(parse_type_anno), + _: ws("="), + parse_expr, + _: ws(";"), + ) + .map(|(name, args, return_type, body)| Function { + name, + args, + return_type, + body, + }) + .context(StrContext::Label("parse function")) + .parse_next(input) +} + +pub fn parse_program(input: &mut &'_ str) -> PResult { + seq! { + repeat(0.., parse_function), + _: ws("$"), + parse_expr, + } + .map(|(functions, main)| Program { functions, main }) + .context(StrContext::Label("parse program")) + .parse_next(input) +} + +pub fn parse(mut input: &str) -> PResult { + terminated(parse_program, eof).parse_next(&mut input) +} + +pub fn parse_trace(mut input: &str) -> PResult { + trace("trace program", parse_program).parse_next(&mut input) +} + +#[cfg(test)] +mod test { + use std::thread; + + use winnow::{ + combinator::{terminated, trace}, + Parser, + }; + + use crate::{ + parse::{parse, parse_arith_expr, parse_base_expr, parse_cmp_expr, parse_ident}, + Expr, + }; + + use super::{parse_args, parse_expr, parse_function, parse_trace}; + + #[test] + pub fn parse_expr_true() { + let mut input = "true"; + assert_eq!(Expr::BoolLit(true), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_expr_false() { + let mut input = "false"; + assert_eq!(Expr::BoolLit(false), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_lit_zero() { + let mut input = "0"; + assert_eq!(Expr::IntLit(0), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_lit_42() { + let mut input = "42"; + assert_eq!(Expr::IntLit(42), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_ident_x() { + let mut input = "x"; + assert_eq!(Expr::Ident("x".into()), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_ident_x1n() { + let mut input = "x1"; + assert_eq!(Expr::Ident("x1".into()), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_ident_usc_x() { + let mut input = "_x"; + assert_eq!(Expr::Ident("_x".into()), parse_expr(&mut input).unwrap()); + } + + #[test] + pub fn parse_ident_empty() { + let mut input = ""; + assert!(parse_base_expr(&mut input).is_err()); + } + + #[test] + pub fn parse_plus() { + let mut input = "x + y"; + assert_eq!( + Expr::Plus( + Expr::Ident("x".into()).into(), + Expr::Ident("y".into()).into(), + ), + trace("parse expr", parse_expr) + .parse_next(&mut input) + .unwrap(), + ); + } + + #[test] + pub fn parse_plus_assoc() { + let mut input = "x + 2 + y"; + assert_eq!( + Expr::Plus( + Box::new(Expr::Plus( + Expr::Ident("x".into()).into(), + Expr::IntLit(2).into(), + )), + Box::new(Expr::Ident("y".into()).into()), + ), + trace("parse expr", parse_expr) + .parse_next(&mut input) + .unwrap(), + ); + } + + #[test] + pub fn parse_lt() { + let mut input = "x < 42"; + assert_eq!( + Expr::Lt( + Box::new(Expr::Ident("x".into())), + Box::new(Expr::IntLit(42)), + ), + trace("parse expr", parse_expr) + .parse_next(&mut input) + .unwrap(), + ); + } + + #[test] + pub fn parse_lt_non_assoc() { + let mut input = "x < 42 < y"; + assert_eq!( + Expr::Lt( + Box::new(Expr::Ident("x".into())), + Box::new(Expr::IntLit(42)), + ), + trace("parse expr", parse_expr) + .parse_next(&mut input) + .unwrap(), + ); + } + + #[test] + pub fn parse_lt_non_assoc_bracketed() { + let mut input = "(x < 42 < y)"; + assert!(trace("parse expr", parse_expr) + .parse_next(&mut input) + .is_err()); + } + + #[test] + pub fn parse_all_constructs() { + let program = " + fun tt() = true; + fun lt(i1, i2) = (1 < 2); + fun and(b1, b2: bool): bool = + if b1 { b2: int } else { false }; + + $ tt() + "; + + assert!(parse_trace(program).is_ok()); + } +} diff --git a/scopegraphs/examples/overload/union_find.rs b/scopegraphs/examples/overload/union_find.rs new file mode 100644 index 0000000..16a56d2 --- /dev/null +++ b/scopegraphs/examples/overload/union_find.rs @@ -0,0 +1,152 @@ +use std::fmt::{Debug, Formatter}; + +use futures::Future; +use smol::channel::{bounded, Sender}; + +use crate::{ast::Type, PartialType, TypeVar}; + +#[derive(Default)] +pub struct UnionFind { + /// Records the parent of each type variable. + /// Kind of assumes type variables are assigned linearly. + /// + /// For example the "parent" of type variable 0 is stored at index 0 + parent: Vec, + /// Keep track of type variables we've given out + vars: usize, + /// A vec of signals for each type variable. + /// + /// For example, whenever type variable 0 is unified with anything, + /// we go through the list at index 0 and notify each. + callbacks: Vec>>, +} + +impl Debug for UnionFind { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{{")?; + for (idx, p) in self.parent.iter().enumerate() { + write!(f, "{idx} -> {p:?}")?; + if (idx + 1) < self.parent.len() { + write!(f, ", ")?; + } + } + write!(f, "}}") + } +} + +impl UnionFind { + /// Create a new type variable + /// (which happens to be one bigger than the previous fresh type variable) + pub(super) fn fresh(&mut self) -> TypeVar { + let old = self.vars; + self.vars += 1; + + TypeVar(old) + } + + /// Unify two partial types, asserting they are equal to each other. + /// + /// If one of left or right is a concrete type, and the other is a type variable, + /// we've essentially resolved what type the type variable is now, and we update the + /// data structure to represent that. The next [`find`](Self::find) of this type variable + /// will return the concrete type after this unification. + /// + /// Sometimes, two type variables are unified. In that case, one of the two is chosen by + /// a fair (trust me) dice roll and is made the representative of both input type variables. + /// Whenever one of the two is now unified with a concrete type, both input type variables + /// become equal to that concrete type. + pub(super) fn unify(&mut self, left: PartialType, right: PartialType) { + let left = self.find_partial_type(left); + let right = self.find_partial_type(right); + + match (left, right) { + (PartialType::Variable(left), right) | (right, PartialType::Variable(left)) => { + // FIXME: use rank heuristic in case right is a variable? + *self.get(left) = right.clone(); + if self.callbacks.len() > left.0 { + for fut in self.callbacks[left.0].drain(..) { + let _ = fut.send_blocking(right.clone()); + } + } + } + (left, right) if left != right => { + panic!("type error: cannot unify {left:?} and {right:?}"); + } + _ => {} + } + } + + /// Find the representative for a given type variable. + /// In the best case, this is a concrete type this type variable is equal to. + /// That's nice, because now we know what that type variable was supposed to be. + /// + /// However, it's possible we find another type variable instead (wrapped in a [`PartialType`]). + /// Now we know that this new type variable has the same type of the given type variable, + /// we just don't know yet which type that is. More unifications are needed. + fn find(&mut self, ty: TypeVar) -> PartialType { + let res = self.get(ty); + if let PartialType::Variable(v) = *res { + if v == ty { + return PartialType::Variable(ty); + } + + // do path compression + let root = self.find(v); + *self.get(v) = root.clone(); + root + } else { + res.clone() + } + } + + /// [find](Self::find), but for a parial type + pub(super) fn find_partial_type(&mut self, ty: PartialType) -> PartialType { + if let PartialType::Variable(v) = ty { + self.find(v) + } else { + ty + } + } + + /// Get a mutable reference to parent of a given type variable. + /// Used in the implementation of [`find`](Self::find) and [`union`](Self::union) + fn get(&mut self, tv: TypeVar) -> &mut PartialType { + let parent = &mut self.parent; + for i in parent.len()..=tv.0 { + parent.push(PartialType::Variable(TypeVar(i))); + } + + &mut parent[tv.0] + } + + #[allow(unused)] + fn type_of(&mut self, var: TypeVar) -> Option { + match self.find(var) { + PartialType::Variable(_) => None, + PartialType::Type(t) => Some(t), + } + } + + pub(super) fn type_of_partial_type(&mut self, var: PartialType) -> Option { + match self.find_partial_type(var) { + PartialType::Variable(_) => None, + PartialType::Type(t) => Some(t), + } + } + + /// Wait for when tv is unified with something. + pub(super) fn wait_for_unification( + &mut self, + tv: TypeVar, + ) -> impl Future { + let callbacks = &mut self.callbacks; + for _ in callbacks.len()..=tv.0 { + callbacks.push(vec![]); + } + + let (tx, rx) = bounded(1); + callbacks[tv.0].push(tx); + + async move { rx.recv().await.expect("sender dropped") } + } +}