Skip to content

Commit

Permalink
Make sure constructors with named fields work in scattered unions
Browse files Browse the repository at this point in the history
For example:

union clause instr = ADD : { rd : regidx, r1 : regidx, r2 : regidx }

This is a bit tricky because the way this works (for regular unions) is we
generate a separate struct type with the correct fields. However in the
union clause case that struct type must be valid in the scattered union
environment and must be de-scattered appropriately with its associated union
clause. To accomplish this, introduce SD_internal_unioncl_record for a struct
type associated with a union clause.
  • Loading branch information
Alasdair committed Sep 22, 2023
1 parent dcd00ca commit 76ef8a6
Show file tree
Hide file tree
Showing 13 changed files with 354 additions and 58 deletions.
2 changes: 2 additions & 0 deletions language/sail.ott
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,8 @@ scattered_def :: 'SD_' ::=
| union id member type_union :: :: unioncl
{{ texlong }} {{ com scattered union definition member }}

| internal_unioncl_record id1 id2 typquant { typ1 id1 , ... , typn idn } :: :: internal_unioncl_record

| scattered mapping id : tannot_opt :: :: mapping

| mapping clause id = mapcl :: :: mapcl
Expand Down
2 changes: 2 additions & 0 deletions src/lib/ast_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ and map_scattered_annot_aux f = function
| SD_funcl fcl -> SD_funcl (map_funcl_annot f fcl)
| SD_variant (id, typq) -> SD_variant (id, typq)
| SD_unioncl (id, tu) -> SD_unioncl (id, tu)
| SD_internal_unioncl_record (id, record_id, typq, fields) -> SD_internal_unioncl_record (id, record_id, typq, fields)
| SD_mapping (id, tannot_opt) -> SD_mapping (id, tannot_opt)
| SD_mapcl (id, mcl) -> SD_mapcl (id, map_mapcl_annot f mcl)
| SD_end id -> SD_end id
Expand Down Expand Up @@ -1122,6 +1123,7 @@ let id_of_scattered (SD_aux (sdef, _)) =
| SD_end id
| SD_variant (id, _)
| SD_unioncl (id, _)
| SD_internal_unioncl_record (_, id, _, _)
| SD_mapping (id, _)
| SD_mapcl (id, _)
| SD_enum id
Expand Down
129 changes: 87 additions & 42 deletions src/lib/initial_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ type type_constructor = kind_aux option list
type ctx = {
kinds : kind_aux KBindings.t;
type_constructors : type_constructor Bindings.t;
scattereds : ctx Bindings.t;
scattereds : (P.typquant * ctx) Bindings.t;
fixities : (prec * int) Bindings.t;
internal_files : StringSet.t;
target_sets : string list StringMap.t;
Expand All @@ -101,7 +101,9 @@ type ctx = {
let rec equal_ctx ctx1 ctx2 =
KBindings.equal ( = ) ctx1.kinds ctx2.kinds
&& Bindings.equal ( = ) ctx1.type_constructors ctx2.type_constructors
&& Bindings.equal equal_ctx ctx1.scattereds ctx2.scattereds
&& Bindings.equal
(fun (typq1, ctx1) (typq2, ctx2) -> typq1 = typq2 && equal_ctx ctx1 ctx2)
ctx1.scattereds ctx2.scattereds
&& Bindings.equal ( = ) ctx1.fixities ctx2.fixities
&& StringSet.equal ctx1.internal_files ctx2.internal_files
&& StringMap.equal ( = ) ctx1.target_sets ctx2.target_sets
Expand All @@ -125,7 +127,10 @@ let merge_ctx l ctx1 ctx2 =
ctx1.type_constructors ctx2.type_constructors;
scattereds =
Bindings.merge
(compatible equal_ctx (fun id -> "Scattered definition " ^ string_of_id id ^ " found with mismatching context"))
(compatible
(fun (typq1, ctx1) (typq2, ctx2) -> typq1 = typq2 && equal_ctx ctx1 ctx2)
(fun id -> "Scattered definition " ^ string_of_id id ^ " found with mismatching context")
)
ctx1.scattereds ctx2.scattereds;
fixities =
Bindings.merge
Expand Down Expand Up @@ -872,22 +877,28 @@ let anon_rec_constructor_typ record_id = function
| args -> P.ATyp_aux (P.ATyp_app (record_id, args), Generated l)
)

let rec realise_union_anon_rec_types orig_union arms =
let realize_union_anon_rec_arm union_id typq = function
| P.Tu_aux (P.Tu_ty_id _, _) as arm -> (None, arm)
| P.Tu_aux (P.Tu_ty_anon_rec (fields, id), l) ->
let open Parse_ast in
let record_str = "_" ^ string_of_parse_id union_id ^ "_" ^ string_of_parse_id id ^ "_record" in
let record_id = Id_aux (Id record_str, Generated l) in
let new_arm = Tu_aux (Tu_ty_id (anon_rec_constructor_typ record_id typq, id), Generated l) in
(Some (record_id, fields, l), new_arm)

let rec realize_union_anon_rec_types orig_union arms =
match orig_union with
| P.TD_variant (union_id, typq, _, flag) -> begin
match arms with
| [] -> []
| arm :: arms -> (
match arm with
| P.Tu_aux (P.Tu_ty_id _, _) -> (None, arm) :: realise_union_anon_rec_types orig_union arms
| P.Tu_aux (P.Tu_ty_anon_rec (fields, id), l) ->
let open Parse_ast in
let record_str = "_" ^ string_of_parse_id union_id ^ "_" ^ string_of_parse_id id ^ "_record" in
let record_id = Id_aux (Id record_str, Generated l) in
let new_arm = Tu_aux (Tu_ty_id (anon_rec_constructor_typ record_id typq, id), Generated l) in
let new_rec_def = TD_aux (TD_record (record_id, typq, fields, flag), Generated l) in
(Some new_rec_def, new_arm) :: realise_union_anon_rec_types orig_union arms
)
| arm :: arms ->
let realized =
match realize_union_anon_rec_arm union_id typq arm with
| Some (record_id, fields, l), new_arm ->
(Some (P.TD_aux (P.TD_record (record_id, typq, fields, flag), Generated l)), new_arm)
| None, arm -> (None, arm)
in
realized :: realize_union_anon_rec_types orig_union arms
end
| _ ->
raise
Expand Down Expand Up @@ -960,6 +971,12 @@ let to_ast_reserved_type_id ctx id =
end
else id

let to_ast_record ctx id typq fields =
let id = to_ast_reserved_type_id ctx id in
let typq, typq_ctx = to_ast_typquant ctx typq in
let fields = List.map (fun (atyp, id) -> (to_ast_typ typq_ctx atyp, to_ast_id ctx id)) fields in
(id, typq, fields, add_constructor id typq ctx)

let rec to_ast_typedef ctx def_annot (P.TD_aux (aux, l) : P.type_def) : uannot def list ctx_out =
match aux with
| P.TD_abbrev (id, typq, kind, typ_arg) ->
Expand All @@ -975,15 +992,11 @@ let rec to_ast_typedef ctx def_annot (P.TD_aux (aux, l) : P.type_def) : uannot d
| None -> ([], ctx)
end
| P.TD_record (id, typq, fields, _) ->
let id = to_ast_reserved_type_id ctx id in
let typq, typq_ctx = to_ast_typquant ctx typq in
let fields = List.map (fun (atyp, id) -> (to_ast_typ typq_ctx atyp, to_ast_id ctx id)) fields in
( [DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields, false), (l, empty_uannot))), def_annot)],
add_constructor id typq ctx
)
let id, typq, fields, ctx = to_ast_record ctx id typq fields in
([DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields, false), (l, empty_uannot))), def_annot)], ctx)
| P.TD_variant (id, typq, arms, _) as union ->
(* First generate auxilliary record types for anonymous records in constructors *)
let records_and_arms = realise_union_anon_rec_types union arms in
let records_and_arms = realize_union_anon_rec_types union arms in
let rec filter_records = function
| [] -> []
| Some x :: xs -> x :: filter_records xs
Expand Down Expand Up @@ -1124,45 +1137,63 @@ let to_ast_dec ctx (P.DEC_aux (regdec, l)) =
)

