Update Worker to support concurrent predictions
This adds a max_concurrency parameter to Worker that allows it to accept
new predictions while others are already in flight.
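
For illustration, here is roughly how the new surface is used (the
predictor ref, payloads and tags are invented for the example):

    worker = make_worker("predict.py:Predictor", max_concurrency=4)
    worker.setup().result()  # setup() returns a Future[Done]

    # With max_concurrency > 1, every prediction must carry a unique tag.
    futures = {
        tag: worker.predict({"prompt": f"job {tag}"}, tag=tag)
        for tag in ("a", "b", "c")
    }
    for tag, fut in futures.items():
        assert not fut.result().error  # each future resolves to a Done event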

I have removed the PROCESSING WorkerState; in concurrent mode, there is no
distinction between READY and PROCESSING because we might be able to accept a
prediction in either case.
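
Acceptance is now a capacity check rather than a state transition.  The
capacity part of the check in predict() boils down to this (restated here
for illustration, not a new API):

    # Old: READY -> predict() -> PROCESSING -> Done -> READY
    # New: state stays READY; busyness lives in the in-flight dictionary.
    def can_accept(worker: Worker) -> bool:
        return (
            worker._state == WorkerState.READY
            and len(worker._predictions_in_flight) < worker._max_concurrency
        )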

Worker now keeps track of multiple in-flight predictions in a dictionary, keyed
on tag.  Tags are required if max_concurrency > 1.  Otherwise tags are
optional (and, if omitted, we store the prediction with tag=None).
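
Concretely, the tag rules behave like this (predictor ref invented; the
exception types are the ones raised in predict() below):

    w = make_worker("predict.py:Predictor", max_concurrency=2)
    w.setup().result()

    w.predict({"x": 1})                 # TypeError: tag is required
    f = w.predict({"x": 1}, tag="t1")   # accepted
    w.predict({"x": 2}, tag="t1")       # InvalidStateException: already running

    seq = make_worker("predict.py:Predictor")  # sequential mode (the default)
    seq.predict({"x": 1})               # accepted, stored under tag=None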

There is one awkward place, which is _prepare_payload().  As I understand it,
this synchronously downloads URLFiles, which blocks us from processing any
other updates from the child while the download runs.
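
One possible follow-up, sketched here only to show the shape of a fix (it
is not part of this commit; _start_prediction and _prepare_pool are
invented names, and the Envelope sends would need a lock before the
Connection could safely be shared across threads):

    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical additions to Worker.  In __init__:
    self._prepare_pool = ThreadPoolExecutor(max_workers=4)

    # The event loop then submits payload preparation instead of running
    # it inline, so select() keeps servicing child events during downloads.
    def _start_prediction(self, state: PredictionState) -> None:
        def prepare_and_send() -> None:
            try:
                # the blocking URLFile download now happens off the loop
                _prepare_payload(state.payload)
            except Exception as e:
                done = Done(error=True, error_detail=str(e))
                self._publish(Envelope(done, state.tag))
                self._complete_prediction(done, state.tag)
            else:
                self._events.send(
                    Envelope(event=PredictionInput(payload=state.payload),
                             tag=state.tag)
                )

        self._prepare_pool.submit(prepare_and_send)
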
philandstuff committed Nov 18, 2024
1 parent f8a5f8c commit 9c51063
Showing 3 changed files with 214 additions and 91 deletions.
217 changes: 133 additions & 84 deletions python/cog/server/worker.py
@@ -4,6 +4,7 @@
import inspect
import multiprocessing
import os
import select
import signal
import sys
import threading
@@ -25,6 +26,7 @@
)

import structlog
from attrs import define

from ..base_predictor import BasePredictor
from ..json import make_encodeable
@@ -69,53 +71,90 @@ class WorkerState(Enum):
NEW = auto()
STARTING = auto()
READY = auto()
PROCESSING = auto()
DEFUNCT = auto()


@define
class PredictionRequest:
tag: Optional[str]


@define
class CancelRequest:
tag: Optional[str]


class PredictionState:
def __init__(
self, tag: Optional[str], payload: Dict[str, Any], result: "Future[Done]"
) -> None:
self.tag = tag
self.payload = payload
self.result = result

self.cancel_sent = False


class Worker:
def __init__(self, child: "_ChildWorker", events: Connection) -> None:
def __init__(
self, child: "_ChildWorker", events: Connection, max_concurrency: int = 1
) -> None:
self._child = child
self._events = events

self._allow_cancel = False
self._sent_shutdown_event = False
self._state = WorkerState.NEW
self._terminating = False

self._result: Optional["Future[Done]"] = None
self._setup_result: "Future[Done]" = Future()
self._subscribers_lock = threading.Lock()
self._subscribers: Dict[
int, Tuple[Callable[[_PublicEventType], None], Optional[str]]
] = {}

self._predict_tag: Optional[str] = None
self._predict_payload: Optional[Dict[str, Any]] = None
self._predict_start = threading.Event() # set when a prediction is started
self._max_concurrency = max_concurrency

self._predictions_lock = threading.Lock()
self._predictions_in_flight: Dict[Optional[str], PredictionState] = {}

recv_conn, send_conn = _spawn.Pipe(duplex=False)
self._request_send_conn = send_conn
self._request_recv_conn = recv_conn

self._pool = ThreadPoolExecutor(max_workers=1)
self._event_consumer = None

def setup(self) -> "Future[Done]":
self._assert_state(WorkerState.NEW)
self._state = WorkerState.STARTING
result = Future()
self._result = result
self._child.start()
self._event_consumer = self._pool.submit(self._consume_events)
return result
return self._setup_result

def predict(
self, payload: Dict[str, Any], tag: Optional[str] = None
) -> "Future[Done]":
self._assert_state(WorkerState.READY)
self._state = WorkerState.PROCESSING
self._allow_cancel = True
result = Future()
self._result = result
self._predict_tag = tag
self._predict_payload = payload
self._predict_start.set()
# TODO: tag is Optional, but it's required in concurrent mode and
# basically unnecessary in sequential mode. Should we have a separate
# ConcurrentWorker?
if self._max_concurrency > 1 and tag is None:
raise TypeError(
"Invalid operation: tag is required when Worker has max_concurrency > 1"
)

with self._predictions_lock:
if len(self._predictions_in_flight) >= self._max_concurrency:
raise InvalidStateException(
"Invalid operation: maximum predictions in flight reached"
)
if tag in self._predictions_in_flight:
raise InvalidStateException(
f"Invalid operation: prediction with tag {tag} already running"
)
self._assert_state(WorkerState.READY)
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
self._request_send_conn.send(PredictionRequest(tag))
return result

def subscribe(
@@ -164,10 +203,7 @@ def terminate(self) -> None:
self._pool.shutdown(wait=False)

def cancel(self, tag: Optional[str] = None) -> None:
if self._allow_cancel:
self._child.send_cancel()
self._events.send(Envelope(event=Cancel(), tag=tag))
self._allow_cancel = False
self._request_send_conn.send(CancelRequest(tag))

def _assert_state(self, state: WorkerState) -> None:
if self._state != state:
@@ -200,90 +236,101 @@ def _consume_events_inner(self) -> None:
# If we didn't get a done event, the child process died.
if not done:
exitcode = self._child.exitcode
assert self._result
self._result.set_exception(
assert self._setup_result
self._setup_result.set_exception(
FatalWorkerException(
f"Predictor setup failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
)
)
self._result = None
self._state = WorkerState.DEFUNCT
return
if done.error:
assert self._result
self._result.set_exception(
assert self._setup_result
self._setup_result.set_exception(
FatalWorkerException(
"Predictor errored during setup: " + done.error_detail
)
)
self._result = None
self._state = WorkerState.DEFUNCT
return

assert self._result

# We capture the setup future and then set state to READY before
# completing it, so that we can immediately accept work.
result = self._result
self._result = None
self._state = WorkerState.READY
result.set_result(done)
self._setup_result.set_result(done)

# Predictions
# Main event loop
while self._child.is_alive():
start = self._predict_start.wait(timeout=0.1)
if not start:
continue
# see if we have any new prediction requests

assert self._predict_payload is not None
assert self._result

# Prepare payload (download URLPath objects)
try:
_prepare_payload(self._predict_payload)
except Exception as e:
done = Envelope(
event=Done(error=True, error_detail=str(e)),
tag=self._predict_tag,
)
self._publish(done)
else:
# Start the prediction
self._events.send(
Envelope(
event=PredictionInput(payload=self._predict_payload),
tag=self._predict_tag,
)
)

# Consume and publish prediction events
done = self._consume_events_until_done()
if not done:
break

# We capture the predict future and then reset state before
# completing it, so that we can immediately accept work.
result = self._result
self._predict_tag = None
self._predict_payload = None
self._predict_start.clear()
self._result = None
self._state = WorkerState.READY
self._allow_cancel = False
result.set_result(done)
read_socks, _, _ = select.select(
[self._request_recv_conn, self._events], [], [], 0.1
)
for sock in read_socks:
if sock == self._request_recv_conn:
ev = self._request_recv_conn.recv()
if isinstance(ev, PredictionRequest):
with self._predictions_lock:
state = self._predictions_in_flight[ev.tag]

# Prepare payload (download URLPath objects)
# FIXME this blocks the event loop, which is bad in concurrent mode
try:
_prepare_payload(state.payload)
except Exception as e:
done = Done(error=True, error_detail=str(e))
self._publish(Envelope(done, state.tag))
self._complete_prediction(done, state.tag)
else:
# Start the prediction
self._events.send(
Envelope(
event=PredictionInput(payload=state.payload),
tag=state.tag,
)
)
if isinstance(ev, CancelRequest):
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(ev.tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel()
self._events.send(Envelope(event=Cancel(), tag=ev.tag))
predict_state.cancel_sent = True

else: # sock == self._events
ev = self._events.recv()
self._publish(ev)
if isinstance(ev.event, Done):
self._complete_prediction(ev.event, ev.tag)

# If we dropped off the end of the loop, it's because the
# child process died.
# child process died. First, process any remaining messages on the connection
while self._events.poll():
ev = self._events.recv()
self._publish(ev)
if isinstance(ev.event, Done):
self._complete_prediction(ev.event, ev.tag)

if not self._terminating:
if self._result:
exitcode = self._child.exitcode
self._result.set_exception(
FatalWorkerException(
f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
)
)
self._result = None
self._state = WorkerState.DEFUNCT
with self._predictions_lock:
for tag in list(self._predictions_in_flight.keys()):
exitcode = self._child.exitcode
self._predictions_in_flight[tag].result.set_exception(
FatalWorkerException(
f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
)
)
del self._predictions_in_flight[tag]

def _complete_prediction(self, done: Done, tag: Optional[str]) -> None:
# We update the in-flight dictionary before completing the prediction
# future, so that we can immediately accept work.
with self._predictions_lock:
predict_state = self._predictions_in_flight.pop(tag)
if len(self._predictions_in_flight) == 0:
self._state = WorkerState.READY
predict_state.result.set_result(done)

def _publish(self, e: Envelope) -> None:
with self._subscribers_lock:
@@ -637,10 +684,12 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None:
)


def make_worker(predictor_ref: str, tee_output: bool = True) -> Worker:
def make_worker(
predictor_ref: str, tee_output: bool = True, max_concurrency: int = 1
) -> Worker:
parent_conn, child_conn = _spawn.Pipe()
child = _ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output)
parent = Worker(child=child, events=parent_conn)
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
return parent


20 changes: 16 additions & 4 deletions python/tests/server/conftest.py
@@ -25,6 +25,7 @@ class AppConfig:
class WorkerConfig:
fixture_name: str
setup: bool = True
max_concurrency: int = 1


def pytest_make_parametrize_id(config, val):
@@ -70,7 +71,7 @@ def uses_predictor_with_client_options(name, **options):
)


def uses_worker(name_or_names, setup=True):
def uses_worker(name_or_names, setup=True, max_concurrency=1):
"""
Decorator for tests that require a Worker instance. `name_or_names` can be
a single fixture name, or a sequence (list, tuple) of fixture names. If
@@ -79,9 +80,16 @@ def uses_worker(name_or_names, setup=True):
If `setup` is True (the default) setup will be run before the test runs.
"""
if isinstance(name_or_names, (tuple, list)):
values = (WorkerConfig(fixture_name=n, setup=setup) for n in name_or_names)
values = (
WorkerConfig(fixture_name=n, setup=setup, max_concurrency=max_concurrency)
for n in name_or_names
)
else:
values = (WorkerConfig(fixture_name=name_or_names, setup=setup),)
values = (
WorkerConfig(
fixture_name=name_or_names, setup=setup, max_concurrency=max_concurrency
),
)
return pytest.mark.parametrize("worker", values, indirect=True)


@@ -143,7 +151,11 @@ def static_schema(client) -> dict:
@pytest.fixture
def worker(request):
ref = _fixture_path(request.param.fixture_name)
w = make_worker(predictor_ref=ref, tee_output=False)
w = make_worker(
predictor_ref=ref,
tee_output=False,
max_concurrency=request.param.max_concurrency,
)
if request.param.setup:
assert not w.setup().result().error
try:
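
For reference, a test exercising the new decorator parameter might look
like this (fixture name and payload invented):

    @uses_worker("sleep", max_concurrency=5)
    def test_concurrent_predictions(worker):
        futures = [worker.predict({"sleep": 0.5}, tag=f"p{i}") for i in range(5)]
        for fut in futures:
            assert not fut.result().error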