-
Notifications
You must be signed in to change notification settings - Fork 108
/
l2_opt.ML
286 lines (263 loc) · 9.9 KB
/
l2_opt.ML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
(*
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
*
* SPDX-License-Identifier: BSD-2-Clause
*)
(*
* Optimise L2 fragments of code by using facts learnt earlier in the fragments
* to simplify code afterwards.
*)
structure L2Opt =
struct
(*
* Map the given simpset to tweak it for L2Opt.
*
* If "use_ugly_rules" is enabled, we will use rules that are useful for
* discharging proofs, but make the output ugly.
*)
fun map_opt_simpset use_ugly_rules =
  let
    (* Congruence rules that are installed unconditionally, in this order. *)
    val always_congs = [@{thm if_cong}, @{thm split_cong}, @{thm HOL.conj_cong}]
    val add_congs = fold Simplifier.add_cong always_congs

    (* Unfolding split_def is only enabled on request: it helps discharge
     * proofs but makes the resulting terms less readable. *)
    val add_ugly =
      if use_ugly_rules then
        (fn ctxt => ctxt addsimps [@{thm split_def}])
      else
        I
  in
    add_congs #> add_ugly
  end
(*
* Solve a goal of the form:
*
* simp_expr P A ?X
*
* This is done by simplifying "A" while assuming "P", and unifying the result
* (usually instantiating "X") in the process.
*)
(* Rule used by solve_simp_expr_tac: a goal "simp_expr P G G'" follows from the
 * meta-equation "simp_expr P G G == simp_expr P G G'". The tactic discharges
 * that premise with the theorem obtained by rewriting "simp_expr P L L". *)
val simp_expr_thm =
@{lemma "(simp_expr P G G == simp_expr P G G') ==> simp_expr P G G'" by (clarsimp simp: simp_expr_def)}
fun solve_simp_expr_tac ctxt =
  Subgoal.FOCUS_PARAMS (fn {context = focus_ctxt, ...} => fn st =>
    let
      (* Build the tactic for a focused goal whose first two simp_expr
       * arguments are P and L. *)
      fun rewrite_tac (P, L) =
        let
          (* Rewrite the term "simp_expr P L L" with the L2Opt simpset
           * (without the "ugly" rules). *)
          val start_ct =
            Thm.cterm_of focus_ctxt (@{mk_term "simp_expr ?P ?L ?L" (P, L)} (P, L))
          val eq_thm =
            Simplifier.asm_full_rewrite (map_opt_simpset false focus_ctxt) start_ct
        in
          if Term.exists_subterm Term.is_Var (Thm.prop_of eq_thm) then
            (* Schematics are still present in the rewrite result: do not use
             * it, fall back to the trivial rule instead. *)
            resolve_tac focus_ctxt @{thms simp_expr_triv} 1
          else
            (* Reduce the goal to a meta-equation and close it with the
             * rewrite result. *)
            resolve_tac focus_ctxt [simp_expr_thm] 1
            THEN resolve_tac focus_ctxt [eq_thm] 1
        end
    in
      (* Only the first premise of the focused state is inspected; anything
       * that is not a simp_expr goal (or no premise at all) makes the tactic
       * fail. *)
      case Option.map Thm.term_of (try hd (Drule.cprems_of st)) of
        SOME (_ $ (Const (@{const_name "simp_expr"}, _) $ P $ L $ _)) =>
          rewrite_tac (P, L) st
      | _ => no_tac st
    end) ctxt
(*
* Solve a goal of the form:
*
* simp_expr P A B
*
* where both "A" and "B" are constants (i.e., not schematics).
*)
(* Tactic on the first subgoal of "thm".
 *
 * Fails (no_tac) when the proof state has no subgoal at all, or when the
 * first subgoal still contains a schematic variable. Otherwise it applies
 * simp_expr_solve_constant followed by clarsimp (with the "ugly" L2Opt
 * rules enabled) and succeeds only if that solves the subgoal outright. *)
fun solve_simp_expr_const_tac ctxt thm =
  case Drule.cprems_of thm of
    (* No subgoal left: fail like solve_simp_expr_tac does, rather than
     * letting Thm.cprem_of raise THM. *)
    [] => no_tac thm
  | first_goal :: _ =>
      if Term.exists_subterm Term.is_Var (Thm.term_of first_goal) then
        no_tac thm
      else
        SOLVES (
          (resolve_tac ctxt @{thms simp_expr_solve_constant} 1)
          THEN (Clasimp.clarsimp_tac (map_opt_simpset true ctxt) 1)) thm
