From 867be6398927268627af37047960559bd8dbe200 Mon Sep 17 00:00:00 2001
From: Wei Ouyang
Date: Fri, 6 Dec 2024 08:02:32 -0800
Subject: [PATCH] Fix artifact unload and load

- Return early in the server app controller when no runners are available.
- Guard against unloading a workspace twice and remove the workspace entry
  from Redis only at the end of unload.
- Drop the queue service's own workspace_unloaded cleanup; the keys of
  non-persistent workspaces (including queues) are deleted during unload.
- Emit the workspace_unloaded event from _close_workspace and log errors
  from the artifact manager and server app controller instead of failing.
- Raise KeyError for missing vector collections, log vector load failures,
  and run the vector tests in dedicated persistent workspaces.
---
 hypha/apps.py           |   7 +-
 hypha/artifact.py       |  10 +-
 hypha/core/workspace.py |  30 +++--
 hypha/queue.py          |  15 ---
 hypha/vectors.py        |  20 ++-
 tests/test_vectors.py   | 291 ++++++++++++++++++++++------------
 6 files changed, 209 insertions(+), 164 deletions(-)

diff --git a/hypha/apps.py b/hypha/apps.py
index b3b50575..f059c9fa 100644
--- a/hypha/apps.py
+++ b/hypha/apps.py
@@ -79,6 +79,8 @@ async def get_runners(self):
         # start the browser runner
         server = await self.store.get_public_api()
         svcs = await server.list_services("public/server-app-worker")
+        if not svcs:
+            return []
         runners = [await server.get_service(svc["id"]) for svc in svcs]
         if runners:
             return runners
@@ -633,7 +635,10 @@ async def close_workspace(self, workspace_info: WorkspaceInfo):
             if app["workspace"] == workspace_info.id:
                 await self._stop(app["id"], raise_exception=False)
         # Send to all runners
-        for runner in await self.get_runners():
+        runners = await self.get_runners()
+        if not runners:
+            return
+        for runner in runners:
             try:
                 await runner.close_workspace(workspace_info.id)
             except Exception as exp:
diff --git a/hypha/artifact.py b/hypha/artifact.py
index 2ac032d0..aadea68a 100644
--- a/hypha/artifact.py
+++ b/hypha/artifact.py
@@ -2776,7 +2776,9 @@ async def prepare_workspace(self, workspace_info: WorkspaceInfo):
                 bucket=s3_config["bucket"],
                 prefix=prefix,
             )
-            logger.info(f"Artifacts in workspace {workspace_info.id} prepared.")
+            logger.info(
+                f"Artifacts (#{len(artifacts)}) in workspace {workspace_info.id} prepared."
+            )
         except Exception as e:
             logger.error(f"Error preparing workspace: {traceback.format_exc()}")
             raise e
@@ -2818,10 +2820,12 @@ async def close_workspace(self, workspace_info: WorkspaceInfo):
                         f"{artifact.workspace}/{artifact.alias}"
                     )
             logger.info(
-                f"Artifacts in workspace {workspace_info.id} prepared for closure."
+                f"Artifacts (#{len(artifacts)}) in workspace {workspace_info.id} prepared for closure."
             )
         except Exception as e:
-            logger.error(f"Error closing workspace: {traceback.format_exc()}")
+            logger.error(
+                f"Error closing workspace {workspace_info.id}: {traceback.format_exc()}"
+            )
             raise e
         finally:
             await session.close()
diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py
index fc02193d..ebe1ffe6 100644
--- a/hypha/core/workspace.py
+++ b/hypha/core/workspace.py
@@ -1754,6 +1754,8 @@ async def unload(self, context=None):
         """Unload the workspace."""
         self.validate_context(context, permission=UserPermission.admin)
         ws = context["ws"]
+        if not await self._redis.hexists("workspaces", ws):
+            raise KeyError(f"Workspace {ws} has already been unloaded.")
         winfo = await self.load_workspace_info(ws)
         # list all the clients in the workspace and send a message to delete them
         client_keys = await self._list_client_keys(winfo.id)
