Skip to content

Commit

Permalink
Lem: Fix the sum-monad for pure early returns
Browse files Browse the repository at this point in the history
This also requires no longer treating the `early_return` helper function
as effectful, so that we can distinguish between early returns and
other effects.  This used to work previously only because we never
re-ran effect inference after the early return rewrite.
  • Loading branch information
bauereiss committed Apr 23, 2024
1 parent 5e9cc7e commit dee5964
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 76 deletions.
19 changes: 19 additions & 0 deletions src/gen_lib/sail2_monadic_combinators.lem
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,18 @@ val and_boolM : forall 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation
-> monad 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv bool 'e
let and_boolM l r = l >>= (fun l -> if l then r else return false)

val and_boolE : forall 'e. either 'e bool -> either 'e bool -> either 'e bool
let and_boolE l r = l >>$= (fun l -> if l then r else Right false)

val or_boolM : forall 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv 'e.
monad 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv bool 'e
-> monad 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv bool 'e
-> monad 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv bool 'e
let or_boolM l r = l >>= (fun l -> if l then return true else r)

val or_boolE : forall 'e. either 'e bool -> either 'e bool -> either 'e bool
let or_boolE l r = l >>$= (fun l -> if l then Right true else r)

val bool_of_bitU_fail : forall 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv 'e.
bitU -> monad 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv bool 'e
let bool_of_bitU_fail = function
Expand Down Expand Up @@ -144,6 +150,13 @@ let rec whileM vars cond body =
body vars >>= fun vars -> whileM vars cond body
else return vars

val whileE : forall 'vars 'e. 'vars -> ('vars -> either 'e bool) -> ('vars -> either 'e 'vars) -> either 'e 'vars
let rec whileE vars cond body =
cond vars >>$= fun cond_val ->
if cond_val then
body vars >>$= fun vars -> whileE vars cond body
else Right vars

val whileMT_aux : forall 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv 'vars 'e.
nat
-> 'vars
Expand Down Expand Up @@ -179,6 +192,12 @@ let rec untilM vars cond body =
cond vars >>= fun cond_val ->
if cond_val then return vars else untilM vars cond body

val untilE : forall 'e 'vars. 'vars -> ('vars -> either 'e bool) -> ('vars -> either 'e 'vars) -> either 'e 'vars
let rec untilE vars cond body =
body vars >>$= fun vars ->
cond vars >>$= fun cond_val ->
if cond_val then Right vars else untilE vars cond body

val untilMT_aux : forall 'abort 'barrier 'cache_op 'fault 'pa 'tlb_op 'translation_summary 'arch_ak 'rv 'vars 'e.
nat
-> 'vars
Expand Down
1 change: 0 additions & 1 deletion src/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ let rewrite_attach_effects effect_info =
let env = env_of_tannot tannot in
let eff =
match e_aux with
| E_app (f, _) when string_of_id f = "early_return" -> monadic_effect
| E_app (f, _) -> begin
match Bindings.find_opt f effect_info.functions with
| Some side_effects -> if pure side_effects then no_effect else monadic_effect
Expand Down
5 changes: 5 additions & 0 deletions src/lib/rewriter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1200,3 +1200,8 @@ let default_fold_exp f x (E_aux (e, ann) as exp) =

let rec foldin_exp f x e = f (default_fold_exp (foldin_exp f)) x e
let foldin_pexp f x e = default_fold_pexp (foldin_exp f) x e

