diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 60b95cea..104ed9a4 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -4369,7 +4369,6 @@ module M(SC:Syscall_t) = { var val1:W16.t; var t:W16.t; var val2:W16.t; - var cond:bool; ctr <- offset; pos <- (W64.of_int 0); @@ -4387,15 +4386,13 @@ module M(SC:Syscall_t) = { t <- (t `<<` (W8.of_int 4)); val2 <- (val2 `|` t); pos <- (pos + (W64.of_int 3)); - cond <- (val1 \ult (W16.of_int 3329)); - if (cond) { + if ((val1 \ult (W16.of_int 3329))) { rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { } - cond <- (val2 \ult (W16.of_int 3329)); - if (cond) { + if ((val2 \ult (W16.of_int 3329))) { if ((ctr \ult (W64.of_int 256))) { rp.[(W64.to_uint ctr)] <- val2; ctr <- (ctr + (W64.of_int 1)); diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jazz b/code/jasmin/mlkem_avx2/gen_matrix.jazz index 4e4d1852..22b1ea24 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jazz +++ b/code/jasmin/mlkem_avx2/gen_matrix.jazz @@ -43,6 +43,7 @@ export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp) for i = 0 to MLKEM_SYMBYTES { + #[declassify] c = (u8)[seedp + i]; seed[i] = c; } diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc index 59e3b518..32f35ed9 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jinc +++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc @@ -29,17 +29,12 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val2 |= t; pos += 3; - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond { + if val1 < MLKEM_Q { rp[ctr] = val1; ctr += 1; } - #[declassify] - cond = val2 < MLKEM_Q; - if cond { + if val2 < MLKEM_Q { if(ctr < MLKEM_N) { rp[ctr] = val2; diff --git a/code/jasmin/mlkem_avx2/indcpa.jinc b/code/jasmin/mlkem_avx2/indcpa.jinc index 382c3a6f..0780285a 100644 --- a/code/jasmin/mlkem_avx2/indcpa.jinc +++ b/code/jasmin/mlkem_avx2/indcpa.jinc @@ -29,6 +29,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn for i=0 to MLKEM_SYMBYTES/8 { + #[declassify] t64 = buf[u64 i]; publicseed[u64 i] = t64; t64 = buf[u64 i + MLKEM_SYMBYTES/8]; @@ -91,6 +92,7 @@ fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u6 pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8 * (int)i] = t64; pkp += 8; @@ -155,6 +157,7 @@ fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8*(int)i] = t64; pkp += 8; diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index 9c4716fc..707d104e 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1634,7 +1634,6 @@ module M(SC:Syscall_t) = { var val1:W16.t; var t:W16.t; var val2:W16.t; - var cond:bool; ctr <- offset; pos <- (W64.of_int 0); @@ -1652,15 +1651,13 @@ module M(SC:Syscall_t) = { t <- (t `<<` (W8.of_int 4)); val2 <- (val2 `|` t); pos <- (pos + (W64.of_int 3)); - cond <- (val1 \ult (W16.of_int 3329)); - if (cond) { + if ((val1 \ult (W16.of_int 3329))) { rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { } - cond <- (val2 \ult (W16.of_int 3329)); - if (cond) { + if ((val2 \ult (W16.of_int 3329))) { if ((ctr \ult (W64.of_int 256))) { rp.[(W64.to_uint ctr)] <- val2; ctr <- (ctr + (W64.of_int 1)); diff --git a/code/jasmin/mlkem_ref/gen_matrix.jazz b/code/jasmin/mlkem_ref/gen_matrix.jazz index 444017dc..ae7012ea 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jazz +++ b/code/jasmin/mlkem_ref/gen_matrix.jazz @@ -1,3 +1,4 @@ +require "params.jinc" require "gen_matrix.jinc" export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp, reg u64 transposed) @@ -11,6 +12,7 @@ export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp, reg u64 transposed) for i = 0 to MLKEM_SYMBYTES { + #[declassify] c = (u8)[seedp + i]; seed[i] = c; } diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc index f261b711..13e6616c 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jinc +++ b/code/jasmin/mlkem_ref/gen_matrix.jinc @@ -27,17 +27,12 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val2 |= t; pos += 3; - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond { + if val1 < MLKEM_Q { rp[ctr] = val1; ctr += 1; } - #[declassify] - cond = val2 < MLKEM_Q; - if cond { + if val2 < MLKEM_Q { if(ctr < MLKEM_N) { rp[ctr] = val2; diff --git a/code/jasmin/mlkem_ref/indcpa.jinc b/code/jasmin/mlkem_ref/indcpa.jinc index a0f2a463..269c2fb3 100644 --- a/code/jasmin/mlkem_ref/indcpa.jinc +++ b/code/jasmin/mlkem_ref/indcpa.jinc @@ -32,6 +32,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn for i=0 to MLKEM_SYMBYTES/8 { + #[declassify] t64 = buf[u64 i]; publicseed[u64 i] = t64; t64 = buf[u64 i + MLKEM_SYMBYTES/8]; @@ -103,6 +104,7 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8 * (int)i] = t64; pkp += 8; @@ -178,6 +180,7 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8*(int)i] = t64; pkp += 8; diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index 774d3c9b..0e7de184 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -591,8 +591,8 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\ to_uint pos{1} = k{2} /\ #{/~pos{1} \ult (of_int (168 - 2))%W64}post). -+ sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt(). - + sp 3 2; if{2}. ++ if; 1: by move => &1 &2; rewrite ultE qE; smt(). + + sp 2 2; if{2}. + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE; smt(). rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /= to_uintD_small /= /#. auto => /> &1 aar ctrl rpl 8?; rewrite ultE /= => *; do split; 2..3:smt(). @@ -638,7 +638,7 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\ rewrite set_eqiE 1,2:/#. by rewrite to_sint_unsigned; rewrite /to_sint /smod /=; smt(W16.to_uint_cmp). - sp 1 0; if{2}. + if{2}. + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#. rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#. auto => /> &1 &2 8?; rewrite ultE /= => *; do split; 2..3:smt().