@@ -1770,16 +1772,13 @@ async def unload(self, context=None):
 
         # Mark the workspace as not ready
         winfo.status = None
-        if winfo.persistent and self._s3_controller:
-            # since the workspace will be persisted, we can remove the workspace info from the redis store
-            await self._redis.hdel("workspaces", ws)
-        elif not winfo.persistent:
+        if not winfo.persistent:
             # delete all the items in redis starting with `workspaces_name:`
+            # including the queue and other associated resources
             keys = await self._redis.keys(f"{ws}:*")
             for key in keys:
                 await self._redis.delete(key)
-            await self._redis.hdel("workspaces", ws)
 
             if self._s3_controller:
                 await self._s3_controller.cleanup_workspace(winfo)
@@ -1789,11 +1788,8 @@ async def unload(self, context=None):
             self._active_svc.remove(ws)
         except KeyError:
             pass
-        await self._close_workspace(winfo)
-
-        await self._event_bus.emit("workspace_unloaded", winfo.model_dump())
-        logger.info("Workspace %s unloaded.", ws)
+        await self._redis.hdel("workspaces", ws)
 
     async def _prepare_workspace(self, workspace_info: WorkspaceInfo):
         """Prepare the workspace."""
@@ -1824,10 +1820,20 @@ async def _close_workspace(self, workspace_info: WorkspaceInfo):
         ), "Workspace must be unloaded before archiving."
         if workspace_info.persistent:
             if self._artifact_manager:
-                await self._artifact_manager.close_workspace(workspace_info)
+                try:
+                    await self._artifact_manager.close_workspace(workspace_info)
+                except Exception as e:
+                    logger.error(f"Artifact manager failed to close workspace: {e}")
             if self._server_app_controller:
-                await self._server_app_controller.close_workspace(workspace_info)
-        logger.info("Workspace %s archived.", workspace_info.id)
+                try:
+                    await self._server_app_controller.close_workspace(workspace_info)
+                except Exception as e:
+                    logger.error(
+                        f"Server app controller failed to close workspace: {e}"
+                    )
+
+        await self._event_bus.emit("workspace_unloaded", workspace_info.model_dump())
+        logger.info("Workspace %s unloaded.", workspace_info.id)
 
     @schema_method
     async def wait_until_ready(self, timeout: Optional[int] = 10, context=None):
diff --git a/hypha/queue.py b/hypha/queue.py
index 02fecbb8..fcf3e7be 100644
--- a/hypha/queue.py
+++ b/hypha/queue.py
@@ -15,21 +15,6 @@ def create_queue_service(store: RedisStore):
     """Create a queue service for Hypha."""
     redis: aioredis.FakeRedis = store.get_redis()
-    event_bus = store.get_event_bus()
-
-    async def on_workspace_unloaded(workspace):
-        # delete all the keys that start with workspace["name"] + ":q:"
-        keys_pattern = workspace["name"] + ":q:*"
-        cursor = "0"
-        while cursor != 0:
-            cursor, keys = await redis.scan(cursor=cursor, match=keys_pattern)
-            if keys:
-                await redis.delete(*keys)
-        if cursor != "0":
-            logger.info("Removed queue keys for workspace: %s", workspace["name"])
-
-    event_bus.on_local("workspace_unloaded", on_workspace_unloaded)
-
     async def push_task(queue_name, task: dict, context: dict = None):
         workspace = context["ws"]
         await redis.lpush(workspace + ":q:" + queue_name, json.dumps(task))
diff --git a/hypha/vectors.py b/hypha/vectors.py
index 894d6747..701065b4 100644
--- a/hypha/vectors.py
+++ b/hypha/vectors.py
@@ -81,7 +81,10 @@ def _get_index_name(self, collection_name: str) -> str:
 
     async def _get_fields(self, collection_name: str):
         index_name = self._get_index_name(collection_name)
