diff --git a/betfair_parser/stream.py b/betfair_parser/stream.py index 4c659ad..79f6f10 100644 --- a/betfair_parser/stream.py +++ b/betfair_parser/stream.py @@ -1,3 +1,4 @@ +import asyncio import itertools import socket import ssl @@ -27,15 +28,25 @@ def create_ssl_socket(hostname, timeout: int = 15) -> ssl.SSLSocket: return secure_sock -class Stream: - _sock: Optional[ssl.SSLSocket] - _io: Optional[socket.SocketIO] - _connection_id: Optional[str] +class _BaseStream: + _connection_id: Optional[str] = None def __init__(self, endpoint) -> None: self._endpoint = endpoint self._id_generator = itertools.count() + @property + def connection_id(self) -> str: + return self._connection_id + + def unique_id(self) -> int: + return next(self._id_generator) + + +class Stream(_BaseStream): + _sock: Optional[ssl.SSLSocket] = None + _io: Optional[socket.SocketIO] = None + def connect(self) -> None: url = urllib.parse.urlparse(self._endpoint) self._sock = create_ssl_socket(url.hostname) @@ -44,9 +55,6 @@ def connect(self) -> None: msg: Connection = self.receive() # type: ignore self._connection_id = msg.connection_id - def unique_id(self) -> int: - return next(self._id_generator) - def send(self, request: STREAM_REQUEST) -> None: if not self._io: raise StreamError("Stream is not connected") @@ -72,10 +80,6 @@ def close(self): self._sock.close() self._sock = None - @property - def connection_id(self) -> str: - return self._connection_id - def authenticate(self, app_key: str, token: str) -> None: self.send(Authentication(id=self.unique_id(), app_key=app_key, session=token)) msg: Status = self.receive() # type: ignore @@ -93,3 +97,63 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + + +class AsyncStream(_BaseStream): + _reader: Optional[asyncio.StreamReader] = None + _writer: Optional[asyncio.StreamWriter] = None + + async def connect(self) -> None: + url = urllib.parse.urlparse(self._endpoint) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_default_certs() + self._reader, self._writer = await asyncio.open_connection( + host=url.hostname, + port=url.port, + ssl=context, + server_hostname=url.hostname, + limit=1_000_000, + ) + msg: Connection = await self.receive() # type: ignore + self._connection_id = msg.connection_id + + async def send(self, request: STREAM_REQUEST) -> None: + if not self._writer: + raise StreamError("Stream is not connected") + msg = encode(request) + b"\r\n" + self._writer.write(msg) + await self._writer.drain() + + async def receive(self) -> STREAM_RESPONSE: + if not self._reader: + raise StreamError("Stream is not connected") + data = await self._reader.readline() + return stream_decode(data) + + async def close(self): + if self._writer: + try: + await self._writer.drain() + finally: + self._writer.close() + await self._writer.wait_closed() + self._writer = None + self._reader = None + + async def authenticate(self, app_key: str, token: str) -> None: + await self.send(Authentication(id=self.unique_id(), app_key=app_key, session=token)) + msg: Status = await self.receive() # type: ignore + if msg.is_error: + raise StreamAuthenticationError(f"{msg.error_code.name}: {msg.error_message}") + if msg.connection_closed: + raise StreamAuthenticationError("Connection was closed by the server unexpectedly") + + async def heartbeat(self) -> None: + await self.send(Heartbeat(id=self.unique_id())) + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() diff --git a/pyproject.toml b/pyproject.toml index e9847fb..72e388e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + [project] name = "betfair_parser" description = "A betfair parser" @@ -19,25 +23,19 @@ classifiers = [ keywords = ["parser", "betfair", "api", "json", "streaming"] dependencies = [ "msgspec>=0.16.0", - "fsspec>=2022", ] - -[project.urls] -Homepage = "https://github.com/limx0/betfair_parser" -Documentation = "https://limx0.github.io/betfair_parser/" -"Bug Tracker" = "https://github.com/limx0/betfair_parser/issues" - -[project.optional-dependencies] -dev = [ +optional-dependencies.dev = [ "pytest>=7.1", + "pytest-asyncio>=0.21", "pytest-benchmark>=4.0", "twine>=4.0.2", "requests>=2.20.0", ] -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +[project.urls] +Homepage = "https://github.com/limx0/betfair_parser" +Documentation = "https://limx0.github.io/betfair_parser/" +"Bug Tracker" = "https://github.com/limx0/betfair_parser/issues" [tool.poetry] name = "betfair_parser" @@ -52,6 +50,7 @@ msgspec = ">=0.16" [tool.poetry.dev-dependencies] pytest = "^7.1" +pytest-asyncio = "^0.21" pytest-benchmark = "^4.0" requests = "^2.20" diff --git a/tests/integration/test_stream.py b/tests/integration/test_stream.py index b913261..91a5737 100644 --- a/tests/integration/test_stream.py +++ b/tests/integration/test_stream.py @@ -18,8 +18,8 @@ OrderSubscription, Status, ) -from betfair_parser.stream import Stream -from tests.integration.test_live import appconfig, skip_not_logged_in # noqa: F401 +from betfair_parser.stream import AsyncStream, Stream +from tests.integration.test_live import appconfig # noqa: F401 @pytest.fixture(scope="module") @@ -28,36 +28,38 @@ def session(appconfig) -> Session: # noqa try: client.login(s, appconfig["username"], appconfig["password"], appconfig["app_key"]) except BetfairError: - pass + pytest.skip("session could not be logged in") return s +SUBSCRIPTIONS = [ + MarketSubscription( + id=1, + heartbeat_ms=500, + market_filter=MarketFilter( + betting_types=[MarketBettingType.ODDS], + event_type_ids=[EventTypeIdCode.HORSE_RACING], + country_codes=["GB", "IE"], + market_types=[MarketTypeCode.WIN], + ), + market_data_filter=MarketDataFilter( + fields=[ + MarketDataFilterFields.EX_MARKET_DEF, + MarketDataFilterFields.EX_ALL_OFFERS, + MarketDataFilterFields.EX_LTP, + MarketDataFilterFields.EX_TRADED_VOL, + ], + ), + ), + OrderSubscription(id=2, heartbeat_ms=500, order_filter=OrderFilter()), +] + + @pytest.mark.parametrize( "subscription", - [ - MarketSubscription( - id=1, - heartbeat_ms=500, - market_filter=MarketFilter( - betting_types=[MarketBettingType.ODDS], - event_type_ids=[EventTypeIdCode.HORSE_RACING], - country_codes=["GB", "IE"], - market_types=[MarketTypeCode.WIN], - ), - market_data_filter=MarketDataFilter( - fields=[ - MarketDataFilterFields.EX_MARKET_DEF, - MarketDataFilterFields.EX_ALL_OFFERS, - MarketDataFilterFields.EX_LTP, - MarketDataFilterFields.EX_TRADED_VOL, - ], - ), - ), - OrderSubscription(id=2, heartbeat_ms=500, order_filter=OrderFilter()), - ], + SUBSCRIPTIONS, ids=lambda x: type(x).__name__, ) -@skip_not_logged_in def test_stream(session, subscription, iterations=3): token = session.headers.get("X-Authentication") app_key = session.headers.get("X-Application") @@ -79,3 +81,32 @@ def test_stream(session, subscription, iterations=3): msg = strm.receive() assert isinstance(msg, req_type) print(msg) + + +@pytest.mark.parametrize( + "subscription", + SUBSCRIPTIONS, + ids=lambda x: type(x).__name__, +) +@pytest.mark.asyncio +async def test_async_stream(session, subscription, iterations=3): + token = session.headers.get("X-Authentication") + app_key = session.headers.get("X-Application") + + async with AsyncStream(STREAM_INTEGRATION) as strm: + await strm.authenticate(app_key, token) + await strm.send(subscription) + msg: Status = await strm.receive() + assert isinstance(msg, Status) + assert not msg.is_error, f"{msg.error_code.name}: {msg.error_message}" + assert not msg.connection_closed + assert msg.id == subscription.id + + print(subscription) + print(msg) + + req_type = MCM if isinstance(subscription, MarketSubscription) else OCM + for _ in range(iterations): + msg = await strm.receive() + assert isinstance(msg, req_type) + print(msg)