Skip to content

Commit

Permalink
AsyncStream and tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
ml31415 committed Sep 25, 2023
1 parent ed6d131 commit 540b370
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 48 deletions.
86 changes: 75 additions & 11 deletions betfair_parser/stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import itertools
import socket
import ssl
Expand Down Expand Up @@ -27,15 +28,25 @@ def create_ssl_socket(hostname, timeout: int = 15) -> ssl.SSLSocket:
return secure_sock


class Stream:
_sock: Optional[ssl.SSLSocket]
_io: Optional[socket.SocketIO]
_connection_id: Optional[str]
class _BaseStream:
_connection_id: Optional[str] = None

def __init__(self, endpoint) -> None:
self._endpoint = endpoint
self._id_generator = itertools.count()

@property
def connection_id(self) -> str:
return self._connection_id

def unique_id(self) -> int:
return next(self._id_generator)


class Stream(_BaseStream):
_sock: Optional[ssl.SSLSocket] = None
_io: Optional[socket.SocketIO] = None

def connect(self) -> None:
url = urllib.parse.urlparse(self._endpoint)
self._sock = create_ssl_socket(url.hostname)
Expand All @@ -44,9 +55,6 @@ def connect(self) -> None:
msg: Connection = self.receive() # type: ignore
self._connection_id = msg.connection_id

def unique_id(self) -> int:
return next(self._id_generator)

def send(self, request: STREAM_REQUEST) -> None:
if not self._io:
raise StreamError("Stream is not connected")
Expand All @@ -72,10 +80,6 @@ def close(self):
self._sock.close()
self._sock = None

@property
def connection_id(self) -> str:
return self._connection_id

def authenticate(self, app_key: str, token: str) -> None:
self.send(Authentication(id=self.unique_id(), app_key=app_key, session=token))
msg: Status = self.receive() # type: ignore
Expand All @@ -93,3 +97,63 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()


class AsyncStream(_BaseStream):
_reader: Optional[asyncio.StreamReader] = None
_writer: Optional[asyncio.StreamWriter] = None

async def connect(self) -> None:
url = urllib.parse.urlparse(self._endpoint)
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.load_default_certs()
self._reader, self._writer = await asyncio.open_connection(
host=url.hostname,
port=url.port,
ssl=context,
server_hostname=url.hostname,
limit=1_000_000,
)
msg: Connection = await self.receive() # type: ignore
self._connection_id = msg.connection_id

async def send(self, request: STREAM_REQUEST) -> None:
if not self._writer:
raise StreamError("Stream is not connected")
msg = encode(request) + b"\r\n"
self._writer.write(msg)
await self._writer.drain()

async def receive(self) -> STREAM_RESPONSE:
if not self._reader:
raise StreamError("Stream is not connected")
data = await self._reader.readline()
return stream_decode(data)

async def close(self):
if self._writer:
try:
await self._writer.drain()
finally:
self._writer.close()
await self._writer.wait_closed()
self._writer = None
self._reader = None

async def authenticate(self, app_key: str, token: str) -> None:
await self.send(Authentication(id=self.unique_id(), app_key=app_key, session=token))
msg: Status = await self.receive() # type: ignore
if msg.is_error:
raise StreamAuthenticationError(f"{msg.error_code.name}: {msg.error_message}")
if msg.connection_closed:
raise StreamAuthenticationError("Connection was closed by the server unexpectedly")

async def heartbeat(self) -> None:
await self.send(Heartbeat(id=self.unique_id()))

