Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add annotations to Jib toplevel definitions #460

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions language/jib.ott
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,11 @@ instr :: 'I_' ::=
| reset ctyp name :: :: reset
| ctyp name = cval :: :: reinit

def_annot :: '' ::=
{{ phantom }}

cdef :: 'CDEF_' ::=
{{ aux _ def_annot }}
| register id : ctyp = {
instr0 ; ... ; instrn
} :: :: register
Expand Down
85 changes: 51 additions & 34 deletions src/lib/jib_compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ let callgraph cdefs =
List.fold_left
(fun graph cdef ->
match cdef with
| CDEF_fundef (id, _, _, body) ->
| CDEF_aux (CDEF_fundef (id, _, _, body), _) ->
let graph = ref graph in
List.iter
(iter_instr (function
Expand Down Expand Up @@ -1375,7 +1375,7 @@ module Make (C : CONFIG) = struct

let letdef_count = ref 0

let compile_funcl ctx id pat guard exp =
let compile_funcl ctx def_annot id pat guard exp =
(* Find the function's type. *)
let quant, Typ_aux (fn_typ, _) =
try Env.get_val_spec id ctx.local_env with Type_error.Type_error _ -> Env.get_val_spec id ctx.tc_env
Expand Down Expand Up @@ -1441,7 +1441,7 @@ module Make (C : CONFIG) = struct
let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in
let instrs = coverage_function_entry id (exp_loc exp) @ instrs in

([CDEF_fundef (id, None, List.map fst compiled_args, instrs)], orig_ctx)
([CDEF_aux (CDEF_fundef (id, None, List.map fst compiled_args, instrs), def_annot)], orig_ctx)

(** Compile a Sail toplevel definition into an IR definition **)
let rec compile_def n total ctx (DEF_aux (aux, _) as def) =
Expand Down Expand Up @@ -1476,15 +1476,16 @@ module Make (C : CONFIG) = struct
end
| _ -> compile_def' n total ctx def

and compile_def' n total ctx (DEF_aux (aux, _) as def) =
and compile_def' n total ctx (DEF_aux (aux, def_annot) as def) =
match aux with
| DEF_register (DEC_aux (DEC_reg (typ, id, None), _)) -> ([CDEF_register (id, ctyp_of_typ ctx typ, [])], ctx)
| DEF_register (DEC_aux (DEC_reg (typ, id, None), _)) ->
([CDEF_aux (CDEF_register (id, ctyp_of_typ ctx typ, []), def_annot)], ctx)
| DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), _)) ->
let aexp = C.optimize_anf ctx (no_shadow ctx.letbind_ids (anf exp)) in
let setup, call, cleanup = compile_aexp ctx aexp in
let instrs = setup @ [call (CL_id (name id, ctyp_of_typ ctx typ))] @ cleanup in
let instrs = unique_names instrs in
([CDEF_register (id, ctyp_of_typ ctx typ, instrs)], ctx)
([CDEF_aux (CDEF_register (id, ctyp_of_typ ctx typ, instrs), def_annot)], ctx)
| DEF_val (VS_aux (VS_val_spec (_, id, ext), _)) ->
let quant, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in
let extern = if Env.is_extern id ctx.tc_env "c" then Some (Env.get_extern id ctx.tc_env "c") else None in
Expand All @@ -1493,16 +1494,16 @@ module Make (C : CONFIG) = struct
in
let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.local_env } in
let arg_ctyps, ret_ctyp = (List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ) in
( [CDEF_val (id, extern, arg_ctyps, ret_ctyp)],
( [CDEF_aux (CDEF_val (id, extern, arg_ctyps, ret_ctyp), def_annot)],
{ ctx with valspecs = Bindings.add id (extern, arg_ctyps, ret_ctyp) ctx.valspecs }
)
| DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) ->
Util.progress "Compiling " (string_of_id id) n total;
compile_funcl ctx id pat None exp
compile_funcl ctx def_annot id pat None exp
| DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_when (pat, guard, exp), _)), _)]), _))
->
Util.progress "Compiling " (string_of_id id) n total;
compile_funcl ctx id pat (Some guard) exp
compile_funcl ctx def_annot id pat (Some guard) exp
| DEF_fundef (FD_aux (FD_function (_, _, []), (l, _))) ->
raise (Reporting.err_general l "Encountered function with no clauses")
| DEF_fundef (FD_aux (FD_function (_, _, _ :: _ :: _), (l, _))) ->
Expand All @@ -1512,7 +1513,7 @@ module Make (C : CONFIG) = struct
| DEF_type (TD_aux (TD_abbrev _, _)) -> ([], ctx)
| DEF_type type_def ->
let tdef, ctx = compile_type_def ctx type_def in
([CDEF_type tdef], ctx)
([CDEF_aux (CDEF_type tdef, def_annot)], ctx)
| DEF_let (LB_aux (LB_val (pat, exp), _)) ->
let ctyp = ctyp_of_typ ctx (typ_of_pat pat) in
let aexp = C.optimize_anf ctx (no_shadow ctx.letbind_ids (anf exp)) in
Expand All @@ -1532,7 +1533,7 @@ module Make (C : CONFIG) = struct
@ [ilabel end_label]
in
let instrs = unique_names instrs in
( [CDEF_let (n, bindings, instrs)],
( [CDEF_aux (CDEF_let (n, bindings, instrs), def_annot)],
{ ctx with letbinds = n :: ctx.letbinds; letbind_ids = IdSet.union (pat_ids pat) ctx.letbind_ids }
)
(* Only DEF_default that matters is default Order, but all order
Expand All @@ -1542,7 +1543,7 @@ module Make (C : CONFIG) = struct
| DEF_overload _ -> ([], ctx)
(* Only the parser and sail pretty printer care about this. *)
| DEF_fixity _ -> ([], ctx)
| DEF_pragma ("abstract", id_str, _) -> ([CDEF_pragma ("abstract", id_str)], ctx)
| DEF_pragma ("abstract", id_str, _) -> ([CDEF_aux (CDEF_pragma ("abstract", id_str), def_annot)], ctx)
(* We just ignore any pragmas we don't want to deal with. *)
| DEF_pragma _ -> ([], ctx)
(* Termination measures only needed for Coq, and other theorem prover output *)
Expand Down Expand Up @@ -1571,7 +1572,7 @@ module Make (C : CONFIG) = struct
let polymorphic_functions =
List.filter_map
(function
| CDEF_val (id, _, param_ctyps, ret_ctyp) ->
| CDEF_aux (CDEF_val (id, _, param_ctyps, ret_ctyp), _) ->
if List.exists is_polymorphic param_ctyps || is_polymorphic ret_ctyp then Some id else None
| _ -> None
)
Expand Down Expand Up @@ -1603,7 +1604,8 @@ module Make (C : CONFIG) = struct
each of the monomorphic calls we just found. *)
let spec_tyargs = ref Bindings.empty in
let rec specialize_fundefs ctx prior = function
| (CDEF_val (id, extern, param_ctyps, ret_ctyp) as orig_cdef) :: cdefs when Bindings.mem id !monomorphic_calls ->
| (CDEF_aux (CDEF_val (id, extern, param_ctyps, ret_ctyp), def_annot) as orig_cdef) :: cdefs
when Bindings.mem id !monomorphic_calls ->
let tyargs =
List.fold_left (fun set ctyp -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty (ret_ctyp :: param_ctyps)
in
Expand All @@ -1620,7 +1622,7 @@ module Make (C : CONFIG) = struct
in
let param_ctyps = List.map (subst_poly substs) param_ctyps in
let ret_ctyp = subst_poly substs ret_ctyp in
Some (CDEF_val (specialized_id, extern, param_ctyps, ret_ctyp))
Some (CDEF_aux (CDEF_val (specialized_id, extern, param_ctyps, ret_ctyp), def_annot))
)
else None
)
Expand All @@ -1630,14 +1632,15 @@ module Make (C : CONFIG) = struct
List.fold_left
(fun ctx cdef ->
match cdef with
| CDEF_val (id, _, param_ctyps, ret_ctyp) ->
| CDEF_aux (CDEF_val (id, _, param_ctyps, ret_ctyp), _) ->
{ ctx with valspecs = Bindings.add id (extern, param_ctyps, ret_ctyp) ctx.valspecs }
| cdef -> ctx
)
ctx specialized_specs
in
specialize_fundefs ctx ((orig_cdef :: specialized_specs) @ prior) cdefs
| (CDEF_fundef (id, heap_return, params, body) as orig_cdef) :: cdefs when Bindings.mem id !monomorphic_calls ->
| (CDEF_aux (CDEF_fundef (id, heap_return, params, body), def_annot) as orig_cdef) :: cdefs
when Bindings.mem id !monomorphic_calls ->
let tyargs = Bindings.find id !spec_tyargs in
let specialized_fundefs =
List.filter_map
Expand All @@ -1651,7 +1654,7 @@ module Make (C : CONFIG) = struct
KBindings.empty (KidSet.elements tyargs) instantiation
in
let body = List.map (map_instr_ctyp (subst_poly substs)) body in
Some (CDEF_fundef (specialized_id, heap_return, params, body))
Some (CDEF_aux (CDEF_fundef (specialized_id, heap_return, params, body), def_annot))
)
else None
)
Expand All @@ -1670,7 +1673,7 @@ module Make (C : CONFIG) = struct
let monomorphic_roots =
List.filter_map
(function
| CDEF_val (id, _, param_ctyps, ret_ctyp) ->
| CDEF_aux (CDEF_val (id, _, param_ctyps, ret_ctyp), _) ->
if List.exists is_polymorphic param_ctyps || is_polymorphic ret_ctyp then None else Some id
| _ -> None
)
Expand All @@ -1684,8 +1687,8 @@ module Make (C : CONFIG) = struct
let cdefs =
List.filter_map
(function
| CDEF_fundef (id, _, _, _) when IdSet.mem id unreachable_polymorphic_functions -> None
| CDEF_val (id, _, _, _) when IdSet.mem id unreachable_polymorphic_functions -> None
| CDEF_aux (CDEF_fundef (id, _, _, _), _) when IdSet.mem id unreachable_polymorphic_functions -> None
| CDEF_aux (CDEF_val (id, _, _, _), _) when IdSet.mem id unreachable_polymorphic_functions -> None
| cdef -> Some cdef
)
cdefs
Expand Down Expand Up @@ -1799,12 +1802,16 @@ module Make (C : CONFIG) = struct
let specialize_field ctx struct_id = visit_cdefs (new specialize_field_visitor instantiations ctx struct_id) in

