Skip to content

Commit

Permalink
[otbn] Allow sim to compare result to dmem
Browse files Browse the repository at this point in the history
Signed-off-by: Amin Abdulrahman <[email protected]>
  • Loading branch information
dop-amin authored and moidx committed Dec 16, 2024
1 parent 18d5f43 commit e2ae6e7
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 50 deletions.
2 changes: 2 additions & 0 deletions hw/ip/otbn/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ py_binary(
srcs = ["otbn_sim_test.py"],
deps = [
"//hw/ip/otbn/util/shared:check",
"//hw/ip/otbn/util/shared:mem_layout",
"//hw/ip/otbn/util/shared:reg_dump",
requirement("pyelftools"),
],
)
112 changes: 86 additions & 26 deletions hw/ip/otbn/util/otbn_sim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
import subprocess
import sys
from enum import IntEnum
from typing import List
from typing import Dict, List
import tempfile

from elftools.elf.elffile import ELFFile # type: ignore
from elftools.elf.sections import SymbolTableSection # type: ignore

from shared.check import CheckResult
from shared.reg_dump import parse_reg_dump
from shared.dmem_dump import parse_dmem_exp, parse_actual_dmem

# Names of special registers
ERR_BITS = 'ERR_BITS'
Expand Down Expand Up @@ -50,11 +55,24 @@ def get_err_names(err: int) -> List[str]:
return out


def _get_symbol_addr_map(elf_file: ELFFile) -> Dict[int, str]:
section = elf_file.get_section_by_name('.symtab')

if not isinstance(section, SymbolTableSection):
return {}

# Filter lables and offsets from data section
return {
sym.name: sym.entry.st_value
for sym in section.iter_symbols() if sym.entry['st_shndx'] == 2
}


