Skip to content

Commit

Permalink
Refactor action signatures and add more types
Browse files Browse the repository at this point in the history
  • Loading branch information
dainnilsson committed Oct 7, 2024
1 parent 353766b commit 85f973f
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 286 deletions.
16 changes: 8 additions & 8 deletions helper/helper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from queue import Queue
from threading import Thread, Event
from typing import Callable, Dict, List
from typing import Callable

import json
import logging
Expand Down Expand Up @@ -78,14 +78,14 @@ def _handle_incoming(event, recv, error, cmd_queue):


def process(
send: Callable[[Dict], None],
recv: Callable[[], Dict],
handler: Callable[[str, List, Dict, Event, Callable[[str], None]], RpcResponse],
send: Callable[[dict], None],
recv: Callable[[], dict],
handler: Callable[[str, list, dict, Event, Callable[[str], None]], RpcResponse],
) -> None:
def error(status: str, message: str, body: Dict = {}):
def error(status: str, message: str, body: dict = {}):
send(dict(kind="error", status=status, message=message, body=body))

def signal(status: str, body: Dict = {}):
def signal(status: str, body: dict = {}):
send(dict(kind="signal", status=status, body=body))

def success(response: RpcResponse):
Expand Down Expand Up @@ -121,8 +121,8 @@ def success(response: RpcResponse):


def run_rpc(
send: Callable[[Dict], None],
recv: Callable[[], Dict],
send: Callable[[dict], None],
recv: Callable[[], dict],
) -> None:
process(send, recv, RootNode())

Expand Down
11 changes: 9 additions & 2 deletions helper/helper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from yubikit.core import InvalidPinError
from functools import partial