let to_ast_scattered ctx (P.SD_aux (aux, l)) =
let aux, ctx =
let extra_def, aux, ctx =
match aux with
| P.SD_function (rec_opt, tannot_opt, _, id) ->
let tannot_opt, _ = to_ast_tannot_opt ctx tannot_opt in
(SD_function (to_ast_rec ctx rec_opt, tannot_opt, to_ast_id ctx id), ctx)
| P.SD_funcl funcl -> (SD_funcl (to_ast_funcl ctx funcl), ctx)
| P.SD_variant (id, typq) ->
(None, SD_function (to_ast_rec ctx rec_opt, tannot_opt, to_ast_id ctx id), ctx)
| P.SD_funcl funcl -> (None, SD_funcl (to_ast_funcl ctx funcl), ctx)
| P.SD_variant (id, parse_typq) ->
let id = to_ast_id ctx id in
let typq, typq_ctx = to_ast_typquant ctx typq in
( SD_variant (id, typq),
add_constructor id typq { ctx with scattereds = Bindings.add id typq_ctx ctx.scattereds }
let typq, typq_ctx = to_ast_typquant ctx parse_typq in
( None,
SD_variant (id, typq),
add_constructor id typq { ctx with scattereds = Bindings.add id (parse_typq, typq_ctx) ctx.scattereds }
)
| P.SD_unioncl (id, tu) ->
let id = to_ast_id ctx id in
| P.SD_unioncl (union_id, tu) ->
let id = to_ast_id ctx union_id in
begin
match Bindings.find_opt id ctx.scattereds with
| Some typq_ctx ->
let tu = to_ast_type_union typq_ctx tu in
(SD_unioncl (id, tu), ctx)
| Some (typq, scattered_ctx) ->
let anon_rec_opt, tu = realize_union_anon_rec_arm union_id typq tu in
let extra_def, scattered_ctx =
match anon_rec_opt with
| Some (record_id, fields, l) ->
let l = gen_loc l in
let record_id, typq, fields, scattered_ctx = to_ast_record scattered_ctx record_id typq fields in
( Some
(DEF_aux
( DEF_scattered
(SD_aux (SD_internal_unioncl_record (id, record_id, typq, fields), (l, empty_uannot))),
mk_def_annot l
)
),
scattered_ctx
)
| None -> (None, scattered_ctx)
in
let tu = to_ast_type_union scattered_ctx tu in
(extra_def, SD_unioncl (id, tu), ctx)
| None -> raise (Reporting.err_typ l ("No scattered union declaration found for " ^ string_of_id id))
end
| P.SD_end id -> (SD_end (to_ast_id ctx id), ctx)
| P.SD_end id -> (None, SD_end (to_ast_id ctx id), ctx)
| P.SD_mapping (id, tannot_opt) ->
let id = to_ast_id ctx id in
let tannot_opt, _ = to_ast_tannot_opt ctx tannot_opt in
(SD_mapping (id, tannot_opt), ctx)
(None, SD_mapping (id, tannot_opt), ctx)
| P.SD_mapcl (id, mapcl) ->
let id = to_ast_id ctx id in
let mapcl = to_ast_mapcl ctx mapcl in
(SD_mapcl (id, mapcl), ctx)
(None, SD_mapcl (id, mapcl), ctx)
| P.SD_enum id ->
let id = to_ast_id ctx id in
(SD_enum id, ctx)
(None, SD_enum id, ctx)
| P.SD_enumcl (id, member) ->
let id = to_ast_id ctx id in
let member = to_ast_id ctx member in
(SD_enumcl (id, member), ctx)
(None, SD_enumcl (id, member), ctx)
in
(SD_aux (aux, (l, empty_uannot)), ctx)
(extra_def, SD_aux (aux, (l, empty_uannot)), ctx)

let to_ast_prec = function P.Infix -> Infix | P.InfixL -> InfixL | P.InfixR -> InfixR

Expand Down Expand Up @@ -1255,8 +1286,8 @@ let rec to_ast_def doc attrs ctx (P.DEF_aux (def, l)) : uannot def list ctx_out
(* Should never occur because of remove_mutrec *)
raise (Reporting.err_unreachable l __POS__ "Internal mutual block found when processing scattered defs")
| P.DEF_scattered sdef ->
let sdef, ctx = to_ast_scattered ctx sdef in
([DEF_aux (DEF_scattered sdef, annot)], ctx)
let extra_def, sdef, ctx = to_ast_scattered ctx sdef in
([DEF_aux (DEF_scattered sdef, annot)] @ Option.to_list extra_def, ctx)
| P.DEF_measure (id, pat, exp) ->
([DEF_aux (DEF_measure (to_ast_id ctx id, to_ast_pat ctx pat, to_ast_exp ctx exp), annot)], ctx)
| P.DEF_loop_measures (id, measures) ->
Expand Down Expand Up @@ -1522,6 +1553,20 @@ let generate_undefineds vs_ids defs =
| (DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), _)), _) as def) :: defs ->
let vs, fn = undefined_scattered id typq in
(def :: vs :: undefined_defs defs) @ [fn]
| (DEF_aux (DEF_scattered (SD_aux (SD_internal_unioncl_record (_, id, typq, fields), _)), _) as def) :: defs
when not (IdSet.mem (prepend_id "undefined_" id) vs_ids) ->
let pat =
p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id)))
in
let vs = mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None)) in
let fn =
mk_fundef
[
mk_funcl (prepend_id "undefined_" id) pat
(mk_exp (E_struct (List.map (fun (_, id) -> mk_fexp id (mk_lit_exp L_undef)) fields)));
]
in
def :: vs :: fn :: undefined_defs defs
| def :: defs -> def :: undefined_defs defs
| [] -> []
in
Expand Down
9 changes: 9 additions & 0 deletions src/lib/pretty_print_sail.ml
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,15 @@ let doc_scattered (SD_aux (sd_aux, _)) =
| SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), _)) ->
separate space [string "scattered mapping"; doc_id id; colon; doc_binding (typq, typ)]
| SD_unioncl (id, tu) -> separate space [string "union clause"; doc_id id; equals; doc_union tu]
| SD_internal_unioncl_record (id, record_id, typq, fields) ->
let prefix = separate space [string "internal_union_record clause"; doc_id id; doc_id record_id] in
let params =
match typq with
| TypQ_aux (TypQ_no_forall, _) | TypQ_aux (TypQ_tq [], _) -> empty
| TypQ_aux (TypQ_tq qs, _) -> doc_param_quants qs
in
separate space
[prefix ^^ params; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace]
| SD_enum id -> separate space [string "scattered enum"; doc_id id]
| SD_enumcl (id, member) -> separate space [string "enum clause"; doc_id id; equals; doc_id member]

