diff --git a/.travis.yml b/.travis.yml index 2c7578c1..4429f764 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ install: script: - make ci - + notifications: email: false @@ -38,6 +38,7 @@ jobs: - pip install --upgrade pip - bash ./scripts/install_nats.sh install: + - pip install nkeys - pip install -e .[fast-mail-parser] - name: "Python: 3.12" python: "3.12" @@ -48,6 +49,7 @@ jobs: - pip install --upgrade pip - bash ./scripts/install_nats.sh install: + - pip install nkeys - pip install -e .[fast-mail-parser] - name: "Python: 3.11" python: "3.11" @@ -58,6 +60,7 @@ jobs: - pip install --upgrade pip - bash ./scripts/install_nats.sh install: + - pip install nkeys - pip install -e .[fast-mail-parser] - name: "Python: 3.11/uvloop" python: "3.11" @@ -68,8 +71,8 @@ jobs: - pip install --upgrade pip - bash ./scripts/install_nats.sh install: + - pip install nkeys uvloop - pip install -e .[fast-mail-parser] - - pip install uvloop - name: "Python: 3.11 (nats-server@main)" python: "3.11" env: @@ -81,6 +84,7 @@ jobs: - pip install --upgrade pip - bash ./scripts/install_nats.sh install: + - pip install nkeys - pip install -e .[fast-mail-parser] allow_failures: - name: "Python: 3.8" diff --git a/Makefile b/Makefile index d5807e7e..c28ec5da 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ REPO_OWNER=nats-io PROJECT_NAME=nats.py SOURCE_CODE=nats +TEST_CODE=tests help: @@ -22,14 +23,17 @@ deps: format: yapf -i --recursive $(SOURCE_CODE) - yapf -i --recursive tests + yapf -i --recursive $(TEST_CODE) -test: +lint: yapf --recursive --diff $(SOURCE_CODE) - yapf --recursive --diff tests + yapf --recursive --diff $(TEST_CODE) mypy flake8 ./nats/js/ + + +test: pytest diff --git a/nats/aio/client.py b/nats/aio/client.py index 81e65f50..279675d9 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -19,12 +19,14 @@ import ipaddress import json import logging +import os import ssl import string import time from collections import UserString from dataclasses import dataclass from email.parser import BytesParser +from enum import Enum from io import BytesIO from pathlib import Path from random import shuffle @@ -185,14 +187,10 @@ async def _default_error_callback(ex: Exception) -> None: _logger.error("nats: encountered error", exc_info=ex) -class Client: - """ - Asyncio based client for NATS. - """ +# Client section - msg_class: type[Msg] = Msg - # FIXME: Use an enum instead. +class ClientState(Enum): DISCONNECTED = 0 CONNECTED = 1 CLOSED = 2 @@ -201,6 +199,12 @@ class Client: DRAINING_SUBS = 5 DRAINING_PUBS = 6 + +class Client: + """Asyncio-based client for NATS.""" + + msg_class: type[Msg] = Msg + def __repr__(self) -> str: return f"" @@ -231,7 +235,7 @@ def __init__(self) -> None: self._client_id: Optional[int] = None self._sid: int = 0 self._subs: Dict[int, Subscription] = {} - self._status: int = Client.DISCONNECTED + self._status = ClientState.DISCONNECTED self._ps: Parser = Parser(self) # pending queue of commands that will be flushed to the server. @@ -512,7 +516,7 @@ async def subscribe_handler(msg): if not self.options["allow_reconnect"]: raise e - await self._close(Client.DISCONNECTED, False) + await self._close(ClientState.DISCONNECTED, False) if self._current_server is not None: self._current_server.last_attempt = time.monotonic() self._current_server.reconnects += 1 @@ -525,7 +529,6 @@ def _setup_nkeys_connect(self) -> None: def _setup_nkeys_jwt_connect(self) -> None: assert self._user_credentials, "_user_credentials required" - import os import nkeys @@ -638,12 +641,12 @@ def _setup_nkeys_seed_connect(self) -> None: import nkeys def _get_nkeys_seed() -> nkeys.KeyPair: - import os - if self._nkeys_seed_str: seed = bytearray(self._nkeys_seed_str.encode()) else: creds = self._nkeys_seed + if creds is None: + raise ValueError("cannot extract nkeys seed") with open(creds, "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] @@ -674,13 +677,12 @@ async def close(self) -> None: sets the client to be in the CLOSED state. No further reconnections occur once reaching this point. """ - await self._close(Client.CLOSED) + await self._close(ClientState.CLOSED) - async def _close(self, status: int, do_cbs: bool = True) -> None: + async def _close(self, status: ClientState, do_cbs: bool = True) -> None: if self.is_closed: self._status = status return - self._status = Client.CLOSED # Kick the flusher once again so that Task breaks and avoid pending futures. await self._flush_pending() @@ -752,6 +754,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: if self._closed_cb is not None: await self._closed_cb() + self._status = ClientState.CLOSED + # Set the client_id and subscription prefix back to None self._client_id = None self._resp_sub_prefix = None @@ -784,7 +788,7 @@ async def drain(self) -> None: # Relinquish CPU to allow drain tasks to start in the background, # before setting state to draining. await asyncio.sleep(0) - self._status = Client.DRAINING_SUBS + self._status = ClientState.DRAINING_SUBS try: await asyncio.wait_for( @@ -797,9 +801,9 @@ async def drain(self) -> None: except asyncio.CancelledError: pass finally: - self._status = Client.DRAINING_PUBS + self._status = ClientState.DRAINING_PUBS await self.flush() - await self._close(Client.CLOSED) + await self._close(ClientState.CLOSED) async def publish( self, @@ -1184,30 +1188,30 @@ def pending_data_size(self) -> int: @property def is_closed(self) -> bool: - return self._status == Client.CLOSED + return self._status == ClientState.CLOSED @property def is_reconnecting(self) -> bool: - return self._status == Client.RECONNECTING + return self._status == ClientState.RECONNECTING @property def is_connected(self) -> bool: - return (self._status == Client.CONNECTED) or self.is_draining + return (self._status == ClientState.CONNECTED) or self.is_draining @property def is_connecting(self) -> bool: - return self._status == Client.CONNECTING + return self._status == ClientState.CONNECTING @property def is_draining(self) -> bool: return ( - self._status == Client.DRAINING_SUBS - or self._status == Client.DRAINING_PUBS + self._status == ClientState.DRAINING_SUBS + or self._status == ClientState.DRAINING_PUBS ) @property def is_draining_pubs(self) -> bool: - return self._status == Client.DRAINING_PUBS + return self._status == ClientState.DRAINING_PUBS @property def connected_server_version(self) -> ServerVersion: @@ -1265,7 +1269,7 @@ async def _flush_pending( except asyncio.CancelledError: pass - def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: + def _setup_server_pool(self, connect_url: Union[str | List[str]]) -> None: if isinstance(connect_url, str): try: if "nats://" in connect_url or "tls://" in connect_url: @@ -1397,7 +1401,7 @@ async def _process_err(self, err_msg: str) -> None: # FIXME: Some errors such as 'Invalid Subscription' # do not cause the server to close the connection. # For now we handle similar as other clients and close. - asyncio.create_task(self._close(Client.CLOSED, do_cbs)) + asyncio.create_task(self._close(ClientState.CLOSED, do_cbs)) async def _process_op_err(self, e: Exception) -> None: """ @@ -1410,7 +1414,7 @@ async def _process_op_err(self, e: Exception) -> None: return if self.options["allow_reconnect"] and self.is_connected: - self._status = Client.RECONNECTING + self._status = ClientState.RECONNECTING self._ps.reset() if (self._reconnection_task is not None @@ -1424,7 +1428,7 @@ async def _process_op_err(self, e: Exception) -> None: else: self._process_disconnect() self._err = e - await self._close(Client.CLOSED, True) + await self._close(ClientState.CLOSED, True) async def _attempt_reconnect(self) -> None: assert self._current_server, "Client.connect must be called first" @@ -1510,7 +1514,7 @@ async def _attempt_reconnect(self) -> None: # to bail earlier in case there are errors in the connection. # await self._flush_pending(force_flush=True) await self._flush_pending() - self._status = Client.CONNECTED + self._status = ClientState.CONNECTED await self.flush() if self._reconnected_cb is not None: await self._reconnected_cb() @@ -1523,7 +1527,7 @@ async def _attempt_reconnect(self) -> None: except (OSError, errors.Error, asyncio.TimeoutError) as e: self._err = e await self._error_cb(e) - self._status = Client.RECONNECTING + self._status = ClientState.RECONNECTING self._current_server.last_attempt = time.monotonic() self._current_server.reconnects += 1 except asyncio.CancelledError: @@ -1586,9 +1590,8 @@ def _connect_command(self) -> bytes: return b"".join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) async def _process_ping(self) -> None: - """ - Process PING sent by server. - """ + """Process PING sent by server.""" + await self._send_command(PONG) await self._flush_pending() @@ -1615,7 +1618,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: if not headers: return None - hdr: Optional[Dict[str, str]] = None + hdr: Dict[str, str] = {} raw_headers = headers[NATS_HDR_LINE_SIZE:] # If the first character is an empty space, then this is @@ -1646,7 +1649,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: i = raw_headers.find(_CRLF_) raw_headers = raw_headers[i + _CRLF_LEN_:] - if len(desc) > 0: + if len(desc): # Heartbeat messages can have both headers and inline status, # check that there are no pending headers to be parsed. i = desc.find(_CRLF_) @@ -1661,7 +1664,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # Just inline status... hdr[nats.js.api.Header.DESCRIPTION] = desc.decode() - if not len(raw_headers) > _CRLF_LEN_: + if len(raw_headers) <= _CRLF_LEN_: return hdr # @@ -1854,7 +1857,7 @@ def _process_disconnect(self) -> None: Process disconnection from the server and set client status to DISCONNECTED. """ - self._status = Client.DISCONNECTED + self._status = ClientState.DISCONNECTED def _process_info( self, info: Dict[str, Any], initial_connection: bool = False @@ -1918,7 +1921,7 @@ async def _process_connect_init(self) -> None: """ assert self._transport, "must be called only from Client.connect" assert self._current_server, "must be called only from Client.connect" - self._status = Client.CONNECTING + self._status = ClientState.CONNECTING # Check whether to reuse the original hostname for an implicit route. hostname = None @@ -2019,7 +2022,7 @@ async def _process_connect_init(self) -> None: ) if PONG_PROTO in next_op: - self._status = Client.CONNECTED + self._status = ClientState.CONNECTED elif ERR_OP in next_op: err_line = next_op.decode() _, err_msg = err_line.split(" ", 1) @@ -2030,7 +2033,7 @@ async def _process_connect_init(self) -> None: raise errors.Error("nats: " + err_msg.rstrip("\r\n")) if PONG_PROTO in next_op: - self._status = Client.CONNECTED + self._status = ClientState.CONNECTED self._reading_task = asyncio.get_running_loop().create_task( self._read_loop() @@ -2143,7 +2146,7 @@ async def __aenter__(self) -> "Client": async def __aexit__(self, *exc_info) -> None: """Close connection to NATS when used in a context manager""" - await self._close(Client.CLOSED, do_cbs=True) + await self._close(ClientState.CLOSED, do_cbs=True) def jetstream(self, **opts) -> nats.js.JetStreamContext: """ diff --git a/nats/js/api.py b/nats/js/api.py index e9db83e1..56e7bf51 100644 --- a/nats/js/api.py +++ b/nats/js/api.py @@ -388,7 +388,7 @@ class StreamsListIterator(Iterable): """ def __init__( - self, offset: int, total: int, streams: List[Dict[str, any]] + self, offset: int, total: int, streams: List[Dict[str, Any]] ) -> None: self.offset = offset self.total = total diff --git a/pyproject.toml b/pyproject.toml index a3955ca1..361f57b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ namespaces = false # to disable scanning PEP 420 namespaces (true by default) [tool.mypy] files = ["nats"] -python_version = "3.7" +python_version = "3.9" ignore_missing_imports = true follow_imports = "silent" show_error_codes = true @@ -61,3 +61,14 @@ coalesce_brackets = true allow_split_before_dict_value = false indent_dictionary_value = true split_before_expression_after_opening_paren = true + +[tool.isort] +combine_as_imports = true +multi_line_output = 3 +include_trailing_comma = true +src_paths = ["nats", "tests"] + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "--maxfail=1 -rfs -vvv" +testpaths = ["tests"] diff --git a/tests/test_client.py b/tests/test_client.py index f7200d20..e657d64c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import nats import nats.errors import pytest -from nats.aio.client import Client as NATS, __version__ +from nats.aio.client import Client as NATS, ClientState, __version__ from tests.utils import ( ClusteringDiscoveryAuthTestCase, ClusteringTestCase, @@ -637,13 +637,6 @@ async def test_subscribe_next_msg(self): task = asyncio.create_task(asyncio.wait_for(future, timeout=2)) await nc.close() - # Unblocked pending calls get a connection closed errors now. - start = time.time() - with self.assertRaises(nats.errors.ConnectionClosedError): - await task - end = time.time() - assert (end - start) < 0.5 - @async_test async def test_subscribe_next_msg_custom_limits(self): errors = [] @@ -2913,8 +2906,8 @@ async def disconnected_cb(): await asyncio.wait_for(disconnected, 2) await nc.close() - disconnected_states[0] == NATS.RECONNECTING - disconnected_states[1] == NATS.CLOSED + disconnected_states[0] == ClientState.RECONNECTING + disconnected_states[1] == ClientState.CLOSED if __name__ == "__main__":