let mangled_pragma orig_id mangled_id =
CDEF_pragma
("mangled", Util.zencode_string (string_of_id orig_id) ^ " " ^ Util.zencode_string (string_of_id mangled_id))
CDEF_aux
( CDEF_pragma
("mangled", Util.zencode_string (string_of_id orig_id) ^ " " ^ Util.zencode_string (string_of_id mangled_id)),
mk_def_annot (gen_loc (id_loc orig_id))
)
in

function
| CDEF_type (CTD_variant (var_id, ctors)) :: cdefs when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors ->
| CDEF_aux (CDEF_type (CTD_variant (var_id, ctors)), def_annot) :: cdefs
when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors ->
let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty ctors in

let _ = visit_cdefs (new scan_variant_visitor instantiations ctx var_id) cdefs in
Expand Down Expand Up @@ -1866,14 +1873,16 @@ module Make (C : CONFIG) = struct
specialize_variants ctx
(List.concat
(List.map
(fun (id, ctors) -> [CDEF_type (CTD_variant (id, ctors)); mangled_pragma var_id id])
(fun (id, ctors) ->
[CDEF_aux (CDEF_type (CTD_variant (id, ctors)), def_annot); mangled_pragma var_id id]
)
monomorphized_variants
)
@ mangled_ctors @ prior
)
cdefs
| CDEF_type (CTD_struct (struct_id, fields)) :: cdefs when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) fields
->
| CDEF_aux (CDEF_type (CTD_struct (struct_id, fields)), def_annot) :: cdefs
when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) fields ->
let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty fields in