-        info = await self._redis.ft(index_name).info()
+        try:
+            info = await self._redis.ft(index_name).info()
+        except aioredis.ResponseError:
+            raise KeyError(f"Vector collection {collection_name} does not exist.")
         fields = parse_attributes(info["attributes"])
         return fields
@@ -318,10 +321,17 @@ async def load_collection(
                 if isinstance(value, list):
                     vector_data[key] = np.array(value, dtype=np.float32)
 
-            # Add vector to Redis
-            await self.add_vectors(
-                collection_name, [{"id": obj["name"].split(".")[0], **vector_data}]
-            )
+            try:
+                # Add vector to Redis
+                await self.add_vectors(
+                    collection_name,
+                    [{"id": obj["name"].split(".")[0], **vector_data}],
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to load vector {obj['name']} from S3 bucket {bucket} under prefix {prefix}: {e}"
+                )
+                raise e
         logger.info(
             f"Collection {collection_name} loaded from S3 bucket {bucket} under prefix {prefix}."
diff --git a/tests/test_vectors.py b/tests/test_vectors.py
index 2f316805..ba64071c 100644
--- a/tests/test_vectors.py
+++ b/tests/test_vectors.py
@@ -186,153 +186,170 @@ async def test_artifact_vector_collection(
     """Test vector-related functions within a vector-collection artifact."""
     # Connect to the server and set up the artifact manager
-    api = await connect_to_server(
+    async with connect_to_server(
         {
             "name": "test deploy client",
             "server_url": SERVER_URL_REDIS_1,
             "token": test_user_token,
         }
-    )
-    artifact_manager = await api.get_service("public/artifact-manager")
-
-    # Create a vector-collection artifact
-    vector_collection_manifest = {
-        "name": "vector-collection",
-        "description": "A test vector collection",
-    }
-    vector_collection_config = {
-        "vector_fields": [
+    ) as api:
+        await api.create_workspace(
             {
-                "type": "VECTOR",
-                "name": "vector",
-                "algorithm": "FLAT",
-                "attributes": {
-                    "TYPE": "FLOAT32",
-                    "DIM": 384,
-                    "DISTANCE_METRIC": "COSINE",
-                },
+                "name": "my-vector-test-workspace",
+                "description": "This is a test workspace",
+                "persistent": True,
             },
-            {"type": "TEXT", "name": "text"},
-            {"type": "TAG", "name": "label"},
-            {"type": "NUMERIC", "name": "rand_number"},
-        ],
-        "embedding_models": {
-            "vector": "fastembed:BAAI/bge-small-en-v1.5",
-        },
-    }
-    vector_collection = await artifact_manager.create(
-        type="vector-collection",
-        manifest=vector_collection_manifest,
-        config=vector_collection_config,
-    )
-    # Add vectors to the collection
-    vectors = [
-        {
-            "vector": [random.random() for _ in range(384)],
-            "text": "This is a test document.",
-            "label": "doc1",
-            "rand_number": random.randint(0, 10),
-        },
-        {
-            "vector": np.random.rand(384),
-            "text": "Another document.",
-            "label": "doc2",
-            "rand_number": random.randint(0, 10),
-        },
+            overwrite=True,
+        )
+
+    async with connect_to_server(
         {
-            "vector": np.random.rand(384),
-            "text": "Yet another document.",
-            "label": "doc3",
-            "rand_number": random.randint(0, 10),
-        },
-    ]
-    await artifact_manager.add_vectors(
-        artifact_id=vector_collection.id,
-        vectors=vectors,
-    )
+            "name": "test deploy client",
+            "server_url": SERVER_URL_REDIS_1,
+            "token": test_user_token,
+            "workspace": "my-vector-test-workspace",
+        }
+    ) as api:
+        artifact_manager = await api.get_service("public/artifact-manager")
 
-    vc = await artifact_manager.read(artifact_id=vector_collection.id)
-    assert vc["config"]["vector_count"] == 3
+        # Create a vector-collection artifact
+        vector_collection_manifest = {
+            "name": "vector-collection",
+            "description": "A test vector collection",
+        }
+        vector_collection_config = {
+            "vector_fields": [
+                {
+                    "type": "VECTOR",
+                    "name": "vector",
+                    "algorithm": "FLAT",
+                    "attributes": {
+                        "TYPE": "FLOAT32",
+                        "DIM": 384,
+                        "DISTANCE_METRIC": "COSINE",
+                    },
+                },
+                {"type": "TEXT", "name": "text"},
+                {"type": "TAG", "name": "label"},
+                {"type": "NUMERIC", "name": "rand_number"},
+            ],
+            "embedding_models": {
+                "vector": "fastembed:BAAI/bge-small-en-v1.5",
+            },
+        }
+        vector_collection = await artifact_manager.create(
+            type="vector-collection",
+            manifest=vector_collection_manifest,
+            config=vector_collection_config,
+        )
+        # Add vectors to the collection
+        vectors = [
+            {
+                "vector": [random.random() for _ in range(384)],
+                "text": "This is a test document.",
+                "label": "doc1",
+                "rand_number": random.randint(0, 10),
+            },
+            {
+                "vector": np.random.rand(384),
+                "text": "Another document.",
+                "label": "doc2",
+                "rand_number": random.randint(0, 10),
+            },
+            {
+                "vector": np.random.rand(384),
+                "text": "Yet another document.",
+                "label": "doc3",
+                "rand_number": random.randint(0, 10),
+            },
+        ]
+        await artifact_manager.add_vectors(
+            artifact_id=vector_collection.id,
+            vectors=vectors,
+        )
 
-    # Search for vectors by query vector
-    query_vector = [random.random() for _ in range(384)]
-    search_results = await artifact_manager.search_vectors(
-        artifact_id=vector_collection.id,
-        query={"vector": query_vector},
-        limit=2,
-    )
-    assert len(search_results) <= 2
+        vc = await artifact_manager.read(artifact_id=vector_collection.id)
+        assert vc["config"]["vector_count"] == 3
 
-    results = await artifact_manager.search_vectors(
-        artifact_id=vector_collection.id,
-        query={"vector": query_vector},
-        limit=3,
-        pagination=True,
-    )
-    assert results["total"] == 3
+        # Search for vectors by query vector
+        query_vector = [random.random() for _ in range(384)]
+        search_results = await artifact_manager.search_vectors(
+            artifact_id=vector_collection.id,
+            query={"vector": query_vector},
+            limit=2,
+        )
+        assert len(search_results) <= 2
 
-    search_results = await artifact_manager.search_vectors(
-        artifact_id=vector_collection.id,
-        filters={"rand_number": [-2, -1]},
-        query={"vector": np.random.rand(384)},
-        limit=2,
-    )
-    assert len(search_results) == 0
+        results = await artifact_manager.search_vectors(
+            artifact_id=vector_collection.id,
+            query={"vector": query_vector},
+            limit=3,
+            pagination=True,
+        )
+        assert results["total"] == 3
 
-    search_results = await artifact_manager.search_vectors(
-        artifact_id=vector_collection.id,
-        filters={"rand_number": [0, 10]},
-        query={"vector": np.random.rand(384)},
-        limit=2,
-    )
-    assert len(search_results) > 0
+        search_results = await artifact_manager.search_vectors(
+            artifact_id=vector_collection.id,
+            filters={"rand_number": [-2, -1]},
+            query={"vector": np.random.rand(384)},
+            limit=2,
+        )
+        assert len(search_results) == 0
 
-    # Search for vectors by text
-    vectors = [
-        {"vector": "This is a test document.", "label": "doc1"},
-        {"vector": "Another test document.", "label": "doc2"},
-    ]
-    await artifact_manager.add_vectors(
-        artifact_id=vector_collection.id,
-        vectors=vectors,
-    )
+        search_results = await artifact_manager.search_vectors(
+            artifact_id=vector_collection.id,
+            filters={"rand_number": [0, 10]},
+            query={"vector": np.random.rand(384)},
+            limit=2,
+        )
+        assert len(search_results) > 0
 
-    text_search_results = await artifact_manager.search_vectors(
-        artifact_id=vector_collection.id,
-        query={"vector": "test document"},
-        limit=2,
-    )
-    assert len(text_search_results) <= 2
+        # Search for vectors by text
+        vectors = [
+            {"vector": "This is a test document.", "label": "doc1"},
+            {"vector": "Another test document.", "label": "doc2"},
+        ]
+        await artifact_manager.add_vectors(
+            artifact_id=vector_collection.id,
+            vectors=vectors,
+        )
 
-    # Retrieve a specific vector
-    retrieved_vector = await artifact_manager.get_vector(
-        artifact_id=vector_collection.id,
-        id=text_search_results[0]["id"],
-    )
-    assert "label" in retrieved_vector
+        text_search_results = await artifact_manager.search_vectors(
+            artifact_id=vector_collection.id,
+            query={"vector": "test document"},
+            limit=2,
+        )
+        assert len(text_search_results) <= 2
 
-    # List vectors in the collection
-    vector_list = await artifact_manager.list_vectors(
-        artifact_id=vector_collection.id,
-        offset=0,
-        limit=10,
-    )
-    assert len(vector_list) > 0
+        # Retrieve a specific vector
+        retrieved_vector = await artifact_manager.get_vector(
+            artifact_id=vector_collection.id,
+            id=text_search_results[0]["id"],
+        )
+        assert "label" in retrieved_vector
 
-    # Remove a vector from the collection
-    await artifact_manager.remove_vectors(
-        artifact_id=vector_collection.id,
-        ids=[vector_list[0]["id"]],
-    )
-    remaining_vectors = await artifact_manager.list_vectors(
-        artifact_id=vector_collection.id,
-        offset=0,
-        limit=10,
-    )
-    assert all(v["id"] != vector_list[0]["id"] for v in remaining_vectors)
+        # List vectors in the collection
+        vector_list = await artifact_manager.list_vectors(
+            artifact_id=vector_collection.id,
+            offset=0,
+            limit=10,
+        )
+        assert len(vector_list) > 0
+
+        # Remove a vector from the collection
+        await artifact_manager.remove_vectors(
+            artifact_id=vector_collection.id,
+            ids=[vector_list[0]["id"]],
+        )
+        remaining_vectors = await artifact_manager.list_vectors(
+            artifact_id=vector_collection.id,
+            offset=0,
+            limit=10,
+        )
+        assert all(v["id"] != vector_list[0]["id"] for v in remaining_vectors)
 
-    # Clean up by deleting the vector collection
-    await artifact_manager.delete(artifact_id=vector_collection.id)
+        # Clean up by deleting the vector collection
+        await artifact_manager.delete(artifact_id=vector_collection.id)
 
 
 async def test_load_dump_vector_collections(
@@ -346,6 +363,22 @@ async def test_load_dump_vector_collections(
             "server_url": SERVER_URL_REDIS_1,
             "token": test_user_token,
         }
+    ) as api:
+        await api.create_workspace(
+            {
+                "name": "my-vector-dump-workspace",
+                "description": "This is a test workspace for dumping vector collections",
+                "persistent": True,
+            },
+            overwrite=True,
+        )
+    async with connect_to_server(
+        {
+            "name": "test deploy client",
+            "server_url": SERVER_URL_REDIS_1,
+            "token": test_user_token,
+            "workspace": "my-vector-dump-workspace",
+        }
     ) as api:
         artifact_manager = await api.get_service("public/artifact-manager")
@@ -406,6 +439,7 @@ async def test_load_dump_vector_collections(
             vectors=vectors,
         )
 
+    await asyncio.sleep(0.2)
     async with connect_to_server(
         {
             "server_url": SERVER_URL_REDIS_1,
@@ -423,6 +457,7 @@ async def test_load_dump_vector_collections(
             "name": "test deploy client",
             "server_url": SERVER_URL_REDIS_1,
             "token": test_user_token,
+            "workspace": "my-vector-dump-workspace",
         }
     ) as api:
         await api.wait_until_ready(timeout=60)