Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial application of procedures #1226

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/base/Cashflow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ struct
new_map
| _ -> targ_tag_map)
| PrimType _ | PolyFun (_, _) | Unit -> targ_tag_map
| Address _ -> (* TODO *) targ_tag_map
| Address _ | ProcType _ -> (* TODO *) targ_tag_map
in
let tvar_tag_map, _ =
List.fold_left arg_typs ~init:(init_targ_to_tag_map, arg_tags)
Expand Down Expand Up @@ -519,7 +519,7 @@ struct
let ctr_arg_filter targ =
let open CFType in
match targ with
| PrimType _ | MapType _ | FunType _ -> true
| PrimType _ | MapType _ | FunType _ | ProcType _ -> true
| ADT _ (* TODO: Detect induction, and ignore only when inductive *)
| TypeVar _ (* TypeVars tagged at type level *) | PolyFun _ (* Ignore *)
| Unit ->
Expand Down
3 changes: 3 additions & 0 deletions src/base/Disambiguate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ module ScillaDisambiguation (SR : Rep) (ER : Rep) = struct
let%bind dis_arg_t = recurse arg_t in
let%bind dis_res_t = recurse res_t in
pure @@ PostDisType.FunType (dis_arg_t, dis_res_t)
| ProcType (p, args) ->
let%bind dis_args = mapM args ~f:recurse in
pure @@ PostDisType.ProcType (p, dis_args)
| ADT (t_name, targs) ->
let%bind dis_t_name =
disambiguate_identifier typ_dict t_name (get_rep t_name)
Expand Down
4 changes: 2 additions & 2 deletions src/base/JSON.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ let build_prim_lit_exn t v =
match t with
| PrimType pt -> build_prim_literal_of_type pt v
| Address _ -> build_prim_literal_of_type (Bystrx_typ Type.address_length) v
| MapType _ | FunType _ | ADT _ | TypeVar _ | PolyFun _ | Unit ->
| MapType _ | FunType _ | ADT _ | TypeVar _ | PolyFun _ | ProcType _ | Unit ->
raise (exn ())

