Skip to content

Commit

Permalink
pretty much admit free
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Nov 4, 2024
1 parent d702a45 commit 2ce5124
Showing 1 changed file with 153 additions and 51 deletions.
204 changes: 153 additions & 51 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1856,7 +1856,6 @@ qed.
*)
bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget".
realize bvaslicegetP.

move => /= arr offset; rewrite /sliceget256_16_256 /= => H k kb.
case (8%| offset) => /= *; last by smt(W256.get_bits2w).
rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=.
Expand Down Expand Up @@ -1973,31 +1972,119 @@ rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list).
rewrite nth_mkseq 1:/# /= bits8E /= initiE /# /=.
qed.

op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256).
op sliceget32_8_256 (arr: W8.t Array32.t) (offset: int) : W256.t =
if 8 %| offset then
get256_direct (WArray32.init8 (fun (i_0 : int) => arr.[i_0])) (offset %/ 8)
else W256.bits2w (take 256 (drop offset (flatten (map W8.w2bits (to_list arr))))).

bind op [W8.t & W256.t & Array32.t] sliceget32_8_256 "asliceget".
realize bvaslicegetP by admit. (* We need a general framework for these *)
realize bvaslicegetP.
move => /= arr offset; rewrite /sliceget32_8_256 /= => H k kb.
case (8%| offset) => /= *; last by smt(W256.get_bits2w).
rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=.
rewrite nth_take 1,2:/# nth_drop 1,2:/#.
rewrite (BitEncoding.BitChunking.nth_flatten false 8 _).
+ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W8.size_w2bits).
rewrite (nth_map W8.zero []); 1: smt(Array32.size_to_list).
by rewrite nth_mkseq /#.
qed.

op sliceget768_16_256 (arr: W16.t Array768.t) (offset: int) : W256.t =
if 8 %| offset then
get256_direct (WArray1536.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 8)
else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))).

op sliceget768_16_256 (arr: W16.t Array768.t) (i: int) : W256.t = get256 (WArray1536.init16 (fun (i_0 : int) => arr.[i_0])) (i %/ 256).

bind op [W16.t & W256.t & Array768.t] sliceget768_16_256 "asliceget".
realize bvaslicegetP by admit. (* We need a general framework for these *)
realize bvaslicegetP.
move => /= arr offset; rewrite /sliceget768_16_256 /= => H k kb.
case (8%| offset) => /= *; last by smt(W256.get_bits2w).
rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=.
rewrite nth_take 1,2:/# nth_drop 1,2:/#.
rewrite (BitEncoding.BitChunking.nth_flatten false 16 _).
+ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits).
rewrite (nth_map W16.zero []); 1: smt(Array768.size_to_list).
by rewrite nth_mkseq /#.
qed.

op sliceset960_8_128 (arr: W8.t Array960.t) (offset: int) (bv: W128.t) : W8.t Array960.t =
if 8 %| offset
then Array960.init (fun (i3 : int) => get8 (set128_direct ((init8 (fun (i_0 : int) => arr.[i_0])))%WArray960 (offset %/ 8) bv) i3)
else Array960.of_list witness (map W8.bits2w (chunk 8 (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 128) (flatten (map W8.w2bits (to_list arr)))))).

op sliceset960_8_128 (arr: W8.t Array960.t) (i: int) (bv: W128.t) : W8.t Array960.t = Array960.init (get8 (set128_direct ((init8 (fun (i_0 : int) => arr.[i_0])))%WArray960 (i %/ 8) bv)).
lemma size_flatten_W8_w2bits (a : W8.t list) :
(size (flatten (map W8.w2bits (a)))) = 8 * size a.
proof.
rewrite size_flatten -map_comp /(\o) /=.
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz count_predT /#.
qed.