(*
* Given a theorem of the form:
*
* monad_equiv P L R Q E
*
* simplify "P", possibly trimming parts of it that are too large.
*
* The idea here is to avoid exponential blow-up by trimming off terms that get
* too large.
*)
fun simp_monad_equiv_pre_tac ctxt =
  Subgoal.FOCUS_PARAMS (fn {context = focus_ctxt, ...} => fn st =>
    let
      val first_goal = Thm.cprem_of st 1

      (* Last-resort alternative: abort loudly instead of failing silently. *)
      fun abort_tac t =
        raise (CTERM ("autocorres: monad_equiv_pre failed to prove goal", [Thm.cprem_of t 1]))
    in
      case Thm.term_of first_goal of
        Const (@{const_name Trueprop}, _) $
            (Const (@{const_name monad_equiv}, _) $ P $ _ $ _ $ _ $ _) =>
          let
            (* A schematic precondition could leave flex-flex pairs behind;
             * the monad_equiv rules are expected never to produce one, so
             * treat it as an internal error. *)
            val _ =
              if exists_subterm is_Var P then
                raise CTERM ("autocorres: bad schematic in monad_equiv_pre", [first_goal])
              else
                ()
            (* Simplify the precondition on its own.
             * NOTE(review): only simplification happens here; no trimming of
             * large terms is visible in this function. *)
            val pre_eq =
              Simplifier.asm_full_rewrite
                (map_opt_simpset false focus_ctxt) (Thm.cterm_of focus_ctxt P)
            val weaken_rule = @{thm monad_equiv_weaken_pre''} OF [pre_eq]
          in
            (resolve_tac focus_ctxt [weaken_rule] 1 ORELSE abort_tac) st
          end
        (* Goals of any other shape are left untouched. *)
      | _ => all_tac st
    end) ctxt
(*
* Recursively simplify a monadic expression, using information gleaned from
* earlier in the program to simplify parts of the program further down.
*)
(* Takes a certified term "ct" and returns the theorem obtained by finishing
 * a goal that starts out as "ct == ?R"; cleanup_thm uses this function as a
 * conversion. *)
fun monad_equiv ctxt ct =
let
(* Mark context as being "invisible" to reduce warnings being printed. *)
val ctxt = Context_Position.set_visible false ctxt
(* Generate our top-level "monad_equiv" goal: start from the meta-equation
 * "ct == ?R", then apply eq_reflection and monad_equiv_eq in turn. *)
val goal = @{mk_term "?L == ?R" (L)} (Thm.term_of ct)
|> Thm.cterm_of ctxt
|> Goal.init
|> Utils.apply_tac "Creating object-level equality." (resolve_tac ctxt @{thms eq_reflection} 1)
|> Utils.apply_tac "Creating 'monad_equiv' goal." (resolve_tac ctxt @{thms monad_equiv_eq} 1)
(* Print a diagnostic if this branch fails.
 * The "false andalso" guard currently disables the printing, so this tactic
 * always behaves as no_tac; replace "false" by "true" to get at most five
 * "Branch failed" reports. *)
val num_failures = ref 0
fun print_failure_tac t =
if (false andalso !num_failures < 5) then
(num_failures := !num_failures + 1; (print_tac ctxt "Branch failed" THEN no_tac) t)
else
(no_tac t)
(* Fetch theorems used in the simplification process. *)
val thms = Utils.get_rules ctxt @{named_theorems L2flow}
(* Tactic to blindly apply simplification rules.
 * The subgoal index passed in (e.g. by THEN_ALL_NEW) is ignored: every step
 * below addresses subgoal 1. After simplifying the precondition, the
 * alternatives are tried in order: constant simp_expr goals, schematic
 * simp_expr goals, then an L2flow rule followed by a recursive call for the
 * new subgoals. SOLVES requires the chosen alternative to solve the goal
 * and DETERM keeps only the first result. *)