(****************************************************************)
Expand Down Expand Up @@ -240,7 +240,7 @@ and json_to_lit_exn t v =
| PrimType _ | Address _ ->
let tv = build_prim_lit_exn t (to_string_exn v) in
tv
| FunType _ | TypeVar _ | PolyFun _ | Unit ->
| FunType _ | TypeVar _ | PolyFun _ | ProcType _ | Unit ->
let exn () =
mk_invalid_json ~kind:"Invalid type in JSON" ~inst:(pp_typ t)
in
Expand Down
2 changes: 2 additions & 0 deletions src/base/Recursion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ module ScillaRecursion (SR : Rep) (ER : Rep) = struct
| MapType (t1, t2) | FunType (t1, t2) ->
let%bind () = walk t1 in
walk t2
| ProcType (_p, args) -> forallM ~f:walk args
| ADT (s, targs) ->
let%bind () = is_adt_in_scope s in
forallM ~f:walk targs
Expand Down Expand Up @@ -288,6 +289,7 @@ module ScillaRecursion (SR : Rep) (ER : Rep) = struct
| MapType (t1, t2) | FunType (t1, t2) ->
let%bind _ = walk t1 in
walk t2
| ProcType (_, args) -> forallM args ~f:walk
| ADT (s, targs) ->
(* Only allow ADTs that are already in scope. This prevents mutually inductive definitions. *)
let%bind _ = is_adt_in_scope s in
Expand Down
6 changes: 3 additions & 3 deletions src/base/SanityChecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ struct
| LibVar _ | LibTyp _ -> None)
in
(* Returns an array with information about matched [Option] arguments
[Some(args)] if the [fun_name] is a procedure. *)
[Some(args)] if [fun_name] is a procedure. *)
let handle_comp (cmod : cmodule) option_args_matches fun_name =
let get_comp_args comp =
match comp.comp_type with
Expand All @@ -769,8 +769,8 @@ struct
match param_ty with
| ADT (id, _targs) when is_option_name id ->
Map.set m ~key:(get_id param_id) ~data:i
| ADT _ | PrimType _ | MapType _ | FunType _ | TypeVar _
| PolyFun _ | Unit | Address _ ->
| ADT _ | PrimType _ | MapType _ | FunType _ | ProcType _
| TypeVar _ | PolyFun _ | Unit | Address _ ->
m)
in
(* Mark Option arguments that matches inside the body. *)
Expand Down
5 changes: 4 additions & 1 deletion src/base/Syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ module ScillaSyntax (SR : Rep) (ER : Rep) (Lit : ScillaLiteral) = struct
ER.rep SIdentifier.t option
* SR.rep SIdentifier.t
* ER.rep SIdentifier.t list
(** [CallProc(I, P, [A1, ... An])] is a procedure call: [I = P A1 ... An] *)
(** [CallProc(I, P, [A1, ... An])] is a procedure call, when all the
arguments are specified: [I = P A1 ... An]. Otherwise, it is or
partial application of procedure that creates a new local variable
[I] that has the [ProcType] type: [I = P A1 ... An]. *)
| Throw of ER.rep SIdentifier.t option
(** [Throw(I)] represents: [throw I] *)
| GasStmt of SGasCharge.gas_charge
Expand Down
1 change: 1 addition & 0 deletions src/base/SyntaxAnnotMapper.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct
| PrimType _ | Unit | TypeVar _ -> ty
| MapType (ty1, ty2) -> MapType (walk ty1, walk ty2)
| FunType (ty1, ty2) -> FunType (walk ty1, walk ty2)
| ProcType (p, args) -> ProcType (p, List.map args ~f:walk)
| PolyFun (tv, ty) -> PolyFun (tv, walk ty)
| ADT (tid, tys) ->
ADT (map_id fl tid, List.map tys ~f:(fun ty -> walk ty))
Expand Down
28 changes: 23 additions & 5 deletions src/base/Type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,21 @@ module type ScillaType = sig
[@@deriving sexp]

