Skip to content

Commit

Permalink
Keep stream api more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
ml31415 committed May 3, 2024
1 parent 16f9de8 commit 9f8fd43
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
58 changes: 46 additions & 12 deletions betfair_parser/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import itertools
import socket
import pathlib
import ssl
import urllib.parse
from collections.abc import AsyncGenerator, Iterable
Expand Down Expand Up @@ -132,15 +133,15 @@ def create_stream_io(endpoint, timeout: float = 15):
return socket.SocketIO(sock, "rwb")


def _message_changes(msg: StreamResponseType) -> Optional[list[str]]:
def changed_markets(msg: StreamResponseType) -> list[str]:
"""Return the market IDs of the markets affected by the given message."""
if isinstance(msg, (Status, Connection)):
return None
return []
if isinstance(msg, MCM) and msg.market_changes:
return [m.id for m in msg.market_changes]
if isinstance(msg, OCM) and msg.order_market_changes:
return [m.id for m in msg.order_market_changes]
return None
return []


class StreamReader:
Expand Down Expand Up @@ -171,15 +172,31 @@ def connect(self, stream: io.RawIOBase) -> None:
stream.write(self.esm.connect()) # send auth
self.esm.receive(stream)

def iter_changes(self, stream: io.RawIOBase) -> Iterable[list[str]]:
"""Iterate over the stream, yielding lists of IDs of the updated markets."""
def iter_changes(self, stream: io.RawIOBase) -> Iterable[ChangeMessageType]:
"""Iterate over the stream, yielding market and order change messages."""
if not self.esm.is_connected:
self.connect(stream)

while True:
changes = _message_changes(self.esm.receive(stream))
if changes:
yield changes
msg = self.esm.receive(stream)
if isinstance(msg, (OCM, MCM)):
yield msg

def iter_changes_and_write(
self,
stream: io.RawIOBase,
path: pathlib.Path | str,
) -> Iterable[ChangeMessageType]:
if not self.esm.is_connected:
self.connect(stream)

with open(path, "ab") as f:
while True:
raw_msg = stream.readline()
msg = self.esm.receive_bytes(raw_msg)
if isinstance(msg, (OCM, MCM)):
yield msg
f.write(raw_msg)


class AsyncStream:
Expand Down Expand Up @@ -247,11 +264,28 @@ async def connect_async(self, stream: AsyncStream) -> None:
await stream.write(self.esm.connect()) # send auth
self.esm.receive_bytes(await stream.readline())

async def iter_changes_async(self, stream: AsyncStream) -> AsyncGenerator[list[str], None]:
async def iter_changes_async(self, stream: AsyncStream) -> AsyncGenerator[ChangeMessageType, None]:
if not self.esm.is_connected:
await self.connect_async(stream)

while True:
changes = _message_changes(self.esm.receive_bytes(await stream.readline()))
if changes:
yield changes
msg = self.esm.receive_bytes(await stream.readline())
if isinstance(msg, (OCM, MCM)):
yield msg

async def iter_changes_and_write_async(
self,
stream: AsyncStream,
path: Union[pathlib.Path, str],
) -> AsyncGenerator[ChangeMessageType, None]:
if not self.esm.is_connected:
await self.connect_async(stream)

loop = asyncio.get_running_loop()
with open(path, "ab") as f:
while True:
raw_msg = await stream.readline()
msg = self.esm.receive_bytes(raw_msg)
if isinstance(msg, (OCM, MCM)):
yield msg
await loop.run_in_executor(None, f.write, raw_msg)
12 changes: 7 additions & 5 deletions tests/integration/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
OrderSubscription,
Status,
)
from betfair_parser.stream import AsyncStream, ExchangeStream, StreamReader, create_stream_io
from betfair_parser.stream import AsyncStream, ExchangeStream, StreamReader, create_stream_io, changed_markets
from tests.integration.test_live import appconfig # noqa: F401


Expand Down Expand Up @@ -132,8 +132,10 @@ def test_stream_reader(session, iterations=15):
sr.subscribe(subscription) # type: ignore[arg-type]

with create_stream_io(STREAM_INTEGRATION) as stream:
for i, changed_ids in enumerate(sr.iter_changes(stream)):
assert len(changed_ids)
for i, change_msg in enumerate(sr.iter_changes(stream)):
changed_ids = changed_markets(change_msg)
if not changed_ids:
continue
assert all(change_id.startswith("1.") for change_id in changed_ids)
assert all(change_id in sr.caches[MARKET_STREAM_ID].order_book for change_id in changed_ids) # type: ignore[union-attr]
if i >= iterations:
Expand All @@ -155,8 +157,8 @@ def test_stream_reader(session, iterations=15):
if not runner_order_book.total_volume:
# skip empty order books
continue
assert runner_order_book.available_to_back
assert runner_order_book.available_to_lay
assert runner_order_book.available_to_back or runner_order_book.available_to_lay[1.01]
assert runner_order_book.available_to_lay or runner_order_book.available_to_back[1000]
assert runner_order_book.last_traded_price

# fields must be deleted when nulled
Expand Down

0 comments on commit 9f8fd43

Please sign in to comment.