From 02b4e843121766dd9376a3b003099ca3e46266d4 Mon Sep 17 00:00:00 2001 From: Stephan Boyer Date: Thu, 16 Apr 2020 22:20:23 -0700 Subject: [PATCH] Implement generalization --- examples/infinite_recursion.g | 2 +- examples/infinite_type.g | 2 +- src/de_bruijn.rs | 84 ++-- src/equality.rs | 38 +- src/evaluator.rs | 12 +- src/main.rs | 17 +- src/normalizer.rs | 13 +- src/parser.rs | 46 +- src/term.rs | 51 +- src/type_checker.rs | 872 ++++++++++++++++++++++++++++------ src/unifier.rs | 180 +++---- 11 files changed, 923 insertions(+), 394 deletions(-) diff --git a/examples/infinite_recursion.g b/examples/infinite_recursion.g index a2d343ba..5d06942a 100644 --- a/examples/infinite_recursion.g +++ b/examples/infinite_recursion.g @@ -1,2 +1,2 @@ -f = x => f x +f = (x : bool) => f x f true diff --git a/examples/infinite_type.g b/examples/infinite_type.g index cea49033..08f0589d 100644 --- a/examples/infinite_type.g +++ b/examples/infinite_type.g @@ -1,3 +1,3 @@ t = int -> t -f : t = x => f +f : t = (x : int) => f f 1 2 3 diff --git a/src/de_bruijn.rs b/src/de_bruijn.rs index 4ac46b75..c4399478 100644 --- a/src/de_bruijn.rs +++ b/src/de_bruijn.rs @@ -6,11 +6,7 @@ use crate::term::{ Quotient, Sum, True, Type, Unifier, Variable, }, }; -use std::{ - cmp::{min, Ordering}, - convert::TryFrom, - rc::Rc, -}; +use std::{cmp::Ordering, convert::TryFrom, rc::Rc}; // Shifting refers to adjusting the De Bruijn indices of free variables. A cutoff determines which // variables are considered free. This function is used to raise or lower a term into a different @@ -23,25 +19,27 @@ pub fn signed_shift<'a>(term: &Term<'a>, cutoff: usize, amount: isize) -> Option // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. let borrow = { subterm.borrow().clone() }; - match borrow { - Ok(subterm) => { - signed_shift( - // The `unwrap` is safe due to [ref:unifier_shifts_valid]. - &signed_shift(&subterm, 0, *subterm_shift).unwrap(), - cutoff, - amount, - ) - } - Err(min_shift) => { - let new_shift = subterm_shift + amount; - - *subterm.borrow_mut() = Err(min(min_shift, new_shift)); + if let Some(subterm) = borrow { + signed_shift(&unsigned_shift(&subterm, 0, *subterm_shift), cutoff, amount) + } else if *subterm_shift >= cutoff { + // This `unwrap` is "virtually safe", unless the conversion overflows. + let signed_subterm_shift = isize::try_from(*subterm_shift).unwrap(); + // Adjust the shift if it results in a non-negative quantity. + if signed_subterm_shift >= -amount { Some(Term { source_range: term.source_range, - variant: Unifier(subterm.clone(), new_shift), + variant: Unifier( + subterm.clone(), + // This `unwrap` is safe due to the check above. + usize::try_from(signed_subterm_shift + amount).unwrap(), + ), }) + } else { + None } + } else { + Some(term.clone()) } } Type | Integer | IntegerLiteral(_) | Boolean | True | False => Some(term.clone()), @@ -189,8 +187,8 @@ pub fn signed_shift<'a>(term: &Term<'a>, cutoff: usize, amount: isize) -> Option // is total. pub fn unsigned_shift<'a>(term: &Term<'a>, cutoff: usize, amount: usize) -> Term<'a> { // The inner `unwrap` is "essentially safe" in that it can only fail in the virtually - // impossible case of the conversion overflowing. The outer `unwrap` is safe due to - // [ref:unifier_shifts_valid]. + // impossible case of the conversion overflowing. The outer `unwrap` is safe because `amount` + // is non-negative. signed_shift(term, cutoff, isize::try_from(amount).unwrap()).unwrap() } @@ -211,16 +209,15 @@ pub fn open<'a>( // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. let borrow = { subterm.borrow().clone() }; - if let Ok(subterm) = borrow { + if let Some(subterm) = borrow { open( - // The `unwrap` is safe due to [ref:unifier_shifts_valid]. - &signed_shift(&subterm, 0, *subterm_shift).unwrap(), + &unsigned_shift(&subterm, 0, *subterm_shift), index_to_replace, term_to_insert, shift_amount, ) } else { - // The `unwrap` is safe because shifting an unresolved unifier is always safe. + // The `unwrap` is NOT justified! signed_shift(term_to_open, 0, -1).unwrap() } } @@ -431,23 +428,38 @@ mod tests { use std::{cell::RefCell, rc::Rc}; #[test] - fn signed_shift_unifier_none() { + fn signed_shift_unifier_none_valid() { assert_same!( signed_shift( &Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 10), }, 0, - -42, + -4, ), Some(Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(-42))), -42), + variant: Unifier(Rc::new(RefCell::new(None)), 6), }), ); } + #[test] + fn signed_shift_unifier_none_invalid() { + assert_same!( + signed_shift( + &Term { + source_range: None, + variant: Unifier(Rc::new(RefCell::new(None)), 0), + }, + 0, + -42, + ), + None, + ); + } + #[test] fn signed_shift_unifier_some() { assert_same!( @@ -455,7 +467,7 @@ mod tests { &Term { source_range: None, variant: Unifier( - Rc::new(RefCell::new(Ok(Term { + Rc::new(RefCell::new(Some(Term { source_range: None, variant: Variable("x", 10), }))), @@ -529,14 +541,14 @@ mod tests { unsigned_shift( &Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }, 0, 42, ), Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 42), + variant: Unifier(Rc::new(RefCell::new(None)), 42), }, ); } @@ -548,7 +560,7 @@ mod tests { &Term { source_range: None, variant: Unifier( - Rc::new(RefCell::new(Ok(Term { + Rc::new(RefCell::new(Some(Term { source_range: None, variant: Variable("x", 0), }))), @@ -1299,7 +1311,7 @@ mod tests { open( &Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 10), }, 0, &Term { @@ -1310,7 +1322,7 @@ mod tests { ), Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(-1))), -1), + variant: Unifier(Rc::new(RefCell::new(None)), 9), }, ); } @@ -1322,7 +1334,7 @@ mod tests { &Term { source_range: None, variant: Unifier( - Rc::new(RefCell::new(Ok(Term { + Rc::new(RefCell::new(Some(Term { source_range: None, variant: Variable("x", 0), }))), diff --git a/src/equality.rs b/src/equality.rs index 111fcbe2..fd61c35b 100644 --- a/src/equality.rs +++ b/src/equality.rs @@ -1,5 +1,5 @@ use crate::{ - de_bruijn::signed_shift, + de_bruijn::unsigned_shift, term::{ Term, Variant::{ @@ -20,16 +20,8 @@ pub fn syntactically_equal<'a>(term1: &Term<'a>, term2: &Term<'a>) -> bool { loop { term1 = if let Unifier(subterm, subterm_shift) = &term1.variant { // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - let borrow = { subterm.borrow().clone() }; - - if let Ok(subterm) = borrow { - if let Some(shifted_term) = signed_shift(&subterm, 0, *subterm_shift) { - shifted_term - } else { - // The `signed_shift` failed. This means the term is malformed. The error will - // be reported during type checking. - return false; - } + if let Some(subterm) = { subterm.borrow().clone() } { + unsigned_shift(&subterm, 0, *subterm_shift) } else { break; } @@ -43,16 +35,8 @@ pub fn syntactically_equal<'a>(term1: &Term<'a>, term2: &Term<'a>) -> bool { loop { term2 = if let Unifier(subterm, subterm_shift) = &term2.variant { // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - let borrow = { subterm.borrow().clone() }; - - if let Ok(subterm) = borrow { - if let Some(shifted_term) = signed_shift(&subterm, 0, *subterm_shift) { - shifted_term - } else { - // The `signed_shift` failed. This means the term is malformed. The error will - // be reported during type checking. - return false; - } + if let Some(subterm) = { subterm.borrow().clone() } { + unsigned_shift(&subterm, 0, *subterm_shift) } else { break; } @@ -177,7 +161,7 @@ mod tests { let term1 = Term { source_range: None, variant: Unifier( - Rc::new(RefCell::new(Ok(Term { + Rc::new(RefCell::new(Some(Term { source_range: None, variant: Variable("x", 0), }))), @@ -203,7 +187,7 @@ mod tests { let term2 = Term { source_range: None, variant: Unifier( - Rc::new(RefCell::new(Ok(Term { + Rc::new(RefCell::new(Some(Term { source_range: None, variant: Variable("x", 0), }))), @@ -216,7 +200,7 @@ mod tests { #[test] fn syntactically_equal_unifier_same_pointer_same_shift() { - let rc = Rc::new(RefCell::new(Err(0))); + let rc = Rc::new(RefCell::new(None)); let term1 = Term { source_range: None, @@ -233,7 +217,7 @@ mod tests { #[test] fn syntactically_inequal_unifier_same_pointer_different_shift() { - let rc = Rc::new(RefCell::new(Err(0))); + let rc = Rc::new(RefCell::new(None)); let term1 = Term { source_range: None, @@ -252,12 +236,12 @@ mod tests { fn syntactically_inequal_unifier_different_pointer_same_shift() { let term1 = Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 5), + variant: Unifier(Rc::new(RefCell::new(None)), 5), }; let term2 = Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 5), + variant: Unifier(Rc::new(RefCell::new(None)), 5), }; assert!(!syntactically_equal(&term1, &term2)); diff --git a/src/evaluator.rs b/src/evaluator.rs index fbbf4000..519f80c8 100644 --- a/src/evaluator.rs +++ b/src/evaluator.rs @@ -1,5 +1,5 @@ use crate::{ - de_bruijn::{open, signed_shift, unsigned_shift}, + de_bruijn::{open, unsigned_shift}, error::Error, format::CodeStr, term::{ @@ -48,13 +48,9 @@ pub fn step<'a>(term: &Term<'a>) -> Option> { | True | False => None, Unifier(subterm, subterm_shift) => { - // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - let borrow = { subterm.borrow().clone() }; - - // If the unifier points to something, step to it. Otherwise, we're stuck. - borrow - .ok() - .and_then(|subterm| signed_shift(&subterm, 0, *subterm_shift)) + // If the unifier points to something, step to it. Otherwise, we're stuck. We `clone` + // the borrowed `subterm` to avoid holding the dynamic borrow for too long. + { subterm.borrow().clone() }.map(|subterm| unsigned_shift(&subterm, 0, *subterm_shift)) } Application(applicand, argument) => { // Try to step the applicand. diff --git a/src/main.rs b/src/main.rs index 0c693e5a..2ff76099 100644 --- a/src/main.rs +++ b/src/main.rs @@ -161,7 +161,7 @@ fn run(source_path: &Path, check_only: bool) -> Result<(), Error> { // Type check the term. let mut typing_context = vec![]; let mut definitions_context = vec![]; - let _ = type_check( + let (elaborated_term, elaborated_type) = type_check( Some(source_path), &source_contents, &term, @@ -170,9 +170,18 @@ fn run(source_path: &Path, check_only: bool) -> Result<(), Error> { ) .map_err(collect_errors)?; - // Evaluate the term. - if !check_only { - let value = evaluate(&term)?; + // Evaluate the term if applicable. + if check_only { + println!( + "Elaborated term:\n\n{}", + elaborated_term.to_string().code_str(), + ); + println!( + "\nElaborated type:\n\n{}", + elaborated_type.to_string().code_str(), + ); + } else { + let value = evaluate(&elaborated_term)?; println!("{}", value.to_string().code_str()); } diff --git a/src/normalizer.rs b/src/normalizer.rs index c1641a93..868bbf3a 100644 --- a/src/normalizer.rs +++ b/src/normalizer.rs @@ -1,5 +1,5 @@ use crate::{ - de_bruijn::{open, signed_shift, unsigned_shift}, + de_bruijn::{open, unsigned_shift}, term::{ Term, Variant::{ @@ -35,12 +35,11 @@ pub fn normalize_weak_head<'a>( let borrow = { subterm.borrow().clone() }; // If the unifier points to something, normalize it. Otherwise, we're stuck. - if let Ok(subterm) = borrow { - if let Some(subterm) = signed_shift(&subterm, 0, *subterm_shift) { - normalize_weak_head(&subterm, definitions_context) - } else { - term.clone() - } + if let Some(subterm) = borrow { + normalize_weak_head( + &unsigned_shift(&subterm, 0, *subterm_shift), + definitions_context, + ) } else { term.clone() } diff --git a/src/parser.rs b/src/parser.rs index e919f3c3..0ef6cb19 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -3366,7 +3366,7 @@ fn resolve_variables<'a>( // Construct and return a unifier. term::Term { source_range: Some(term.source_range.0), - variant: term::Variant::Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: term::Variant::Unifier(Rc::new(RefCell::new(None)), 0), } } } @@ -3409,7 +3409,7 @@ fn resolve_variables<'a>( } else { Rc::new(term::Term { source_range: None, - variant: term::Variant::Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: term::Variant::Unifier(Rc::new(RefCell::new(None)), 0), }) }, Rc::new(resolve_variables( @@ -3541,7 +3541,9 @@ fn resolve_variables<'a>( // Resolve variables in the definitions and annotations. let mut resolved_definitions = vec![]; - for (inner_variable, inner_annotation, inner_definition) in &definitions { + for (i, (inner_variable, inner_annotation, inner_definition)) in + definitions.iter().enumerate() + { // Temporarily borrow from the scope guard. let mut guard = context_cell.borrow_mut(); let (borrowed_context, _) = &mut (*guard); @@ -3558,7 +3560,10 @@ fn resolve_variables<'a>( )), None => Rc::new(term::Term { source_range: None, - variant: term::Variant::Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: term::Variant::Unifier( + Rc::new(RefCell::new(None)), + definitions.len() - i, + ), }), }; @@ -3937,22 +3942,15 @@ fn check_definitions<'a>( assert_eq!(*subterm_shift, 0); // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - let borrow = { subterm.borrow().clone() }; - - match borrow { - Ok(subterm) => { - check_definitions( - source_path, - source_contents, - &subterm, - depth, - context, - errors, - ); - } - Err(min_shift) => { - assert_eq!(min_shift, 0); - } + if let Some(subterm) = { subterm.borrow().clone() } { + check_definitions( + source_path, + source_contents, + &subterm, + depth, + context, + errors, + ); } } term::Variant::Lambda(_, _, domain, body) => { @@ -4205,7 +4203,7 @@ mod tests { parse(None, source, &tokens[..], &context[..]).unwrap(), Term { source_range: Some((0, 1)), - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }, ); } @@ -4237,7 +4235,7 @@ mod tests { false, Rc::new(Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }), Rc::new(Term { source_range: Some((5, 6)), @@ -4263,7 +4261,7 @@ mod tests { true, Rc::new(Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }), Rc::new(Term { source_range: Some((7, 8)), @@ -4645,7 +4643,7 @@ mod tests { "x", Rc::new(Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 2), }), Rc::new(Term { source_range: Some((4, 26)), diff --git a/src/term.rs b/src/term.rs index 344c1f00..c4f52fc6 100644 --- a/src/term.rs +++ b/src/term.rs @@ -1,5 +1,5 @@ use crate::{ - de_bruijn::signed_shift, + de_bruijn::unsigned_shift, token::{BOOLEAN_KEYWORD, FALSE_KEYWORD, INTEGER_KEYWORD, TRUE_KEYWORD, TYPE_KEYWORD}, }; use num_bigint::BigInt; @@ -21,10 +21,7 @@ pub struct Term<'a> { // Each term has a "variant" describing what kind of term it is. #[derive(Clone, Debug)] pub enum Variant<'a> { - // For `Unifier` terms, we maintain the invariant that the `subterm`, if it exists, can be - // shifted by `subterm_shift` [tag:unifier_shifts_valid]. Note that shifting is a partial - // operation, and that's why this invariant doesn't trivially hold. - Unifier(Rc, isize>>>, isize), // (subterm/min_shift, subterm_shift) + Unifier(Rc>>>, usize), // (subterm, subterm_shift) Type, Variable(&'a str, usize), @@ -61,19 +58,22 @@ impl<'a> Display for Term<'a> { impl<'a> Display for Variant<'a> { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match self { - Self::Unifier(subterm, _) => { + Self::Unifier(subterm_rc, subterm_shift) => { // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too // long. - let borrow = { subterm.borrow().clone() }; - - if let Ok(subterm) = borrow { - write!(f, "{}", subterm) + if let Some(subterm) = { subterm_rc.borrow().clone() } { + write!( + f, + "{}[{:p}]", + unsigned_shift(&subterm, 0, *subterm_shift), + *subterm_rc, + ) } else { - write!(f, "_") + write!(f, "_[{:p}^{}]", *subterm_rc, subterm_shift) } } Self::Type => write!(f, "{}", TYPE_KEYWORD), - Self::Variable(variable, _) => write!(f, "{}", variable), + Self::Variable(variable, index) => write!(f, "{}@{}", variable, index), Self::Lambda(variable, implicit, domain, body) => { if *implicit { write!(f, "{{{} : {}}} => {}", variable, domain, body) @@ -149,12 +149,10 @@ impl<'a> Display for Variant<'a> { // parsing ambiguities in any context. fn group<'a>(term: &Term<'a>) -> String { match &term.variant { - Variant::Unifier(subterm, _) => { + Variant::Unifier(subterm, subterm_shift) => { // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - let borrow = { subterm.borrow().clone() }; - - if let Ok(subterm) = borrow { - group(&subterm) + if let Some(subterm) = { subterm.borrow().clone() } { + group(&unsigned_shift(&subterm, 0, *subterm_shift)) } else { format!("{}", term) } @@ -190,15 +188,12 @@ pub fn free_variables<'a>(term: &Term<'a>, cutoff: usize, variables: &mut HashSe match &term.variant { Variant::Unifier(subterm, subterm_shift) => { // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - let borrow = { subterm.borrow().clone() }; - - if let Ok(subterm) = borrow { - if let Some(subterm) = signed_shift(&subterm, 0, *subterm_shift) { - free_variables(&subterm, cutoff, variables); - } else { - // The `signed_shift` failed. This means the term is malformed. The - // error will be reported during type checking. - } + if let Some(subterm) = { subterm.borrow().clone() } { + free_variables( + &unsigned_shift(&subterm, 0, *subterm_shift), + cutoff, + variables, + ); } } Variant::Type @@ -277,7 +272,7 @@ mod tests { free_variables( &Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }, 10, &mut variables, @@ -294,7 +289,7 @@ mod tests { &Term { source_range: None, variant: Unifier( - Rc::new(RefCell::new(Ok(Term { + Rc::new(RefCell::new(Some(Term { source_range: None, variant: Variable("x", 15), }))), diff --git a/src/type_checker.rs b/src/type_checker.rs index e32aa27e..cca888bc 100644 --- a/src/type_checker.rs +++ b/src/type_checker.rs @@ -11,12 +11,13 @@ use crate::{ Product, Quotient, Sum, True, Type, Unifier, Variable, }, }, - unifier::unify, + unifier::{collect_unifiers, unify}, }; use scopeguard::defer; -use std::{cell::RefCell, path::Path, rc::Rc}; +use std::{cell::RefCell, collections::HashSet, path::Path, rc::Rc}; -// This is the top-level type checking function. Invariants: +// This is the top-level type checking function. It returns the pair `(elaborated_term, type)`. +// Invariants: // - The two contexts have the same length. // - When this function is finished, the contexts are left unmodified. pub fn type_check<'a>( @@ -25,10 +26,10 @@ pub fn type_check<'a>( term: &Term<'a>, typing_context: &mut Vec<(Rc>, usize)>, definitions_context: &mut Vec>, usize)>>, -) -> Result, Vec> { +) -> Result<(Term<'a>, Term<'a>), Vec> { let mut errors = vec![]; - let result = type_check_rec( + let (elaborated_term, term_type) = type_check_rec( source_path, source_contents, term, @@ -38,13 +39,14 @@ pub fn type_check<'a>( ); if errors.is_empty() { - Ok(result) + Ok((elaborated_term, term_type)) } else { Err(errors) } } // This helper function is the workhorse of the `type_check` function above. +#[allow(clippy::cognitive_complexity)] #[allow(clippy::too_many_lines)] pub fn type_check_rec<'a>( source_path: Option<&'a Path>, @@ -53,7 +55,7 @@ pub fn type_check_rec<'a>( typing_context: &mut Vec<(Rc>, usize)>, definitions_context: &mut Vec>, usize)>>, errors: &mut Vec, -) -> Term<'a> { +) -> (Term<'a>, Term<'a>) { // Construct the type of all types once here rather than constructing it many times later. let type_term = Term { source_range: None, @@ -73,16 +75,19 @@ pub fn type_check_rec<'a>( }; // The typing rules are syntax-directed, so we pattern-match on the term. - match &term.variant { - Unifier(_, _) | Type | Integer | Boolean => type_term, + let (elaborated_term, elaborated_type) = match &term.variant { + Unifier(_, _) | Type | Integer | Boolean => (term.clone(), type_term.clone()), Variable(_, index) => { // Shift the type such that it's valid in the current context. let (variable_type, offset) = &typing_context[typing_context.len() - 1 - *index]; - unsigned_shift(variable_type, 0, *index + 1 - offset) + ( + term.clone(), + unsigned_shift(variable_type, 0, *index + 1 - offset), + ) } Lambda(variable, implicit, domain, body) => { // Infer the type of the domain. - let domain_type = type_check_rec( + let (domain, domain_type) = type_check_rec( source_path, source_contents, domain, @@ -104,11 +109,11 @@ pub fn type_check_rec<'a>( // Temporarily add the variable's type to the context for the purpose of inferring the // codomain. - typing_context.push((domain.clone(), 0)); + typing_context.push((Rc::new(domain.clone()), 0)); definitions_context.push(None); // Infer the codomain. - let codomain = type_check_rec( + let (body, codomain) = type_check_rec( source_path, source_contents, body, @@ -121,15 +126,46 @@ pub fn type_check_rec<'a>( definitions_context.pop(); typing_context.pop(); - // Construct and return the pi type. - Term { - source_range: term.source_range, - variant: Pi(variable, *implicit, domain.clone(), Rc::new(codomain)), + /* + // Replace outer implicit variables in the codomain with unifiers. These will be + // generalized below. + while let Pi(_, true, _, codomain_codomain) = &codomain.variant { + let unifier = Term { + source_range: None, + variant: Unifier(Rc::new(RefCell::new(None)), 0), + }; + + codomain = open(codomain_codomain, 0, &unifier, 0); + + // We could skip this check and always construct an application as per the + // `else` branch. However, we do this to avoid creating unnecessary "junk" + // applications. + if let Lambda(_, true, _, body_body) = &body.variant { + body = open(body_body, 0, &unifier, 0); + } else { + body = Term { + source_range: None, + variant: Application(Rc::new(body), Rc::new(unifier)), + }; + } } + */ + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Lambda(variable, *implicit, Rc::new(domain.clone()), Rc::new(body)), + }, + Term { + source_range: term.source_range, + variant: Pi(variable, *implicit, Rc::new(domain), Rc::new(codomain)), + }, + ) } - Pi(_, _, domain, codomain) => { + Pi(variable, implicit, domain, codomain) => { // Infer the type of the domain. - let domain_type = type_check_rec( + let (domain, domain_type) = type_check_rec( source_path, source_contents, domain, @@ -151,11 +187,11 @@ pub fn type_check_rec<'a>( // Temporarily add the variable's type to the context for the purpose of inferring the // type of the codomain. - typing_context.push((domain.clone(), 0)); + typing_context.push((Rc::new(domain.clone()), 0)); definitions_context.push(None); // Infer the type of the codomain. - let codomain_type = type_check_rec( + let (codomain, codomain_type) = type_check_rec( source_path, source_contents, codomain, @@ -179,12 +215,18 @@ pub fn type_check_rec<'a>( definitions_context.pop(); typing_context.pop(); - // The type of a pi type is the type of all types. - type_term + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Pi(variable, *implicit, Rc::new(domain), Rc::new(codomain)), + }, + type_term.clone(), + ) } Application(applicand, argument) => { // Infer the type of the applicand. - let applicand_type = type_check_rec( + let (applicand, applicand_type) = type_check_rec( source_path, source_contents, applicand, @@ -196,13 +238,13 @@ pub fn type_check_rec<'a>( // Construct a unification term for the domain. let domain = Rc::new(Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }); // Construct a unification term for the codomain. let codomain = Rc::new(Term { source_range: None, - variant: Unifier(Rc::new(RefCell::new(Err(0))), 0), + variant: Unifier(Rc::new(RefCell::new(None)), 0), }); // Construct a pi type for unification. @@ -231,7 +273,7 @@ pub fn type_check_rec<'a>( }; // Infer the type of the argument. - let argument_type = type_check_rec( + let (argument, argument_type) = type_check_rec( source_path, source_contents, argument, @@ -255,8 +297,14 @@ pub fn type_check_rec<'a>( )); } - // Construct and return the codomain specialized to the argument. - open(&codomain, 0, argument, 0) + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Application(Rc::new(applicand), Rc::new(argument.clone())), + }, + open(&codomain, 0, &argument, 0), + ) } Let(definitions, body) => { // When the function returns, remove the variables from the context that we are @@ -295,43 +343,49 @@ pub fn type_check_rec<'a>( } // Infer/check the types of the definitions. - for (_, annotation, definition) in definitions.iter() { - // Temporarily borrow from the scope guard. - let mut guard = context_cell.borrow_mut(); - let ((borrowed_typing_context, borrowed_definitions_context), _) = &mut (*guard); - - // Infer the type of the definition. - let definition_type = type_check_rec( - source_path, - source_contents, - definition, - borrowed_typing_context, - borrowed_definitions_context, - errors, - ); - - // Check the type against the annotation. - if !unify(&definition_type, &annotation, borrowed_definitions_context) { - errors.push(throw( - &format!( - "This has type {}, but it was expected to have type {}:", - definition_type.to_string().code_str(), - annotation.to_string().code_str(), - ), + let definitions: Vec<_> = definitions + .iter() + .map(|(variable, annotation, definition)| { + // Temporarily borrow from the scope guard. + let mut guard = context_cell.borrow_mut(); + let ((borrowed_typing_context, borrowed_definitions_context), _) = + &mut (*guard); + + // Infer the type of the definition. + let (definition, definition_type) = type_check_rec( source_path, - definition - .source_range - .map(|source_range| (source_contents, source_range)), - )); - } - } + source_contents, + definition, + borrowed_typing_context, + borrowed_definitions_context, + errors, + ); + + // Check the type against the annotation. + if !unify(&definition_type, &annotation, borrowed_definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it was expected to have type {}:", + definition_type.to_string().code_str(), + annotation.to_string().code_str(), + ), + source_path, + definition + .source_range + .map(|source_range| (source_contents, source_range)), + )); + } + + (*variable, annotation.clone(), Rc::new(definition)) + }) + .collect(); // Temporarily borrow from the scope guard. let mut guard = context_cell.borrow_mut(); let ((borrowed_typing_context, borrowed_definitions_context), _) = &mut (*guard); // Infer the type of the body. - let body_type = type_check_rec( + let (body, body_type) = type_check_rec( source_path, source_contents, body, @@ -340,53 +394,59 @@ pub fn type_check_rec<'a>( errors, ); - // Return the opened type of the body. - (0..definitions.len()).fold(body_type, |acc, i| { - // Compute this once rather than multiple times. - let definitions_len_minus_one_minus_i = definitions.len() - 1 - i; - - // Open the body. - open( - &acc, - 0, - &Term { - source_range: None, - variant: Let( - definitions - .iter() - .map(|(variable, annotation, definition)| { - ( - *variable, - Rc::new(unsigned_shift( - annotation, - 0, - definitions_len_minus_one_minus_i, - )), - Rc::new(unsigned_shift( - definition, - 0, - definitions_len_minus_one_minus_i, - )), - ) - }) - .collect(), - Rc::new(Term { - source_range: None, - variant: Variable( - definitions[definitions_len_minus_one_minus_i].0, - i, - ), - }), - ), - }, - 0, - ) - }) + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Let(definitions.clone(), Rc::new(body.clone())), + }, + (0..definitions.len()).fold(body_type, |acc, i| { + // Compute this once rather than multiple times. + let definitions_len_minus_one_minus_i = definitions.len() - 1 - i; + + // Open the body. + open( + &acc, + 0, + &Term { + source_range: None, + variant: Let( + definitions + .iter() + .map(|(variable, annotation, definition)| { + ( + *variable, + Rc::new(unsigned_shift( + annotation, + 0, + definitions_len_minus_one_minus_i, + )), + Rc::new(unsigned_shift( + definition, + 0, + definitions_len_minus_one_minus_i, + )), + ) + }) + .collect(), + Rc::new(Term { + source_range: None, + variant: Variable( + definitions[definitions_len_minus_one_minus_i].0, + i, + ), + }), + ), + }, + 0, + ) + }), + ) } - IntegerLiteral(_) => integer_term, + IntegerLiteral(_) => (term.clone(), integer_term), Negation(subterm) => { // Infer the type of the subterm. - let subterm_type = type_check_rec( + let (subterm, subterm_type) = type_check_rec( source_path, source_contents, subterm, @@ -410,15 +470,378 @@ pub fn type_check_rec<'a>( )); }; - // Return the type of integers. - integer_term + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Negation(Rc::new(subterm)), + }, + integer_term, + ) + } + Sum(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Sum(Rc::new(term1), Rc::new(term2)), + }, + integer_term, + ) + } + Difference(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Difference(Rc::new(term1), Rc::new(term2)), + }, + integer_term, + ) + } + Product(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Product(Rc::new(term1), Rc::new(term2)), + }, + integer_term, + ) + } + Quotient(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: Quotient(Rc::new(term1), Rc::new(term2)), + }, + integer_term, + ) + } + LessThan(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: LessThan(Rc::new(term1), Rc::new(term2)), + }, + boolean_term, + ) + } + LessThanOrEqualTo(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: LessThanOrEqualTo(Rc::new(term1), Rc::new(term2)), + }, + boolean_term, + ) } - Sum(term1, term2) - | Difference(term1, term2) - | Product(term1, term2) - | Quotient(term1, term2) => { + EqualTo(term1, term2) => { // Infer the type of the left subterm. - let term1_type = type_check_rec( + let (term1, term1_type) = type_check_rec( source_path, source_contents, term1, @@ -443,7 +866,7 @@ pub fn type_check_rec<'a>( }; // Infer the type of the right subterm. - let term2_type = type_check_rec( + let (term2, term2_type) = type_check_rec( source_path, source_contents, term2, @@ -467,16 +890,18 @@ pub fn type_check_rec<'a>( )); }; - // Return the type of integers. - integer_term + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: EqualTo(Rc::new(term1), Rc::new(term2)), + }, + boolean_term, + ) } - LessThan(term1, term2) - | LessThanOrEqualTo(term1, term2) - | EqualTo(term1, term2) - | GreaterThan(term1, term2) - | GreaterThanOrEqualTo(term1, term2) => { + GreaterThan(term1, term2) => { // Infer the type of the left subterm. - let term1_type = type_check_rec( + let (term1, term1_type) = type_check_rec( source_path, source_contents, term1, @@ -501,7 +926,7 @@ pub fn type_check_rec<'a>( }; // Infer the type of the right subterm. - let term2_type = type_check_rec( + let (term2, term2_type) = type_check_rec( source_path, source_contents, term2, @@ -525,12 +950,78 @@ pub fn type_check_rec<'a>( )); }; - // Return the type of Booleans. - boolean_term + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: GreaterThan(Rc::new(term1), Rc::new(term2)), + }, + boolean_term, + ) + } + GreaterThanOrEqualTo(term1, term2) => { + // Infer the type of the left subterm. + let (term1, term1_type) = type_check_rec( + source_path, + source_contents, + term1, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the left subterm is the type of integers. + if !unify(&term1_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term1_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term1 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Infer the type of the right subterm. + let (term2, term2_type) = type_check_rec( + source_path, + source_contents, + term2, + typing_context, + definitions_context, + errors, + ); + + // Check that the type of the right subterm is the type of integers. + if !unify(&term2_type, &integer_term, definitions_context) { + errors.push(throw( + &format!( + "This has type {}, but it should have type {}:", + term2_type.to_string().code_str(), + integer_term.to_string().code_str(), + ), + source_path, + term2 + .source_range + .map(|source_range| (source_contents, source_range)), + )); + }; + + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: GreaterThanOrEqualTo(Rc::new(term1), Rc::new(term2)), + }, + boolean_term, + ) } If(condition, then_branch, else_branch) => { // Infer the type of the condition. - let condition_type = type_check_rec( + let (condition, condition_type) = type_check_rec( source_path, source_contents, condition, @@ -555,7 +1046,7 @@ pub fn type_check_rec<'a>( }; // Infer the type of the then branch. - let then_branch_type = type_check_rec( + let (then_branch, then_branch_type) = type_check_rec( source_path, source_contents, then_branch, @@ -565,7 +1056,7 @@ pub fn type_check_rec<'a>( ); // Infer the type of the else branch. - let else_branch_type = type_check_rec( + let (else_branch, else_branch_type) = type_check_rec( source_path, source_contents, else_branch, @@ -589,11 +1080,86 @@ pub fn type_check_rec<'a>( )); }; - // Return the type of the branches. - then_branch_type + // Return the new term and type. + ( + Term { + source_range: term.source_range, + variant: If( + Rc::new(condition), + Rc::new(then_branch), + Rc::new(else_branch), + ), + }, + then_branch_type, + ) } - True | False => boolean_term, - } + True | False => (term.clone(), boolean_term), + }; + + // Collect the unifiers for generalization. + let mut unifiers = vec![]; + let mut visited = HashSet::new(); + collect_unifiers(&elaborated_type, 0, &mut unifiers, &mut visited); + collect_unifiers(&elaborated_term, 0, &mut unifiers, &mut visited); + + (elaborated_term, elaborated_type) + + /* + println!("Elaborated term: {}", elaborated_term); + println!("Elaborated type: {}", elaborated_type); + + // Generalize. + let (generalized_term, generalized_type) = unifiers.iter().rev().enumerate().fold( + (elaborated_term, elaborated_type), + |acc, (i, unifier)| { + let new_unifier = Rc::new(RefCell::new(Err(0))); + + { + *unifier.0.borrow_mut() = Ok(Term { + source_range: None, + variant: Unifier(new_unifier.clone(), -1), // -1 to cancel out the shifts below + }); + } + + let shifted_term = Term { + source_range: None, + variant: Lambda( + PLACEHOLDER_VARIABLE, + true, + Rc::new(type_term.clone()), + Rc::new(unsigned_shift(&acc.0, 0, 1)), + ), + }; + + let shifted_type = Term { + source_range: None, + variant: Pi( + PLACEHOLDER_VARIABLE, + true, + Rc::new(type_term.clone()), + Rc::new(unsigned_shift(&acc.1, 0, 1)), + ), + }; + + let generalization_variable = Term { + source_range: None, + variant: Variable(PLACEHOLDER_VARIABLE, unifier.1 + i), + }; + + // Resolve the unifier again to undo the shift. + { + *new_unifier.borrow_mut() = Ok(generalization_variable); + } + + (shifted_term, shifted_type) + }, + ); + + println!("Generalized term: {}", generalized_term); + println!("Generalized type: {}", generalized_type); + + (generalized_term, generalized_type) + */ } #[cfg(test)] @@ -621,7 +1187,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -661,7 +1227,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -692,7 +1258,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -723,7 +1289,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -763,7 +1329,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -844,7 +1410,7 @@ mod tests { let mut definitions_context = vec![None, None]; let term_source = " ( - (a : type) => (P: (x : a) -> type) => (f : (x : a) -> P x) => (x : a) => f x # , + (a : type) => (P: a -> type) => (f : (x : a) -> P x) => (x : a) => f x # , ) ( ((t : type) => t) foo # , ) ( @@ -857,7 +1423,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -888,7 +1454,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -913,7 +1479,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -938,7 +1504,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -963,7 +1529,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -988,7 +1554,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1013,7 +1579,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1038,7 +1604,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1063,7 +1629,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1088,7 +1654,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1113,7 +1679,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1138,7 +1704,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1163,7 +1729,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1188,7 +1754,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1213,7 +1779,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1238,7 +1804,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1263,7 +1829,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, @@ -1288,7 +1854,7 @@ mod tests { let term_tokens = tokenize(None, term_source).unwrap(); let term_term = parse(None, term_source, &term_tokens[..], &parsing_context[..]).unwrap(); - let term_type_term = type_check( + let (_, term_type_term) = type_check( None, term_source, &term_term, diff --git a/src/unifier.rs b/src/unifier.rs index 107f536b..f0eeb5d7 100644 --- a/src/unifier.rs +++ b/src/unifier.rs @@ -1,7 +1,6 @@ use crate::{ de_bruijn::signed_shift, equality::syntactically_equal, - format::CodeStr, normalizer::normalize_weak_head, term::{ Term, @@ -15,6 +14,7 @@ use crate::{ use std::{ cell::RefCell, collections::HashSet, + convert::TryFrom, hash::{Hash, Hasher}, ptr, rc::Rc, @@ -22,7 +22,7 @@ use std::{ // This struct is a "newtype" for `Rc` that implements `Eq` and `Hash` based on the underlying // pointer, rather than the value being pointed to. -struct HashableRc(Rc); +pub struct HashableRc(Rc); impl Hash for HashableRc { fn hash(&self, state: &mut H) { @@ -65,75 +65,43 @@ pub fn unify<'a>( { true } - (Unifier(subterm1, subterm_shift1), _) => { - // We `clone` the borrowed `subterm1` to avoid holding the dynamic borrow for too long. - let borrow = { subterm1.borrow().clone() }; - - if let Err(min_shift) = borrow { - // Occurs check - let mut unifiers = vec![]; - let mut visited = HashSet::new(); - collect_unifiers(&whnf2, 0, &mut unifiers, &mut visited); - if visited.contains(&HashableRc(subterm1.clone())) { - return false; - } - - // Ensure the target can be shifted by `min_shift`. - if signed_shift(&whnf2, 0, min_shift).is_none() { - return false; - } - - // Unshift - if let Some(unshifted_term) = signed_shift(&whnf2, 0, -*subterm_shift1) { - // Unify - *subterm1.borrow_mut() = Ok(unshifted_term); - - // We did it! - true - } else { - false - } - } else { - panic!( - "Encountered a non-{} unifier after reduction to weak-head normal form", - "None".code_str(), - ); + // The `unwrap` is "virtually safe", unless the conversion overflows. + (Unifier(subterm1, subterm_shift1), _) + if signed_shift(&whnf2, 0, -isize::try_from(*subterm_shift1).unwrap()).is_some() => + { + // Occurs check + let mut unifiers = vec![]; + let mut visited = HashSet::new(); + collect_unifiers(&whnf2, 0, &mut unifiers, &mut visited); + if visited.contains(&HashableRc(subterm1.clone())) { + return false; } + + // Unify. The `unwrap` is "virtually safe", unless the conversion overflows. + *subterm1.borrow_mut() = + signed_shift(&whnf2, 0, -isize::try_from(*subterm_shift1).unwrap()); + + // We did it! + true } - (_, Unifier(subterm2, subterm_shift2)) => { - // We `clone` the borrowed `subterm2` to avoid holding the dynamic borrow for too long. - let borrow = { subterm2.borrow().clone() }; - - if let Err(min_shift) = borrow { - // Occurs check - let mut unifiers = vec![]; - let mut visited = HashSet::new(); - collect_unifiers(&whnf1, 0, &mut unifiers, &mut visited); - if visited.contains(&HashableRc(subterm2.clone())) { - return false; - } - - // Ensure the target can be shifted by `min_shift`. - if signed_shift(&whnf1, 0, min_shift).is_none() { - return false; - } - - // Unshift - if let Some(unshifted_term) = signed_shift(&whnf1, 0, -*subterm_shift2) { - // Unify - *subterm2.borrow_mut() = Ok(unshifted_term); - - // We did it! - true - } else { - false - } - } else { - panic!( - "Encountered a non-{} unifier after reduction to weak-head normal form", - "None".code_str(), - ); + // The `unwrap` is "virtually safe", unless the conversion overflows. + (_, Unifier(subterm2, subterm_shift2)) + if signed_shift(&whnf1, 0, -isize::try_from(*subterm_shift2).unwrap()).is_some() => + { + // Occurs check + let mut unifiers = vec![]; + let mut visited = HashSet::new(); + collect_unifiers(&whnf1, 0, &mut unifiers, &mut visited); + if visited.contains(&HashableRc(subterm2.clone())) { + return false; } + + // Unify. The `unwrap` is "virtually safe", unless the conversion overflows. + *subterm2.borrow_mut() = + signed_shift(&whnf1, 0, -isize::try_from(*subterm_shift2).unwrap()); + + // We did it! + true } (Type, Type) | (Integer, Integer) | (Boolean, Boolean) | (True, True) | (False, False) => { true @@ -194,7 +162,9 @@ pub fn unify<'a>( && unify(then_branch1, then_branch2, definitions_context) && unify(else_branch1, else_branch2, definitions_context) } - (Variable(_, _), _) + (Unifier(_, _), _) + | (_, Unifier(_, _)) + | (Variable(_, _), _) | (_, Variable(_, _)) | (Lambda(_, _, _, _), _) | (_, Lambda(_, _, _, _)) @@ -243,18 +213,20 @@ pub fn unify<'a>( // This function collects all the unresolved unifiers in a term. The unifiers are deduplicated and // returned in the order they are first encountered in the term. -fn collect_unifiers<'a>( +#[allow(clippy::type_complexity)] +pub fn collect_unifiers<'a>( term: &Term<'a>, depth: usize, - unifiers: &mut Vec, isize>>>>, - visited: &mut HashSet, isize>>>>, + unifiers: &mut Vec>>>>, + visited: &mut HashSet>>>>, ) { match &term.variant { Unifier(unifier, _) => { // We `clone` the borrowed `subterm` to avoid holding the dynamic borrow for too long. - if let Ok(subterm) = { unifier.borrow().clone() } { + if let Some(subterm) = { unifier.borrow().clone() } { collect_unifiers(&subterm, depth, unifiers, visited); } else if visited.insert(HashableRc(unifier.clone())) { + // This `unwrap` is "virtually safe", unless the conversion overflows. unifiers.push(unifier.clone()); } } @@ -304,11 +276,14 @@ fn collect_unifiers<'a>( mod tests { use crate::{ parser::parse, + term::{ + Term, + Variant::{Application, Unifier, Variable}, + }, tokenizer::tokenize, - type_checker::type_check, unifier::{collect_unifiers, unify}, }; - use std::collections::HashSet; + use std::{cell::RefCell, collections::HashSet, rc::Rc}; #[test] fn unify_unifier_left() { @@ -1047,21 +1022,21 @@ mod tests { #[test] fn collect_unifiers_unifier_deduplication() { - let parsing_context = []; - let mut typing_context = vec![]; - let mut definitions_context = vec![]; - - let source = "(x => x) _"; - let tokens = tokenize(None, source).unwrap(); - let term = parse(None, source, &tokens[..], &parsing_context[..]).unwrap(); - let _ = type_check( - None, - source, - &term, - &mut typing_context, - &mut definitions_context, - ) - .unwrap(); + let rc = Rc::new(RefCell::new(None)); + + let term = Term { + source_range: None, + variant: Application( + Rc::new(Term { + source_range: None, + variant: Unifier(rc.clone(), 0), + }), + Rc::new(Term { + source_range: None, + variant: Unifier(rc.clone(), 0), + }), + ), + }; let mut unifiers = vec![]; let mut visited = HashSet::new(); @@ -1087,21 +1062,16 @@ mod tests { #[test] fn collect_unifiers_resolved() { - let parsing_context = []; - let mut typing_context = vec![]; - let mut definitions_context = vec![]; - - let source = "(x => x) type"; - let tokens = tokenize(None, source).unwrap(); - let term = parse(None, source, &tokens[..], &parsing_context[..]).unwrap(); - let _ = type_check( - None, - source, - &term, - &mut typing_context, - &mut definitions_context, - ) - .unwrap(); + let term = Term { + source_range: None, + variant: Unifier( + Rc::new(RefCell::new(Some(Term { + source_range: None, + variant: Variable("x", 0), + }))), + 0, + ), + }; let mut unifiers = vec![]; let mut visited = HashSet::new();