bind op [W8.t & W128.t & Array960.t] sliceset960_8_128 "asliceset".
realize bvaslicesetP by admit. (* We need a general framework for these *)
realize bvaslicesetP.
move => arr offset bv H /= k kb; rewrite /sliceset960_8_128 /=.
case (8 %| offset) => /= *; last first.
+ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take;
by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0).
rewrite -(map_comp W8.w2bits W8.bits2w) /(\o).
have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W8)) idfun (chunk 8
(take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 128) (flatten (map W8.w2bits (to_list arr))))).
rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W8.bits2wK).
rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take;
by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0).
by rewrite !nth_cat !size_cat /=;
smt(nth_take nth_drop size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0).
rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits).
rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list).
rewrite nth_mkseq 1:/# /= initiE 1:/# /= /get8 /set128_direct.
rewrite initiE 1:/# /=.
case (offset <= k && k < offset + 128) => *; 1: by
rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //.
rewrite ifF 1:/# initiE 1:/# /=.
rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits).
rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list).
rewrite nth_mkseq /#.
qed.


op sliceset960_8_32 (arr: W8.t Array960.t) (i: int) (bv: W32.t) : W8.t Array960.t = Array960.init
op sliceset960_8_32 (arr: W8.t Array960.t) (offset: int) (bv: W32.t) : W8.t Array960.t =
if 8 %| offset
then Array960.init
(WArray960.get8
(set32_direct (WArray960.init8 (fun (i_0 : int) => arr.[i_0])) (
i %/ 8) bv)).
offset %/ 8) bv))
else Array960.of_list witness (map W8.bits2w (chunk 8 (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 32) (flatten (map W8.w2bits (to_list arr)))))).


bind op [W8.t & W32.t & Array960.t] sliceset960_8_32 "asliceset".
realize bvaslicesetP by admit. (* We need a general framework for these *)

realize bvaslicesetP.
move => arr offset bv H /= k kb; rewrite /sliceset960_8_32 /=.
case (8 %| offset) => /= *; last first.
+ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take;
by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0).
rewrite -(map_comp W8.w2bits W8.bits2w) /(\o).
have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W8)) idfun (chunk 8
(take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 32) (flatten (map W8.w2bits (to_list arr))))).
rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W8.bits2wK).
rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take;
by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0).
by rewrite !nth_cat !size_cat /=;
smt(nth_take nth_drop size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0).
rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits).
rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list).
rewrite nth_mkseq 1:/# /= initiE 1:/# /= /get8 /set32_direct.
rewrite initiE 1:/# /=.
case (offset <= k && k < offset + 32) => *; 1: by
rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //.
rewrite ifF 1:/# initiE 1:/# /=.
rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits).
rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list).
rewrite nth_mkseq /#.
qed.

theory W10.
abbrev [-printing] size = 10.
Expand Down Expand Up @@ -2099,17 +2186,19 @@ qed.


op sll_64 (w1 w2 : W64.t) : W64.t =
if (to_uint w2 < 64) then w1 `<<` (truncateu8 w2) else W64.zero.
if (64 <= to_uint w2) then W64.zero else w1 `<<` (truncateu8 w2).

bind op [W64.t] sll_64 "shl".
realize bvshlP.
proof.
rewrite /sll_64 => bv1 bv2.
case : (to_uint bv2 < 64).
case : (64 <= to_uint bv2); last first.
+ rewrite /(`<<`) W64.to_uint_shl; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp)).
admit. (* What is the circuit behavior? Does it give zero? Yes. *)
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring.
by rewrite exprD_nneg 1,2:/# /= /#.
qed.

bind op [W32.t & W16.t] W2u16.truncateu16 "truncate".
Expand Down Expand Up @@ -2199,65 +2288,78 @@ qed.


op sra_32 (w1 w2 : W32.t) : W32.t =
w1 `|>>` (truncateu8 w2).
if (32 <= to_uint w2) then W32.zero else w1 `|>>` (truncateu8 w2).