fun solve_goal_tac _ =
(simp_monad_equiv_pre_tac ctxt 1)
THEN DETERM (
SOLVES
((solve_simp_expr_const_tac ctxt)
ORELSE
((solve_simp_expr_tac ctxt 1)
ORELSE
((resolve_tac ctxt thms THEN_ALL_NEW solve_goal_tac) 1
ORELSE
((print_failure_tac))))))
(* Apply the rules. *)
val thm =
Utils.apply_tac "Simplifying L2" (solve_goal_tac 1) goal
|> Goal.finish ctxt
in
thm
end
(*
* A simproc implementing the "L2_gets_bind" rule. The rule, unfortunately, has
* the ability to cause exponential growth in the spec size in some cases;
* thus, we can only selectively apply it in cases where this doesn't happen.
*
* In particular, we propagate a "gets" into its usage if it is used at most once.
*
* Or, if the user asks for "no_opt", we only erase the "gets" if it is never used.
* (Even with "no optimisation", we still want to get rid of control flow variables
* emitted by c-parser. Hopefully the user won't mind if their own unused variables
* also disappear.)
*)
(* Meta-equality form of L2_gets_bind; this is the rewrite rule that
 * l2_gets_bind_simproc' hands back whenever it decides to fire. *)
val l2_gets_bind_thm = mk_meta_eq @{thm L2_gets_bind}
fun l2_gets_bind_simproc' ctxt cterm =
  let
    (* Purely syntactic test: an application whose last argument is an
     * abstraction with an atomic body (bound variable, free variable or
     * constant).
     * NOTE(review): the pattern registered for this simproc makes "lhs" an
     * L2_gets applied to two arguments, so the argument inspected here is the
     * last of those; confirm that this is the intended one. *)
    fun is_simple (_ $ Abs (_, _, body)) =
          (case body of
            Bound _ => true
          | Free _ => true
          | Const _ => true
          | _ => false)
      | is_simple _ = false

    (* The variable bound by the continuation is replaced by this marker so
     * that its occurrences can be counted. *)
    val marker = Free ("_dummy", dummyT)
    fun count_marker t =
      Term.fold_aterms (fn Free ("_dummy", _) => (fn k => k + 1) | _ => I) t 0
  in
    case Thm.term_of cterm of
      Const (@{const_name "L2_seq"}, _) $ lhs $ Abs (_, _, rhs) =>
        (* Fire when the bound value is used at most once, or when the
         * left-hand side passes the syntactic test above. *)
        if is_simple lhs orelse count_marker (subst_bounds ([marker], rhs)) <= 1 then
          SOME l2_gets_bind_thm
        else
          NONE
    | _ => NONE
  end
val l2_gets_bind_simproc =
  let
    (* The simproc is only tried on an L2_seq whose first argument is an
     * L2_gets of a function that ignores its argument. *)
    val patterns = ["L2_seq (L2_gets (%_. ?A) ?n) ?B"]
  in
    Utils.mk_simproc' @{context} ("L2_gets_bind_simproc", patterns, l2_gets_bind_simproc')
  end
(* Simproc to clean up guards. *)
(* Body of the guard clean-up simproc.
 *
 *   ss    - simpset used for the rewrite (installed into "ctxt", with the
 *           HOL.conj_cong congruence rule added)
 *   ctxt  - proof context the simpset is put into
 *   cterm - the term to simplify
 *
 * Returns SOME of the rewrite theorem when rewriting changed the term, and
 * NONE when both sides of the (eta-contracted) result are equal under
 * Term_Ord.fast_term_ord, so that the simproc does not fire for a no-op. *)
fun l2_guard_simproc' ss ctxt cterm =
  let
    val simp_ctxt = Simplifier.add_cong @{thm HOL.conj_cong} (put_simpset ss ctxt)
    val simp_thm = Simplifier.asm_full_rewrite simp_ctxt cterm
    (* The rewrite result is a meta-equation; take it apart with
     * Logic.dest_equals instead of a non-exhaustive list pattern. *)
    val (lhs, rhs) =
      Logic.dest_equals (Thm.prop_of (Drule.eta_contraction_rule simp_thm))
  in
    if Term_Ord.fast_term_ord (lhs, rhs) = EQUAL then
      NONE
    else
      SOME simp_thm
  end
fun l2_guard_simproc ss =
  let
    (* The simproc is tried on every L2_guard term; "ss" is the simpset that
     * l2_guard_simproc' rewrites with. *)
    val name = "L2_guard_simproc"
    val patterns = ["L2_guard ?G"]
  in
    Utils.mk_simproc' @{context} (name, patterns, l2_guard_simproc' ss)
  end
