diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index b2e3c7da..aab3e17a 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -1815,8 +1815,12 @@ rewrite BVA_Top_Array1088_Array1088_t.tolistP. apply eq_in_mkseq => i i_bnd; smt(Array1088.initE). qed. -op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (Array256.to_list arr))))). +op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = + if 8 %| offset then + get256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) + else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))). +(* lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) : 0 <= offset => offset + csize <= 16 * size l => @@ -1833,13 +1837,6 @@ rewrite -get_w2bits;congr. by rewrite (nth_map witness) 1:/#. qed. -lemma size_flatten_W16_w2bits (a : W16.t list) : - (size (flatten (map W16.w2bits (a)))) = 16 * size a. -proof. - rewrite size_flatten -map_comp /(\o) /=. - rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. - rewrite StdBigop.Bigint.big_constz count_predT /#. -qed. lemma aligned_get256_16_256 arr offset : 0 <= offset <= 16*256 - 256 => @@ -1856,20 +1853,29 @@ rewrite get_bits8 1:/#. smt(@IntDiv). qed. - +*) bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget". realize bvaslicegetP. -move => *; rewrite /sliceget256_16_256 bits2wK // size_take //= size_drop //=. -admit. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *) -by smt(Array256.size_to_list size_flatten_W16_w2bits). + +move => /= arr offset; rewrite /sliceget256_16_256 /= => H k kb. +case (8%| offset) => /= *; last by smt(W256.get_bits2w). +rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=. +rewrite nth_take 1,2:/# nth_drop 1,2:/#. +rewrite (BitEncoding.BitChunking.nth_flatten false 16 _). ++ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). +by rewrite nth_mkseq /#. qed. import BitEncoding BS2Int BitChunking. -op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t Array256.t = Array256.of_list witness (map W16.bits2w (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ +op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t Array256.t = + if 8 %| offset + then (init (fun (i3 : int) => get16 (set256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) bv) i3))%Array256 + else Array256.of_list witness (map W16.bits2w (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ drop (offset + 256) (flatten (map W16.w2bits (to_list arr)))))). - +(* lemma aligned_set256_16_256 arr offset bv : 0 <= offset <= 16*256 - 256 => 256 %| offset => @@ -1930,24 +1936,41 @@ case (2 * i < 32 * (offset %/ 256 + 1));last first. rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). qed. +*) - +lemma size_flatten_W16_w2bits (a : W16.t list) : + (size (flatten (map W16.w2bits (a)))) = 16 * size a. +proof. + rewrite size_flatten -map_comp /(\o) /=. + rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite StdBigop.Bigint.big_constz count_predT /#. +qed. bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset". -realize bvaslicesetP. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *) -move => arr offset bv *. have ? : 0 <= offset by admit. -rewrite /sliceset256_16_256 of_listK. -+ rewrite size_map size_chunk // !size_cat size_take 1:/#. - by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). -rewrite -(map_comp W16.w2bits W16.bits2w) /(\o). -have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16 +realize bvaslicesetP. +move => arr offset bv H /= k kb; rewrite /sliceset256_16_256 /=. +case (8 %| offset) => /= *; last first. ++ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite -(map_comp W16.w2bits W16.bits2w) /(\o). + have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))). -rewrite iffE => [#] -> * /=. -+ by smt(in_chunk_size W16.bits2wK). -rewrite map_id /= chunkK //. -+ rewrite !size_cat size_take 1:/#. - by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W16.bits2wK). + rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + by rewrite !nth_cat !size_cat /=; + smt(nth_take nth_drop size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). +rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). +rewrite nth_mkseq 1:/# /= initiE 1:/# /= get16E pack2E initiE 1:/# /= initiE 1:/# /= /set256_direct. +rewrite initiE 1:/# /=. +case (offset <= k && k < offset + 256) => *; 1: by + rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. +rewrite ifF 1:/# initiE 1:/# /=. +rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). +rewrite nth_mkseq 1:/# /= bits8E /= initiE /# /=. qed. op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256). @@ -2386,36 +2409,24 @@ proc. inline *. proc change 1 : (init_1088_8 (fun i => W8.zero));1: by auto. proc change 3 : (init_256_16 (fun i => r.[i]));1: by auto. -proc change ^while.1 : (sliceget256_16_256 ap (i0*256)). - move => &hr; rewrite aligned_get256_16_256; 2..: by smt(). admit. (* We need a loop invariant *) -proc change ^while.11 : (sliceset256_16_256 ap (i0*256) a1). - move => &hr; rewrite aligned_set256_16_256;2..: by smt(). admit. (* We need a loop invariant *) +proc change ^while.1 : (sliceget256_16_256 ap (i0*256)); 1: by smt(). +proc change ^while.11 : (sliceset256_16_256 ap (i0*256) a1);1 : by smt(). proc change 10 : (init_768_16 (fun i => if 0 <= i && i < 256 then aux.[i] else r.[i]));1: by auto. proc change 11 : (init_256_16 (fun i => r.[256 + i]));1: by auto. - -proc change ^while{2}.1 : (sliceget256_16_256 ap0 (i1*256)). - move => &hr; rewrite aligned_get256_16_256; 2..: by smt(). admit. (* We need a loop invariant *) - -proc change ^while{2}.11 : (sliceset256_16_256 ap0 (i1*256) a2). - move => &hr; rewrite aligned_set256_16_256;2..: by smt(). admit. (* We need a loop invariant *) - +proc change ^while{2}.1 : (sliceget256_16_256 ap0 (i1*256)); 1: by smt(). +proc change ^while{2}.11 : (sliceset256_16_256 ap0 (i1*256) a2);1: by smt(). proc change 18 : (init_768_16 (fun i => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]));1: by auto. proc change 19 : (init_256_16 (fun i => r.[2*256 + i]));1: by auto. -proc change ^while{3}.1 : (sliceget256_16_256 ap1 (i2*256)). - move => &hr; rewrite aligned_get256_16_256;2..: by smt(). admit. (* We need a loop invariant *) - -proc change ^while{3}.11 : (sliceset256_16_256 ap1 (i2*256) a3). - move => &hr; rewrite aligned_set256_16_256;2..: by smt(). admit. (* We need a loop invariant *) +proc change ^while{3}.1 : (sliceget256_16_256 ap1 (i2*256)); 1: by smt(). +proc change ^while{3}.11 : (sliceset256_16_256 ap1 (i2*256) a3); 1: by smt(). proc change 26 : (init_768_16 (fun i => if 2 * 256 <= i < 3 * 256 then aux.[i - 2 * 256] else r.[i]));1: by auto. - proc change 30 : (init_960_8 (fun i_0 => ctp0.[i_0 + 0]));1: by done. proc change 37 : (sliceget32_8_256 pvc_shufbidx_s 0);1: by auto. proc change ^while{4}.1 : (sliceget768_16_256 a (i*256));1: by smt(). - proc change ^while{4}.25 : (sliceset960_8_128 rp (i * 160) lo); 1: by smt(). proc change ^while{4}.26 : (sliceset960_8_32 rp (i * 160 + 128) (VPEXTR_32 hi W8.zero));1: by smt(). cfold 38.