Skip to content

Commit

Permalink
Update the way we compute which functions should be reducible
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Feb 25, 2025
1 parent 15a17a7 commit d4ad011
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 24 deletions.
25 changes: 5 additions & 20 deletions src/pure/PureMicroPasses.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3344,33 +3344,18 @@ let compute_reducible (_ctx : ctx) (transl : pure_fun_translation list) :
match trans.f.body with
| None -> trans
| Some body -> (
(* Check if the body is exactly a call to a loop function.
Note that we check that the arguments are exactly the input
variables - otherwise we may not want the call to be reducible;
for instance when using the [progress] tactic we might want to
use a more specialized specification theorem. *)
let app, args = destruct_apps body.body in
let app, _ = destruct_apps body.body in
match app.e with
| Qualif
{
id = FunOrOp (Fun (FromLlbc (FunId fid, Some _lp_id)));
generics = _;
}
when fid = FRegular trans.f.def_id ->
if
List.length body.inputs = List.length args
&& List.for_all
(fun ((var, arg) : var * texpression) ->
match arg.e with
| Var var_id -> var_id = var.id
| _ -> false)
(List.combine body.inputs args)
then
let f =
{ trans.f with backend_attributes = { reducible = true } }
in
{ trans with f }
else trans
let f =
{ trans.f with backend_attributes = { reducible = true } }
in
{ trans with f }
| _ -> trans)
in
List.map update_one transl
Expand Down
4 changes: 2 additions & 2 deletions tests/lean/Arrays.lean
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ divergent def sum_loop (s : Slice U32) (sum1 : U32) (i : Usize) : Result U32 :=

/- [arrays::sum]:
Source: 'tests/src/arrays.rs', lines 270:0-278:1 -/
def sum (s : Slice U32) : Result U32 :=
sum_loop s 0#u32 0#usize
@[reducible] def sum (s : Slice U32) : Result U32 :=
sum_loop s 0#u32 0#usize

/- [arrays::sum2]: loop 0:
Source: 'tests/src/arrays.rs', lines 284:4-287:5 -/
Expand Down
1 change: 1 addition & 0 deletions tests/lean/Avl/Funs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ divergent def Tree.find_loop