(*
* Adjust "case_prod" commands so that constructs such as:
*
* while C (%x. gets (case x of (a, b) => %s. P a b)) ...
*
* are transformed into:
*
* while C (%(a, b). gets (%s. P a b)) ...
*)
fun fix_L2_while_loop_splits_conv ctxt =
  let
    (* Start from the bare HOL simpset so that only the split fix-up rules
     * (plus their congruence rules) take part in the rewrite. *)
    val fixup_ctxt = put_simpset HOL_ss ctxt addsimps @{thms L2_split_fixups}
    val fixup_cong_ctxt =
      fold Simplifier.add_cong @{thms L2_split_fixups_congs} fixup_ctxt
  in
    Simplifier.asm_full_rewrite fixup_cong_ctxt
  end
(*
* Carry out flow-sensitive optimisations on the given 'thm'.
*
* "n" is the argument number to cleanup, counting from 1. So for example, if
* our input theorem was "corres P A B", an "n" of 2 would simplify "A".
* If n < 0, then the cleanup is applied to the -n-th argument from the end.
*
* If "fast_mode" is 0, perform flow-sensitive optimisations (which tend to be
* time-consuming). If 1, only apply L2Peephole and L2Opt simplification rules.
* If 2, do not use AutoCorres' simplification rules at all.
*)
(* Returns a pair: the cleaned-up theorem, and the list of traces that were
 * actually produced (peephole trace first, then flow trace; entries that
 * are NONE are dropped from the list). *)
fun cleanup_thm ctxt thm fast_mode n do_trace =
let
(* Don't print out warning messages. *)
val ctxt = Context_Position.set_visible false ctxt
(* Setup basic simplifier.
 * The steps are order-dependent: the guard simproc captures the simpset as
 * it stands after the L2opt rules and the gets-bind simproc have been added
 * (both only when fast_mode < 2), and before map_opt_simpset adds its
 * congruence rules. *)
fun basic_ss ctxt =
put_simpset AUTOCORRES_SIMPSET ctxt
|> (fn ctxt => if fast_mode < 2 then ctxt addsimps (Utils.get_rules ctxt @{named_theorems L2opt}) else ctxt)
|> (fn ctxt => if fast_mode < 2 then ctxt addsimprocs [l2_gets_bind_simproc] else ctxt)
|> (fn ctxt => ctxt addsimprocs [l2_guard_simproc (simpset_of ctxt)])
|> map_opt_simpset false
(* Peephole conversion: beta/eta-normalise, fix up while-loop splits, then
 * rewrite with the simplifier built by basic_ss. *)
fun simp_conv ctxt =
Drule.beta_eta_conversion
then_conv (fix_L2_while_loop_splits_conv ctxt)
then_conv (Simplifier.rewrite (basic_ss ctxt))
(* Lift a context-dependent conversion so that it acts on argument "n" of
 * the theorem's statement. *)
fun l2conv conv =
Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv n (conv ctxt)) ctxt
(* Apply peephole optimisations to the theorem. *)
val (new_thm, peephole_trace) =
AutoCorresTrace.fconv_rule_maybe_traced ctxt (l2conv simp_conv) thm do_trace
|> apfst Drule.eta_contraction_rule
(* Apply flow-sensitive optimisations, and then re-apply simple simplifications. *)
(* This step only runs when fast_mode = 0; otherwise the theorem is passed
 * through unchanged and no flow trace is recorded. *)
(* TODO: trace monad_equiv using trace_solve_tac rather than fconv_rule_traced *)
val (new_thm, flow_trace) =
if fast_mode = 0 then
AutoCorresTrace.fconv_rule_maybe_traced ctxt (
l2conv (fn ctxt =>
monad_equiv ctxt
then_conv (simp_conv (put_simpset AUTOCORRES_SIMPSET ctxt))
)) new_thm do_trace
else
(new_thm, NONE)
(* Beta/Eta normalise. *)
val new_thm = Conv.fconv_rule (l2conv (K Drule.beta_eta_conversion)) new_thm
in
(* Keep only the traces that are present. *)
(new_thm, List.mapPartial I [peephole_trace, flow_trace])
end
(* Also tag the traces in a suitable format to be stored in AutoCorresData. *)
fun cleanup_thm_tagged ctxt thm fast_mode n do_trace phase =
  let
    val (new_thm, traces) = cleanup_thm ctxt thm fast_mode n do_trace
    (* Labels for the two optimisation passes, prefixed by the phase name.
     * NOTE(review): cleanup_thm drops missing traces before this pairing, so
     * labels are assigned by position; confirm that Utils.zip copes with a
     * trace list shorter than the label list. *)
    val labels = [phase ^ " peephole opt", phase ^ " flow opt"]
    val tagged_traces = Utils.zip labels (map AutoCorresData.SimpTrace traces)
  in
    (new_thm, tagged_traces)
  end
end