Skip to content

Commit

Permalink
Merge pull request #173 from simple-crypto/fix_sasca
Browse files Browse the repository at this point in the history
Fix misc sasca bugs with PUB and/or generic factors
  • Loading branch information
cassiersg authored Aug 28, 2024
2 parents 57895bb + 3b961ad commit f2ffe70
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 51 deletions.
73 changes: 41 additions & 32 deletions src/scalib_ext/scalib/src/sasca/belief_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::sync::Arc;
use itertools::Itertools;
use thiserror::Error;

use crate::sasca::factor_graph::GenFactorOperand;

use super::factor_graph as fg;
use super::factor_graph::{
EdgeId, EdgeSlice, EdgeVec, ExprFactor, Factor, FactorGraph, FactorId, FactorKind, FactorVec,
Expand Down Expand Up @@ -298,7 +300,11 @@ impl BPState {
prop_factor!(factor_gen_and, &self.pub_reduced[factor_id])
}
ExprFactor::XOR => prop_factor!(factor_xor, &self.pub_reduced[factor_id]),
ExprFactor::NOT => prop_factor!(factor_not, (self.graph.nc - 1) as u32),
ExprFactor::NOT => prop_factor!(
factor_not,
&self.pub_reduced[factor_id],
(self.graph.nc - 1) as u32
),
ExprFactor::ADD { .. } => {
prop_factor!(factor_add, &self.pub_reduced[factor_id], &self.plans)
}
Expand Down Expand Up @@ -638,14 +644,15 @@ fn factor_not<'a>(
belief_from_var: &'a mut EdgeSlice<Distribution>,
dest: &'a [VarId],
clear_incoming: bool,
pub_reduced: &PublicValue,
inv_cst: u32,
) -> impl Iterator<Item = Distribution> + 'a {
factor_xor(
factor,
belief_from_var,
dest,
clear_incoming,
&PublicValue::Single(inv_cst),
&pub_reduced.map(|x| x ^ inv_cst),
)
}