def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument('simulator',
help='Path to the standalone OTBN simulator.')
parser.add_argument('expected',
parser.add_argument('--expected_regs',
metavar='FILE',
type=argparse.FileType('r'),
help=(f'File containing expected register values. '
Expand All @@ -63,23 +81,47 @@ def main() -> int:
f'{ERR_BITS} is not listed, the test will assume '
f'there are no errors expected (i.e. {ERR_BITS}'
f'= 0).'))
parser.add_argument('--expected_dmem',
metavar='FILE',
type=argparse.FileType('r'),
help=('File containing expected dmem values. '
'Addresses that are not listed are allowed to '
'have any value.'))
parser.add_argument('elf',
help='Path to the .elf file for the OTBN program.')
parser.add_argument('-v', '--verbose', action='store_true')
args = parser.parse_args()

# Parse expected values.
result = CheckResult()
expected_regs = parse_reg_dump(args.expected.read())

# Run the simulation and produce a register dump.
cmd = [args.simulator, '--dump-regs', '-', args.elf]
sim_proc = subprocess.run(cmd, check=True,
stdout=subprocess.PIPE, universal_newlines=True)
actual_regs = parse_reg_dump(sim_proc.stdout)
with tempfile.NamedTemporaryFile() as regs_file, tempfile.NamedTemporaryFile() as dmem_file:
cmd = [
args.simulator,
"--dump-regs",
regs_file.name,
"--dump-dmem",
dmem_file.name,
args.elf,
]
# Run the simulation and produce a register and dmem dump.
subprocess.run(
cmd, check=True, universal_newlines=True
)

if args.expected_dmem is not None:
dmem_file.seek(0)
actual_dmem = parse_actual_dmem(dmem_file.read())
expected_dmem = parse_dmem_exp(args.expected_dmem.read())

actual_regs = parse_reg_dump(regs_file.read().decode('utf-8'))

expected_err = 0
if args.expected_regs is not None:
expected_regs = parse_reg_dump(args.expected_regs.read())
expected_err = expected_regs.get(ERR_BITS, 0)

# Special handling for the ERR_BITS register.
expected_err = expected_regs.get(ERR_BITS, 0)
actual_err = actual_regs[ERR_BITS]
insn_cnt = actual_regs[INSN_CNT]
stop_pc = actual_regs[STOP_PC]
Expand All @@ -88,24 +130,42 @@ def main() -> int:
# case, give a special error message and exit rather than print all the
# mismatched registers.
if actual_err != 0:
err_names = ', '.join(get_err_names(actual_err))
result.err(f'OTBN encountered an unexpected error: {err_names}.\n'
f' {ERR_BITS}\t= {actual_err:#010x}\n'
f' {INSN_CNT}\t= {insn_cnt:#010x}\n'
f' {STOP_PC}\t= {stop_pc:#010x}')
err_names = ", ".join(get_err_names(actual_err))
result.err(f"OTBN encountered an unexpected error: {err_names}.\n"
f" {ERR_BITS}\t= {actual_err:#010x}\n"
f" {INSN_CNT}\t= {insn_cnt:#010x}\n"
f" {STOP_PC}\t= {stop_pc:#010x}")

else:
for reg, expected_value in expected_regs.items():
actual_value = actual_regs.get(reg, None)
if actual_value != expected_value:
if reg.startswith('w'):
expected_str = f'{expected_value:#066x}'
actual_str = f'{actual_value:#066x}'
else:
expected_str = f'{expected_value:#010x}'
actual_str = f'{actual_value:#010x}'
result.err(f'Mismatch for register {reg}:\n'
f' Expected: {expected_str}\n'
f' Actual: {actual_str}')
if args.expected_regs is not None:
for reg, expected_value in expected_regs.items():
actual_value = actual_regs.get(reg, None)
if actual_value != expected_value:
if reg.startswith("w"):
expected_str = f"{expected_value:#066x}"
actual_str = f"{actual_value:#066x}"
else:
expected_str = f"{expected_value:#010x}"
actual_str = f"{actual_value:#010x}"
result.err(f"Mismatch for register {reg}:\n"
f" Expected: {expected_str}\n"
f" Actual: {actual_str}")

if args.expected_dmem is not None:
elf_file = ELFFile(open(args.elf, 'rb'))
symbol_addr_map = _get_symbol_addr_map(elf_file)

for label, value in expected_dmem.items():
try:
offset = symbol_addr_map[label]
if actual_dmem[offset:offset + len(value)] != value:
result.err(
f"Mismatch for dmem {label}:\n"
f" Expected: {value.hex()}\n"
f" Actual: {actual_dmem[offset:offset+len(value)].hex()}"
)
except KeyError:
result.err(f'No label "{label}" found in elf-file.')

if result.has_errors() or result.has_warnings() or args.verbose:
print(result.report())
Expand Down
59 changes: 59 additions & 0 deletions hw/ip/otbn/util/shared/dmem_dump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright lowRISC contributors (OpenTitan project).
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

import re
import struct
from typing import Dict

from hw.ip.otbn.util.shared.mem_layout import get_memory_layout

_DMEM_RE = re.compile(
r'\s*(?P<label>[a-zA-Z0-9_]+)\s*:\s*(?P<val>(:?[0-9a-f]+))$')


def parse_dmem_exp(dump: str) -> Dict[str, int]:
'''Parse the expected dmem.
Format:
label: hex_data
Returns a dictionary mapping labels to the expected bytes.
'''

out = {}
for line in dump.split('\n'):
# Remove comments and ignore blank lines.
line = line.split('#', 1)[0].strip()
if not line:
continue
m = _DMEM_RE.match(line)
if not m:
raise ValueError(f'Failed to parse dmem dump line ({line}).')
label = m.group('label')
value = bytes.fromhex(m.group('val'))

if label in out:
raise ValueError(f'DMEM dump contains multiple values '
f'for {label}.')
out[label] = value

return out


def parse_actual_dmem(dump: bytes) -> bytes:
'''Parse the dmem dump.
Returns the dmem bytes except integrity info.
'''
dmem_bytes = []
# 8 32-bit data words + 1 byte integrity info per word = 40 bytes
bytes_w_integrity = 8 * 4 + 8
for w in struct.iter_unpack(f"<{bytes_w_integrity}s", dump):
tmp = []
# discard byte indicating integrity status
for v in struct.iter_unpack("<BI", w[0]):
tmp += [x for x in struct.unpack("4B", v[1].to_bytes(4, "big"))]
dmem_bytes += tmp
assert len(dmem_bytes) == get_memory_layout().dmem_size_bytes
return bytes(dmem_bytes)
20 changes: 15 additions & 5 deletions rules/otbn.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,33 @@ def _otbn_binary(ctx, additional_srcs = []):
),
]

def _run_sim_test(ctx, exp, additional_srcs = []):
def _run_sim_test(ctx, exp, dexp, additional_srcs = []):
providers = _otbn_binary(ctx, additional_srcs)

# Extract the output .elf file from the output group.
elf = providers[1].elf.to_list()[0]