async def __aenter__(self):
await self.connect()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()
23 changes: 11 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[project]
name = "betfair_parser"
description = "A betfair parser"
Expand All @@ -19,25 +23,19 @@ classifiers = [
keywords = ["parser", "betfair", "api", "json", "streaming"]
dependencies = [
"msgspec>=0.16.0",
"fsspec>=2022",
]

[project.urls]
Homepage = "https://github.com/limx0/betfair_parser"
Documentation = "https://limx0.github.io/betfair_parser/"
"Bug Tracker" = "https://github.com/limx0/betfair_parser/issues"

[project.optional-dependencies]
dev = [
optional-dependencies.dev = [
"pytest>=7.1",
"pytest-asyncio>=0.21",
"pytest-benchmark>=4.0",
"twine>=4.0.2",
"requests>=2.20.0",
]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[project.urls]
Homepage = "https://github.com/limx0/betfair_parser"
Documentation = "https://limx0.github.io/betfair_parser/"
"Bug Tracker" = "https://github.com/limx0/betfair_parser/issues"

[tool.poetry]
name = "betfair_parser"
Expand All @@ -52,6 +50,7 @@ msgspec = ">=0.16"

[tool.poetry.dev-dependencies]
pytest = "^7.1"
pytest-asyncio = "^0.21"
pytest-benchmark = "^4.0"
requests = "^2.20"

Expand Down
81 changes: 56 additions & 25 deletions tests/integration/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
OrderSubscription,
Status,
)
from betfair_parser.stream import Stream
from tests.integration.test_live import appconfig, skip_not_logged_in # noqa: F401
from betfair_parser.stream import AsyncStream, Stream
from tests.integration.test_live import appconfig # noqa: F401


@pytest.fixture(scope="module")
Expand All @@ -28,36 +28,38 @@ def session(appconfig) -> Session: # noqa
try:
client.login(s, appconfig["username"], appconfig["password"], appconfig["app_key"])
except BetfairError:
pass
pytest.skip("session could not be logged in")
return s


SUBSCRIPTIONS = [
MarketSubscription(
id=1,
heartbeat_ms=500,
market_filter=MarketFilter(
betting_types=[MarketBettingType.ODDS],
event_type_ids=[EventTypeIdCode.HORSE_RACING],
country_codes=["GB", "IE"],
market_types=[MarketTypeCode.WIN],
),
market_data_filter=MarketDataFilter(
fields=[
MarketDataFilterFields.EX_MARKET_DEF,
MarketDataFilterFields.EX_ALL_OFFERS,
MarketDataFilterFields.EX_LTP,
MarketDataFilterFields.EX_TRADED_VOL,
],
),
),
OrderSubscription(id=2, heartbeat_ms=500, order_filter=OrderFilter()),
]


@pytest.mark.parametrize(
"subscription",
[
MarketSubscription(
id=1,
heartbeat_ms=500,
market_filter=MarketFilter(
betting_types=[MarketBettingType.ODDS],
event_type_ids=[EventTypeIdCode.HORSE_RACING],
country_codes=["GB", "IE"],
market_types=[MarketTypeCode.WIN],
),
market_data_filter=MarketDataFilter(
fields=[
MarketDataFilterFields.EX_MARKET_DEF,
MarketDataFilterFields.EX_ALL_OFFERS,
MarketDataFilterFields.EX_LTP,
MarketDataFilterFields.EX_TRADED_VOL,
],
),
),
OrderSubscription(id=2, heartbeat_ms=500, order_filter=OrderFilter()),
],
SUBSCRIPTIONS,
ids=lambda x: type(x).__name__,
)
@skip_not_logged_in
def test_stream(session, subscription, iterations=3):
token = session.headers.get("X-Authentication")
app_key = session.headers.get("X-Application")
Expand All @@ -79,3 +81,32 @@ def test_stream(session, subscription, iterations=3):
msg = strm.receive()
assert isinstance(msg, req_type)
print(msg)


@pytest.mark.parametrize(
"subscription",
SUBSCRIPTIONS,
ids=lambda x: type(x).__name__,
)
@pytest.mark.asyncio
async def test_async_stream(session, subscription, iterations=3):
token = session.headers.get("X-Authentication")
app_key = session.headers.get("X-Application")

async with AsyncStream(STREAM_INTEGRATION) as strm:
await strm.authenticate(app_key, token)
await strm.send(subscription)
msg: Status = await strm.receive()
assert isinstance(msg, Status)
assert not msg.is_error, f"{msg.error_code.name}: {msg.error_message}"
assert not msg.connection_closed
assert msg.id == subscription.id

print(subscription)
print(msg)

req_type = MCM if isinstance(subscription, MarketSubscription) else OCM
for _ in range(iterations):
msg = await strm.receive()
assert isinstance(msg, req_type)
print(msg)

0 comments on commit 540b370

Please sign in to comment.