let has_early_return (e : 'a exp) =
let e_app (id, args) = string_of_id id = "early_return" || List.fold_left ( || ) false args in
let e_return _ = true in
fold_exp { (pure_exp_alg false ( || )) with e_app; e_return } e
2 changes: 2 additions & 0 deletions src/lib/rewriter.mli
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ val add_e_typ : Env.t -> typ -> 'a exp -> 'a exp

val add_typs_let : Env.t -> typ -> typ -> 'a exp -> 'a exp

val has_early_return : 'a exp -> bool

(* In-order fold over expressions *)
val foldin_exp : (('a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp
val foldin_pexp : (('a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b pexp -> 'a * 'b pexp
46 changes: 26 additions & 20 deletions src/lib/rewrites.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1749,10 +1749,9 @@ let rewrite_ast_early_return effect_info env ast =
let early_ret_spec =
fst
(Type_error.check_defs initial_env
[gen_vs ~pure:false ("early_return", "forall ('a : Type) ('b : Type). 'a -> 'b")]
[gen_vs ~pure:true ("early_return", "forall ('a : Type) ('b : Type). 'a -> 'b")]
)
in
let effect_info = Effects.add_monadic_built_in (mk_id "early_return") effect_info in

let new_ast =
rewrite_ast_base
Expand Down Expand Up @@ -2264,7 +2263,13 @@ let rewrite_ast_letbind_effects effect_info env =

let purify (E_aux (aux, (l, tannot))) = E_aux (aux, (l, add_effect_annot tannot no_effect)) in

let value (E_aux (exp_aux, _) as exp) = not (effectful exp || updates_vars exp) in
let needs_monad exp = effectful exp || has_early_return exp in
let pexp_needs_monad pexp =
let _, guard, exp, _ = destruct_pexp pexp in
let guard_needs_monad = match guard with Some g -> needs_monad g | None -> false in
needs_monad exp || guard_needs_monad
in
let value exp = not (needs_monad exp || updates_vars exp) in

let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
n_exp exp (fun exp -> if value exp then k exp else monadic (letbind exp k))
Expand Down Expand Up @@ -2328,7 +2333,7 @@ let rewrite_ast_letbind_effects effect_info env =
| E_app (op_bool, [l; r]) when string_of_id op_bool = "and_bool" || string_of_id op_bool = "or_bool" ->
(* Leave effectful operands of Boolean "and"/"or" in place to allow
short-circuiting. *)
let newreturn = effectful l || effectful r in
let newreturn = needs_monad l || needs_monad r in
let l = n_exp_term ~cast:true newreturn l in
let r = n_exp_term ~cast:true newreturn r in
k (rewrap (E_app (op_bool, [l; r])))
Expand All @@ -2341,7 +2346,7 @@ let rewrite_ast_letbind_effects effect_info env =
| E_tuple exps -> n_exp_nameL exps (fun exps -> k (pure_rewrap (E_tuple exps)))
| E_if (exp1, exp2, exp3) ->
let e_if exp1 =
let newreturn = effectful exp2 || effectful exp3 in
let newreturn = needs_monad exp2 || needs_monad exp3 in
let exp2 = n_exp_term newreturn exp2 in
let exp3 = n_exp_term newreturn exp3 in
k (rewrap (E_if (exp1, exp2, exp3)))
Expand All @@ -2351,7 +2356,7 @@ let rewrite_ast_letbind_effects effect_info env =
n_exp_name start (fun start ->
n_exp_name stop (fun stop ->
n_exp_name by (fun by ->
let body = n_exp_term (effectful body) body in
let body = n_exp_term (needs_monad body) body in
k (rewrap (E_for (id, start, stop, by, dir, body)))
)
)
Expand All @@ -2362,8 +2367,8 @@ let rewrite_ast_letbind_effects effect_info env =
| Measure_aux (Measure_none, _) -> measure
| Measure_aux (Measure_some exp, l) -> Measure_aux (Measure_some (n_exp_term false exp), l)
in
let cond = n_exp_term ~cast:true (effectful cond) cond in
let body = n_exp_term (effectful body) body in
let cond = n_exp_term ~cast:true (needs_monad cond) cond in
let body = n_exp_term (needs_monad body) body in
k (rewrap (E_loop (loop, measure, cond, body)))
| E_vector exps -> n_exp_nameL exps (fun exps -> k (pure_rewrap (E_vector exps)))
| E_vector_access (exp1, exp2) ->
Expand Down Expand Up @@ -2398,25 +2403,25 @@ let rewrite_ast_letbind_effects effect_info env =
n_exp_name exp1 (fun exp1 -> n_fexpL fexps (fun fexps -> k (pure_rewrap (E_struct_update (exp1, fexps)))))
| E_field (exp1, id) -> n_exp_name exp1 (fun exp1 -> k (pure_rewrap (E_field (exp1, id))))
| E_match (exp1, pexps) ->
let newreturn = List.exists effectful_pexp pexps in
let newreturn = List.exists pexp_needs_monad pexps in
n_exp_name exp1 (fun exp1 -> n_pexpL newreturn pexps (fun pexps -> k (rewrap (E_match (exp1, pexps)))))
| E_try (exp1, pexps) ->
let newreturn = effectful exp1 || List.exists effectful_pexp pexps in
let newreturn = needs_monad exp1 || List.exists pexp_needs_monad pexps in
let exp1 = n_exp_term newreturn exp1 in
n_pexpL newreturn pexps (fun pexps -> k (rewrap (E_try (exp1, pexps))))
| E_let (lb, body) -> n_lb lb (fun lb -> rewrap (E_let (lb, n_exp body k)))
| E_sizeof nexp -> k (rewrap (E_sizeof nexp))
| E_constraint nc -> k (rewrap (E_constraint nc))
| E_assign (lexp, exp1) -> n_lexp lexp (fun lexp -> n_exp_name exp1 (fun exp1 -> k (rewrap (E_assign (lexp, exp1)))))
| E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'), annot))
| E_exit exp' -> k (E_aux (E_exit (n_exp_term (needs_monad exp') exp'), annot))
| E_assert (exp1, exp2) ->
n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> k (rewrap (E_assert (exp1, exp2)))))
| E_var (lexp, exp1, exp2) ->
n_lexp lexp (fun lexp -> n_exp exp1 (fun exp1 -> rewrap (E_var (lexp, exp1, n_exp exp2 k))))
| E_internal_return exp1 ->
let is_early_return = function E_aux (E_app (id, _), _) -> string_of_id id = "early_return" | _ -> false in
n_exp_name exp1 (fun exp1 ->
k (if effectful exp1 || is_early_return exp1 then exp1 else rewrap (E_internal_return exp1))
k (if needs_monad exp1 || is_early_return exp1 then exp1 else rewrap (E_internal_return exp1))
)
| E_internal_value v -> k (rewrap (E_internal_value v))
| E_return exp' -> n_exp_name exp' (fun exp' -> k (pure_rewrap (E_return exp')))
Expand All @@ -2426,10 +2431,10 @@ let rewrite_ast_letbind_effects effect_info env =
in

let rewrite_fun _ (FD_aux (FD_function (recopt, tannotopt, funcls), fdannot)) =
(* TODO EFFECT *)
let effectful_vs = false in
let effectful_funcl (FCL_aux (FCL_funcl (_, pexp), _)) = effectful_pexp pexp in
let newreturn = effectful_vs || List.exists effectful_funcl funcls in
let funcl_needs_monad (FCL_aux (FCL_funcl (id, pexp), _)) =
pexp_needs_monad pexp || not (Effects.function_is_pure id effect_info)
in
let newreturn = List.exists funcl_needs_monad funcls in
let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) =
let _ = reset_fresh_name_counter () in
FCL_aux (FCL_funcl (id, n_pexp newreturn pexp (fun x -> x)), annot)
Expand All @@ -2442,7 +2447,7 @@ let rewrite_ast_letbind_effects effect_info env =
| DEF_let (LB_aux (lb, annot)) ->
let rewrap lb = DEF_let (LB_aux (lb, annot)) in
begin
match lb with LB_val (pat, exp) -> rewrap (LB_val (pat, n_exp_term (effectful exp) exp))
match lb with LB_val (pat, exp) -> rewrap (LB_val (pat, n_exp_term (needs_monad exp) exp))
end
| DEF_fundef fdef -> DEF_fundef (rewrite_fun rewriters fdef)
| DEF_internal_mutrec fdefs -> DEF_internal_mutrec (List.map (rewrite_fun rewriters) fdefs)
Expand Down Expand Up @@ -2500,7 +2505,7 @@ let rewrite_ast_internal_lets env =
let rhs = add_e_typ (env_of exp) ltyp (rhs exp) in
E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs), annot), body)
| LB_aux (LB_val (pat, exp'), annot') ->
if effectful exp' then E_internal_plet (pat, exp', body) else E_let (lb, body)
if effectful exp' || has_early_return exp' then E_internal_plet (pat, exp', body) else E_let (lb, body)
in

let e_var (lexp, exp1, exp2) =
Expand All @@ -2511,7 +2516,7 @@ let rewrite_ast_internal_lets env =
(unaux_pat (add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot))), annot)
| _ -> failwith "E_var with unexpected lexp"
in
if effectful exp1 then E_internal_plet (P_aux (paux, annot), exp1, exp2)
if effectful exp1 || has_early_return exp1 then E_internal_plet (P_aux (paux, annot), exp1, exp2)
else E_let (LB_aux (LB_val (P_aux (paux, annot), exp1), annot), exp2)
in