exp_content = ""
exp_files = []
if exp != None:
exp_content += "--expected_regs {} ".format(exp.short_path)
exp_files.append(exp)
if dexp != None:
exp_content += "--expected_dmem {}".format(dexp.short_path)
exp_files.append(dexp)

# Create a simple script that runs the OTBN test wrapper on the .elf file
# using the provided simulator path.
sim_test_wrapper = ctx.executable._sim_test_wrapper
simulator = ctx.executable._simulator
ctx.actions.write(
output = ctx.outputs.executable,
content = "{} {} {} {}".format(sim_test_wrapper.short_path, simulator.short_path, exp.short_path, elf.short_path),
content = "{} {} {} {}".format(sim_test_wrapper.short_path, exp_content, simulator.short_path, elf.short_path),
)

# Runfiles include sources, the .elf file, the simulator and test wrapper
# themselves, and all the simulator and test wrapper runfiles.
runfiles = ctx.runfiles(files = (ctx.files.srcs + additional_srcs + [elf, exp, ctx.executable._simulator, ctx.executable._sim_test_wrapper]))
runfiles = ctx.runfiles(files = (ctx.files.srcs + additional_srcs + exp_files + [elf, ctx.executable._simulator, ctx.executable._sim_test_wrapper]))
runfiles = runfiles.merge(ctx.attr._simulator[DefaultInfo].default_runfiles)
runfiles = runfiles.merge(ctx.attr._sim_test_wrapper[DefaultInfo].default_runfiles)
return [
Expand All @@ -180,7 +189,7 @@ def _otbn_sim_test(ctx):
them on the simulator. Tests are expected to count failures in the w0
register; the test checks that w0=0 to determine if the test passed.
"""
return _run_sim_test(ctx, ctx.file.exp)
return _run_sim_test(ctx, ctx.file.exp, ctx.file.dexp)

def _otbn_autogen_sim_test_impl(ctx):
"""
Expand Down Expand Up @@ -209,7 +218,7 @@ def _otbn_autogen_sim_test_impl(ctx):
executable = ctx.executable.testgen,
)

return _run_sim_test(ctx, exp, additional_srcs = [data])
return _run_sim_test(ctx, exp, None, additional_srcs = [data])

def _otbn_consttime_test_impl(ctx):
"""This rule checks if a program or subroutine is constant-time.
Expand Down Expand Up @@ -325,6 +334,7 @@ otbn_sim_test = rv_rule(
"srcs": attr.label_list(allow_files = True),
"deps": attr.label_list(providers = [DefaultInfo]),
"exp": attr.label(allow_single_file = True),
"dexp": attr.label(allow_single_file = True),
"_cc_toolchain": attr.label(default = Label("@bazel_tools//tools/cpp:current_cc_toolchain")),
"_otbn_as": attr.label(
default = "//hw/ip/otbn/util:otbn_as",
Expand Down
2 changes: 1 addition & 1 deletion sw/otbn/crypto/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ otbn_sim_test(
srcs = [
"mul_test.s",
],
exp = "mul_test.exp",
dexp = "mul_test.dexp",
deps = [
"//sw/otbn/crypto:mul",
],
Expand Down
1 change: 1 addition & 0 deletions sw/otbn/crypto/tests/mul_test.dexp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
result: 6ff7b440e4486d8c1a343ef7de3f5755ff3b910ae028fecfda01f020f2a033736e328be51f4f5577051a9b76882e22e038083011956960faf5c89c3b5553facfad4cbd8311cf8bca78c79f34b479815005968c5d6f20fd62716dbfa334951f41436378bfb1669647382cf4af4eaeb12f7479d0a1dcfd33179d2bd863
6 changes: 0 additions & 6 deletions sw/otbn/crypto/tests/mul_test.exp

This file was deleted.

12 changes: 0 additions & 12 deletions sw/otbn/crypto/tests/mul_test.s
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,6 @@ main:
la x12, result
jal x1, bignum_mul

/* Load result into w0 through w3.
[w0..w3] <= dmem[result] */
la x2, result
li x3, 0
bn.lid x3, 0(x2++)
addi x3, x3, 1
bn.lid x3, 0(x2++)
addi x3, x3, 1
bn.lid x3, 0(x2++)
addi x3, x3, 1
bn.lid x3, 0(x2)

ecall


Expand Down

0 comments on commit e2ae6e7

Please sign in to comment.