Expand Down
4 changes: 3 additions & 1 deletion src/lib/rewriter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ let rewrite_scattered rewriters (SD_aux (sd, (l, annot))) =
match sd with
| SD_funcl funcl -> SD_funcl (rewrite_funcl rewriters funcl)
| SD_mapcl (id, mapcl) -> SD_mapcl (id, rewrite_mapcl rewriters mapcl)
| SD_variant _ | SD_unioncl _ | SD_mapping _ | SD_function _ | SD_end _ | SD_enum _ | SD_enumcl _ -> sd
| SD_variant _ | SD_unioncl _ | SD_mapping _ | SD_function _ | SD_end _ | SD_enum _ | SD_enumcl _
| SD_internal_unioncl_record _ ->
sd
in
SD_aux (sd, (l, annot))

Expand Down
16 changes: 15 additions & 1 deletion src/lib/scattered.ml
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,21 @@ let fake_rec_opt l = Rec_aux (Rec_nonrec, gen_loc l)

let no_tannot_opt l = Typ_annot_opt_aux (Typ_annot_opt_none, gen_loc l)

let rec get_union_records id acc = function
| DEF_aux (DEF_scattered (SD_aux (SD_internal_unioncl_record (uid, record_id, typq, fields), annot)), def_annot)
:: defs
when Id.compare uid id = 0 ->
let def = DEF_aux (DEF_type (TD_aux (TD_record (record_id, typq, fields, false), annot)), def_annot) in
get_union_records id (def :: acc) defs
| def :: defs -> get_union_records id acc defs
| [] -> acc

