Skip to content

Commit

Permalink
speeding things up a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Nov 30, 2024
1 parent 41c3b76 commit b4198d1
Show file tree
Hide file tree
Showing 11 changed files with 1,082 additions and 809 deletions.
141 changes: 69 additions & 72 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ import SLH64.
require import
Array1 Array4 Array5 Array6 Array7 Array8 Array9 Array16 Array24 Array25
Array32 Array33 Array64 Array128 Array136 Array256 Array400 Array536 Array768
Array960 Array1024 Array1088 Array2048 Array2144 Array2304.

require import
WArray8 WArray16 WArray32 WArray33 WArray40 WArray64 WArray128 WArray136
WArray160 WArray192 WArray200 WArray224 WArray256 WArray288 WArray512
WArray536 WArray768 WArray800 WArray960 WArray1088 WArray1536 WArray2048
WArray2144 WArray4608.
Array960 Array1024 Array1088 Array2048 Array2144 Array2304 WArray8 WArray16
WArray32 WArray33 WArray40 WArray64 WArray128 WArray136 WArray160 WArray192
WArray200 WArray224 WArray256 WArray288 WArray512 WArray536 WArray768
WArray800 WArray960 WArray1088 WArray1536 WArray2048 WArray2144 WArray4608.

abbrev gen_matrix_indexes =
(Array16.of_list witness
Expand Down Expand Up @@ -1313,7 +1310,7 @@ module M(SC:Syscall_t) = {
var t1:W256.t;
t0 <- (VMOVSLDUP_256 b);
t0 <- (VPBLEND_8u32 a t0 (W8.of_int 170));
a <- (VPSRL_4u64 a (W8.of_int 32));
a <- (VPSRL_4u64 a (W128.of_int 32));
t1 <- (VPBLEND_8u32 a b (W8.of_int 170));
return (t0, t1);
}
Expand All @@ -1322,9 +1319,9 @@ module M(SC:Syscall_t) = {
var r1:W256.t;
var t0:W256.t;
var t1:W256.t;
t0 <- (VPSLL_8u32 b (W8.of_int 16));
t0 <- (VPSLL_8u32 b (W128.of_int 16));
r0 <- (VPBLEND_16u16 a t0 (W8.of_int 170));
t1 <- (VPSRL_8u32 a (W8.of_int 16));
t1 <- (VPSRL_8u32 a (W128.of_int 16));
r1 <- (VPBLEND_16u16 t1 b (W8.of_int 170));
return (r0, r1);
}
Expand Down Expand Up @@ -1466,15 +1463,15 @@ module M(SC:Syscall_t) = {
proc __csubq (r:W256.t, qx16:W256.t) : W256.t = {
var t:W256.t;
r <- (VPSUB_16u16 r qx16);
t <- (VPSRA_16u16 r (W8.of_int 15));
t <- (VPSRA_16u16 r (W128.of_int 15));
t <- (VPAND_256 t qx16);
r <- (VPADD_16u16 t r);
return r;
}
proc __red16x (r:W256.t, qx16:W256.t, vx16:W256.t) : W256.t = {
var x:W256.t;
x <- (VPMULH_16u16 r vx16);
x <- (VPSRA_16u16 x (W8.of_int 10));
x <- (VPSRA_16u16 x (W128.of_int 10));
x <- (VPMULL_16u16 x qx16);
r <- (VPSUB_16u16 r x);
return r;
Expand Down Expand Up @@ -1939,8 +1936,8 @@ module M(SC:Syscall_t) = {
proc __rol_4u64 (a:W256.t, o:int) : W256.t = {
var r:W256.t;
var t256:W256.t;
r <- (VPSLL_4u64 a (W8.of_int o));
t256 <- (VPSRL_4u64 a (W8.of_int (64 - o)));
r <- (VPSLL_4u64 a (W128.of_int o));
t256 <- (VPSRL_4u64 a (W128.of_int (64 - o)));
r <- (r `|` t256);
return r;
}
Expand Down Expand Up @@ -3101,8 +3098,8 @@ module M(SC:Syscall_t) = {
al <- (VPBLEND_16u16 a0 _zero (W8.of_int 170));
ah <- (VPBLEND_16u16 a1 _zero (W8.of_int 170));
al <- (VPACKUS_8u32 al ah);
a0 <- (VPSRL_8u32 a0 (W8.of_int 16));
a1 <- (VPSRL_8u32 a1 (W8.of_int 16));
a0 <- (VPSRL_8u32 a0 (W128.of_int 16));
a1 <- (VPSRL_8u32 a1 (W128.of_int 16));
ah <- (VPACKUS_8u32 a0 a1);
return (al, ah);
}
Expand Down Expand Up @@ -3515,27 +3512,27 @@ module M(SC:Syscall_t) = {
(t6, t3) <@ __shuffle1 (t0, t3);
(t0, t4) <@ __shuffle1 (t1, t4);
(t1, t5) <@ __shuffle1 (t2, t5);
t7 <- (VPSRL_16u16 t6 (W8.of_int 12));
t8 <- (VPSLL_16u16 t3 (W8.of_int 4));
t7 <- (VPSRL_16u16 t6 (W128.of_int 12));
t8 <- (VPSLL_16u16 t3 (W128.of_int 4));
t7 <- (VPOR_256 t7 t8);
t6 <- (VPAND_256 mask t6);
t7 <- (VPAND_256 mask t7);
t8 <- (VPSRL_16u16 t3 (W8.of_int 8));
t9 <- (VPSLL_16u16 t0 (W8.of_int 8));
t8 <- (VPSRL_16u16 t3 (W128.of_int 8));
t9 <- (VPSLL_16u16 t0 (W128.of_int 8));
t8 <- (VPOR_256 t8 t9);
t8 <- (VPAND_256 mask t8);
t9 <- (VPSRL_16u16 t0 (W8.of_int 4));
t9 <- (VPSRL_16u16 t0 (W128.of_int 4));
t9 <- (VPAND_256 mask t9);
t10 <- (VPSRL_16u16 t4 (W8.of_int 12));
t11 <- (VPSLL_16u16 t1 (W8.of_int 4));
t10 <- (VPSRL_16u16 t4 (W128.of_int 12));
t11 <- (VPSLL_16u16 t1 (W128.of_int 4));
t10 <- (VPOR_256 t10 t11);
t4 <- (VPAND_256 mask t4);
t10 <- (VPAND_256 mask t10);
t11 <- (VPSRL_16u16 t1 (W8.of_int 8));
tt <- (VPSLL_16u16 t5 (W8.of_int 8));
t11 <- (VPSRL_16u16 t1 (W128.of_int 8));
tt <- (VPSLL_16u16 t5 (W128.of_int 8));
t11 <- (VPOR_256 t11 tt);
t11 <- (VPAND_256 mask t11);
tt <- (VPSRL_16u16 t5 (W8.of_int 4));
tt <- (VPSRL_16u16 t5 (W128.of_int 4));
tt <- (VPAND_256 mask tt);
rp <-
(Array256.init
Expand Down Expand Up @@ -3640,13 +3637,13 @@ module M(SC:Syscall_t) = {
g3 <- (VPSHUFD_256 f (W8.of_int (85 * i)));
g3 <- (VPSLLV_8u32 g3 shift);
g3 <- (VPSHUFB_256 g3 idx);
g0 <- (VPSLL_16u16 g3 (W8.of_int 12));
g1 <- (VPSLL_16u16 g3 (W8.of_int 8));
g2 <- (VPSLL_16u16 g3 (W8.of_int 4));
g0 <- (VPSRA_16u16 g0 (W8.of_int 15));
g1 <- (VPSRA_16u16 g1 (W8.of_int 15));
g2 <- (VPSRA_16u16 g2 (W8.of_int 15));
g3 <- (VPSRA_16u16 g3 (W8.of_int 15));
g0 <- (VPSLL_16u16 g3 (W128.of_int 12));
g1 <- (VPSLL_16u16 g3 (W128.of_int 8));
g2 <- (VPSLL_16u16 g3 (W128.of_int 4));
g0 <- (VPSRA_16u16 g0 (W128.of_int 15));
g1 <- (VPSRA_16u16 g1 (W128.of_int 15));
g2 <- (VPSRA_16u16 g2 (W128.of_int 15));
g3 <- (VPSRA_16u16 g3 (W128.of_int 15));
g0 <- (VPAND_256 g0 hqs);
g1 <- (VPAND_256 g1 hqs);
g2 <- (VPAND_256 g2 hqs);
Expand Down Expand Up @@ -3719,19 +3716,19 @@ module M(SC:Syscall_t) = {
(get256_direct (WArray128.init8 (fun i_0 => buf.[i_0])) (24 * i));
f0 <- (VPERMQ f0 (W8.of_int 148));
f0 <- (VPSHUFB_256 f0 shufbidx);
f1 <- (VPSRL_8u32 f0 (W8.of_int 1));
f2 <- (VPSRL_8u32 f0 (W8.of_int 2));
f1 <- (VPSRL_8u32 f0 (W128.of_int 1));
f2 <- (VPSRL_8u32 f0 (W128.of_int 2));
f0 <- (VPAND_256 mask249 f0);
f1 <- (VPAND_256 mask249 f1);
f2 <- (VPAND_256 mask249 f2);
f0 <- (VPADD_8u32 f0 f1);
f0 <- (VPADD_8u32 f0 f2);
f1 <- (VPSRL_8u32 f0 (W8.of_int 3));
f1 <- (VPSRL_8u32 f0 (W128.of_int 3));
f0 <- (VPADD_8u32 f0 mask6DB);
f0 <- (VPSUB_8u32 f0 f1);
f1 <- (VPSLL_8u32 f0 (W8.of_int 10));
f2 <- (VPSRL_8u32 f0 (W8.of_int 12));
f3 <- (VPSRL_8u32 f0 (W8.of_int 2));
f1 <- (VPSLL_8u32 f0 (W128.of_int 10));
f2 <- (VPSRL_8u32 f0 (W128.of_int 12));
f3 <- (VPSRL_8u32 f0 (W128.of_int 2));
f0 <- (VPAND_256 f0 mask07);
f1 <- (VPAND_256 f1 mask70);
f2 <- (VPAND_256 f2 mask07);
Expand Down Expand Up @@ -3786,16 +3783,16 @@ module M(SC:Syscall_t) = {
i <- 0;
while ((i < aux)) {
f0 <- (get256 (WArray128.init8 (fun i_0 => buf.[i_0])) i);
f1 <- (VPSRL_16u16 f0 (W8.of_int 1));
f1 <- (VPSRL_16u16 f0 (W128.of_int 1));
f0 <- (VPAND_256 mask55 f0);
f1 <- (VPAND_256 mask55 f1);
f0 <- (VPADD_32u8 f0 f1);
f1 <- (VPSRL_16u16 f0 (W8.of_int 2));
f1 <- (VPSRL_16u16 f0 (W128.of_int 2));
f0 <- (VPAND_256 mask33 f0);
f1 <- (VPAND_256 mask33 f1);
f0 <- (VPADD_32u8 f0 mask33);
f0 <- (VPSUB_32u8 f0 f1);
f1 <- (VPSRL_16u16 f0 (W8.of_int 4));
f1 <- (VPSRL_16u16 f0 (W128.of_int 4));
f0 <- (VPAND_256 mask0F f0);
f1 <- (VPAND_256 mask0F f1);
f0 <- (VPSUB_32u8 f0 mask03);
Expand Down Expand Up @@ -4647,21 +4644,21 @@ module M(SC:Syscall_t) = {
t5 <- (get256 (WArray512.init16 (fun i_0 => a.[i_0])) ((8 * i) + 5));
t6 <- (get256 (WArray512.init16 (fun i_0 => a.[i_0])) ((8 * i) + 6));
t7 <- (get256 (WArray512.init16 (fun i_0 => a.[i_0])) ((8 * i) + 7));
tt <- (VPSLL_16u16 t1 (W8.of_int 12));
tt <- (VPSLL_16u16 t1 (W128.of_int 12));
tt <- (tt `|` t0);
t0 <- (VPSRL_16u16 t1 (W8.of_int 4));
t1 <- (VPSLL_16u16 t2 (W8.of_int 8));
t0 <- (VPSRL_16u16 t1 (W128.of_int 4));
t1 <- (VPSLL_16u16 t2 (W128.of_int 8));
t0 <- (t0 `|` t1);
t1 <- (VPSRL_16u16 t2 (W8.of_int 8));
t2 <- (VPSLL_16u16 t3 (W8.of_int 4));
t1 <- (VPSRL_16u16 t2 (W128.of_int 8));
t2 <- (VPSLL_16u16 t3 (W128.of_int 4));
t1 <- (t1 `|` t2);
t2 <- (VPSLL_16u16 t5 (W8.of_int 12));
t2 <- (VPSLL_16u16 t5 (W128.of_int 12));
t2 <- (t2 `|` t4);
t3 <- (VPSRL_16u16 t5 (W8.of_int 4));
t4 <- (VPSLL_16u16 t6 (W8.of_int 8));
t3 <- (VPSRL_16u16 t5 (W128.of_int 4));
t4 <- (VPSLL_16u16 t6 (W128.of_int 8));
t3 <- (t3 `|` t4);
t4 <- (VPSRL_16u16 t6 (W8.of_int 8));
t5 <- (VPSLL_16u16 t7 (W8.of_int 4));
t4 <- (VPSRL_16u16 t6 (W128.of_int 8));
t5 <- (VPSLL_16u16 t7 (W128.of_int 4));
t4 <- (t4 `|` t5);
(ttt, t0) <@ __shuffle1 (tt, t0);
(tt, t2) <@ __shuffle1 (t1, t2);
Expand Down Expand Up @@ -4721,8 +4718,8 @@ module M(SC:Syscall_t) = {
f1 <- (get256 (WArray512.init16 (fun i_0 => a.[i_0])) ((2 * i) + 1));
f0 <- (VPSUB_16u16 hq f0);
f1 <- (VPSUB_16u16 hq f1);
g0 <- (VPSRA_16u16 f0 (W8.of_int 15));
g1 <- (VPSRA_16u16 f1 (W8.of_int 15));
g0 <- (VPSRA_16u16 f0 (W128.of_int 15));
g1 <- (VPSRA_16u16 f1 (W128.of_int 15));
f0 <- (VPXOR_256 f0 g0);
f1 <- (VPXOR_256 f1 g1);
f0 <- (VPSUB_16u16 f0 hhq);
Expand Down Expand Up @@ -4806,7 +4803,7 @@ module M(SC:Syscall_t) = {
f <- (VPERMQ f (W8.of_int 148));
f <- (VPSHUFB_256 f shufbidx);
f <- (VPSLLV_8u32 f sllvdidx);
f <- (VPSRL_16u16 f (W8.of_int 1));
f <- (VPSRL_16u16 f (W128.of_int 1));
f <- (VPAND_256 f mask);
f <- (VPMULHRS_16u16 f q);
r <-
Expand Down Expand Up @@ -4841,7 +4838,7 @@ module M(SC:Syscall_t) = {
a <@ __polyvec_csubq (a);
x16p <- jvx16;
v <- (get256 (WArray32.init16 (fun i_0 => x16p.[i_0])) 0);
v8 <- (VPSLL_16u16 v (W8.of_int 3));
v8 <- (VPSLL_16u16 v (W128.of_int 3));
off <- (VPBROADCAST_16u16 pvc_off_s);
shift1 <- (VPBROADCAST_16u16 pvc_shift1_s);
mask <- (VPBROADCAST_16u16 pvc_mask_s);
Expand All @@ -4855,17 +4852,17 @@ module M(SC:Syscall_t) = {
f0 <- (get256 (WArray1536.init16 (fun i_0 => a.[i_0])) i);
f1 <- (VPMULL_16u16 f0 v8);
f2 <- (VPADD_16u16 f0 off);
f0 <- (VPSLL_16u16 f0 (W8.of_int 3));
f0 <- (VPSLL_16u16 f0 (W128.of_int 3));
f0 <- (VPMULH_16u16 f0 v);
f2 <- (VPSUB_16u16 f1 f2);
f1 <- (VPANDN_256 f1 f2);
f1 <- (VPSRL_16u16 f1 (W8.of_int 15));
f1 <- (VPSRL_16u16 f1 (W128.of_int 15));
f0 <- (VPSUB_16u16 f0 f1);
f0 <- (VPMULHRS_16u16 f0 shift1);
f0 <- (VPAND_256 f0 mask);
f0 <- (VPMADDWD_256 f0 shift2);
f0 <- (VPSLLV_8u32 f0 sllvdidx);
f0 <- (VPSRL_4u64 f0 (W8.of_int 12));
f0 <- (VPSRL_4u64 f0 (W128.of_int 12));
f0 <- (VPSHUFB_256 f0 shufbidx);
t0 <- (truncateu128 f0);
t1 <- (VEXTRACTI128 f0 (W8.of_int 1));
Expand Down Expand Up @@ -4901,7 +4898,7 @@ module M(SC:Syscall_t) = {
a <@ __polyvec_csubq (a);
x16p <- jvx16;
v <- (get256 (WArray32.init16 (fun i_0 => x16p.[i_0])) 0);
v8 <- (VPSLL_16u16 v (W8.of_int 3));
v8 <- (VPSLL_16u16 v (W128.of_int 3));
off <- (VPBROADCAST_16u16 pvc_off_s);
shift1 <- (VPBROADCAST_16u16 pvc_shift1_s);
mask <- (VPBROADCAST_16u16 pvc_mask_s);
Expand All @@ -4915,17 +4912,17 @@ module M(SC:Syscall_t) = {
f0 <- (get256 (WArray1536.init16 (fun i_0 => a.[i_0])) i);
f1 <- (VPMULL_16u16 f0 v8);
f2 <- (VPADD_16u16 f0 off);
f0 <- (VPSLL_16u16 f0 (W8.of_int 3));
f0 <- (VPSLL_16u16 f0 (W128.of_int 3));
f0 <- (VPMULH_16u16 f0 v);
f2 <- (VPSUB_16u16 f1 f2);
f1 <- (VPANDN_256 f1 f2);
f1 <- (VPSRL_16u16 f1 (W8.of_int 15));
f1 <- (VPSRL_16u16 f1 (W128.of_int 15));
f0 <- (VPSUB_16u16 f0 f1);
f0 <- (VPMULHRS_16u16 f0 shift1);
f0 <- (VPAND_256 f0 mask);
f0 <- (VPMADDWD_256 f0 shift2);
f0 <- (VPSLLV_8u32 f0 sllvdidx);
f0 <- (VPSRL_4u64 f0 (W8.of_int 12));
f0 <- (VPSRL_4u64 f0 (W128.of_int 12));
f0 <- (VPSHUFB_256 f0 shufbidx);
t0 <- (truncateu128 f0);
t1 <- (VEXTRACTI128 f0 (W8.of_int 1));
Expand Down Expand Up @@ -5096,8 +5093,8 @@ module M(SC:Syscall_t) = {
if ((r = 56)) {
a.[x] <- (VPSHUFB_256 a.[x] r56);
} else {
t <- (VPSLL_4u64 a.[x] (W8.of_int r));
a.[x] <- (VPSRL_4u64 a.[x] (W8.of_int (64 - r)));
t <- (VPSLL_4u64 a.[x] (W128.of_int r));
a.[x] <- (VPSRL_4u64 a.[x] (W128.of_int (64 - r)));
a.[x] <- (a.[x] `|` t);
}
}
Expand Down Expand Up @@ -5283,7 +5280,7 @@ module M(SC:Syscall_t) = {
((2 %% (2 ^ 2)) +
((2 ^ 2) *
((3 %% (2 ^ 2)) + ((2 ^ 2) * ((0 %% (2 ^ 2)) + ((2 ^ 2) * 1))))))));
t.[1] <- (c14 \vshr64u256 (W8.of_int 63));
t.[1] <- (c14 \vshr64u256 (W128.of_int 63));
t.[2] <- (c14 \vadd64u256 c14);
t.[1] <- (t.[1] `|` t.[2]);
d14 <-
Expand All @@ -5301,7 +5298,7 @@ module M(SC:Syscall_t) = {
((0 %% (2 ^ 2)) + ((2 ^ 2) * ((0 %% (2 ^ 2)) + ((2 ^ 2) * 0))))))));
c00 <- (c00 `^` state.[0]);
c00 <- (c00 `^` t.[0]);
t.[0] <- (c00 \vshr64u256 (W8.of_int 63));
t.[0] <- (c00 \vshr64u256 (W128.of_int 63));
t.[1] <- (c00 \vadd64u256 c00);
t.[1] <- (t.[1] `|` t.[0]);
state.[2] <- (state.[2] `^` d00);
Expand Down Expand Up @@ -5945,7 +5942,7 @@ module M(SC:Syscall_t) = {
((2 %% (2 ^ 2)) +
((2 ^ 2) *
((3 %% (2 ^ 2)) + ((2 ^ 2) * ((0 %% (2 ^ 2)) + ((2 ^ 2) * 1))))))));
t.[1] <- (c14 \vshr64u256 (W8.of_int 63));
t.[1] <- (c14 \vshr64u256 (W128.of_int 63));
t.[2] <- (c14 \vadd64u256 c14);
t.[1] <- (t.[1] `|` t.[2]);
d14 <-
Expand All @@ -5963,7 +5960,7 @@ module M(SC:Syscall_t) = {
((0 %% (2 ^ 2)) + ((2 ^ 2) * ((0 %% (2 ^ 2)) + ((2 ^ 2) * 0))))))));
c00 <- (c00 `^` state.[0]);
c00 <- (c00 `^` t.[0]);
t.[0] <- (c00 \vshr64u256 (W8.of_int 63));
t.[0] <- (c00 \vshr64u256 (W128.of_int 63));
t.[1] <- (c00 \vadd64u256 c00);
t.[1] <- (t.[1] `|` t.[0]);
state.[2] <- (state.[2] `^` d00);
Expand Down Expand Up @@ -7195,8 +7192,8 @@ module M(SC:Syscall_t) = {
((1 %% (2 ^ 2)) + ((2 ^ 2) * ((1 %% (2 ^ 2)) + ((2 ^ 2) * 2))))))));
f0 <- (VPSHUFB_256 f0 load_shuffle);
f1 <- (VPSHUFB_256 f1 load_shuffle);
g0 <- (VPSRL_16u16 f0 (W8.of_int 4));
g1 <- (VPSRL_16u16 f1 (W8.of_int 4));
g0 <- (VPSRL_16u16 f0 (W128.of_int 4));
g1 <- (VPSRL_16u16 f1 (W128.of_int 4));
f0 <- (VPBLEND_16u16 f0 g0 (W8.of_int 170));
f1 <- (VPBLEND_16u16 f1 g1 (W8.of_int 170));
f0 <- (VPAND_256 f0 mask);
Expand Down Expand Up @@ -7364,7 +7361,7 @@ module M(SC:Syscall_t) = {
((2 ^ 2) *
((1 %% (2 ^ 2)) + ((2 ^ 2) * ((1 %% (2 ^ 2)) + ((2 ^ 2) * 2))))))));
f0 <- (VPSHUFB_256 f0 load_shuffle);
g0 <- (VPSRL_16u16 f0 (W8.of_int 4));
g0 <- (VPSRL_16u16 f0 (W128.of_int 4));
f0 <- (VPBLEND_16u16 f0 g0 (W8.of_int 170));
f0 <- (VPAND_256 f0 mask);
g0 <- (VPCMPGT_16u16 bounds f0);
Expand Down
9 changes: 3 additions & 6 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ import SLH64.

require import
Array4 Array5 Array24 Array25 Array32 Array33 Array34 Array64 Array128
Array168 Array256 Array768 Array960 Array1088 Array2304.

require import
WArray20 WArray32 WArray33 WArray34 WArray40 WArray64 WArray128 WArray168
WArray192 WArray200 WArray256 WArray512 WArray960 WArray1088 WArray1536
WArray4608.
Array168 Array256 Array768 Array960 Array1088 Array2304 WArray20 WArray32
WArray33 WArray34 WArray40 WArray64 WArray128 WArray168 WArray192 WArray200
WArray256 WArray512 WArray960 WArray1088 WArray1536 WArray4608.

abbrev jzetas_inv =
(Array128.of_list witness
Expand Down
2 changes: 1 addition & 1 deletion jasmin
Submodule jasmin updated 121 files
Loading

0 comments on commit b4198d1

Please sign in to comment.