let cdefs = specialize_field ctx struct_id cdefs in
Expand Down Expand Up @@ -1932,7 +1941,9 @@ module Make (C : CONFIG) = struct
specialize_variants ctx
(List.concat
(List.map
(fun (id, fields) -> [CDEF_type (CTD_struct (id, fields)); mangled_pragma struct_id id])
(fun (id, fields) ->
[CDEF_aux (CDEF_type (CTD_struct (id, fields)), def_annot); mangled_pragma struct_id id]
)
monomorphized_structs
)
@ mangled_fields @ prior
Expand Down Expand Up @@ -2018,7 +2029,7 @@ module Make (C : CONFIG) = struct
in

let rec precise_calls prior = function
| (CDEF_type (CTD_variant (var_id, ctors)) as cdef) :: cdefs ->
| (CDEF_aux (CDEF_type (CTD_variant (var_id, ctors)), _) as cdef) :: cdefs ->
List.iter
(fun (id, ctyp) ->
constructor_types := Bindings.add id ([ctyp], CT_variant (var_id, ctors)) !constructor_types
Expand All @@ -2035,8 +2046,8 @@ module Make (C : CONFIG) = struct
to sort the type definitions in the list of cdefs. *)
let sort_ctype_defs reverse cdefs =
(* Split the cdefs into type definitions and non type definitions *)
let is_ctype_def = function CDEF_type _ -> true | _ -> false in
let unwrap = function CDEF_type ctdef -> ctdef | _ -> assert false in
let is_ctype_def = function CDEF_aux (CDEF_type _, _) -> true | _ -> false in
let unwrap = function CDEF_aux (CDEF_type ctdef, def_annot) -> (ctdef, def_annot) | _ -> assert false in
let ctype_defs = List.map unwrap (List.filter is_ctype_def cdefs) in
let cdefs = List.filter (fun cdef -> not (is_ctype_def cdef)) cdefs in

Expand All @@ -2053,7 +2064,7 @@ module Make (C : CONFIG) = struct
let module IdGraph = Graph.Make (Id) in
let graph =
List.fold_left
(fun g ctdef ->
(fun g (ctdef, _) ->
List.fold_left
(fun g id -> IdGraph.add_edge id (ctdef_id ctdef) g)
(IdGraph.add_edges (ctdef_id ctdef) [] g) (* Make sure even types with no dependencies are in graph *)
Expand All @@ -2065,7 +2076,12 @@ module Make (C : CONFIG) = struct
(* Then select the ctypes in the correct order as given by the topsort *)
let ids = IdGraph.topsort graph in
let ctype_defs =
List.map (fun id -> CDEF_type (List.find (fun ctdef -> Id.compare (ctdef_id ctdef) id = 0) ctype_defs)) ids
List.map
(fun id ->
let ctdef, def_annot = List.find (fun (ctdef, _) -> Id.compare (ctdef_id ctdef) id = 0) ctype_defs in
CDEF_aux (CDEF_type ctdef, def_annot)
)
ids
in

(if reverse then List.rev ctype_defs else ctype_defs) @ cdefs
Expand Down Expand Up @@ -2116,7 +2132,8 @@ module Make (C : CONFIG) = struct
let dummy_exn = mk_id "__dummy_exn#" in
let cdefs, ctx =
if not (Bindings.mem (mk_id "exception") ctx.variants) then
( CDEF_type (CTD_variant (mk_id "exception", [(dummy_exn, CT_unit)])) :: cdefs,
( CDEF_aux (CDEF_type (CTD_variant (mk_id "exception", [(dummy_exn, CT_unit)])), mk_def_annot Parse_ast.Unknown)
:: cdefs,
{
ctx with
variants = Bindings.add (mk_id "exception") ([], Bindings.singleton dummy_exn CT_unit) ctx.variants;
Expand Down
25 changes: 16 additions & 9 deletions src/lib/jib_optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ let rec flatten_instrs = function
| instr :: instrs -> instr :: flatten_instrs instrs
| [] -> []

let flatten_cdef = function
let flatten_cdef_aux = function
| CDEF_fundef (function_id, heap_return, args, body) ->
flat_counter := 0;
CDEF_fundef (function_id, heap_return, args, flatten_instrs body)
Expand All @@ -128,6 +128,8 @@ let flatten_cdef = function
CDEF_let (n, bindings, flatten_instrs instrs)
| cdef -> cdef

let flatten_cdef (CDEF_aux (aux, def_annot)) = CDEF_aux (flatten_cdef_aux aux, def_annot)

let unique_per_function_ids cdefs =
let unique_id i = function
| Name (id, ssa_num) -> Name (append_id id ("#u" ^ string_of_int i), ssa_num)
Expand All @@ -146,7 +148,7 @@ let unique_per_function_ids cdefs =
| instr :: instrs -> instr :: unique_instrs i instrs
| [] -> []
in
let unique_cdef i = function
let unique_cdef_aux i = function
| CDEF_register (id, ctyp, instrs) -> CDEF_register (id, ctyp, unique_instrs i instrs)
| CDEF_type ctd -> CDEF_type ctd
| CDEF_let (n, bindings, instrs) -> CDEF_let (n, bindings, unique_instrs i instrs)
Expand All @@ -156,6 +158,7 @@ let unique_per_function_ids cdefs =
| CDEF_finish (id, instrs) -> CDEF_finish (id, unique_instrs i instrs)
| CDEF_pragma (name, str) -> CDEF_pragma (name, str)
in
let unique_cdef i (CDEF_aux (aux, def_annot)) = CDEF_aux (unique_cdef_aux i aux, def_annot) in
List.mapi unique_cdef cdefs

let rec cval_subst id subst = function
Expand Down Expand Up @@ -257,7 +260,8 @@ let rec clexp_subst id subst = function
| CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot substitute into read-modify-write construct"

let rec find_function fid = function
| CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 -> Some (heap_return, args, body)
| CDEF_aux (CDEF_fundef (fid', heap_return, args, body), _) :: _ when Id.compare fid fid' = 0 ->
Some (heap_return, args, body)
| cdef :: cdefs -> find_function fid cdefs
| [] -> None

Expand Down Expand Up @@ -524,12 +528,15 @@ let remove_tuples cdefs ctx =
let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in
let fields = List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), ctyp)) ctyps in
[
CDEF_type (CTD_struct (mk_id name, fields));
CDEF_pragma
( "tuplestruct",
Util.string_of_list " "
(fun x -> x)
(Util.zencode_string name :: List.map (fun (id, _) -> Util.zencode_string (string_of_id id)) fields)
CDEF_aux (CDEF_type (CTD_struct (mk_id name, fields)), mk_def_annot Parse_ast.Unknown);
CDEF_aux
( CDEF_pragma
( "tuplestruct",
Util.string_of_list " "
(fun x -> x)
(Util.zencode_string name :: List.map (fun (id, _) -> Util.zencode_string (string_of_id id)) fields)
),
mk_def_annot Parse_ast.Unknown
);
]
| _ -> assert false
Expand Down
Loading
Loading