This repository has been archived by the owner on Sep 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net.py
119 lines (90 loc) · 3.44 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
import json
import trio
from utils import truncate_middle
from typings import *
# Module logger. The level is pinned to INFO here, so this logger's
# debug() records are dropped unless its level is lowered explicitly.
log = logging.getLogger(name=__name__)
log.setLevel(level=logging.INFO)

# Maximum number of bytes requested from the stream per receive_some() call.
BUFSIZE = 4 * 1024
class ConnectionClosed(Exception):
    """Raised by JSONStream when the peer closes or breaks the stream while a
    read or write is in progress."""
class JSONStream:
    """Newline-delimited JSON messages over a ``trio.abc.Stream``.

    Each message is a JSON object (a Python dict) serialized on one line and
    terminated by a newline byte. Reads and writes are each guarded by their
    own single-slot ``trio.CapacityLimiter``. One reader and one writer may
    therefore be active at the same time, but never two readers or two
    writers.
    """

    def __init__(self, stream: trio.abc.Stream):
        self._stream = stream
        # Serializes writers: only one task at a time may call send_all().
        self._write_cap = trio.CapacityLimiter(1)
        # Serializes readers: only one task at a time may pull from the
        # stream and consume _read_buf.
        self._read_cap = trio.CapacityLimiter(1)
        # Bytes received from the stream but not yet returned as a message.
        self._read_buf = bytearray()

    async def read(self) -> Message:
        """Read and return the next newline-terminated JSON object.

        Bytes received beyond the newline stay buffered for the next call.

        Raises:
            ConnectionClosed: the stream broke, or reached EOF, before a
                complete line arrived.
            ValueError: the line is blank, is not valid UTF-8
                (UnicodeDecodeError), is not valid JSON, or does not decode
                to a dict. The offending line has already been consumed, so
                the next call starts at the following line.
        """
        async with self._read_cap:
            log.debug("Acquired reading semaphore")
            newline = self._read_buf.find(b"\n")
            while newline == -1:
                # Bytes already in the buffer are known to contain no newline,
                # so only the freshly received bytes need to be scanned.
                scanned = len(self._read_buf)
                try:
                    data = await self._stream.receive_some(BUFSIZE)
                except trio.BrokenResourceError as exc:
                    raise ConnectionClosed(
                        "stream closed suddenly while reading"
                    ) from exc
                if not data:
                    # Empty chunk: the peer closed its end.
                    raise ConnectionClosed("stream closed while reading")
                self._read_buf += data
                newline = self._read_buf.find(b"\n", scanned)
                log.debug("Adding to buffer %s", data)
            end = newline + 1  # include the terminating newline
            # Take the line out of the buffer *before* decoding it. Otherwise
            # an undecodable line would stay at the front of the buffer and
            # make every subsequent read fail.
            raw = self._read_buf[:end]
            del self._read_buf[:end]

        # Parsing happens after the read limiter is released.
        line = raw.decode("utf-8")
        log.debug("(release read semaphore) Parsing line: %r", line)
        if not line.strip():
            raise ValueError(f"Invalid empty value: {line!r}")
        try:
            obj = cast(Message, json.loads(line))
        except ValueError:
            log.exception("Invalid JSON: %r", line)
            raise
        if not isinstance(obj, dict):
            raise ValueError(f"should be dict, got {type(obj)} in {obj}")
        log.debug("Read %r", obj)
        return obj

    async def write(self, obj: Message) -> None:
        """Serialize *obj* as a single JSON line and send it.

        Raises:
            ValueError: *obj* is not a dict.
            ConnectionClosed: the stream broke while sending.
        """
        if not isinstance(obj, dict):
            raise ValueError(f"should send dict, got {obj!r}")
        # json.dumps (no indent) escapes control characters inside strings,
        # so the only raw newline in the payload is the terminator below.
        payload = (json.dumps(obj) + "\n").encode("utf-8")
        async with self._write_cap:
            log.debug("Sending %s", obj)
            try:
                await self._stream.send_all(payload)
            except trio.BrokenResourceError as exc:
                raise ConnectionClosed("stream closed while writing") from exc

    async def aclose(self) -> None:
        """Close the underlying stream.

        First waits up to 2 seconds for any in-progress write and read to
        finish, by acquiring both limiters. After that timeout the stream is
        closed anyway. The stream is also closed if this call is cancelled
        from outside while waiting. Whatever limiters were acquired are
        always released, even if closing the stream raises.

        NOTE(review): per trio's Stream docs, a task still blocked in read()
        or write() when the stream is force-closed gets
        trio.ClosedResourceError, not ConnectionClosed.
        """
        log.info(f"closing stream {self}")
        acquired = 0
        try:
            with trio.move_on_after(2) as cancel_scope:
                await self._write_cap.acquire()
                acquired += 1  # count before anything else can raise
                log.debug("Got write semaphore")
                await self._read_cap.acquire()
                acquired += 1
                log.debug("Got read semaphore")
            if cancel_scope.cancelled_caught:
                log.warning(
                    "Forcefully closing stream after 2 seconds, "
                    f"{acquired} semaphore(s) acquired"
                )
        finally:
            try:
                # trio's AsyncResource contract requires aclose() to close
                # the resource even when the calling scope is cancelled.
                await self._stream.aclose()
                log.debug("Stream closed, releasing semaphores")
            finally:
                if acquired >= 1:
                    self._write_cap.release()
                if acquired >= 2:
                    self._read_cap.release()

    def __str__(self) -> str:
        # Shorten the wrapped stream's repr via utils.truncate_middle
        # (presumably to about 20 characters -- confirm in utils).
        return f"JSONStream({truncate_middle(repr(self._stream), 20)})"

    def __repr__(self) -> str:
        return str(self)

    def __eq__(self, o: Any) -> bool:
        """Equal when wrapping the same stream object with identical
        unread buffered bytes."""
        return (
            isinstance(o, JSONStream)
            and self._stream is o._stream
            and self._read_buf == o._read_buf
        )

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None  # type: ignore[assignment]