type t =
| PrimType of PrimType.t
| MapType of t * t
| PrimType of PrimType.t (** [PrimType(T)] represents a primary type *)
| MapType of t * t (** [MapType(K, V)] is a mapping type: [K |-> V] *)
| FunType of t * t
(** [FunType(Ty1, Ty2)] is a function type: [Ty1 => Ty2] *)
| ADT of loc TIdentifier.t * t list
| TypeVar of string
(** [ADT(N, Args)] represents ADT type [N] with type parameters [Args] *)
| TypeVar of string (** [TypeVar(T)] is a type variable *)
| PolyFun of string * t
| Unit
| Address of t addr_kind
(** [PolyFun('A, T)] represents a polymorphic function type where
['A] is a type parameter. For example: [forall 'A. List 'a -> List 'A] *)
| ProcType of string * t list
(** [ProcType(P, Args)] is a type of partial application of procedure
[P] which has formal arguments with types [Args] *)
| Unit (** [Unit] is a unit type *)
| Address of t addr_kind (** [Address(A)] represents address *)
[@@deriving sexp, to_yojson]

val pp_typ : t -> string
Expand Down Expand Up @@ -223,6 +230,7 @@ module MkType (I : ScillaIdentifier) = struct
| ADT of loc TIdentifier.t * t list
| TypeVar of string
| PolyFun of string * t
| ProcType of string * t list
| Unit
| Address of (t addr_kind[@to_yojson fun _ -> `String "Address"])
[@@deriving sexp, to_yojson]
Expand All @@ -239,6 +247,9 @@ module MkType (I : ScillaIdentifier) = struct
in
String.concat ~sep:" " elems
| FunType (at, vt) -> sprintf "%s -> %s" (with_paren at) (recurser vt)
| ProcType (p, args_tys) ->
sprintf "%s (%s)" p
(List.map args_tys ~f:recurser |> String.concat ~sep:", ")
| TypeVar tv -> tv
| PolyFun (tv, bt) -> sprintf "forall %s. %s" tv (recurser bt)
| Unit -> sprintf "()"
Expand Down Expand Up @@ -280,6 +291,8 @@ module MkType (I : ScillaIdentifier) = struct
| PrimType _ | Unit -> acc
| MapType (kt, vt) -> go kt acc |> go vt
| FunType (at, rt) -> go at acc |> go rt
| ProcType (_, args) ->
List.fold_left args ~init:[] ~f:(fun acc arg -> go arg acc)
| TypeVar n -> add acc n
| ADT (_, ts) -> List.fold_left ts ~init:acc ~f:(Fn.flip go)
| PolyFun (arg, bt) ->
Expand Down Expand Up @@ -314,6 +327,9 @@ module MkType (I : ScillaIdentifier) = struct
let ats = subst_type_in_type tvar tp at in
let rts = subst_type_in_type tvar tp rt in
FunType (ats, rts)
| ProcType (p, args_tys) ->
let args_tyss = List.map args_tys ~f:(subst_type_in_type tvar tp) in
ProcType (p, args_tyss)
| TypeVar n -> if String.(tvar = n) then tp else tm
| ADT (s, ts) ->
let ts' = List.map ts ~f:(subst_type_in_type tvar tp) in
Expand All @@ -337,6 +353,8 @@ module MkType (I : ScillaIdentifier) = struct
match t with
| MapType (kt, vt) -> MapType (kt, recursor vt taken)
| FunType (at, rt) -> FunType (recursor at taken, recursor rt taken)
| ProcType (p, args_tys) ->
ProcType (p, List.map args_tys ~f:(fun ty -> recursor ty taken))
| ADT (n, ts) ->
let ts' = List.map ts ~f:(fun w -> recursor w taken) in
ADT (n, ts')
Expand Down
136 changes: 99 additions & 37 deletions src/base/TypeChecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
match t with
| PrimType _ | Unit | TypeVar _ -> 1
| PolyFun (_, t) -> 1 + type_size t
| ProcType (_, args) -> 1 + List.length args
| MapType (t1, t2) | FunType (t1, t2) -> 1 + type_size t1 + type_size t2
| ADT (_, ts) ->
List.fold_left ts ~init:1 ~f:(fun acc t -> acc + type_size t)
Expand All @@ -163,7 +164,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
| MapType (_, _)
| FunType (_, _)
| ADT (_, _)
| PolyFun (_, _) ->
| PolyFun (_, _)
| ProcType (_, _) ->
1
| TypeVar n -> if String.(n = tvar) then tp_size else 1
| Address AnyAddr | Address LibAddr | Address CodeAddr -> 1
Expand Down Expand Up @@ -201,6 +203,9 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
else
let%bind res = recurser t' in
pure (PolyFun (arg, res))
| ProcType (pname, args) ->
let%bind args' = mapM args ~f:recurser in
pure (ProcType (pname, args'))
| Address AnyAddr | Address LibAddr | Address CodeAddr -> pure t
| Address (ContrAddr fts) ->
let%bind fts_res =
Expand Down Expand Up @@ -551,15 +556,37 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
(* Typing statements *)
(**************************************************************)

(* Auxiliary structure for types of fields and BC components *)
type stmt_tenv = {
pure : TEnv.t;
fields : TEnv.t;
procedures : (TCName.t * (TCType.t list * TCType.t option)) list;
}
(** Auxiliary structure for types of fields and BC components *)

(** Looks up a procedure with name [pname] or a local binding to a partial
application of procedure. *)
let lookup_proc env pname =
List.Assoc.find env.procedures ~equal:[%equal: TCName.t] (get_id pname)
let proc =
List.Assoc.find env.procedures ~equal:[%equal: TCName.t] (get_id pname)
in
match proc with
| None -> (
(* Lookup local bind to a partial application *)
let get_proc_type_args (rr : resolve_result) =
match (rr_typ rr).tp with
| ProcType (_proc_name, arg_tys) -> Some arg_tys
| _ -> None
in
match TEnv.resolveT env.pure (get_id pname) ~lopt:None with
| Ok bind ->
get_proc_type_args bind
|> Option.value_map
~f:(fun arg_tys ->
(* Partially applied procedures never have a return type *)
Some (arg_tys, None))
~default:None
| _ -> None)
| res -> res

let type_map_access_helper env maptype keys =
let rec helper maptype keys =
Expand Down Expand Up @@ -985,8 +1012,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CreateEvnt typed_i, rep)
checked_stmts
| CallProc (id_opt, p, args) ->
let%bind arg_typs, ret_ty_opt =
| CallProc (id_opt, p, actual_args) ->
let%bind formal_args, ret_ty_opt =
match lookup_proc env p with
| Some (arg_typs, ret_ty_opt) -> pure (arg_typs, ret_ty_opt)
| None ->
Expand All @@ -995,41 +1022,76 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
~inst:(as_error_string p)
(SR.get_loc (get_rep p)))
in
let%bind typed_args =
let%bind targs, typed_actuals = type_actuals env.pure args in
let%bind targs, typed_actuals = type_actuals env.pure actual_args in
let is_partial_application =
Option.is_some id_opt
&& List.length formal_args > List.length targs
in
if is_partial_application then
let%bind _ =
fromR_TE
@@ proc_type_applies arg_typs targs ~lc:(SR.get_loc rep)
@@ partial_proc_type_applies formal_args targs
~lc:(SR.get_loc rep)
in
pure typed_actuals
in
let%bind typed_id_opt, checked_stmts =
match id_opt with
| None ->
let%bind checked_stmts = type_stmts comp sts get_loc env in
pure @@ (None, checked_stmts)
| Some id -> (
match ret_ty_opt with
| Some ret_ty ->
let typed_id = add_type_to_ident id (mk_qual_tp ret_ty) in
let%bind checked_stmts =
with_extended_env env get_tenv_pure
[ (id, ret_ty) ]
[]
(type_stmts comp sts get_loc)
in
pure @@ (Some typed_id, checked_stmts)
| None ->
fail
(mk_type_error1
~kind:"Procedure does not return a value"
~inst:(as_error_string p)
(SR.get_loc (get_rep p))))
in
pure
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CallProc (typed_id_opt, p, typed_args), rep)
checked_stmts
let id = Option.value_exn id_opt in
let proc_name =
SIdentifier.Name.as_string (SIdentifier.get_id p)
in
let unapplied_formal_args =
List.sub formal_args
~pos:(List.length targs - 1)
~len:(List.length formal_args - List.length targs)
in
let partial_applied_type =
ProcType (proc_name, unapplied_formal_args)
in
let typed_id =
add_type_to_ident id (mk_qual_tp partial_applied_type)
in
let%bind checked_stmts =
with_extended_env env get_tenv_pure
[ (id, partial_applied_type) ]
[]
(type_stmts comp sts get_loc)
in
pure
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CallProc (Some typed_id, p, typed_actuals), rep)
checked_stmts
else
let%bind _ =
fromR_TE
@@ proc_type_applies formal_args targs ~lc:(SR.get_loc rep)
in
let%bind typed_id_opt, checked_stmts =
match id_opt with
| None ->
let%bind checked_stmts = type_stmts comp sts get_loc env in
pure @@ (None, checked_stmts)
| Some id -> (
match ret_ty_opt with
| Some ret_ty ->
let typed_id =
add_type_to_ident id (mk_qual_tp ret_ty)
in
let%bind checked_stmts =
with_extended_env env get_tenv_pure
[ (id, ret_ty) ]
[]
(type_stmts comp sts get_loc)
in
pure @@ (Some typed_id, checked_stmts)
| None ->
fail
(mk_type_error1
~kind:"Procedure does not return a value"
~inst:(as_error_string p)
(SR.get_loc (get_rep p))))
in
pure
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CallProc (typed_id_opt, p, typed_actuals), rep)
checked_stmts
| Iterate (l, p) -> (
let%bind lt =
fromR_TE
Expand Down
Loading