/- [avl::{avl::Tree<T>}#3::find]:
Source: 'src/avl.rs', lines 342:4-354:5 -/
@[reducible]
def Tree.find
{T : Type} (OrdInst : Ord T) (self : Tree T) (value : T) : Result Bool :=
Tree.find_loop OrdInst value self.root
Expand Down
4 changes: 4 additions & 0 deletions tests/lean/Betree/Funs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ divergent def betree.List.len_loop

/- [betree::betree::{betree::betree::List<T>}#1::len]:
Source: 'src/betree.rs', lines 276:4-284:5 -/
@[reducible]
def betree.List.len {T : Type} (self : betree.List T) : Result U64 :=
betree.List.len_loop self 0#u64

Expand All @@ -110,6 +111,7 @@ divergent def betree.List.reverse_loop

/- [betree::betree::{betree::betree::List<T>}#1::reverse]:
Source: 'src/betree.rs', lines 304:4-312:5 -/
@[reducible]
def betree.List.reverse
{T : Type} (self : betree.List T) : Result (betree.List T) :=
betree.List.reverse_loop self betree.List.Nil
Expand All @@ -134,6 +136,7 @@ divergent def betree.List.split_at_loop

/- [betree::betree::{betree::betree::List<T>}#1::split_at]:
Source: 'src/betree.rs', lines 287:4-302:5 -/
@[reducible]
def betree.List.split_at
{T : Type} (self : betree.List T) (n : U64) :
Result ((betree.List T) × (betree.List T))
Expand Down Expand Up @@ -197,6 +200,7 @@ divergent def betree.ListPairU64T.partition_at_pivot_loop

/- [betree::betree::{betree::betree::List<(u64, T)>}#2::partition_at_pivot]:
Source: 'src/betree.rs', lines 355:4-370:5 -/
@[reducible]
def betree.ListPairU64T.partition_at_pivot
{T : Type} (self : betree.List (U64 × T)) (pivot : U64) :
Result ((betree.List (U64 × T)) × (betree.List (U64 × T)))
Expand Down
1 change: 1 addition & 0 deletions tests/lean/Bst/Funs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ divergent def TreeSet.find_loop

/- [bst::{bst::TreeSet<T>}::find]:
Source: 'src/bst.rs', lines 32:4-44:5 -/
@[reducible]
def TreeSet.find
{T : Type} (OrdInst : Ord T) (self : TreeSet T) (value : T) : Result Bool :=
TreeSet.find_loop OrdInst value self.root
Expand Down
1 change: 1 addition & 0 deletions tests/lean/Hashmap/Funs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ divergent def HashMap.move_elements_loop

/- [hashmap::{hashmap::HashMap<T>}::move_elements]:
Source: 'tests/src/hashmap.rs', lines 185:4-195:5 -/
@[reducible]
def HashMap.move_elements
{T : Type} (ntable : HashMap T) (slots : alloc.vec.Vec (AList T)) :
Result ((HashMap T) × (alloc.vec.Vec (AList T)))
Expand Down
8 changes: 6 additions & 2 deletions tests/lean/Loops.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ divergent def sum_loop (max : U32) (i : U32) (s : U32) : Result U32 :=

/- [loops::sum]:
Source: 'tests/src/loops.rs', lines 8:0-18:1 -/
def sum (max : U32) : Result U32 :=
sum_loop max 0#u32 0#u32
@[reducible] def sum (max : U32) : Result U32 :=
sum_loop max 0#u32 0#u32

/- [loops::sum_with_mut_borrows]: loop 0:
Source: 'tests/src/loops.rs', lines 26:4-31:5 -/
Expand All @@ -37,6 +37,7 @@ divergent def sum_with_mut_borrows_loop

/- [loops::sum_with_mut_borrows]:
Source: 'tests/src/loops.rs', lines 23:0-35:1 -/
@[reducible]
def sum_with_mut_borrows (max : U32) : Result U32 :=
sum_with_mut_borrows_loop max 0#u32 0#u32

Expand All @@ -54,6 +55,7 @@ divergent def sum_with_shared_borrows_loop

/- [loops::sum_with_shared_borrows]:
Source: 'tests/src/loops.rs', lines 38:0-52:1 -/
@[reducible]
def sum_with_shared_borrows (max : U32) : Result U32 :=
sum_with_shared_borrows_loop max 0#u32 0#u32

Expand All @@ -72,6 +74,7 @@ divergent def sum_array_loop

/- [loops::sum_array]:
Source: 'tests/src/loops.rs', lines 54:0-62:1 -/
@[reducible]
def sum_array {N : Usize} (a : Array U32 N) : Result U32 :=
sum_array_loop a 0#usize 0#u32

Expand All @@ -93,6 +96,7 @@ divergent def clear_loop

/- [loops::clear]:
Source: 'tests/src/loops.rs', lines 66:0-72:1 -/
@[reducible]
def clear (v : alloc.vec.Vec U32) : Result (alloc.vec.Vec U32) :=
clear_loop v 0#usize

Expand Down
1 change: 1 addition & 0 deletions tests/lean/MiniTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ divergent def Tree.explore_loop (current_tree : Option Node) : Result Unit :=

/- [mini_tree::{mini_tree::Tree}::explore]:
Source: 'tests/src/mini_tree.rs', lines 14:4-20:5 -/
@[reducible]
def Tree.explore (self : Tree) : Result Unit :=
Tree.explore_loop self.root

Expand Down
1 change: 1 addition & 0 deletions tests/lean/RenameAttribute.lean
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ divergent def No_borrows_sum_loop

/- [rename_attribute::sum]:
Source: 'tests/src/rename_attribute.rs', lines 65:0-75:1 -/
@[reducible]
def No_borrows_sum (max : U32) : Result U32 :=
No_borrows_sum_loop max 0#u32 0#u32

Expand Down
4 changes: 4 additions & 0 deletions tests/lean/Tutorial/Tutorial.lean
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ divergent def reverse_loop

/- [tutorial::reverse]:
Source: 'src/lib.rs', lines 146:0-154:1 -/
@[reducible]
def reverse {T : Type} (l : CList T) : Result (CList T) :=
reverse_loop l CList.CNil

Expand All @@ -239,6 +240,7 @@ divergent def zero_loop

/- [tutorial::zero]:
Source: 'src/lib.rs', lines 162:0-168:1 -/
@[reducible]
def zero (x : alloc.vec.Vec U32) : Result (alloc.vec.Vec U32) :=
zero_loop x 0#usize

Expand All @@ -265,6 +267,7 @@ divergent def add_no_overflow_loop

/- [tutorial::add_no_overflow]:
Source: 'src/lib.rs', lines 175:0-181:1 -/
@[reducible]
def add_no_overflow
(x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) :
Result (alloc.vec.Vec U32)
Expand Down Expand Up @@ -302,6 +305,7 @@ divergent def add_with_carry_loop

/- [tutorial::add_with_carry]:
Source: 'src/lib.rs', lines 186:0-199:1 -/
@[reducible]
def add_with_carry
(x : alloc.vec.Vec U32) (y : alloc.vec.Vec U32) :
Result (U8 × (alloc.vec.Vec U32))
Expand Down

0 comments on commit d4ad011

Please sign in to comment.