-
Notifications
You must be signed in to change notification settings - Fork 367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add type hints around crypto module #1714
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from dataclasses import dataclass | ||
import hashlib | ||
from datetime import datetime, timedelta, timezone | ||
from typing import Optional, List, TYPE_CHECKING, Union | ||
import logging | ||
import copy | ||
|
||
|
@@ -14,6 +15,10 @@ | |
class InvalidSignature(Exception): # type: ignore | ||
pass | ||
|
||
if TYPE_CHECKING: | ||
from asyncua.common.utils import Buffer | ||
from asyncua.ua.uaprotocol_hand import SecurityPolicy, SecurityPolicyFactory | ||
|
||
_logger = logging.getLogger('asyncua.uaprotocol') | ||
|
||
|
||
|
@@ -51,7 +56,7 @@ def is_chunk_count_within_limit(self, sz: int) -> bool: | |
_logger.error("Number of message chunks: %s is > configured max chunk count: %s", sz, self.max_chunk_count) | ||
return within_limit | ||
|
||
def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: | ||
def create_acknowledge_and_set_limits(self, msg: "ua.Hello") -> "ua.Acknowledge": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that too cannot be correct. should be possible to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there should be no reason to use " for typing |
||
ack = ua.Acknowledge() | ||
ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_send_buffer) | ||
ack.SendBufferSize = min(msg.SendBufferSize, self.max_recv_buffer) | ||
|
@@ -64,14 +69,14 @@ def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: | |
_logger.info("updating server limits to: %s", self) | ||
return ack | ||
|
||
def create_hello_limits(self, msg: ua.Hello) -> ua.Hello: | ||
def create_hello_limits(self, msg: "ua.Hello") -> "ua.Hello": | ||
msg.ReceiveBufferSize = self.max_recv_buffer | ||
msg.SendBufferSize = self.max_send_buffer | ||
msg.MaxChunkCount = self.max_chunk_count | ||
msg.MaxMessageSize = self.max_chunk_count | ||
return msg | ||
|
||
def update_client_limits(self, msg: ua.Acknowledge) -> None: | ||
def update_client_limits(self, msg: "ua.Acknowledge") -> None: | ||
self.max_chunk_count = msg.MaxChunkCount | ||
self.max_recv_buffer = msg.ReceiveBufferSize | ||
self.max_send_buffer = msg.SendBufferSize | ||
|
@@ -105,7 +110,7 @@ def from_binary(security_policy, data): | |
return MessageChunk.from_header_and_body(security_policy, h, data, use_prev_key=True) | ||
|
||
@staticmethod | ||
def from_header_and_body(security_policy, header, buf, use_prev_key=False): | ||
def from_header_and_body(security_policy: "SecurityPolicy", header, buf, use_prev_key=False): | ||
if not len(buf) >= header.body_size: | ||
raise ValueError('Full body expected here') | ||
data = buf.copy(header.body_size) | ||
|
@@ -156,7 +161,7 @@ def max_body_size(crypto, max_chunk_size): | |
return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size() | ||
|
||
@staticmethod | ||
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1): | ||
def message_to_chunks(security_policy: "SecurityPolicy", body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1): | ||
""" | ||
Pack message body (as binary string) into one or more chunks. | ||
Size of each chunk will not exceed max_chunk_size. | ||
|
@@ -179,7 +184,7 @@ def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.Mes | |
crypto = security_policy.symmetric_cryptography | ||
max_size = MessageChunk.max_body_size(crypto, max_chunk_size) | ||
|
||
chunks = [] | ||
chunks: List[MessageChunk] = [] | ||
for i in range(0, len(body), max_size): | ||
part = body[i:i + max_size] | ||
if i + max_size >= len(body): | ||
|
@@ -204,22 +209,22 @@ class SecureConnection: | |
""" | ||
Common logic for client and server | ||
""" | ||
def __init__(self, security_policy, limits: TransportLimits): | ||
self._sequence_number = 0 | ||
self._peer_sequence_number = None | ||
self._incoming_parts = [] | ||
self.security_policy = security_policy | ||
self._policies = [] | ||
self._open = False | ||
def __init__(self, security_policy: "SecurityPolicy", limits: TransportLimits) -> None: | ||
self._sequence_number: int = 0 | ||
self._peer_sequence_number: Optional[int] = None | ||
self._incoming_parts: List[MessageChunk] = [] | ||
self.security_policy: SecurityPolicy = security_policy | ||
self._policies: List[SecurityPolicyFactory] = [] | ||
self._open: bool = False | ||
self.security_token = ua.ChannelSecurityToken() | ||
self.next_security_token = ua.ChannelSecurityToken() | ||
self.prev_security_token = ua.ChannelSecurityToken() | ||
self.local_nonce = 0 | ||
self.remote_nonce = 0 | ||
self._allow_prev_token = False | ||
self._limits = limits | ||
self.local_nonce: int = 0 | ||
self.remote_nonce:int = 0 | ||
self._allow_prev_token: bool = False | ||
self._limits: TransportLimits = limits | ||
|
||
def set_channel(self, params, request_type, client_nonce): | ||
def set_channel(self, params, request_type, client_nonce) -> None: | ||
""" | ||
Called on client side when getting secure channel data from server. | ||
""" | ||
|
@@ -241,7 +246,7 @@ def set_channel(self, params, request_type, client_nonce): | |
|
||
self._allow_prev_token = True | ||
|
||
def open(self, params, server): | ||
def open(self, params, server) -> ua.OpenSecureChannelResult: | ||
""" | ||
Called on server side to open secure channel. | ||
""" | ||
|
@@ -276,32 +281,32 @@ def open(self, params, server): | |
|
||
return response | ||
|
||
def close(self): | ||
def close(self) -> None: | ||
self._open = False | ||
|
||
def is_open(self): | ||
def is_open(self) -> bool: | ||
return self._open | ||
|
||
def set_policy_factories(self, policies): | ||
def set_policy_factories(self, policies: "List[SecurityPolicyFactory]") -> None: | ||
""" | ||
Set a list of available security policies. | ||
Use this in servers with multiple endpoints with different security. | ||
""" | ||
self._policies = policies | ||
|
||
@staticmethod | ||
def _policy_matches(policy, uri, mode=None): | ||
def _policy_matches(policy: "SecurityPolicy", uri, mode=None) -> bool: | ||
return policy.URI == uri and (mode is None or policy.Mode == mode) | ||
|
||
def select_policy(self, uri, peer_certificate, mode=None): | ||
def select_policy(self, uri: str, peer_certificate, mode=None): | ||
for policy in self._policies: | ||
if policy.matches(uri, mode): | ||
self.security_policy = policy.create(peer_certificate) | ||
return | ||
if self.security_policy.URI != uri or (mode is not None and self.security_policy.Mode != mode): | ||
raise ua.UaError(f"No matching policy: {uri}, {mode}") | ||
|
||
def revolve_tokens(self): | ||
def revolve_tokens(self) -> None: | ||
""" | ||
Revolve security tokens of the security channel. Start using the | ||
next security token negotiated during the renewal of the channel and | ||
|
@@ -389,7 +394,7 @@ def _check_incoming_chunk(self, chunk): | |
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection") | ||
self._peer_sequence_number = seq_num | ||
|
||
def receive_from_header_and_body(self, header, body): | ||
def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,"ua.Hello",ua.Acknowledge,ua.ErrorMessage]: | ||
""" | ||
Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA | ||
specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message | ||
|
@@ -430,7 +435,7 @@ def receive_from_header_and_body(self, header, body): | |
return msg | ||
raise ua.UaError(f"Unsupported message type {header.MessageType}") | ||
|
||
def _receive(self, msg): | ||
def _receive(self, msg: MessageChunk) -> Optional[ua.Message]: | ||
if msg.MessageHeader.packet_size > self._limits.max_recv_buffer: | ||
self._incoming_parts = [] | ||
_logger.error("Message size: %s is > chunk max size: %s", msg.MessageHeader.packet_size, self._limits.max_recv_buffer) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that check for isinstance cannot be correct. if that is because of mypy then annotate the "recevie_from_header_xx" over