From 6ab10690ed67836706b0d7423f4f6cd527e3cb78 Mon Sep 17 00:00:00 2001 From: Vincent Laporte Date: Fri, 14 Jun 2024 10:34:09 +0200 Subject: [PATCH] Declassify the random seed This changes what is declassified and when. Before, declassification only occurs during rejection sampling and what is declassified is whether rejection occurs or not. After, the full sampling of the matrix expects its seed to be public: what is declassified is the public key when it is read from memory in the IND-CPA encryption, when it is derived from the output of SHA3 in keygen. --- code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec | 7 ++----- code/jasmin/mlkem_avx2/gen_matrix.jazz | 1 + code/jasmin/mlkem_avx2/gen_matrix.jinc | 9 ++------- code/jasmin/mlkem_avx2/indcpa.jinc | 3 +++ code/jasmin/mlkem_ref/extraction/jkem.ec | 7 ++----- code/jasmin/mlkem_ref/gen_matrix.jazz | 2 ++ code/jasmin/mlkem_ref/gen_matrix.jinc | 9 ++------- code/jasmin/mlkem_ref/indcpa.jinc | 3 +++ proof/correctness/MLKEM_InnerPKE.ec | 6 +++--- 9 files changed, 20 insertions(+), 27 deletions(-) diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 60b95cea..104ed9a4 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -4369,7 +4369,6 @@ module M(SC:Syscall_t) = { var val1:W16.t; var t:W16.t; var val2:W16.t; - var cond:bool; ctr <- offset; pos <- (W64.of_int 0); @@ -4387,15 +4386,13 @@ module M(SC:Syscall_t) = { t <- (t `<<` (W8.of_int 4)); val2 <- (val2 `|` t); pos <- (pos + (W64.of_int 3)); - cond <- (val1 \ult (W16.of_int 3329)); - if (cond) { + if ((val1 \ult (W16.of_int 3329))) { rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { } - cond <- (val2 \ult (W16.of_int 3329)); - if (cond) { + if ((val2 \ult (W16.of_int 3329))) { if ((ctr \ult (W64.of_int 256))) { rp.[(W64.to_uint ctr)] <- val2; ctr <- (ctr + (W64.of_int 1)); diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jazz b/code/jasmin/mlkem_avx2/gen_matrix.jazz index 4e4d1852..22b1ea24 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jazz +++ b/code/jasmin/mlkem_avx2/gen_matrix.jazz @@ -43,6 +43,7 @@ export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp) for i = 0 to MLKEM_SYMBYTES { + #[declassify] c = (u8)[seedp + i]; seed[i] = c; } diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc index 59e3b518..32f35ed9 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jinc +++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc @@ -29,17 +29,12 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val2 |= t; pos += 3; - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond { + if val1 < MLKEM_Q { rp[ctr] = val1; ctr += 1; } - #[declassify] - cond = val2 < MLKEM_Q; - if cond { + if val2 < MLKEM_Q { if(ctr < MLKEM_N) { rp[ctr] = val2; diff --git a/code/jasmin/mlkem_avx2/indcpa.jinc b/code/jasmin/mlkem_avx2/indcpa.jinc index 382c3a6f..0780285a 100644 --- a/code/jasmin/mlkem_avx2/indcpa.jinc +++ b/code/jasmin/mlkem_avx2/indcpa.jinc @@ -29,6 +29,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn for i=0 to MLKEM_SYMBYTES/8 { + #[declassify] t64 = buf[u64 i]; publicseed[u64 i] = t64; t64 = buf[u64 i + MLKEM_SYMBYTES/8]; @@ -91,6 +92,7 @@ fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u6 pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8 * (int)i] = t64; pkp += 8; @@ -155,6 +157,7 @@ fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8*(int)i] = t64; pkp += 8; diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index 9c4716fc..707d104e 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1634,7 +1634,6 @@ module M(SC:Syscall_t) = { var val1:W16.t; var t:W16.t; var val2:W16.t; - var cond:bool; ctr <- offset; pos <- (W64.of_int 0); @@ -1652,15 +1651,13 @@ module M(SC:Syscall_t) = { t <- (t `<<` (W8.of_int 4)); val2 <- (val2 `|` t); pos <- (pos + (W64.of_int 3)); - cond <- (val1 \ult (W16.of_int 3329)); - if (cond) { + if ((val1 \ult (W16.of_int 3329))) { rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { } - cond <- (val2 \ult (W16.of_int 3329)); - if (cond) { + if ((val2 \ult (W16.of_int 3329))) { if ((ctr \ult (W64.of_int 256))) { rp.[(W64.to_uint ctr)] <- val2; ctr <- (ctr + (W64.of_int 1)); diff --git a/code/jasmin/mlkem_ref/gen_matrix.jazz b/code/jasmin/mlkem_ref/gen_matrix.jazz index 444017dc..ae7012ea 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jazz +++ b/code/jasmin/mlkem_ref/gen_matrix.jazz @@ -1,3 +1,4 @@ +require "params.jinc" require "gen_matrix.jinc" export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp, reg u64 transposed) @@ -11,6 +12,7 @@ export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp, reg u64 transposed) for i = 0 to MLKEM_SYMBYTES { + #[declassify] c = (u8)[seedp + i]; seed[i] = c; } diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc index f261b711..13e6616c 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jinc +++ b/code/jasmin/mlkem_ref/gen_matrix.jinc @@ -27,17 +27,12 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val2 |= t; pos += 3; - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond { + if val1 < MLKEM_Q { rp[ctr] = val1; ctr += 1; } - #[declassify] - cond = val2 < MLKEM_Q; - if cond { + if val2 < MLKEM_Q { if(ctr < MLKEM_N) { rp[ctr] = val2; diff --git a/code/jasmin/mlkem_ref/indcpa.jinc b/code/jasmin/mlkem_ref/indcpa.jinc index a0f2a463..269c2fb3 100644 --- a/code/jasmin/mlkem_ref/indcpa.jinc +++ b/code/jasmin/mlkem_ref/indcpa.jinc @@ -32,6 +32,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn for i=0 to MLKEM_SYMBYTES/8 { + #[declassify] t64 = buf[u64 i]; publicseed[u64 i] = t64; t64 = buf[u64 i + MLKEM_SYMBYTES/8]; @@ -103,6 +104,7 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8 * (int)i] = t64; pkp += 8; @@ -178,6 +180,7 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { + #[declassify] t64 = (u64)[pkp]; publicseed.[u64 8*(int)i] = t64; pkp += 8; diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index 774d3c9b..0e7de184 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -591,8 +591,8 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\ to_uint pos{1} = k{2} /\ #{/~pos{1} \ult (of_int (168 - 2))%W64}post). -+ sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt(). - + sp 3 2; if{2}. ++ if; 1: by move => &1 &2; rewrite ultE qE; smt(). + + sp 2 2; if{2}. + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE; smt(). rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /= to_uintD_small /= /#. auto => /> &1 aar ctrl rpl 8?; rewrite ultE /= => *; do split; 2..3:smt(). @@ -638,7 +638,7 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\ rewrite set_eqiE 1,2:/#. by rewrite to_sint_unsigned; rewrite /to_sint /smod /=; smt(W16.to_uint_cmp). - sp 1 0; if{2}. + if{2}. + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#. rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#. auto => /> &1 &2 8?; rewrite ultE /= => *; do split; 2..3:smt().