Skip to content

Commit

Permalink
[prim, rom_ctrl] Remove S&P layer from data scrambling
Browse files Browse the repository at this point in the history
As elaborated in #20788, the S&P layer is removed from data scrambling
in order to improve error detection guarantees, interactions with ECC,
and timing.

Signed-off-by: Michael Schaffner <[email protected]>
  • Loading branch information
msfschaffner committed Jan 17, 2024
1 parent f38b01f commit 24da245
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 153 deletions.
1 change: 0 additions & 1 deletion hw/dv/sv/mem_bkdr_util/mem_bkdr_util__rom.sv
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ virtual function bit [38:0] rom_encrypt_read32(bit [bus_params_pkg::BUS_AW-1:0]
zero_key[i] = '0;
end

data_arr = sram_scrambler_pkg::sp_decrypt(data_arr, 39, zero_key);
for (int i = 0; i < 39; i++) begin
data[i] = data_arr[i] ^ keystream[i];
end
Expand Down
60 changes: 1 addition & 59 deletions hw/dv/sv/mem_bkdr_util/sram_scrambler_pkg.sv
Original file line number Diff line number Diff line change
Expand Up @@ -262,34 +262,6 @@ package sram_scrambler_pkg;
data_enc[i] = data[i] ^ keystream[i % ks_width];
end

if (data_width == sp_width) begin
// pass the entire word through the subst/perm network at once (the next cases would give the
// same results too, but this should be a bit more efficient)
data_enc = sp_encrypt(data_enc, data_width, zero_key);
end else if (sp_width == 8) begin
// pass each byte of the encoded result through the subst/perm network (special case of the
// general code below)
for (int i = 0; i < data_width / 8; i++) begin
byte_to_enc = data_enc[i*8 +: 8];
enc_byte = sp_encrypt(byte_to_enc, 8, zero_key);
data_enc[i*8 +: 8] = enc_byte;
end
end else begin
// divide the word into sp_width chunks to pass it through the subst/perm network
for (int chunk_lsb = 0; chunk_lsb < data_width; chunk_lsb += sp_width) begin
int bits_remaining = data_width - chunk_lsb;
int chunk_width = (bits_remaining < sp_width) ? bits_remaining : sp_width;
logic chunk[] = new[chunk_width];

for (int j = 0; j < chunk_width; j++) begin
chunk[j] = data_enc[chunk_lsb + j];
end
chunk = sp_encrypt(chunk, chunk_width, zero_key);
for (int j = 0; j < chunk_width; j++) begin
data_enc[chunk_lsb + j] = chunk[j];
end
end
end
return data_enc;

endfunction : encrypt_sram_data
Expand All @@ -311,39 +283,9 @@ package sram_scrambler_pkg;

// Generate the keystream
keystream = gen_keystream(addr, addr_width, key, nonce);

if (data_width == sp_width) begin
// pass the entire word through the subst/perm network at once (the next cases would give the
// same results too, but this should be a bit more efficient)
data_dec = sp_decrypt(data, data_width, zero_key);
end else if (sp_width == 8) begin
// pass each byte of the data through the subst/perm network (special case of the general code
// below)
for (int i = 0; i < data_width / 8; i++) begin
byte_to_dec = data[i*8 +: 8];
dec_byte = sp_decrypt(byte_to_dec, 8, zero_key);
data_dec[i*8 +: 8] = dec_byte;
end
end else begin
// divide the word into sp_width chunks to pass it through the subst/perm network
for (int chunk_lsb = 0; chunk_lsb < data_width; chunk_lsb += sp_width) begin
int bits_remaining = data_width - chunk_lsb;
int chunk_width = (bits_remaining < sp_width) ? bits_remaining : sp_width;
logic chunk[] = new[chunk_width];

for (int j = 0; j < chunk_width; j++) begin
chunk[j] = data[chunk_lsb + j];
end
chunk = sp_decrypt(chunk, chunk_width, zero_key);
for (int j = 0; j < chunk_width; j++) begin
data_dec[chunk_lsb + j] = chunk[j];
end
end
end

// XOR result data with the keystream
for (int i = 0; i < data_width; i++) begin
data_dec[i] = data_dec[i] ^ keystream[i % ks_width];
data_dec[i] = data[i] ^ keystream[i % ks_width];
end

return data_dec;
Expand Down
4 changes: 1 addition & 3 deletions hw/ip/otbn/rtl/otbn.sv
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ module otbn
.Width (39),
.Depth (ImemSizeWords),
.DataBitsPerMask(39),
.EnableParity (0),
.DiffWidth (39)
.EnableParity (0)
) u_imem (
.clk_i,
.rst_ni(rst_n),
Expand Down Expand Up @@ -535,7 +534,6 @@ module otbn
.Depth (DmemSizeWords),
.DataBitsPerMask (39),
.EnableParity (0),
.DiffWidth (39),
.ReplicateKeyStream(1)
) u_dmem (
.clk_i,
Expand Down
74 changes: 12 additions & 62 deletions hw/ip/prim/rtl/prim_ram_1p_scr.sv
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ module prim_ram_1p_scr import prim_ram_1p_pkg::*; #(
// to 2*5 + 1 effective rounds. Setting this to 2 halves this to approximately 5 effective rounds.
// Number of PRINCE half rounds, can be [1..5]
parameter int NumPrinceRoundsHalf = 2,
// Number of extra diffusion rounds. Setting this to 0 to disable diffusion.
parameter int NumDiffRounds = 2,
// This parameter governs the block-width of additional diffusion layers.
// For intra-byte diffusion, set this parameter to 8.
parameter int DiffWidth = DataBitsPerMask,
// Number of address scrambling rounds. Setting this to 0 disables address scrambling.
parameter int NumAddrScrRounds = 2,
// If set to 1, the same 64bit key stream is replicated if the data port is wider than 64bit.
Expand All @@ -55,7 +50,10 @@ module prim_ram_1p_scr import prim_ram_1p_pkg::*; #(
// use the same key, but they use a different IV
localparam int DataKeyWidth = 128,
// Each 64 bit scrambling primitive requires a 64bit IV
localparam int NonceWidth = 64 * NumParScr
localparam int NonceWidth = 64 * NumParScr,
// TO BE REMOVED
localparam int DiffWidth = 32,
localparam int NumDiffRounds = 2
) (
input clk_i,
input rst_ni,
Expand Down Expand Up @@ -89,8 +87,6 @@ module prim_ram_1p_scr import prim_ram_1p_pkg::*; #(

// The depth needs to be a power of 2 in case address scrambling is turned on
`ASSERT_INIT(DepthPow2Check_A, NumAddrScrRounds <= '0 || 2**$clog2(Depth) == Depth)
`ASSERT_INIT(DiffWidthMinimum_A, DiffWidth >= 4)
`ASSERT_INIT(DiffWidthWithParity_A, EnableParity && (DiffWidth == 8) || !EnableParity)

/////////////////////////////////////////
// Pending Write and Address Registers //
Expand Down Expand Up @@ -211,67 +207,21 @@ module prim_ram_1p_scr import prim_ram_1p_pkg::*; #(
end
end

// Replicate keystream if needed
logic [Width-1:0] keystream_repl;
assign keystream_repl = Width'({NumParKeystr{keystream}});

/////////////////////
// Data Scrambling //
/////////////////////

// Data scrambling is a two step process. First, we XOR the write data with the keystream obtained
// by operating a reduced-round PRINCE cipher in CTR-mode. Then, we diffuse data within each byte
// in order to get a limited "avalanche" behavior in case parts of the bytes are flipped as a
// result of a malicious attempt to tamper with the data in memory. We perform the diffusion only
// within bytes in order to maintain the ability to write individual bytes. Note that the
// keystream XOR is performed first for the write path such that it can be performed last for the
// read path. This allows us to hide a part of the combinational delay of the PRINCE primitive
// behind the propagation delay of the SRAM macro and the per-byte diffusion step.
// Data scrambling XORs the write data with the keystream obtained by operating a reduced-round
// PRINCE cipher in CTR-mode.

// Replicate keystream if needed
logic [Width-1:0] keystream_repl;
assign keystream_repl = Width'({NumParKeystr{keystream}});

logic [Width-1:0] rdata_scr, rdata;
logic [Width-1:0] wdata_scr_d, wdata_scr_q, wdata_q;
for (genvar k = 0; k < (Width + DiffWidth - 1) / DiffWidth; k++) begin : gen_diffuse_data
// If the Width is not divisible by DiffWidth, we need to adjust the width of the last slice.
localparam int LocalWidth = (Width - k * DiffWidth >= DiffWidth) ? DiffWidth :
(Width - k * DiffWidth);

// Write path. Note that since this does not fan out into the interconnect, the write path is
// not as critical as the read path below in terms of timing.
// Apply the keystream first
logic [LocalWidth-1:0] wdata_xor;
assign wdata_xor = wdata_q[k*DiffWidth +: LocalWidth] ^
keystream_repl[k*DiffWidth +: LocalWidth];

// Byte aligned diffusion using a substitution / permutation network
prim_subst_perm #(
.DataWidth ( LocalWidth ),
.NumRounds ( NumDiffRounds ),
.Decrypt ( 0 )
) u_prim_subst_perm_enc (
.data_i ( wdata_xor ),
.key_i ( '0 ),
.data_o ( wdata_scr_d[k*DiffWidth +: LocalWidth] )
);

// Read path. This is timing critical. The keystream XOR operation is performed last in order to
// hide the combinational delay of the PRINCE primitive behind the propagation delay of the
// SRAM and the byte diffusion.
// Reverse diffusion first
logic [LocalWidth-1:0] rdata_xor;
prim_subst_perm #(
.DataWidth ( LocalWidth ),
.NumRounds ( NumDiffRounds ),
.Decrypt ( 1 )
) u_prim_subst_perm_dec (
.data_i ( rdata_scr[k*DiffWidth +: LocalWidth] ),
.key_i ( '0 ),
.data_o ( rdata_xor )
);

// Apply Keystream, replicate it if needed
assign rdata[k*DiffWidth +: LocalWidth] = rdata_xor ^
keystream_repl[k*DiffWidth +: LocalWidth];
end
assign wdata_scr_d = wdata_q ^ keystream_repl;
assign rdata = rdata_scr ^ keystream_repl;

////////////////////////////////////////////////
// Scrambled data register and forwarding mux //
Expand Down
16 changes: 1 addition & 15 deletions hw/ip/rom_ctrl/rtl/rom_ctrl_scrambled_rom.sv
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,8 @@ module rom_ctrl_scrambled_rom

assign scr_rdata_o = rdata_scr;

// Data scrambling ===========================================================

logic [Width-1:0] rdata_xor;

prim_subst_perm #(
.DataWidth (Width),
.NumRounds (2),
.Decrypt (1)
) u_sp_data (
.data_i (rdata_scr),
.key_i ('0),
.data_o (rdata_xor)
);

// XOR rdata with keystream ==================================================

assign clr_rdata_o = rdata_xor ^ keystream[Width-1:0];
assign clr_rdata_o = rdata_scr ^ keystream[Width-1:0];

endmodule
15 changes: 4 additions & 11 deletions hw/ip/rom_ctrl/util/scramble_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,14 @@ def addr_sp_dec(self, phy_addr: int) -> int:
return subst_perm_dec(phy_addr, addr_scr_nonce, self._addr_width,
self.subst_perm_rounds)

def data_sp_enc(self, width: int, data: int) -> int:
return subst_perm_enc(data, 0, width, self.subst_perm_rounds)

def data_sp_dec(self, width: int, data: int) -> int:
return subst_perm_dec(data, 0, width, self.subst_perm_rounds)

def scramble_word(self, width: int, log_addr: int, clr_data: int) -> int:
'''Scramble clr_data at the given logical address.'''
keystream = self.get_keystream(log_addr, width)
return self.data_sp_enc(width, keystream ^ clr_data)
return keystream ^ clr_data

def unscramble_word(self, width: int, log_addr: int, scr_data: int) -> int:
keystream = self.get_keystream(log_addr, width)
sp_scr_data = self.data_sp_dec(width, scr_data)
return keystream ^ sp_scr_data
return keystream ^ scr_data

def scramble(self, mem: MemFile) -> MemFile:
assert len(mem.chunks) == 1
Expand All @@ -259,11 +252,11 @@ def scramble(self, mem: MemFile) -> MemFile:
#
# Then, for all i, we have:
#
# clr[i] = PRINCE(i) ^ data_sp_dec(scr[addr_sp_enc(i)])
# clr[i] = PRINCE(i) ^ scr[addr_sp_enc(i)]
#
# Change coordinates by evaluating at addr_sp_dec(i):
#
# clr[addr_sp_dec(i)] = PRINCE(addr_sp_dec(i)) ^ data_sp_dec(scr[i])
# clr[addr_sp_dec(i)] = PRINCE(addr_sp_dec(i)) ^ scr[i]
#
# so
#
Expand Down
3 changes: 1 addition & 2 deletions hw/ip/sram_ctrl/rtl/sram_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,7 @@ module sram_ctrl
.Width(DataWidth),
.Depth(Depth),
.EnableParity(0),
.DataBitsPerMask(DataWidth),
.DiffWidth(DataWidth)
.DataBitsPerMask(DataWidth)
) u_prim_ram_1p_scr (
.clk_i,
.rst_ni,
Expand Down

0 comments on commit 24da245

Please sign in to comment.