Skip to content

Commit

Permalink
Improve task scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Dec 4, 2024
1 parent 7b93e43 commit 2bacdf9
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 27 deletions.
15 changes: 14 additions & 1 deletion hypha/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def shutdown(_) -> None:

self.event_bus.on_local("shutdown", shutdown)

async def client_disconnected(info: dict) -> None:
    """Handle client disconnected event."""
    # info has the shape {"id": client_id, "workspace": ws}
    full_client_id = info["workspace"] + "/" + info["id"]
    # Remove the session entry (if any) and stop its runner.
    session = self._sessions.pop(full_client_id, None)
    if session is not None:
        try:
            await session["_runner"].stop(full_client_id)
        except Exception as err:
            logger.warning(f"Failed to stop browser tab: {err}")

self.event_bus.on_local("client_disconnected", client_disconnected)

async def get_runners(self):
# start the browser runner
server = await self.store.get_public_api()
Expand Down Expand Up @@ -424,7 +438,6 @@ async def _close_after_time_limit():
await asyncio.sleep(time_limit)
if full_client_id in self._sessions:
await runner.stop(full_client_id)
del self._sessions[full_client_id]
logger.info(
f"App {full_client_id} stopped after time limit {time_limit} reached."
)
Expand Down
2 changes: 1 addition & 1 deletion hypha/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ async def disconnect(self, reason=None):
self._event_bus.off(f"{self._workspace}/*:msg", self._handle_message)

self._handle_message = None
logger.info(f"Redis Connection Disconnected: {reason}")
logger.debug(f"Redis Connection Disconnected: {reason}")
if self._handle_disconnected:
await self._handle_disconnected(reason)

Expand Down
104 changes: 82 additions & 22 deletions hypha/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
from hypha_rpc import RPC
from hypha_rpc.utils.schema import schema_method
from starlette.routing import Mount
from pydantic.fields import Field
from aiocache.backends.redis import RedisCache
from aiocache.serializers import PickleSerializer
from taskiq import TaskiqScheduler
from taskiq.api import run_receiver_task, run_scheduler_task
from hypha.taskiq_utils.redis_broker import ListQueueBroker
from hypha.taskiq_utils.schedule_source import RedisScheduleSource