Expand Down Expand Up @@ -910,9 +917,11 @@ fn factor_gen_factor<'a>(
};
let res: Vec<Distribution> = dest.iter().map(|dest| {
let dest_idx = factor.edges.get_index_of(dest).unwrap();
let op_dest_idx = operands.iter().position(|op| if let GenFactorOperand::Var { factor_edge_id, .. } = op { *factor_edge_id == dest_idx } else { false }).expect("must have dest operand");
let mut distr = belief_from_var[factor.edges[dest_idx]].clone();
distr.ensure_full();
for i in 0..nmulti {
let nmulti_actual = if factor.multi { nmulti } else { 1 };
for i in 0..nmulti_actual {
let gen_factor = match gen_factor {
GenFactor::Single(x) => x,
GenFactor::Multi(x) => &x[i],
Expand All @@ -922,10 +931,10 @@ fn factor_gen_factor<'a>(
assert_eq!(gen_factor.shape().len(), operands.len());
// First slice the array with the constants.
let gen_factor = gen_factor.slice_each_axis(|ax| match operands[ax.axis.index()] {
fg::GenFactorOperand::Var(_, _) => ndarray::Slice::new(0, None, 1),
fg::GenFactorOperand::Pub(pub_idx) => {
let mut pub_val = public_values[factor.publics[pub_idx].0].get(i) as isize;
if factor.publics[pub_idx].1 {
fg::GenFactorOperand::Var { ..} => ndarray::Slice::new(0, None, 1),
fg::GenFactorOperand::Pub { pub_id } => {
let mut pub_val = public_values[factor.publics[pub_id].0].get(i) as isize;
if factor.publics[pub_id].1 {
if nc.is_power_of_two() {
pub_val = !pub_val;
} else {
Expand All @@ -938,12 +947,12 @@ fn factor_gen_factor<'a>(
});
let mut gen_factor = gen_factor.to_owned();
for (op_idx, op) in operands.iter().enumerate() {
if op_idx != dest_idx {
if let fg::GenFactorOperand::Var(var_idx, neg) = op {
if *neg {
if let fg::GenFactorOperand::Var { factor_edge_id, negated } = op {
if *factor_edge_id != dest_idx {
if *negated {
todo!("Negated operands on generalized factors not yet implemented.");
}
let distr = &belief_from_var[factor.edges[*var_idx]];
let distr = &belief_from_var[factor.edges[*factor_edge_id]];
let mut new_gen_factor: ndarray::ArrayD<f64> = ndarray::ArrayD::zeros(gen_factor.slice_axis(ndarray::Axis(op_idx), ndarray::Slice::new(0, Some(1), 1)).shape());
if let Some(distr) = distr.value() {
for (d, gf) in distr.slice(s![i,..]).iter().zip(gen_factor.axis_chunks_iter(ndarray::Axis(op_idx), 1)) {
Expand All @@ -959,10 +968,10 @@ fn factor_gen_factor<'a>(
}
}
// Drop useless axes.
for _ in 0..dest_idx {
for _ in 0..op_dest_idx {
gen_factor.index_axis_inplace(ndarray::Axis(0), 0);
}
for _ in (dest_idx+1)..operands.len() {
for _ in (op_dest_idx+1)..operands.len() {
gen_factor.index_axis_inplace(ndarray::Axis(1), 0);
}
distr.value_mut().unwrap().slice_mut(s![i,..]).assign(&gen_factor);
Expand All @@ -974,43 +983,43 @@ fn factor_gen_factor<'a>(
dest.fill(0.0);
for op_values in gen_factor.outer_iter() {
let mut res = 1.0;
for (op_idx, (op, val)) in operands.iter().zip(op_values.iter()).enumerate() {
if op_idx != dest_idx {
match op {
fg::GenFactorOperand::Var(var_idx, neg) => {
for (op, val) in operands.iter().zip(op_values.iter()) {
match op {
fg::GenFactorOperand::Var { factor_edge_id, negated} => {
if *factor_edge_id != dest_idx {
let mut val = *val;
if *neg {
if *negated {
if nc.is_power_of_two() {
val = !val & ((nc - 1) as ClassVal);
} else {
// TODO Check that we enforce this at graph creation time and return a proper error.
panic!("Cannot negate operands with non-power-of-two number of classes.");
}
}
let distr = &belief_from_var[factor.edges[*var_idx]];
let distr = &belief_from_var[factor.edges[*factor_edge_id]];
// For uniform, we implicitly multiply by 1.0
if let Some(distr) = distr.value() {
res *= distr[(i, val as usize)];
}
}
fg::GenFactorOperand::Pub(pub_idx) => {
let mut pub_val = public_values[factor.publics[*pub_idx].0].get(i);
if factor.publics[*pub_idx].1 {
if nc.is_power_of_two() {
pub_val = !pub_val & ((nc - 1) as ClassVal);
} else {
// TODO Check that we enforce this at graph creation time and return a proper error.
panic!("Cannot negate operands with non-power-of-two number of classes.");
}
}
if pub_val != *val {
res = 0.0;
}
fg::GenFactorOperand::Pub{pub_id} => {
let mut pub_val = public_values[factor.publics[*pub_id].0].get(i);
if factor.publics[*pub_id].1 {
if nc.is_power_of_two() {
pub_val = !pub_val & ((nc - 1) as ClassVal);
} else {
// TODO Check that we enforce this at graph creation time and return a proper error.
panic!("Cannot negate operands with non-power-of-two number of classes.");
}
}
if pub_val != *val {
res = 0.0;
}
}
}
}
dest[op_values[dest_idx] as usize] += res;
dest[op_values[op_dest_idx] as usize] += res;
}
}
}
Expand Down
39 changes: 23 additions & 16 deletions src/scalib_ext/scalib/src/sasca/factor_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ pub(super) enum ExprFactor {

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub(super) enum GenFactorOperand {
Var(usize, bool),
Pub(usize),
Var {
factor_edge_id: usize,
negated: bool,
},
Pub {
pub_id: usize,
},
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -374,8 +379,10 @@ impl FactorGraph {
let ops: Vec<&PublicValue> = operands
.iter()
.map(|op| match op {
GenFactorOperand::Var(idx, ..) => &var_assignments[*idx],
GenFactorOperand::Pub(idx) => &public_values[*idx],
GenFactorOperand::Var { factor_edge_id, .. } => {
&var_assignments[*factor_edge_id]
}
GenFactorOperand::Pub { pub_id } => &public_values[*pub_id],
})
.collect();
let nmulti_ops = ops.iter().find_map(|op| {
Expand Down Expand Up @@ -451,25 +458,25 @@ impl FactorGraph {
self.factors
.values()
.map(|factor| {
let mut pubs = factor
.publics
.iter()
.map(|(pub_id, nv)| (&public_values[*pub_id], *nv));
match &factor.kind {
// Not used
FactorKind::Assign {
expr: ExprFactor::NOT,
..
}
| FactorKind::Assign {
expr: ExprFactor::LOOKUP { .. },
..
}
| FactorKind::GenFactor { .. } => PublicValue::Single(0),
FactorKind::Assign { expr, has_res } => self.merge_pubs(
expr,
!has_res,
factor
.publics
.iter()
.map(|(pub_id, nv)| (&public_values[*pub_id], *nv)),
),
FactorKind::Assign {
expr: ExprFactor::NOT,
..
} => pubs
.next()
.map(|(val, _)| val.clone())
.unwrap_or(PublicValue::Single(0)),
FactorKind::Assign { expr, has_res } => self.merge_pubs(expr, !has_res, pubs),
}
})
.collect()
Expand Down
7 changes: 5 additions & 2 deletions src/scalib_ext/scalib/src/sasca/fg_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,13 @@ impl fg::FactorGraph {
let mut operands = Vec::new();
for (i, p) in is_pub.iter().enumerate() {
if *p {
operands.push(fg::GenFactorOperand::Pub(n_pubs));
operands.push(fg::GenFactorOperand::Pub { pub_id: n_pubs });
n_pubs += 1;
} else {
operands.push(fg::GenFactorOperand::Var(n_vars, vars[i].neg));
operands.push(fg::GenFactorOperand::Var {
factor_edge_id: n_vars,
negated: vars[i].neg,
});
n_vars += 1;
}
}
Expand Down
76 changes: 75 additions & 1 deletion tests/test_factorgraph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from scalib.attacks import FactorGraph, BPState
from scalib.attacks import FactorGraph, BPState, GenFactor
import numpy as np
import os
import copy
import itertools as it


def normalize_distr(x):
Expand Down Expand Up @@ -1778,3 +1779,76 @@ def test_cycle_detection_single_factor_with_multi():
fg = FactorGraph(graph_desc)
bp = BPState(fg, 2)
assert bp.is_cyclic()


def test_generic_single_multi():
nc = 2
n_exec = 2
graph_desc = f"""
NC {nc}
VAR SINGLE A
VAR SINGLE B
VAR SINGLE C
GENERIC SINGLE XOR
PROPERTY XOR(A,B,C)
"""

def xor(a, b):
return a ^ b

fg = FactorGraph(graph_desc)

XOR = np.array(
[[a, b, a ^ b] for a, b in it.product(range(nc), repeat=2)],
dtype=np.uint32,
)
gen_factors = {
"XOR": GenFactor.sparse_functional(XOR),
}

bp = BPState(fg, n_exec, gen_factors=gen_factors)
bp.bp_loopy(1, True)


def test_factor_not_pub():
nc = 4
graph_desc = f"""
NC {nc}
PUB SINGLE A
VAR SINGLE B
PROPERTY B = !A
"""
fg = FactorGraph(graph_desc)
for a in range(nc):
bp = BPState(fg, 1, public_values={"A": a})
bp.bp_acyclic("B")
result = bp.get_distribution("B")
assert np.argmax(result) == (nc - 1) ^ a


def test_factor_gen_pub():
nc = 2
graph_desc = f"""
NC {nc}
PUB SINGLE A
VAR SINGLE B
GENERIC SINGLE NOT
PROPERTY NOT(A,B)
"""
fg = FactorGraph(graph_desc)
not_factors = [
GenFactor.sparse_functional(
np.array([(a, (nc - 1) ^ a) for a in range(nc)], dtype=np.uint32)
),
GenFactor.dense(np.array([[0, 1], [1, 0]], dtype=np.float64)),
]
for nf in not_factors:
for a in range(nc):
bp = BPState(fg, 1, public_values={"A": a}, gen_factors={"NOT": nf})
bp.bp_acyclic("B")
result = bp.get_distribution("B")
assert np.argmax(result) == (nc - 1) ^ a

0 comments on commit f2ffe70

Please sign in to comment.