Skip to content

Commit

Permalink
path for slices
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Nov 4, 2024
1 parent e21494d commit 6068dcb
Showing 1 changed file with 56 additions and 45 deletions.
101 changes: 56 additions & 45 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1815,8 +1815,12 @@ rewrite BVA_Top_Array1088_Array1088_t.tolistP.
apply eq_in_mkseq => i i_bnd; smt(Array1088.initE).
qed.

op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (Array256.to_list arr))))).
op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t =
if 8 %| offset then
get256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8)
else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))).

(*
lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) :
0 <= offset =>
offset + csize <= 16 * size l =>
Expand All @@ -1833,13 +1837,6 @@ rewrite -get_w2bits;congr.
by rewrite (nth_map witness) 1:/#.
qed.

lemma size_flatten_W16_w2bits (a : W16.t list) :
(size (flatten (map W16.w2bits (a)))) = 16 * size a.
proof.
rewrite size_flatten -map_comp /(\o) /=.
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz count_predT /#.
qed.

lemma aligned_get256_16_256 arr offset :
0 <= offset <= 16*256 - 256 =>
Expand All @@ -1856,20 +1853,29 @@ rewrite get_bits8 1:/#.
smt(@IntDiv).
qed.


*)
bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget".
realize bvaslicegetP.
move => *; rewrite /sliceget256_16_256 bits2wK // size_take //= size_drop //=.
admit. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *)
by smt(Array256.size_to_list size_flatten_W16_w2bits).

move => /= arr offset; rewrite /sliceget256_16_256 /= => H k kb.
case (8%| offset) => /= *; last by smt(W256.get_bits2w).
rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=.
rewrite nth_take 1,2:/# nth_drop 1,2:/#.
rewrite (BitEncoding.BitChunking.nth_flatten false 16 _).
+ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits).
rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list).
by rewrite nth_mkseq /#.
qed.

import BitEncoding BS2Int BitChunking.

op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t Array256.t = Array256.of_list witness (map W16.bits2w (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++
op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t Array256.t =
if 8 %| offset
then (init (fun (i3 : int) => get16 (set256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) bv) i3))%Array256
else Array256.of_list witness (map W16.bits2w (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 256) (flatten (map W16.w2bits (to_list arr)))))).


(*
lemma aligned_set256_16_256 arr offset bv :
0 <= offset <= 16*256 - 256 =>
256 %| offset =>
Expand Down Expand Up @@ -1930,24 +1936,41 @@ case (2 * i < 32 * (offset %/ 256 + 1));last first.
rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
qed.
*)


lemma size_flatten_W16_w2bits (a : W16.t list) :
(size (flatten (map W16.w2bits (a)))) = 16 * size a.
proof.
rewrite size_flatten -map_comp /(\o) /=.
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz count_predT /#.
qed.

bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset".
realize bvaslicesetP. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *)
move => arr offset bv *. have ? : 0 <= offset by admit.
rewrite /sliceset256_16_256 of_listK.
+ rewrite size_map size_chunk // !size_cat size_take 1:/#.
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite -(map_comp W16.w2bits W16.bits2w) /(\o).
have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16
realize bvaslicesetP.
move => arr offset bv H /= k kb; rewrite /sliceset256_16_256 /=.
case (8 %| offset) => /= *; last first.
+ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take;
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite -(map_comp W16.w2bits W16.bits2w) /(\o).
have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16
(take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))).
rewrite iffE => [#] -> * /=.
+ by smt(in_chunk_size W16.bits2wK).
rewrite map_id /= chunkK //.
+ rewrite !size_cat size_take 1:/#.
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W16.bits2wK).
rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take;
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
by rewrite !nth_cat !size_cat /=;
smt(nth_take nth_drop size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits).
rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list).
rewrite nth_mkseq 1:/# /= initiE 1:/# /= get16E pack2E initiE 1:/# /= initiE 1:/# /= /set256_direct.
rewrite initiE 1:/# /=.
case (offset <= k && k < offset + 256) => *; 1: by
rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //.
rewrite ifF 1:/# initiE 1:/# /=.
rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits).
rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list).
rewrite nth_mkseq 1:/# /= bits8E /= initiE /# /=.
qed.

op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256).
Expand Down Expand Up @@ -2386,36 +2409,24 @@ proc.
inline *.
proc change 1 : (init_1088_8 (fun i => W8.zero));1: by auto.
proc change 3 : (init_256_16 (fun i => r.[i]));1: by auto.
proc change ^while.1 : (sliceget256_16_256 ap (i0*256)).
move => &hr; rewrite aligned_get256_16_256; 2..: by smt(). admit. (* We need a loop invariant *)
proc change ^while.11 : (sliceset256_16_256 ap (i0*256) a1).
move => &hr; rewrite aligned_set256_16_256;2..: by smt(). admit. (* We need a loop invariant *)
proc change ^while.1 : (sliceget256_16_256 ap (i0*256)); 1: by smt().
proc change ^while.11 : (sliceset256_16_256 ap (i0*256) a1);1 : by smt().

proc change 10 : (init_768_16 (fun i => if 0 <= i && i < 256 then aux.[i] else r.[i]));1: by auto.
proc change 11 : (init_256_16 (fun i => r.[256 + i]));1: by auto.

proc change ^while{2}.1 : (sliceget256_16_256 ap0 (i1*256)).
move => &hr; rewrite aligned_get256_16_256; 2..: by smt(). admit. (* We need a loop invariant *)

proc change ^while{2}.11 : (sliceset256_16_256 ap0 (i1*256) a2).
move => &hr; rewrite aligned_set256_16_256;2..: by smt(). admit. (* We need a loop invariant *)

proc change ^while{2}.1 : (sliceget256_16_256 ap0 (i1*256)); 1: by smt().
proc change ^while{2}.11 : (sliceset256_16_256 ap0 (i1*256) a2);1: by smt().

proc change 18 : (init_768_16 (fun i => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]));1: by auto.
proc change 19 : (init_256_16 (fun i => r.[2*256 + i]));1: by auto.
proc change ^while{3}.1 : (sliceget256_16_256 ap1 (i2*256)).
move => &hr; rewrite aligned_get256_16_256;2..: by smt(). admit. (* We need a loop invariant *)

proc change ^while{3}.11 : (sliceset256_16_256 ap1 (i2*256) a3).
move => &hr; rewrite aligned_set256_16_256;2..: by smt(). admit. (* We need a loop invariant *)
proc change ^while{3}.1 : (sliceget256_16_256 ap1 (i2*256)); 1: by smt().
proc change ^while{3}.11 : (sliceset256_16_256 ap1 (i2*256) a3); 1: by smt().

proc change 26 : (init_768_16 (fun i => if 2 * 256 <= i < 3 * 256 then aux.[i - 2 * 256] else r.[i]));1: by auto.

proc change 30 : (init_960_8 (fun i_0 => ctp0.[i_0 + 0]));1: by done.
proc change 37 : (sliceget32_8_256 pvc_shufbidx_s 0);1: by auto.

proc change ^while{4}.1 : (sliceget768_16_256 a (i*256));1: by smt().

proc change ^while{4}.25 : (sliceset960_8_128 rp (i * 160) lo); 1: by smt().
proc change ^while{4}.26 : (sliceset960_8_32 rp (i * 160 + 128) (VPEXTR_32 hi W8.zero));1: by smt().
cfold 38.
Expand Down

0 comments on commit 6068dcb

Please sign in to comment.