Skip to content

Commit

Permalink
feat: Array/Option.unattach (#5586)
Browse files Browse the repository at this point in the history
More support for automatically removing `.attach`, for `Array` and
`Option`.
  • Loading branch information
kim-em authored Oct 3, 2024
1 parent b7d6a4b commit a4fda01
Show file tree
Hide file tree
Showing 5 changed files with 387 additions and 24 deletions.
149 changes: 149 additions & 0 deletions src/Init/Data/Array/Attach.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Joachim Breitner, Mario Carneiro
-/
prelude
import Init.Data.Array.Mem
import Init.Data.Array.Lemmas
import Init.Data.List.Attach

namespace Array
Expand All @@ -26,4 +27,152 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
with the same elements but in the type `{x // x ∈ xs}`. -/
@[inline] def attach (xs : Array α) : Array {x // x ∈ xs} := xs.attachWith _ fun _ => id

@[simp] theorem _root_.List.attachWith_toArray {l : List α} {P : α → Prop} {H : ∀ x ∈ l.toArray, P x} :
l.toArray.attachWith P H = (l.attachWith P (by simpa using H)).toArray := by
simp [attachWith]

@[simp] theorem _root_.List.attach_toArray {l : List α} :
l.toArray.attach = (l.attachWith (· ∈ l.toArray) (by simp)).toArray := by
simp [attach]

@[simp] theorem toList_attachWith {l : Array α} {P : α → Prop} {H : ∀ x ∈ l, P x} :
(l.attachWith P H).toList = l.toList.attachWith P (by simpa [mem_toList] using H) := by
simp [attachWith]

@[simp] theorem toList_attach {α : Type _} {l : Array α} :
l.attach.toList = l.toList.attachWith (· ∈ l) (by simp [mem_toList]) := by
simp [attach]

/-! ## unattach
`Array.unattach` is the (one-sided) inverse of `Array.attach`. It is a synonym for `Array.map Subtype.val`.
We use it by providing a simp lemma `l.attach.unattach = l`, and simp lemmas which recognize higher order
functions applied to `l : Array { x // p x }` which only depend on the value, not the predicate, and rewrite these
in terms of a simpler function applied to `l.unattach`.
Further, we provide simp lemmas that push `unattach` inwards.
-/

/--
A synonym for `l.map (·.val)`. Mostly this should not be needed by users.
It is introduced as in intermediate step by lemmas such as `map_subtype`,
and is ideally subsequently simplified away by `unattach_attach`.
If not, usually the right approach is `simp [Array.unattach, -Array.map_subtype]` to unfold.
-/
def unattach {α : Type _} {p : α → Prop} (l : Array { x // p x }) := l.map (·.val)

@[simp] theorem unattach_nil {p : α → Prop} : (#[] : Array { x // p x }).unattach = #[] := rfl
@[simp] theorem unattach_push {p : α → Prop} {a : { x // p x }} {l : Array { x // p x }} :
(l.push a).unattach = l.unattach.push a.1 := by
simp only [unattach, Array.map_push]

@[simp] theorem size_unattach {p : α → Prop} {l : Array { x // p x }} :
l.unattach.size = l.size := by
unfold unattach
simp

@[simp] theorem _root_.List.unattach_toArray {p : α → Prop} {l : List { x // p x }} :
l.toArray.unattach = l.unattach.toArray := by
simp only [unattach, List.map_toArray, List.unattach]

@[simp] theorem toList_unattach {p : α → Prop} {l : Array { x // p x }} :
l.unattach.toList = l.toList.unattach := by
simp only [unattach, toList_map, List.unattach]

@[simp] theorem unattach_attach {l : Array α} : l.attach.unattach = l := by
cases l
simp

@[simp] theorem unattach_attachWith {p : α → Prop} {l : Array α}
{H : ∀ a ∈ l, p a} :
(l.attachWith p H).unattach = l := by
cases l
simp

/-! ### Recognizing higher order functions using a function that only depends on the value. -/

/--
This lemma identifies folds over arrays of subtypes, where the function only depends on the value, not the proposition,
and simplifies these to the function directly taking the value.
-/
theorem foldl_subtype {p : α → Prop} {l : Array { x // p x }}
{f : β → { x // p x } → β} {g : β → α → β} {x : β}
{hf : ∀ b x h, f b ⟨x, h⟩ = g b x} :
l.foldl f x = l.unattach.foldl g x := by
cases l
simp only [List.foldl_toArray', List.unattach_toArray]
rw [List.foldl_subtype] -- Why can't simp do this?
simp [hf]

/-- Variant of `foldl_subtype` with side condition to check `stop = l.size`. -/
@[simp] theorem foldl_subtype' {p : α → Prop} {l : Array { x // p x }}
{f : β → { x // p x } → β} {g : β → α → β} {x : β}
{hf : ∀ b x h, f b ⟨x, h⟩ = g b x} (h : stop = l.size) :
l.foldl f x 0 stop = l.unattach.foldl g x := by
subst h
rwa [foldl_subtype]

/--
This lemma identifies folds over arrays of subtypes, where the function only depends on the value, not the proposition,
and simplifies these to the function directly taking the value.
-/
theorem foldr_subtype {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → β → β} {g : α → β → β} {x : β}
{hf : ∀ x h b, f ⟨x, h⟩ b = g x b} :
l.foldr f x = l.unattach.foldr g x := by
cases l
simp only [List.foldr_toArray', List.unattach_toArray]
rw [List.foldr_subtype]
simp [hf]

/-- Variant of `foldr_subtype` with side condition to check `stop = l.size`. -/
@[simp] theorem foldr_subtype' {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → β → β} {g : α → β → β} {x : β}
{hf : ∀ x h b, f ⟨x, h⟩ b = g x b} (h : start = l.size) :
l.foldr f x start 0 = l.unattach.foldr g x := by
subst h
rwa [foldr_subtype]

/--
This lemma identifies maps over arrays of subtypes, where the function only depends on the value, not the proposition,
and simplifies these to the function directly taking the value.
-/
@[simp] theorem map_subtype {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → β} {g : α → β} {hf : ∀ x h, f ⟨x, h⟩ = g x} :
l.map f = l.unattach.map g := by
cases l
simp only [List.map_toArray, List.unattach_toArray]
rw [List.map_subtype]
simp [hf]

@[simp] theorem filterMap_subtype {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → Option β} {g : α → Option β} {hf : ∀ x h, f ⟨x, h⟩ = g x} :
l.filterMap f = l.unattach.filterMap g := by
cases l
simp only [size_toArray, List.filterMap_toArray', List.unattach_toArray, List.length_unattach,
mk.injEq]
rw [List.filterMap_subtype]
simp [hf]

@[simp] theorem unattach_filter {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → Bool} {g : α → Bool} {hf : ∀ x h, f ⟨x, h⟩ = g x} :
(l.filter f).unattach = l.unattach.filter g := by
cases l
simp [hf]

/-! ### Simp lemmas pushing `unattach` inwards. -/

@[simp] theorem unattach_reverse {p : α → Prop} {l : Array { x // p x }} :
l.reverse.unattach = l.unattach.reverse := by
cases l
simp

@[simp] theorem unattach_append {p : α → Prop} {l₁ l₂ : Array { x // p x }} :
(l₁ ++ l₂).unattach = l₁.unattach ++ l₂.unattach := by
cases l₁
cases l₂
simp

end Array
126 changes: 112 additions & 14 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,52 @@ theorem toArray_concat {as : List α} {x : α} :
funext a
simp

@[simp] theorem foldrM_toArray [Monad m] (f : α → β → m β) (init : β) (l : List α) :
theorem foldrM_toArray [Monad m] (f : α → β → m β) (init : β) (l : List α) :
l.toArray.foldrM f init = l.foldrM f init := by
rw [foldrM_eq_reverse_foldlM_toList]
simp

@[simp] theorem foldlM_toArray [Monad m] (f : β → α → m β) (init : β) (l : List α) :
theorem foldlM_toArray [Monad m] (f : β → α → m β) (init : β) (l : List α) :
l.toArray.foldlM f init = l.foldlM f init := by
rw [foldlM_eq_foldlM_toList]

@[simp] theorem foldr_toArray (f : α → β → β) (init : β) (l : List α) :
theorem foldr_toArray (f : α → β → β) (init : β) (l : List α) :
l.toArray.foldr f init = l.foldr f init := by
rw [foldr_eq_foldr_toList]

@[simp] theorem foldl_toArray (f : β → α → β) (init : β) (l : List α) :
theorem foldl_toArray (f : β → α → β) (init : β) (l : List α) :
l.toArray.foldl f init = l.foldl f init := by
rw [foldl_eq_foldl_toList]

/-- Variant of `foldrM_toArray` with a side condition for the `start` argument. -/
@[simp] theorem foldrM_toArray' [Monad m] (f : α → β → m β) (init : β) (l : List α)
(h : start = l.toArray.size) :
l.toArray.foldrM f init start 0 = l.foldrM f init := by
subst h
rw [foldrM_eq_reverse_foldlM_toList]
simp

/-- Variant of `foldlM_toArray` with a side condition for the `stop` argument. -/
@[simp] theorem foldlM_toArray' [Monad m] (f : β → α → m β) (init : β) (l : List α)
(h : stop = l.toArray.size) :
l.toArray.foldlM f init 0 stop = l.foldlM f init := by
subst h
rw [foldlM_eq_foldlM_toList]

/-- Variant of `foldr_toArray` with a side condition for the `start` argument. -/
@[simp] theorem foldr_toArray' (f : α → β → β) (init : β) (l : List α)
(h : start = l.toArray.size) :
l.toArray.foldr f init start 0 = l.foldr f init := by
subst h
rw [foldr_eq_foldr_toList]

/-- Variant of `foldl_toArray` with a side condition for the `stop` argument. -/
@[simp] theorem foldl_toArray' (f : β → α → β) (init : β) (l : List α)
(h : stop = l.toArray.size) :
l.toArray.foldl f init 0 stop = l.foldl f init := by
subst h
rw [foldl_eq_foldl_toList]

@[simp] theorem append_toArray (l₁ l₂ : List α) :
l₁.toArray ++ l₂.toArray = (l₁ ++ l₂).toArray := by
apply ext'
Expand Down Expand Up @@ -730,6 +759,18 @@ theorem foldr_induction
simp [foldr, foldrM]; split; {exact go _ h0}
· next h => exact (Nat.eq_zero_of_not_pos h ▸ h0)

@[congr]
theorem foldl_congr {as bs : Array α} (h₀ : as = bs) {f g : β → α → β} (h₁ : f = g)
{a b : β} (h₂ : a = b) {start start' stop stop' : Nat} (h₃ : start = start') (h₄ : stop = stop') :
as.foldl f a start stop = bs.foldl g b start' stop' := by
congr

@[congr]
theorem foldr_congr {as bs : Array α} (h₀ : as = bs) {f g : α → β → β} (h₁ : f = g)
{a b : β} (h₂ : a = b) {start start' stop stop' : Nat} (h₃ : start = start') (h₄ : stop = stop') :
as.foldr f a start stop = bs.foldr g b start' stop' := by
congr

/-! ### map -/

@[simp] theorem mem_map {f : α → β} {l : Array α} : b ∈ l.map f ↔ ∃ a, a ∈ l ∧ f a = b := by
Expand Down Expand Up @@ -814,6 +855,13 @@ theorem map_spec (as : Array α) (f : α → β) (p : Fin as.size → β → Pro
(as.map f)[i]? = as[i]?.map f := by
simp [getElem?_def]

@[simp] theorem map_push {f : α → β} {as : Array α} {x : α} :
(as.push x).map f = (as.map f).push (f x) := by
ext
· simp
· simp only [getElem_map, get_push, size_map]
split <;> rfl

/-! ### mapIdx -/

-- This could also be proved from `SatisfiesM_mapIdxM` in Batteries.
Expand Down Expand Up @@ -920,6 +968,13 @@ abbrev filter_data := @toList_filter
theorem mem_of_mem_filter {a : α} {l} (h : a ∈ filter p l) : a ∈ l :=
(mem_filter.mp h).1

@[congr]
theorem filter_congr {as bs : Array α} (h : as = bs)
{f : α → Bool} {g : α → Bool} (h' : f = g) {start stop start' stop' : Nat}
(h₁ : start = start') (h₂ : stop = stop') :
filter f as start stop = filter g bs start' stop' := by
congr

/-! ### filterMap -/

@[simp] theorem toList_filterMap (f : α → Option β) (l : Array α) :
Expand All @@ -942,6 +997,13 @@ abbrev filterMap_data := @toList_filterMap
b ∈ filterMap f l ↔ ∃ a, a ∈ l ∧ f a = some b := by
simp only [mem_def, toList_filterMap, List.mem_filterMap]

@[congr]
theorem filterMap_congr {as bs : Array α} (h : as = bs)
{f : α → Option β} {g : α → Option β} (h' : f = g) {start stop start' stop' : Nat}
(h₁ : start = start') (h₂ : stop = stop') :
filterMap f as start stop = filterMap g bs start' stop' := by
congr

/-! ### empty -/

theorem size_empty : (#[] : Array α).size = 0 := rfl
Expand Down Expand Up @@ -1432,18 +1494,44 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
· simp
· simp_all [List.set_eq_of_length_le]

@[simp] theorem anyM_toArray [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α) :
theorem anyM_toArray [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α) :
l.toArray.anyM p = l.anyM p := by
rw [← anyM_toList]

@[simp] theorem any_toArray (p : α → Bool) (l : List α) : l.toArray.any p = l.any p := by
theorem any_toArray (p : α → Bool) (l : List α) : l.toArray.any p = l.any p := by
rw [any_toList]

@[simp] theorem allM_toArray [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α) :
theorem allM_toArray [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α) :
l.toArray.allM p = l.allM p := by
rw [← allM_toList]

@[simp] theorem all_toArray (p : α → Bool) (l : List α) : l.toArray.all p = l.all p := by
theorem all_toArray (p : α → Bool) (l : List α) : l.toArray.all p = l.all p := by
rw [all_toList]

/-- Variant of `anyM_toArray` with a side condition on `stop`. -/
@[simp] theorem anyM_toArray' [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α)
(h : stop = l.toArray.size) :
l.toArray.anyM p 0 stop = l.anyM p := by
subst h
rw [← anyM_toList]

/-- Variant of `any_toArray` with a side condition on `stop`. -/
@[simp] theorem any_toArray' (p : α → Bool) (l : List α) (h : stop = l.toArray.size) :
l.toArray.any p 0 stop = l.any p := by
subst h
rw [any_toList]

/-- Variant of `allM_toArray` with a side condition on `stop`. -/
@[simp] theorem allM_toArray' [Monad m] [LawfulMonad m] (p : α → m Bool) (l : List α)
(h : stop = l.toArray.size) :
l.toArray.allM p 0 stop = l.allM p := by
subst h
rw [← allM_toList]

/-- Variant of `all_toArray` with a side condition on `stop`. -/
@[simp] theorem all_toArray' (p : α → Bool) (l : List α) (h : stop = l.toArray.size) :
l.toArray.all p 0 stop = l.all p := by
subst h
rw [all_toList]

@[simp] theorem swap_toArray (l : List α) (i j : Fin l.toArray.size) :
Expand All @@ -1459,15 +1547,25 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
apply ext'
simp

@[simp] theorem filter_toArray (p : α → Bool) (l : List α) :
l.toArray.filter p = (l.filter p).toArray := by
@[simp] theorem filter_toArray' (p : α → Bool) (l : List α) (h : stop = l.toArray.size) :
l.toArray.filter p 0 stop = (l.filter p).toArray := by
subst h
apply ext'
erw [toList_filter] -- `erw` required to unify `l.length` with `l.toArray.size`.
rw [toList_filter]

@[simp] theorem filterMap_toArray (f : α → Option β) (l : List α) :
l.toArray.filterMap f = (l.filterMap f).toArray := by
@[simp] theorem filterMap_toArray' (f : α → Option β) (l : List α) (h : stop = l.toArray.size) :
l.toArray.filterMap f 0 stop = (l.filterMap f).toArray := by
subst h
apply ext'
erw [toList_filterMap] -- `erw` required to unify `l.length` with `l.toArray.size`.
rw [toList_filterMap]

theorem filter_toArray (p : α → Bool) (l : List α) :
l.toArray.filter p = (l.filter p).toArray := by
simp

theorem filterMap_toArray (f : α → Option β) (l : List α) :
l.toArray.filterMap f = (l.filterMap f).toArray := by
simp

@[simp] theorem flatten_toArray (l : List (List α)) : (l.toArray.map List.toArray).flatten = l.join.toArray := by
apply ext'
Expand Down
Loading

0 comments on commit a4fda01

Please sign in to comment.