Skip to content

Commit

Permalink
add type hints around crypto module
Browse files Browse the repository at this point in the history
  • Loading branch information
jbeckerwsi committed Oct 2, 2024
1 parent 1343fa7 commit 7de3e1d
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 98 deletions.
3 changes: 2 additions & 1 deletion asyncua/client/ha/reconciliator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ async def reconciliate(self) -> None:
for url in valid_urls:
digest_ideal = get_digest(ideal_map[url])
digest_real = get_digest(real_map[url])
if url not in real_map or digest_ideal != digest_real:
#if url not in real_map or digest_ideal != digest_real:
if url not in real_map or ideal_map[url] != real_map[url]:
targets.add(url)
if not targets:
_logger.info(
Expand Down
4 changes: 2 additions & 2 deletions asyncua/client/ua_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _process_received_data(self, data: bytes) -> None:
return
msg = self._connection.receive_from_header_and_body(header, buf)
self._process_received_message(msg)
if header.MessageType == ua.MessageType.SecureOpen:
if header.MessageType == ua.MessageType.SecureOpen and isinstance(msg,ua.Message):
params: ua.OpenSecureChannelParameters = self._open_secure_channel_exchange
response: ua.OpenSecureChannelResponse = struct_from_binary(ua.OpenSecureChannelResponse, msg.body())
response.ResponseHeader.ServiceResult.check()
Expand All @@ -107,7 +107,7 @@ def _process_received_data(self, data: bytes) -> None:
self.disconnect_socket()
return

def _process_received_message(self, msg: Union[ua.Message, ua.Acknowledge, ua.ErrorMessage]):
def _process_received_message(self, msg: Union[None,ua.Message, ua.Acknowledge, ua.ErrorMessage]):
if msg is None:
pass
elif isinstance(msg, ua.Message):
Expand Down
53 changes: 29 additions & 24 deletions asyncua/common/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import hashlib
from datetime import datetime, timedelta, timezone
from typing import Optional, List, TYPE_CHECKING, Union
import logging
import copy

Expand All @@ -14,6 +15,10 @@
class InvalidSignature(Exception): # type: ignore
pass

if TYPE_CHECKING:
from asyncua.common.utils import Buffer
from asyncua.ua.uaprotocol_hand import SecurityPolicy, SecurityPolicyFactory

_logger = logging.getLogger('asyncua.uaprotocol')


Expand Down Expand Up @@ -105,7 +110,7 @@ def from_binary(security_policy, data):
return MessageChunk.from_header_and_body(security_policy, h, data, use_prev_key=True)

@staticmethod
def from_header_and_body(security_policy, header, buf, use_prev_key=False):
def from_header_and_body(security_policy: "SecurityPolicy", header, buf, use_prev_key=False):
if not len(buf) >= header.body_size:
raise ValueError('Full body expected here')
data = buf.copy(header.body_size)
Expand Down Expand Up @@ -156,7 +161,7 @@ def max_body_size(crypto, max_chunk_size):
return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()

@staticmethod
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
def message_to_chunks(security_policy: "SecurityPolicy", body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
"""
Pack message body (as binary string) into one or more chunks.
Size of each chunk will not exceed max_chunk_size.
Expand All @@ -179,7 +184,7 @@ def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.Mes
crypto = security_policy.symmetric_cryptography
max_size = MessageChunk.max_body_size(crypto, max_chunk_size)

chunks = []
chunks: List[MessageChunk] = []
for i in range(0, len(body), max_size):
part = body[i:i + max_size]
if i + max_size >= len(body):
Expand All @@ -204,22 +209,22 @@ class SecureConnection:
"""
Common logic for client and server
"""
def __init__(self, security_policy, limits: TransportLimits):
self._sequence_number = 0
self._peer_sequence_number = None
self._incoming_parts = []
self.security_policy = security_policy
self._policies = []
self._open = False
def __init__(self, security_policy: "SecurityPolicy", limits: TransportLimits) -> None:
self._sequence_number: int = 0
self._peer_sequence_number: Optional[int] = None
self._incoming_parts: List[MessageChunk] = []
self.security_policy: "SecurityPolicy" = security_policy
self._policies: "List[SecurityPolicyFactory]" = []
self._open: bool = False
self.security_token = ua.ChannelSecurityToken()
self.next_security_token = ua.ChannelSecurityToken()
self.prev_security_token = ua.ChannelSecurityToken()
self.local_nonce = 0
self.remote_nonce = 0
self._allow_prev_token = False
self._limits = limits
self.local_nonce: int = 0
self.remote_nonce:int = 0
self._allow_prev_token: bool = False
self._limits: TransportLimits = limits

def set_channel(self, params, request_type, client_nonce):
def set_channel(self, params, request_type, client_nonce) -> None:
"""
Called on client side when getting secure channel data from server.
"""
Expand All @@ -241,7 +246,7 @@ def set_channel(self, params, request_type, client_nonce):

self._allow_prev_token = True

def open(self, params, server):
def open(self, params, server) -> ua.OpenSecureChannelResult:
"""
Called on server side to open secure channel.
"""
Expand Down Expand Up @@ -276,32 +281,32 @@ def open(self, params, server):

return response

def close(self):
def close(self) -> None:
self._open = False

def is_open(self):
def is_open(self) -> bool:
return self._open

def set_policy_factories(self, policies):
def set_policy_factories(self, policies: "List[SecurityPolicyFactory]") -> None:
"""
Set a list of available security policies.
Use this in servers with multiple endpoints with different security.
"""
self._policies = policies

@staticmethod
def _policy_matches(policy, uri, mode=None):
def _policy_matches(policy: "SecurityPolicy", uri, mode=None) -> bool:
return policy.URI == uri and (mode is None or policy.Mode == mode)

def select_policy(self, uri, peer_certificate, mode=None):
def select_policy(self, uri: str, peer_certificate, mode=None):
for policy in self._policies:
if policy.matches(uri, mode):
self.security_policy = policy.create(peer_certificate)
return
if self.security_policy.URI != uri or (mode is not None and self.security_policy.Mode != mode):
raise ua.UaError(f"No matching policy: {uri}, {mode}")

def revolve_tokens(self):
def revolve_tokens(self) -> None:
"""
Revolve security tokens of the security channel. Start using the
next security token negotiated during the renewal of the channel and
Expand Down Expand Up @@ -389,7 +394,7 @@ def _check_incoming_chunk(self, chunk):
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection")
self._peer_sequence_number = seq_num

def receive_from_header_and_body(self, header, body):
def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,ua.Hello,ua.Acknowledge,ua.ErrorMessage]:
"""
Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA
specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message
Expand Down Expand Up @@ -430,7 +435,7 @@ def receive_from_header_and_body(self, header, body):
return msg
raise ua.UaError(f"Unsupported message type {header.MessageType}")

def _receive(self, msg):
def _receive(self, msg: MessageChunk) -> Optional[ua.Message]:
if msg.MessageHeader.packet_size > self._limits.max_recv_buffer:
self._incoming_parts = []
_logger.error("Message size: %s is > chunk max size: %s", msg.MessageHeader.packet_size, self._limits.max_recv_buffer)
Expand Down
29 changes: 19 additions & 10 deletions asyncua/crypto/permission_rules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from abc import abstractmethod
from asyncua import ua
from asyncua.server.users import UserRole
from abc import ABC

WRITE_TYPES = [
from typing import TYPE_CHECKING, Tuple, Dict, Set

if TYPE_CHECKING:
from asyncua.server.users import User
from asyncua.common.utils import Buffer

WRITE_TYPES: Tuple[int,...] = (
ua.ObjectIds.WriteRequest_Encoding_DefaultBinary,
ua.ObjectIds.RegisterServerRequest_Encoding_DefaultBinary,
ua.ObjectIds.RegisterServer2Request_Encoding_DefaultBinary,
Expand All @@ -11,9 +19,9 @@
ua.ObjectIds.DeleteReferencesRequest_Encoding_DefaultBinary,
ua.ObjectIds.RegisterNodesRequest_Encoding_DefaultBinary,
ua.ObjectIds.UnregisterNodesRequest_Encoding_DefaultBinary
]
)

READ_TYPES = [
READ_TYPES: Tuple[int,...] = (
ua.ObjectIds.CreateSessionRequest_Encoding_DefaultBinary,
ua.ObjectIds.CloseSessionRequest_Encoding_DefaultBinary,
ua.ObjectIds.ActivateSessionRequest_Encoding_DefaultBinary,
Expand All @@ -33,16 +41,17 @@
ua.ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary,
ua.ObjectIds.CallRequest_Encoding_DefaultBinary,
ua.ObjectIds.SetMonitoringModeRequest_Encoding_DefaultBinary,
ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary
]
ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary,
)


class PermissionRuleset:
class PermissionRuleset(ABC):
"""
Base class for permission ruleset
"""

def check_validity(self, user, action_type, body):
@abstractmethod
def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool:
raise NotImplementedError


Expand All @@ -52,16 +61,16 @@ class SimpleRoleRuleset(PermissionRuleset):
Admins alone can write, admins and users can read, and anonymous users can't do anything.
"""

def __init__(self):
def __init__(self) -> None:
write_ids = list(map(ua.NodeId, WRITE_TYPES))
read_ids = list(map(ua.NodeId, READ_TYPES))
self._permission_dict = {
self._permission_dict: Dict[UserRole, Set[ua.NodeId]] = {
UserRole.Admin: set().union(write_ids, read_ids),
UserRole.User: set().union(read_ids),
UserRole.Anonymous: set()
}

def check_validity(self, user, action_type_id, body):
def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool:
if action_type_id in self._permission_dict[user.role]:
return True
else:
Expand Down
16 changes: 8 additions & 8 deletions asyncua/crypto/security_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ class Verifier:
__metaclass__ = ABCMeta

@abstractmethod
def signature_size(self):
def signature_size(self) -> None:
pass

@abstractmethod
def verify(self, data, signature):
def verify(self, data, signature) -> None:
pass

def reset(self):
def reset(self) -> None:
attrs = self.__dict__
for k in attrs:
attrs[k] = None
Expand All @@ -70,11 +70,11 @@ class Encryptor:
__metaclass__ = ABCMeta

@abstractmethod
def plain_block_size(self):
def plain_block_size(self) -> int:
pass

@abstractmethod
def encrypted_block_size(self):
def encrypted_block_size(self) -> int:
pass

@abstractmethod
Expand All @@ -90,18 +90,18 @@ class Decryptor:
__metaclass__ = ABCMeta

@abstractmethod
def plain_block_size(self):
def plain_block_size(self) -> int:
pass

@abstractmethod
def encrypted_block_size(self):
def encrypted_block_size(self) -> int:
pass

@abstractmethod
def decrypt(self, data):
pass

def reset(self):
def reset(self) -> None:
attrs = self.__dict__
for k in attrs:
attrs[k] = None
Expand Down
Loading

0 comments on commit 7de3e1d

Please sign in to comment.