from hypha import __version__
from hypha.core import (
Expand Down Expand Up @@ -168,14 +172,24 @@ def __init__(
from redis import asyncio as aioredis

self._redis = aioredis.from_url(redis_uri)

else: # Create a redis server with fakeredis
from fakeredis import aioredis

self._redis = aioredis.FakeRedis.from_url("redis://localhost:9997/11")

self._broker = ListQueueBroker(self._redis)
self._broker_task = None
self._source = RedisScheduleSource(self._redis)
self._scheduler = TaskiqScheduler(
broker=self._broker,
sources=[self._source],
)
self._house_keeping_schedule = None
self._first_run = True

self._redis_cache = RedisCache(serializer=PickleSerializer())
self._redis_cache.client = self._redis

self._root_user = None
self._event_bus = RedisEventBus(self._redis)

Expand Down Expand Up @@ -278,23 +292,25 @@ async def _run_startup_functions(self, startup_functions):

async def housekeeping(self):
    """Perform a single housekeeping pass over all workspaces.

    Intended to be invoked periodically by the task scheduler. On the
    very first invocation after startup the pass is skipped (the server
    has just started, so there is nothing stale to clean up yet); every
    subsequent invocation lists all workspaces and runs cleanup on each.

    Errors are logged and swallowed so a failing pass never kills the
    scheduler loop.
    """
    if self._first_run:
        # Skip the run scheduled immediately after startup.
        logger.info("Skipping housekeeping on first run")
        self._first_run = False
        return
    try:
        logger.info(f"Running housekeeping task at {datetime.datetime.now()}")
        async with self.get_workspace_interface(
            self._root_user, "ws-user-root", client_id="housekeeping"
        ) as api:
            workspaces = await api.list_workspaces()
            for workspace in workspaces:
                summary = await api.cleanup(workspace.id)
                if "removed_clients" in summary:
                    logger.info(
                        f"Removed {len(summary['removed_clients'])} clients from workspace {workspace.id}"
                    )
    except Exception as e:
        # Never propagate: housekeeping is best-effort background work.
        logger.exception(f"Error in housekeeping: {e}")

async def upgrade(self):
"""Upgrade the store."""
Expand Down Expand Up @@ -521,7 +537,36 @@ async def init(self, reset_redis, startup_functions=None):
logger.info("Server initialized with server id: %s", self._server_id)
logger.info("Currently connected hypha servers: %s", servers)

asyncio.create_task(self.housekeeping())
# Setup broker and scheduler
await self._broker.startup()
self._broker_task = asyncio.create_task(run_receiver_task(self._broker))
await self._scheduler.startup()
self._scheduler_task = asyncio.create_task(run_scheduler_task(self._scheduler))

# Schedule housekeeping via cron "*/1 * * * *" — i.e. every minute, not
# every 10 minutes as previously stated; confirm the intended interval.
self._house_keeping_schedule = await self.schedule_task(
self.housekeeping, task_name="housekeeping", corn="*/1 * * * *"
)
self._first_run = True

async def schedule_task(
    self,
    task,
    *args,
    corn: str = None,
    time: datetime.datetime = None,
    task_name: str = None,
    **kwargs,
):
    """Schedule a task via the taskiq broker.

    Args:
        task: The (async) callable to run.
        *args: Positional arguments forwarded to the task.
        corn: A cron expression (e.g. "*/10 * * * *") for recurring
            execution. NOTE: the name is a misspelling of "cron" kept
            for backward compatibility with existing callers.
        time: A datetime for a one-shot scheduled execution. Mutually
            exclusive with ``corn``.
        task_name: Optional explicit name to register the task under.
        **kwargs: Keyword arguments forwarded to the task.

    Returns:
        The schedule handle (for ``corn``/``time``) or the taskiq task
        handle (for immediate dispatch).
    """
    assert not (corn and time), "Only one of corn or time can be provided"
    # Register once up front: the previous code registered on each branch
    # and referenced an undefined name on the ``time`` and immediate paths.
    my_task = self._broker.register_task(task, task_name=task_name)
    if corn:
        return await my_task.schedule_by_cron(self._source, corn, *args, **kwargs)
    if time:
        return await my_task.schedule_by_time(self._source, time, *args, **kwargs)
    # Neither a cron spec nor a time: dispatch immediately.
    return await my_task.kiq(*args, **kwargs)

async def _register_root_services(self):
"""Register root services."""
Expand Down Expand Up @@ -749,7 +794,7 @@ def create_rpc(
"""Create a rpc object for a workspace."""
client_id = client_id or "anonymous-client-" + random_id(readable=False)
assert "/" not in client_id
logger.info("Creating RPC for client %s", client_id)
logger.debug("Creating RPC for client %s", client_id)
assert user_info is not None, "User info is required"
connection = RedisRPCConnection(
self._event_bus,
Expand Down Expand Up @@ -865,7 +910,22 @@ def unmount_app(self, path):
async def teardown(self):
"""Teardown the server."""
self._ready = False
logger.info("Tearing down the public workspace...")
logger.info("Tearing down the redis store...")
if self._house_keeping_schedule:
await self._house_keeping_schedule.unschedule()
self._broker_task.cancel()
try:
await self._broker_task
except asyncio.CancelledError:
print("Broker successfully exited.")
await self._broker.shutdown()
self._scheduler_task.cancel()
try:
await self._scheduler_task
except asyncio.CancelledError:
print("Scheduler successfully exited.")
await self._scheduler.shutdown()

client_id = self._public_workspace_interface.rpc.get_client_info()["id"]
await self.remove_client(client_id, "public", self._root_user, unload=True)
client_id = self._root_workspace_interface.rpc.get_client_info()["id"]
Expand Down
16 changes: 13 additions & 3 deletions hypha/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def __init__(
self._active_svc = Gauge(
"active_services", "Number of active services", ["workspace"]
)
self._active_clients = Gauge(
"active_clients", "Number of active clients", ["workspace"]
)
self._enable_service_search = enable_service_search

async def _get_sql_session(self):
Expand Down Expand Up @@ -1196,9 +1199,7 @@ async def register_service(
"client_connected", {"id": client_id, "workspace": ws}
)
logger.info(f"Adding built-in service: {service.id}")
builtins = await self._redis.keys(
f"services:*|*:{client_id}:built-in@*"
)
self._active_clients.labels(workspace=ws).inc()
else:
# Remove the service embedding from the config
if service.config and service.config.service_embedding is not None:
Expand Down Expand Up @@ -1323,6 +1324,7 @@ async def unregister_service(
await self._event_bus.emit(
"client_disconnected", {"id": client_id, "workspace": ws}
)
self._active_clients.labels(workspace=ws).dec()
else:
await self._event_bus.emit("service_removed", service.model_dump())
self._active_svc.labels(workspace=ws).dec()
Expand Down Expand Up @@ -1710,6 +1712,12 @@ async def delete_client(
for key in keys:
await self._redis.delete(key)

await self._event_bus.emit(
"client_disconnected", {"id": client_id, "workspace": cws}
)
self._active_clients.labels(workspace=cws).dec()
self._active_svc.labels(workspace=cws).dec(len(keys) - 1)

if unload:
if await self._redis.hexists("workspaces", cws):
if user_info.is_anonymous and cws == user_info.get_workspace():
Expand Down Expand Up @@ -1765,6 +1773,8 @@ async def unload(self, context=None):
await self._s3_controller.cleanup_workspace(winfo)

self._active_ws.dec()
self._active_clients.remove(ws)
self._active_svc.remove(ws)

await self._event_bus.emit("workspace_unloaded", winfo.model_dump())
logger.info("Workspace %s unloaded.", ws)
Expand Down
Empty file added hypha/taskiq_utils/__init__.py
Empty file.
117 changes: 117 additions & 0 deletions hypha/taskiq_utils/redis_broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""This file contains modules from the taskiq-redis package
We need to patch it so we can use fakeredis when redis is not available.
The library was created by Pavel Kirilin, released under the MIT license.
"""
from fakeredis import aioredis
from logging import getLogger
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, TypeVar

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.message import BrokerMessage

_T = TypeVar("_T")

logger = getLogger("taskiq.redis_broker")


class BaseRedisBroker(AsyncBroker):
    """Base broker that works with Redis.

    Patched version of the taskiq-redis base broker: instead of creating
    its own connection pool from a URL, it receives an already-constructed
    async redis client (a real ``redis.asyncio`` client or a fakeredis
    instance), so the broker works even when a redis server is unavailable.
    """

    def __init__(
        self,
        redis: aioredis.FakeRedis,
        task_id_generator: Optional[Callable[[], str]] = None,
        result_backend: Optional[AsyncResultBackend[_T]] = None,
        queue_name: str = "taskiq",
    ) -> None:
        """
        Constructs a new broker.
        :param redis: an async redis client instance (fakeredis or a real
            ``redis.asyncio`` client — only duck-typed methods are used).
        :param task_id_generator: custom task_id generator.
        :param result_backend: custom result backend.
        :param queue_name: name for a list/channel in redis.
        """
        super().__init__(
            result_backend=result_backend,
            task_id_generator=task_id_generator,
        )

        # The client is owned by the caller; the broker only borrows it.
        self.redis = redis
        self.queue_name = queue_name

    async def shutdown(self) -> None:
        """Run base broker shutdown.

        Unlike upstream taskiq-redis, no connection pool is closed here:
        the redis client is shared and managed by the caller.
        """
        await super().shutdown()


class PubSubBroker(BaseRedisBroker):
    """Broker that works with Redis and broadcasts tasks to all workers."""

    async def kick(self, message: BrokerMessage) -> None:
        """
        Publish message over PUBSUB channel.
        :param message: message to send.
        """
        # A per-message "queue_name" label overrides the broker default.
        channel = message.labels.get("queue_name") or self.queue_name
        await self.redis.publish(channel, message.message)

    async def listen(self) -> AsyncGenerator[bytes, None]:
        """
        Listen redis queue for new messages.
        This function listens to the pubsub channel
        and yields all messages with proper types.
        :yields: broker messages.
        """
        pubsub = self.redis.pubsub()
        await pubsub.subscribe(self.queue_name)
        async for entry in pubsub.listen():
            if not entry:
                continue
            # Skip subscribe/unsubscribe confirmations and other control
            # entries; only real payloads are forwarded.
            if entry["type"] != "message":
                logger.debug("Received non-message from redis: %s", entry)
                continue
            yield entry["data"]


class ListQueueBroker(BaseRedisBroker):
    """Broker that works with Redis and distributes tasks between workers."""

    async def kick(self, message: BrokerMessage) -> None:
        """
        Put a message in a list.
        This method appends a message to the list of all messages.
        :param message: message to append.
        """
        # A per-message "queue_name" label overrides the broker default.
        target_queue = message.labels.get("queue_name") or self.queue_name
        await self.redis.lpush(target_queue, message.message)

    async def listen(self) -> AsyncGenerator[bytes, None]:
        """
        Listen redis queue for new messages.
        This function listens to the queue
        and yields new messages if they have BrokerMessage type.
        :yields: broker messages.
        """
        while True:
            try:
                # brpop blocks until data arrives and returns
                # (queue_name, payload); only the payload is yielded.
                _, payload = await self.redis.brpop(self.queue_name)
            except ConnectionError as exc:
                logger.warning("Redis connection error: %s", exc)
                continue
            yield payload
Loading

0 comments on commit 2bacdf9

Please sign in to comment.