Expand Down Expand Up @@ -3075,7 +3080,8 @@ let rewrite_ast_remove_superfluous_returns env =

let e_aux (exp, annot) =
match exp with
| (E_let (LB_aux (LB_val (pat, exp1), _), exp2) | E_internal_plet (pat, exp1, exp2)) when effectful exp1 -> begin
| (E_let (LB_aux (LB_val (pat, exp1), _), exp2) | E_internal_plet (pat, exp1, exp2))
when effectful exp1 || has_early_return exp1 -> begin
match (untyp_pat pat, uncast_exp exp2) with
| ( (P_aux (P_lit (L_aux (lit, _)), _), ptyp),
(E_aux (E_internal_return (E_aux (E_lit (L_aux (lit', _)), _)), a), etyp) )
Expand Down
53 changes: 22 additions & 31 deletions src/sail_coq_backend/pretty_print_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1128,14 +1128,6 @@ let rec doc_pat ctxt apat_needed (P_aux (p, (l, annot))) =
| P_not _ -> unreachable l __POS__ "Coq backend doesn't support not patterns"
| P_or _ -> unreachable l __POS__ "Coq backend doesn't support or patterns yet"

let contains_early_return exp =
let e_app (f, args) =
let rets, args = List.split args in
(List.fold_left ( || ) (string_of_id f = "early_return") rets, E_app (f, args))
in
fst
(fold_exp { (Rewriter.compute_exp_alg false ( || )) with e_return = (fun (_, r) -> (true, E_return r)); e_app } exp)

let find_e_ids exp =
let e_id id = (IdSet.singleton id, E_id id) in
fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with e_id } exp)
Expand Down Expand Up @@ -1526,6 +1518,7 @@ let doc_exp, doc_let =
begin
match f with
| (Id_aux (Id "and_bool", _) | Id_aux (Id "or_bool", _)) when effectful (effect_of full_exp) ->
(* TODO: Pure early return? *)
let suffix = "M" in
let call = doc_id ctxt (append_id f suffix) in
debug ctxt (lazy ("Effectful boolean op: " ^ string_of_id f));
Expand Down Expand Up @@ -1576,7 +1569,9 @@ let doc_exp, doc_let =
in
let effects = effectful (effect_of body) in
let combinator =
if effects then if ctxt.is_monadic then "foreach_ZM" else "foreach_ZE" else "foreach_Z"
if ctxt.is_monadic && effectful (effect_of body) then "foreach_ZM"
else if has_early_return body then "foreach_ZE"
else "foreach_Z"
in
let combinator = combinator ^ dir in
let body_ctxt = add_single_kid_id_rename ctxt loopvar (mk_kid ("loop_" ^ string_of_id loopvar)) in
Expand Down Expand Up @@ -1614,13 +1609,16 @@ let doc_exp, doc_let =
let a' = mk_tannot (env_of_annot (l, a)) bool_typ in
E_aux (E_typ (bool_typ, exp), (l, a'))
in
let monad = if ctxt.is_monadic then "M" else "E" in
let csuffix, cond, body, body_effectful =
match (effectful (effect_of cond), effectful (effect_of body)) with
| false, false -> ("", cond, body, false)
| false, true -> (monad, return cond, body, true)
| true, false -> (monad, simple_bool cond, return body, true)
| true, true -> (monad, simple_bool cond, body, true)
let needs_monad e = effectful (effect_of e) || has_early_return e in
let csuffix =
if needs_monad full_exp then if effectful (effect_of full_exp) then "M" else "E" else ""
in
let cond, body, body_effectful =
match (needs_monad cond, needs_monad body) with
| false, false -> (cond, body, false)
| false, true -> (return cond, body, true)
| true, false -> (simple_bool cond, return body, true)
| true, true -> (simple_bool cond, body, true)
in
(* If rewrite_loops_with_escape_effect added a dummy assertion to
ensure that the loop can escape when it reaches the limit, omit
Expand Down Expand Up @@ -1984,18 +1982,18 @@ let doc_exp, doc_let =
let cast_ex, _, cast_typ' = classify_ex_type ctxt env ~rawbools:true cast_typ in
let inner_ex, _, inner_typ' = classify_ex_type ctxt env inner_typ in
let autocast_out = autocast_req ctxt env outer_typ cast_typ outer_typ' cast_typ' in
let effects = effectful (effect_of e) in
let needs_monad = effectful (effect_of e) || has_early_return e in
let () =
debug ctxt
( lazy
(" effectful: " ^ string_of_bool effects ^ " outer_ex: " ^ string_of_ex_kind outer_ex ^ " cast_ex: "
(" effectful: " ^ string_of_bool needs_monad ^ " outer_ex: " ^ string_of_ex_kind outer_ex ^ " cast_ex: "
^ string_of_ex_kind cast_ex ^ " inner_ex: " ^ string_of_ex_kind inner_ex ^ " autocast_out: "
^ string_of_auto_t autocast_out
)
)
in
let epp = epp ^/^ doc_tannot ctxt (env_of e) effects typ in
let autocast_name = if effects then "autocast_m" else "autocast" in
let epp = epp ^/^ doc_tannot ctxt (env_of e) needs_monad typ in
let autocast_name = if effectful (effect_of e) then "autocast_m" else "autocast" in
let epp =
match autocast_out with
| No -> epp
Expand Down Expand Up @@ -2254,7 +2252,9 @@ let doc_exp, doc_let =
across if expressions in complex situations, so provide an
annotation for monadic expressions. *)
let add_type_pp pp =
if effectful (effect_of t) then pp ^/^ string "return" ^/^ doc_tannot_core ctxt full_env true full_typ else pp
if effectful (effect_of t) || has_early_return t then
pp ^/^ string "return" ^/^ doc_tannot_core ctxt full_env true full_typ
else pp
in
let t_pp = top_exp ctxt false t in
let else_pp =
Expand Down Expand Up @@ -2984,7 +2984,7 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_
let ctxt =
{
ctxt0 with
early_ret = (if contains_early_return exp then Some ret_typ else None);
early_ret = (if has_early_return exp then Some ret_typ else None);
ret_typ_pp = doc_typ ctxt0 env ret_typ;
}
in
Expand Down Expand Up @@ -3095,15 +3095,6 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_

let doc_funcl_body ctxt (exp, is_monadic, fixupspp) =
let bodypp = doc_fun_body ctxt is_monadic exp in
let bodypp =
if is_monadic then
(* Sometimes a function is marked effectful by effect inference
when it's not (especially mappings)... TODO: this seems
bad!? *)
if not (effectful (effect_of exp)) then string "returnM" ^/^ parens bodypp else bodypp
else if Option.is_some ctxt.early_ret then bodypp
else bodypp
in
let bodypp = separate (break 1) (fixupspp @ [bodypp]) in
group bodypp

Expand Down
2 changes: 2 additions & 0 deletions src/sail_coq_backend/sail_plugin_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ let coq_rewrites =
("rewrite_explicit_measure", []);
("rewrite_loops_with_escape_effect", []);
("recheck_defs", []);
("infer_effects", [Bool_arg true]);
("attach_effects", []);
("remove_blocks", []);
("attach_effects", []);
("letbind_effects", []);
Expand Down
Loading

0 comments on commit dee5964

Please sign in to comment.