import inspect
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -127,7 +128,13 @@ def __call__(self, action, target, params, event, signal, traversed=None):
action, target[1:], params, event, signal, traversed
)
elif action in self.list_actions():
response = self.get_action(action)(params, event, signal)
action_f = self.get_action(action)
args = inspect.signature(action_f).parameters
if "event" in args:
params["event"] = event
if "signal" in args:
params["signal"] = signal
response = action_f(**params)
elif action in self.list_children():
traversed += [action]
response = self.get_child(action)(
Expand Down Expand Up @@ -224,7 +231,7 @@ def get_child(self, name):
return self._child

@action
def get(self, params, event, signal):
def get(self):
return dict(
data=self.get_data(),
actions=self.list_actions(),
Expand Down
26 changes: 13 additions & 13 deletions helper/helper/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
from hashlib import sha256
from dataclasses import asdict
from typing import Mapping, Tuple
from typing import Mapping

import os
import sys
Expand Down Expand Up @@ -113,19 +113,19 @@ def nfc(self):
return self._readers

@action
def diagnose(self, *ignored):
def diagnose(self):
return dict(diagnostics=get_diagnostics())

@action(closes_child=False)
def logging(self, params, event, signal):
level = LOG_LEVEL[params["level"].upper()]
set_log_level(level)
logger.info(f"Log level set to: {level.name}")
def logging(self, level: str):
lvl = LOG_LEVEL[level.upper()]
set_log_level(lvl)
logger.info(f"Log level set to: {lvl.name}")
return dict()

@action(closes_child=False)
def qr(self, params, event, signal):
return dict(result=scan_qr(params.get("image")))
def qr(self, image: str | None = None):
return dict(result=scan_qr(image))


def _id_from_fingerprint(fp):
Expand All @@ -142,7 +142,7 @@ def __init__(self):
self._reader_mapping = {}

@action(closes_child=False)
def scan(self, *ignored):
def scan(self):
return self.list_children()

def list_children(self):
Expand Down Expand Up @@ -173,7 +173,7 @@ def create_child(self, name):

class _ScanDevices:
def __init__(self):
self._state: Tuple[Mapping[PID, int], int] = ({}, 0)
self._state: tuple[Mapping[PID, int], int] = ({}, 0)
self._caching = False

def __call__(self):
Expand Down Expand Up @@ -225,7 +225,7 @@ def close(self):
super().close()

@action(closes_child=False)
def scan(self, *ignored):
def scan(self):
return self.get_data()

def get_data(self):
Expand Down Expand Up @@ -460,8 +460,8 @@ def _refresh_data(self):
return dict(present=False, status="no-card")

@action(closes_child=False)
def get(self, params, event, signal):
return super().get(params, event, signal)
def get(self):
return super().get()

@child
def ccid(self):
Expand Down
29 changes: 11 additions & 18 deletions helper/helper/fido.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _prepare_reset_usb(device, event, signal):
raise TimeoutException()

@action
def reset(self, params, event, signal):
def reset(self, event, signal):
target = _ctap_id(self.ctap)
device = self.ctap.device
if isinstance(device, CtapPcscDevice):
Expand Down Expand Up @@ -206,8 +206,7 @@ def reset(self, params, event, signal):
return RpcResponse(dict(), ["device_info", "device_closed"])

@action(condition=lambda self: self._info.options["clientPin"])
def unlock(self, params, event, signal):
pin = params.pop("pin")
def unlock(self, pin: str):
permissions = ClientPin.PERMISSION(0)
if CredentialManagement.is_supported(self._info):
permissions |= ClientPin.PERMISSION.CREDENTIAL_MGMT
Expand All @@ -227,25 +226,21 @@ def unlock(self, params, event, signal):
return _handle_pin_error(e, self.client_pin)

@action
def set_pin(self, params, event, signal):
def set_pin(self, new_pin: str, pin: str | None = None):
has_pin = self.ctap.get_info().options["clientPin"]
try:
if has_pin:
self.client_pin.change_pin(
params.pop("pin"),
params.pop("new_pin"),
)
assert pin # nosec
self.client_pin.change_pin(pin, new_pin)
else:
self.client_pin.set_pin(
params.pop("new_pin"),
)
self.client_pin.set_pin(new_pin)
self._info = self.ctap.get_info()
return RpcResponse(dict(), ["device_info"])
except CtapError as e:
return _handle_pin_error(e, self.client_pin)

@action(condition=lambda self: Config.is_supported(self._info))
def enable_ep_attestation(self, params, event, signal):
def enable_ep_attestation(self):
if self._info.options["clientPin"] and not self._token:
raise AuthRequiredException()
config = Config(self.ctap, self.client_pin.protocol, self._token)
Expand Down Expand Up @@ -343,7 +338,7 @@ def get_data(self):
return self.data

@action
def delete(self, params, event, signal):
def delete(self):
self.credman.delete_cred(self.data["credential_id"])
self.refresh_rps()
return dict()
Expand Down Expand Up @@ -378,8 +373,7 @@ def create_child(self, name):
return super().create_child(name)

@action
def add(self, params, event, signal):
name = params.get("name", None)
def add(self, event, signal, name: str | None = None):
enroller = self.bio.enroll()
template_id = None
while template_id is None:
Expand Down Expand Up @@ -410,15 +404,14 @@ def get_data(self):
return dict(template_id=self.template_id, name=self.name)

@action
def rename(self, params, event, signal):
name = params.pop("name")
def rename(self, name: str):
self.bio.set_name(self.template_id, name)
self.name = name
self.refresh()
return dict()

@action
def delete(self, params, event, signal):
def delete(self):
self.bio.remove_enrollment(self.template_id)
self.refresh()
return dict()
57 changes: 41 additions & 16 deletions helper/helper/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,23 @@
# limitations under the License.

from .base import RpcResponse, RpcNode, action
from yubikit.core import require_version, NotSupportedError, TRANSPORT, Connection
from yubikit.core import (
require_version,
NotSupportedError,
TRANSPORT,
USB_INTERFACE,
Connection,
)
from yubikit.core.smartcard import SmartCardConnection
from yubikit.core.otp import OtpConnection
from yubikit.core.fido import FidoConnection
from yubikit.management import ManagementSession, DeviceConfig, Mode, CAPABILITY
from yubikit.management import (
ManagementSession,
DeviceConfig,
Mode,
CAPABILITY,
DEVICE_FLAG,
)
from ykman.device import list_all_devices
from dataclasses import asdict
from time import sleep
Expand Down Expand Up @@ -75,18 +87,26 @@ def _await_reboot(self, serial, usb_enabled):
logger.warning("Timed out waiting for device")

@action
def configure(self, params, event, signal):
reboot = params.pop("reboot", False)
cur_lock_code = bytes.fromhex(params.pop("cur_lock_code", "")) or None
new_lock_code = bytes.fromhex(params.pop("new_lock_code", "")) or None
def configure(
self,
reboot: bool = False,
cur_lock_code: str = "",
new_lock_code: str = "",
enabled_capabilities: dict = {},
auto_eject_timeout: int | None = None,
challenge_response_timeout: int | None = None,
device_flags: int | None = None,
):
cur_code = bytes.fromhex(cur_lock_code) or None
new_code = bytes.fromhex(new_lock_code) or None
config = DeviceConfig(
params.pop("enabled_capabilities", {}),
params.pop("auto_eject_timeout", None),
params.pop("challenge_response_timeout", None),
params.pop("device_flags", None),
enabled_capabilities,
auto_eject_timeout,
challenge_response_timeout,
DEVICE_FLAG(device_flags) if device_flags else None,
)
serial = self.session.read_device_info().serial
self.session.write_device_config(config, reboot, cur_lock_code, new_lock_code)
self.session.write_device_config(config, reboot, cur_code, new_code)
flags = ["device_info"]
if reboot:
enabled = config.enabled_capabilities.get(TRANSPORT.USB)
Expand All @@ -95,17 +115,22 @@ def configure(self, params, event, signal):
return RpcResponse(dict(), flags)

@action
def set_mode(self, params, event, signal):
def set_mode(
self,
interfaces: int,
challenge_response_timeout: int = 0,
auto_eject_timeout: int | None = None,
):
self.session.set_mode(
Mode(params["interfaces"]),
params.pop("challenge_response_timeout", 0),
params.pop("auto_eject_timeout"),
Mode(USB_INTERFACE(interfaces)),
challenge_response_timeout,
auto_eject_timeout,
)
return dict()

@action(
condition=lambda self: issubclass(self._connection_type, SmartCardConnection)
)
def device_reset(self, params, event, signal):
def device_reset(self):
self.session.device_reset()
return RpcResponse(dict(), ["device_info"])
Loading

0 comments on commit 85f973f

Please sign in to comment.