let rec filter_union_clauses id = function
| DEF_aux (DEF_scattered (SD_aux (SD_unioncl (uid, _), _)), _) :: defs when Id.compare id uid = 0 ->
filter_union_clauses id defs
| DEF_aux (DEF_scattered (SD_aux (SD_internal_unioncl_record (uid, _, _, _), _)), _) :: defs
when Id.compare id uid = 0 ->
filter_union_clauses id defs
| def :: defs -> def :: filter_union_clauses id defs
| [] -> []

Expand Down Expand Up @@ -168,14 +180,16 @@ let rec descatter' accumulator funcls mapcls = function
regular union declaration. *)
| DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), (l, _))), def_annot) :: defs ->
let tus = get_scattered_union_clauses id defs in
let records = get_union_records id [] defs in
begin
match tus with
| [] -> raise (Reporting.err_general l "No clauses found for scattered union type")
| _ ->
let accumulator =
DEF_aux
(DEF_type (TD_aux (TD_variant (id, typq, tus, false), (gen_loc l, Type_check.empty_tannot))), def_annot)
:: accumulator
:: records
@ accumulator
in
descatter' accumulator funcls mapcls (filter_union_clauses id defs)
end
Expand Down
43 changes: 30 additions & 13 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3692,7 +3692,7 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, uannot)) as mpa
match mpat_aux with
| MP_lit lit ->
let var = fresh_var () in
let guard = mk_exp (E_app_infix (mk_exp (E_id var), mk_id "==", mk_exp (E_lit lit))) in
let guard = mk_exp ~loc:l (E_app_infix (mk_exp (E_id var), mk_id "==", mk_exp (E_lit lit))) in
let typed_mpat, env, guards = bind_mpat allow_unknown other_env env (mk_mpat (MP_id var)) typ in
(typed_mpat, env, guard :: guards)
| _ -> raise typ_exn
Expand Down Expand Up @@ -4152,6 +4152,20 @@ let check_type_union u_l non_rec_env env variant typq (Tu_aux (Tu_ty_id (arg_typ
wf_binding l env (typq, typ);
env |> Env.add_union_id v (typq, typ) |> Env.add_val_spec v (typq, typ)

let check_record l env def_annot id typq fields =
forbid_recursive_types l (fun () ->
List.iter (fun ((Typ_aux (_, l) as field), _) -> wf_binding l env (typq, field)) fields
);
let env =
try
match get_def_attribute "bitfield" def_annot with
| Some (_, size) when not (Env.is_bitfield id env) ->
Env.add_bitfield id (bits_typ env (nconstant (Big_int.of_string size))) Bindings.empty env
| _ -> env
with _ -> env
in
Env.add_record id typq fields env

let rec check_typedef : Env.t -> def_annot -> uannot type_def -> tannot def list * Env.t =
fun env def_annot (TD_aux (tdef, (l, _))) ->
match tdef with
Expand All @@ -4163,18 +4177,8 @@ let rec check_typedef : Env.t -> def_annot -> uannot type_def -> tannot def list
end;
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_typ_synonym id typq typ_arg env)
| TD_record (id, typq, fields, _) ->
forbid_recursive_types l (fun () ->
List.iter (fun ((Typ_aux (_, l) as field), _) -> wf_binding l env (typq, field)) fields
);
let env =
try
match get_def_attribute "bitfield" def_annot with
| Some (_, size) when not (Env.is_bitfield id env) ->
Env.add_bitfield id (bits_typ env (nconstant (Big_int.of_string size))) Bindings.empty env
| _ -> env
with _ -> env
in
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_record id typq fields env)
let env = check_record l env def_annot id typq fields in
([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], env)
| TD_variant (id, typq, arms, _) ->
let rec_env = Env.add_variant id (typq, arms) env in
(* register_value is a special type used by theorem prover
Expand Down Expand Up @@ -4248,6 +4252,19 @@ and check_scattered : Env.t -> def_annot -> uannot scattered_def -> tannot def l
in
raise (Type_error (l', err_because (err, id_loc id, Err_other msg)))
)
| SD_internal_unioncl_record (id, record_id, typq, fields) ->
let definition_env = Env.get_scattered_variant_env id env in
let definition_env = check_record l definition_env def_annot record_id typq fields in
let env = Env.set_scattered_variant_env ~variant_env:definition_env id env in
let env = Env.add_record record_id typq fields env in
( [
DEF_aux
( DEF_scattered (SD_aux (SD_internal_unioncl_record (id, record_id, typq, fields), (l, empty_tannot))),
def_annot
);
],
env
)
| SD_funcl (FCL_aux (FCL_funcl (id, _), (fcl_def_annot, _)) as funcl) ->
let typq, typ = Env.get_val_spec id env in
let funcl_env = Env.add_typquant fcl_def_annot.loc typq env in
Expand Down
Loading

0 comments on commit 76ef8a6

Please sign in to comment.