bind op [W32.t] sra_32 "ashr".
realize bvashrP by admit.
realize bvashrP.
rewrite /sra_32 => bv1 bv2.
case : (32 <= to_uint bv2 < 32); last by admit.
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 32) + 32 by ring.
rewrite exprD_nneg 1,2:/# /= mulrC. by admit.
qed.

op sra_16 (w1 w2 : W16.t) : W16.t =
w1 `|>>` (truncateu8 w2).
if (16 <= to_uint w2) then W16.zero else w1 `|>>` (truncateu8 w2).

bind op [W16.t] sra_16 "ashr".
realize bvashrP by admit.
realize bvashrP.
rewrite /sra_16 => bv1 bv2.
case : (16 <= to_uint bv2); last by admit.
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring.
rewrite exprD_nneg 1,2:/# /= mulrC. by admit.
qed.


op srl_16 (w1 w2 : W16.t) : W16.t =
if 16 <= (to_uint w2) then W16.zero else
w1 `>>` (truncateu8 w2).

bind op [W16.t] srl_16 "shr".
realize bvshrP.
move => w1 w2.
rewrite /srl_16.
case : (16 <= to_uint w2).
move => gt.
simplify.
rewrite eq_sym.
apply (divz_eq0 (to_uint w1) (2 ^ to_uint w2)).
smt(StdOrder.IntOrder.expr_gt0).
split.
smt(W16.to_uint_cmp).
have : (2 ^ 16 <= 2 ^ (to_uint w2)).
apply StdOrder.IntOrder.ler_weexpn2l => //.

move => bnd2.
smt(W16.to_uint_cmp).
move => bnd.
have : (to_uint w2 < 16).
smt().
move => bnd2.

rewrite /srl_16 /(`>>`) to_uint_shr.
smt(W16.to_uint_cmp).
rewrite /truncateu8.
rewrite to_uint_small.
smt(W16.to_uint_cmp).
congr. congr.
rewrite pmod_small.
smt(W16.to_uint_cmp).
trivial.
rewrite /srl_16 => bv1 bv2.
case : (16 <= to_uint bv2); last first.
+ rewrite /(`>>`) W16.to_uint_shr; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W16.to_uint_cmp)).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring.
rewrite exprD_nneg 1,2:/# /=.
smt(StdOrder.IntOrder.expr_gt0 W16.to_uint_cmp pow2_16).
qed.


op sll_16 (w1 w2 : W16.t) : W16.t =
w1 `<<` (truncateu8 w2).
if (16 <= to_uint w2) then W16.zero else w1 `<<` (truncateu8 w2).

bind op [W16.t] sll_16 "shl".
realize bvshlP by admit. (* not provable. missing %% 65536? *)
realize bvshlP.
rewrite /sll_16 => bv1 bv2.
case : (16 <= to_uint bv2); last first.
+ rewrite /(`<<`) W16.to_uint_shl; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W16.to_uint_cmp pow2_16)).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring.
by rewrite exprD_nneg 1,2:/# /= /#.
qed.

op srl_64 (w1 w2 : W64.t) : W64.t =
w1 `>>` (truncateu8 w2).
if (64 <= to_uint w2) then W64.zero else w1 `>>` (truncateu8 w2).

bind op [W64.t] srl_64 "shr".
realize bvshrP by admit. (* not provable. missing %% 256? *)
realize bvshrP.
rewrite /srl_64 => bv1 bv2.
case : (64 <= to_uint bv2); last first.
+ rewrite /(`>>`) W64.to_uint_shr; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp)).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring.
rewrite exprD_nneg 1,2:/# /=.
smt(StdOrder.IntOrder.expr_gt0 W64.to_uint_cmp pow2_64).
qed.

op lane_func_reduce(c : W16.t) : W16.t =
let t = (sigextu32 c) * (W32.of_int 20159) in
Expand Down

0 comments on commit 2ce5124

Please sign in to comment.