-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUniform_Sampling.thy
248 lines (218 loc) · 10.4 KB
/
Uniform_Sampling.thy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
theory Uniform_Sampling imports
Cyclic_Group_SPMF
"~~/src/HOL/Number_Theory/Cong"
begin
(* sample_uniform_units q samples via spmf_of_set from the nonzero naturals
   below q, i.e. from {1, ..., q - 1}.
   NOTE(review): the name suggests the units of Z_q; that set coincides with
   {1, ..., q - 1} only when q is prime.
   For q <= 1 the sampled set is empty; lossless_sample_uniform_units assumes q > 1. *)
definition sample_uniform_units :: "nat \<Rightarrow> nat spmf"
where "sample_uniform_units q = spmf_of_set ({..< q} - {0})"
(* The units sampler is lossless for p > 1. After unfolding the definition,
   auto closes the goal; the sampled set {..< p} - {0} is finite and contains 1
   because p > 1. *)
lemma lossless_sample_uniform_units:
assumes "(p::nat) > 1"
shows "lossless_spmf (sample_uniform_units p)"
unfolding sample_uniform_units_def
using assms by auto
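(* General one-time-pad lemmas: mapping a bijection of the sampled set over
   sample_uniform_units / sample_uniform leaves the distribution unchanged. *)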
(* One-time-pad lemma for sample_uniform_units: if f is a bijection of
   {..<q} - {0} onto itself, mapping f over sample_uniform_units q leaves the
   distribution unchanged.
   Given inj_on, simp rewrites map_spmf f over spmf_of_set into spmf_of_set of
   the image; sur then identifies that image with the original set. Both
   assumptions are kept as stated so existing callers are unaffected. *)
lemma one_time_pad':
  assumes inj_on: "inj_on f ({..<q} - {0})"
    and sur: "f ` ({..<q} - {0}) = ({..<q} - {0})"
  shows "map_spmf f (sample_uniform_units q) = (sample_uniform_units q)"
proof -
  have rhs: "sample_uniform_units q = spmf_of_set ({..<q} - {0})"
    by (auto simp add: sample_uniform_units_def)
  have "map_spmf f (spmf_of_set ({..<q} - {0})) = spmf_of_set (f ` ({..<q} - {0}))"
    by (simp add: inj_on)
  (* the image is the set itself -- this is exactly the assumption sur *)
  also have "\<dots> = spmf_of_set ({..<q} - {0})"
    by (simp only: sur)
  finally show ?thesis
    by (simp only: rhs)
qed
(* One-time-pad lemma for sample_uniform: if f is a bijection of {..<q} onto
   itself, mapping f over sample_uniform q leaves the distribution unchanged.
   Given inj_on, simp rewrites map_spmf f over spmf_of_set into spmf_of_set of
   the image; sur then identifies that image with {..<q}. Both assumptions are
   kept as stated so existing callers are unaffected. *)
lemma one_time_pad:
  assumes inj_on: "inj_on f {..<q}"
    and sur: "f ` {..<q} = {..<q}"
  shows "map_spmf f (sample_uniform q) = (sample_uniform q)"
proof -
  have rhs: "sample_uniform q = spmf_of_set {..<q}"
    by (auto simp add: sample_uniform_def)
  have "map_spmf f (spmf_of_set {..<q}) = spmf_of_set (f ` {..<q})"
    by (simp add: inj_on)
  (* the image is the set itself -- this is exactly the assumption sur *)
  also have "\<dots> = spmf_of_set {..<q}"
    by (simp only: sur)
  finally show ?thesis
    by (simp only: rhs)
qed
(* Additive one-time pad: b |-> (y + b) mod q *)
(* Left cancellation of addition modulo q: if x, x' < q and
   (y + x) mod q = (y + x') mod q, then x = x'.
   The proof moves to the congruence [y + x = y + x'] (mod q), cancels y with
   cong_add_lcancel_nat, and uses that x and x' are their own remainders mod q. *)
lemma plus_inj_eq:
assumes x: "x < q"
and x': "x' < q"
and map: "((y :: nat) + x) mod q = (y + x') mod q"
shows "x = x'"
proof-
have "((y :: nat) + x) mod q = (y + x') mod q \<Longrightarrow> x mod q = x' mod q"
proof-
have "((y:: nat) + x) mod q = (y + x') mod q \<Longrightarrow> [((y:: nat) + x) = (y + x')] (mod q)"
by(simp add: cong_nat_def)
moreover have "[((y:: nat) + x) = (y + x')] (mod q) \<Longrightarrow> [x = x'] (mod q)"
by (simp add: cong_add_lcancel_nat)
moreover have "[x = x'] (mod q) \<Longrightarrow> x mod q = x' mod q"
by(simp add: cong_nat_def)
ultimately show ?thesis by(simp add: map)
qed
(* x, x' < q, so each equals its own remainder mod q *)
moreover have "x mod q = x' mod q \<Longrightarrow> x = x'"
by(simp add: x x')
ultimately show ?thesis by(simp add: map)
qed
(* b |-> (y + b) mod q is injective on {..<q}; after unfolding inj_on_def the
   goal is closed with plus_inj_eq. *)
lemma inj_uni_samp_plus: "inj_on (\<lambda>(b :: nat). (y + b) mod q ) {..<q}"
by(simp add: inj_on_def)(auto simp only: plus_inj_eq)
(* Given injectivity (provable by inj_uni_samp_plus), b |-> (y + b) mod q maps
   {..<q} onto itself. Uses endo_inj_surj; its remaining side conditions
   (finiteness, image inside {..<q}) are discharged by auto. *)
lemma surj_uni_samp_plus:
assumes inj: "inj_on (\<lambda>(b :: nat). (y + b) mod q ) {..<q}"
shows "(\<lambda>(b :: nat). (y + b) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using inj by auto
(* Additive one-time pad: shifting a sample from sample_uniform q by y modulo q
   leaves the distribution unchanged. Obtained from one_time_pad using
   inj_uni_samp_plus and surj_uni_samp_plus. *)
lemma samp_uni_plus_one_time_pad:
shows "map_spmf (\<lambda>b. (y + b) mod q) (sample_uniform q) = sample_uniform q"
using inj_uni_samp_plus surj_uni_samp_plus one_time_pad by simp
(* Multiplicative one-time pad: b |-> x * b mod q, for x coprime to q *)
(* Left cancellation of multiplication modulo q by a unit: if x is coprime to
   q, y, y' < q and x * y mod q = x * y' mod q, then y = y'.
   The coprimality is what cong_mult_lcancel_nat needs to cancel x in the
   congruence [x*y = x*y'] (mod q). *)
lemma mult_inj_eq:
assumes coprime: "coprime x (q::nat)"
and y: "y < q"
and y': "y' < q"
and map: "x * y mod q = x * y' mod q"
shows "y = y'"
proof-
have "x*y mod q = x*y' mod q \<Longrightarrow> y mod q = y' mod q"
proof-
have "x*y mod q = x*y' mod q \<Longrightarrow> [x*y = x*y'] (mod q)"
by(simp add: cong_nat_def)
moreover have "[x*y = x*y'] (mod q) = [y = y'] (mod q)"
by(simp add: cong_mult_lcancel_nat coprime)
moreover have "[y = y'] (mod q) \<Longrightarrow> y mod q = y' mod q"
by(simp add: cong_nat_def)
ultimately show ?thesis by(simp add: map)
qed
(* y, y' < q, so each equals its own remainder mod q *)
moreover have "y mod q = y' mod q \<Longrightarrow> y = y'"
by(simp add: y y')
ultimately show ?thesis by(simp add: map)
qed
(* For x coprime to q, b |-> x * b mod q is injective on {..<q}
   (pointwise by mult_inj_eq). *)
lemma inj_on_mult:
assumes coprime: "coprime x (q::nat)"
shows "inj_on (\<lambda> b. x*b mod q) {..<q}"
apply(auto simp add: inj_on_def)
using coprime by(simp only: mult_inj_eq)
(* Given injectivity (provable by inj_on_mult), b |-> x * b mod q maps {..<q}
   onto itself, via endo_inj_surj on the finite set {..<q}. *)
lemma surj_on_mult:
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (\<lambda> b. x*b mod q) {..<q}"
shows "(\<lambda> b. x*b mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using coprime inj by auto
(* Multiplicative one-time pad: for x coprime to q, multiplying a sample from
   sample_uniform q by x modulo q leaves the distribution unchanged. *)
lemma mult_one_time_pad:
assumes coprime: "coprime x q"
shows "map_spmf (\<lambda> b. x*b mod q) (sample_uniform q) = sample_uniform q"
using inj_on_mult surj_on_mult one_time_pad coprime by simp
(* Variant of inj_on_mult restricted to the nonzero residues {..<q} - {0}. *)
lemma inj_on_mult':
assumes coprime: "coprime x (q::nat)"
shows "inj_on (\<lambda> b. x*b mod q) ({..<q} - {0})"
apply(auto simp add: inj_on_def)
using coprime by(simp only: mult_inj_eq)
(* Variant of surj_on_mult on the nonzero residues: for x coprime to q,
   b |-> x * b mod q maps {..<q} - {0} onto itself.
   Proof by endo_inj_surj. The set is finite, the map is injective (assumption
   inj), and it sends {..<q} - {0} into itself: x * b mod q < q since q > b,
   and x * b mod q = 0 = x * 0 mod q would force b = 0 by mult_inj_eq. *)
lemma surj_on_mult':
  assumes coprime: "coprime x (q::nat)"
    and inj: "inj_on (\<lambda> b. x*b mod q) ({..<q} - {0})"
  shows "(\<lambda> b. x*b mod q) ` ({..<q} - {0}) = ({..<q} - {0})"
proof (rule endo_inj_surj)
  show "finite ({..<q} - {0})" by auto
  show "(\<lambda>b. x * b mod q) ` ({..<q} - {0}) \<subseteq> {..<q} - {0}"
  proof
    fix z
    assume "z \<in> (\<lambda>b. x * b mod q) ` ({..<q} - {0})"
    then obtain b where b_lt: "b < q" and b_nz: "b \<noteq> 0" and z_def: "z = x * b mod q"
      by auto
    (* q exceeds b, so it is positive: needed for the remainder bound and as
       the "y' < q" premise of mult_inj_eq with y' = 0 *)
    have q_pos: "0 < q"
      using b_lt by linarith
    have z_lt: "z < q"
      using q_pos by (simp add: z_def)
    have z_nz: "z \<noteq> 0"
    proof
      assume "z = 0"
      with z_def have "x * b mod q = x * 0 mod q"
        by simp
      (* cancel the unit x: b and 0 are both below q *)
      with coprime b_lt q_pos have "b = 0"
        by (rule mult_inj_eq)
      with b_nz show False
        by simp
    qed
    from z_lt z_nz show "z \<in> {..<q} - {0}"
      by simp
  qed
  show "inj_on (\<lambda>b. x * b mod q) ({..<q} - {0})"
    using inj by blast
qed
(* Multiplicative one-time pad for sample_uniform_units: for x coprime to q,
   multiplying a sample by x modulo q leaves the distribution unchanged.
   Obtained from one_time_pad' using inj_on_mult' and surj_on_mult'. *)
lemma mult_one_time_pad':
assumes coprime: "coprime x q"
shows "map_spmf (\<lambda> b. x*b mod q) (sample_uniform_units q) = sample_uniform_units q"
using inj_on_mult' surj_on_mult' one_time_pad' coprime by simp
(* Affine one-time pad: b |-> (y + x * b) mod q, for x coprime to q *)
(* Cancellation for the affine map b |-> (y + x * b) mod q: for x coprime to q
   and x', y' < q, equal images imply x' = y'.
   NOTE(review): despite the "samp_uni" prefix this is a pure injectivity
   lemma (used by inj_on_add_mult); the name is kept for existing callers.
   The proof cancels y with cong_add_lcancel_nat, then x with
   cong_mult_lcancel_nat, which needs the coprimality. *)
lemma samp_uni_add_mult:
assumes coprime: "coprime x (q::nat)"
and x': "x' < q"
and y': "y' < q"
and map: "(y + x * x') mod q = (y + x * y') mod q"
shows "x' = y'"
proof-
have "(y + x * x') mod q = (y + x * y') mod q \<Longrightarrow> x' mod q = y' mod q"
proof-
have "(y + x * x') mod q = (y + x * y') mod q \<Longrightarrow> [y + x*x' = y + x *y'] (mod q)"
using cong_nat_def by blast
moreover have "[y + x*x' = y + x *y'] (mod q) \<Longrightarrow> [x' = y'] (mod q)"
by(simp add: cong_add_lcancel_nat)(simp add: coprime cong_mult_lcancel_nat)
ultimately show ?thesis by(simp add: cong_nat_def map)
qed
(* x', y' < q, so each equals its own remainder mod q *)
moreover have "x' mod q = y' mod q \<Longrightarrow> x' = y'"
by(simp add: x' y')
ultimately show ?thesis by(simp add: map)
qed
(* For x coprime to q, b |-> (y + x * b) mod q is injective on {..<q}
   (pointwise by samp_uni_add_mult). *)
lemma inj_on_add_mult:
assumes coprime: "coprime x (q::nat)"
shows "inj_on (\<lambda> b. (y + x*b) mod q) {..<q}"
apply(auto simp add: inj_on_def)
using coprime by(simp only: samp_uni_add_mult)
(* Given injectivity (provable by inj_on_add_mult), b |-> (y + x * b) mod q
   maps {..<q} onto itself, via endo_inj_surj on the finite set {..<q}. *)
lemma surj_on_add_mult:
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (\<lambda> b. (y + x*b) mod q) {..<q}"
shows "(\<lambda> b. (y + x*b) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using coprime inj by auto
(* Affine one-time pad: for x coprime to q, mapping b |-> (y + x * b) mod q
   over sample_uniform q leaves the distribution unchanged. *)
lemma add_mult_one_time_pad:
assumes coprime: "coprime x q"
shows "map_spmf (\<lambda> b. (y + x*b) mod q) (sample_uniform q) = (sample_uniform q)"
using inj_on_add_mult surj_on_add_mult one_time_pad coprime by simp
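(* Subtractive one-time pad: b |-> (y - b) mod q, written as (y + (q - b)) mod q to avoid truncated nat subtraction *)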
(* b |-> (y + (q - b)) mod q, i.e. y - b modulo q without nat truncation,
   is injective on {..<q}.
   NOTE(review): the metis steps below look like Sledgehammer output; the
   auxiliary fact derived from nat_diff_split is chained into the following
   metis call, so do not drop it without re-checking the proof. *)
lemma inj_on_minus: "inj_on (\<lambda>(b :: nat). (y + (q - b)) mod q ) {..<q}"
proof(unfold inj_on_def; auto)
fix x :: nat and y' :: nat
assume x: "x < q"
assume y': "y' < q"
(* after auto, the hypothesis is stated with y + q - b instead of y + (q - b) *)
assume map: "(y + q - x) mod q = (y + q - y') mod q"
(* auxiliary fact about truncated subtraction, consumed by the next metis *)
have "\<forall>n na p. \<exists>nb. \<forall>nc nd pa. (\<not> (nc::nat) < nd \<or> \<not> pa (nc - nd) \<or> pa 0) \<and> (\<not> p (0::nat) \<or> p (n - na) \<or> na + nb = n)"
by (metis (no_types) nat_diff_split)
(* since x, y' < q, the truncated differences y' - q and x - q vanish *)
hence "\<not> y < y' - q \<and> \<not> y < x - q"
using y' x by (metis add.commute less_diff_conv not_add_less2)
(* rewrite the hypothesis as a common additive shift of y' and x *)
hence "\<exists>n. (y' + n) mod q = (n + x) mod q"
using map by (metis add.commute add_diff_inverse_nat less_diff_conv mod_add_left_eq)
(* cancel the shift with plus_inj_eq *)
thus "x = y'"
by (metis plus_inj_eq x y' add.commute)
qed
(* Given injectivity (provable by inj_on_minus), b |-> (y + (q - b)) mod q maps
   {..<q} onto itself, via endo_inj_surj on the finite set {..<q}. *)
lemma surj_on_minus:
assumes inj: "inj_on (\<lambda>(b :: nat). (y + (q - b)) mod q ) {..<q}"
shows "(\<lambda>(b :: nat). (y + (q - b)) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using inj by auto
(* Subtractive one-time pad: mapping b |-> (y + (q - b)) mod q over
   sample_uniform q leaves the distribution unchanged. Obtained from
   one_time_pad using inj_on_minus and surj_on_minus. *)
lemma samp_uni_minus_one_time_pad:
shows "map_spmf(\<lambda> b. (y + (q - b)) mod q) (sample_uniform q) = sample_uniform q"
using inj_on_minus surj_on_minus one_time_pad by simp
(* Negating the outcome of coin_spmf leaves its distribution unchanged:
   Not is a bijection of {True, False}, which is UNIV for bool (UNIV_bool).
   NOTE(review): one_time_pad is stated for sample_uniform, so it likely does
   not contribute to the final simp call here -- confirm before relying on it. *)
lemma not_coin_spmf: "map_spmf (\<lambda> a. \<not> a) coin_spmf = coin_spmf"
proof-
have "inj_on Not {True, False}"
by simp
moreover have "Not ` {True, False} = {True, False}"
by auto
ultimately show ?thesis using one_time_pad
by (simp add: UNIV_bool)
qed
end