diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf2c07a4..9cf8bf03 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,6 +26,11 @@ jobs: ports: - 5432:5432 options: --health-cmd "pg_isready -U postgres" --health-interval 10s --health-timeout 5s --health-retries 5 + redis: + image: redis/redis-stack:7.2.0-v13 + ports: + - 6333:6379 + options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5 steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -36,11 +41,6 @@ jobs: run: | python -m pip install --upgrade pip pip install tox tox-gh-actions - - name: Start Redis - uses: supercharge/redis-github-action@1.8.0 - with: - redis-version: 6 - redis-port: 6333 - name: Test with tox run: tox env: diff --git a/docs/artifact-manager.md b/docs/artifact-manager.md index af11e9b5..b5499a9e 100644 --- a/docs/artifact-manager.md +++ b/docs/artifact-manager.md @@ -235,7 +235,7 @@ print("Valid dataset committed.") ## API References -### `create(parent_id: str, alias: str, type: str, manifest: dict, permissions: dict=None, config: dict=None, version: str = None, comment: str = None, publish_to: str = None) -> None` +### `create(parent_id: str, alias: str, type: str, manifest: dict, permissions: dict=None, config: dict=None, version: str = None, comment: str = None, overwrite: bool = False, publish_to: str = None) -> None` Creates a new artifact or collection with the specified manifest. The artifact is staged until committed. For collections, the `collection` field should be an empty list. @@ -270,7 +270,7 @@ Creates a new artifact or collection with the specified manifest. The artifact i - `S3_BUCKET`: The bucket name of the S3 storage for the artifact. Default to the hypha workspaces bucket. - `S3_PREFIX`: The prefix of the S3 storage for the artifact. Default: `""`. - `S3_PUBLIC_ENDPOINT_URL`: The public endpoint URL of the S3 storage for the artifact. If the S3 server is not public, you can set this to the public endpoint URL. Default: `None`. - +- `overwrite`: Optional. A boolean flag to overwrite the existing artifact with the same alias. Default is `False`. - `publish_to`: Optional. A string specifying the target platform to publish the artifact. Supported values are `zenodo` and `sandbox_zenodo`. If set, the artifact will be published to the specified platform. The artifact must have a valid Zenodo metadata schema to be published. **Note 1: If you set `version="stage"`, you must call `commit()` to finalize the artifact.** @@ -564,7 +564,7 @@ manifest = await artifact_manager.read(artifact_id="other_workspace/example-data --- -### `list(artifact_id: str=None, keywords: List[str] = None, filters: dict = None, mode: str = "AND", page: int = 0, page_size: int = 100, order_by: str = None, silent: bool = False) -> list` +### `list(artifact_id: str=None, keywords: List[str] = None, filters: dict = None, mode: str = "AND", offset: int = 0, limit: int = 100, order_by: str = None, silent: bool = False) -> list` Retrieve a list of child artifacts within a specified collection, supporting keyword-based fuzzy search, field-specific filters, and flexible ordering. This function allows detailed control over the search and pagination of artifacts in a collection, including staged artifacts if specified. 
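A minimal usage sketch of the new `overwrite` flag, following the updated `create()` signature above; the server URL, token, and alias are placeholders:

```python
import asyncio
from hypha_rpc import connect_to_server

async def main():
    # Placeholder server URL and token; adjust for your deployment.
    api = await connect_to_server(
        {"server_url": "https://hypha.example.com", "token": "<token>"}
    )
    artifact_manager = await api.get_service("public/artifact-manager")

    # Without `overwrite=True`, creating an artifact whose alias already exists
    # raises FileExistsError; with it, the existing record is replaced instead.
    dataset = await artifact_manager.create(
        alias="example-dataset",
        type="dataset",
        manifest={"name": "Example dataset", "description": "A demo dataset"},
        overwrite=True,
    )
    print(dataset.id)

asyncio.run(main())
```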
@@ -584,9 +584,9 @@ Retrieve a list of child artifacts within a specified collection, supporting key - `mode` (str, optional): Defines how multiple conditions (from keywords and filters) are combined. Use `"AND"` to ensure all conditions must match, or `"OR"` to include artifacts meeting any condition. Default is `"AND"`. -- `page` (int, optional): The page number for pagination. Used in conjunction with `page_size` to limit results. Default is `0`, which returns the first page of results. +- `offset` (int, optional): The number of artifacts to skip before listing results. Default is `0`. -- `page_size` (int, optional): The maximum number of artifacts to return per page. This is capped at 1000 for performance considerations. Default is `100`. +- `limit` (int, optional): The maximum number of artifacts to return. Default is `100`. - `order_by` (str, optional): The field used to order results. Options include: - `view_count`, `download_count`, `last_modified`, `created_at`, and `id`. @@ -609,8 +609,8 @@ results = await artifact_manager.list( filters={"created_by": "user123", "stage": False}, order_by="view_count>", mode="AND", - page=1, - page_size=50 + offset=0, + limit=50 ) ``` @@ -737,8 +737,8 @@ Qury parameters are passed after the `?` in the URL and are used to control the - **keywords**: A list of search terms used for fuzzy searching across all manifest fields, separated by commas. - **filters**: A dictionary of filters to apply to the search, in the format of a JSON string. - **mode**: The mode for combining multiple conditions. Default is `AND`. -- **page**: The page number for pagination. Default is `0`. -- **page_size**: The maximum number of artifacts to return per page. Default is `100`. +- **offset**: The number of artifacts to skip before listing results. Default is `0`. +- **limit**: The maximum number of artifacts to return. Default is `100`. - **order_by**: The field used to order results. Default is ascending by id. - **silent**: A boolean flag to prevent incrementing the view count for the parent artifact when listing children, listing files, or reading the artifact. Default is `False`. 
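The renamed pagination parameters map onto a simple paging loop; a sketch assuming an `artifact_manager` service proxy obtained as above and a placeholder collection alias:

```python
async def list_all_children(artifact_manager, collection_alias="my-collection"):
    """Page through all children of a collection using the new offset/limit parameters."""
    offset, limit = 0, 100
    children = []
    while True:
        batch = await artifact_manager.list(
            artifact_id=collection_alias,  # placeholder collection alias
            offset=offset,
            limit=limit,
            order_by="id",
        )
        children.extend(batch)
        if len(batch) < limit:
            break  # last page reached
        offset += limit
    return children
```

The same `offset` and `limit` names are accepted as query parameters on the HTTP children endpoint.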
diff --git a/hypha/VERSION b/hypha/VERSION index bf562dc3..f0e3ef88 100644 --- a/hypha/VERSION +++ b/hypha/VERSION @@ -1,3 +1,3 @@ { - "version": "0.20.39.post7" + "version": "0.20.39.post8" } diff --git a/hypha/artifact.py b/hypha/artifact.py index 16d07384..b927d067 100644 --- a/hypha/artifact.py +++ b/hypha/artifact.py @@ -3,8 +3,10 @@ import sys import uuid_utils as uuid import random +import numpy as np import re import json +import asyncio from sqlalchemy import ( event, Column, @@ -28,7 +30,6 @@ async_sessionmaker, AsyncSession, ) -from sqlalchemy.orm import relationship # For parent-child relationships from fastapi import APIRouter, Depends, HTTPException from hypha.core import ( UserInfo, @@ -142,6 +143,9 @@ def __init__( self.s3_controller = s3_controller self.workspace_bucket = workspace_bucket self.store = store + self._vectordb_client = self.store.get_vectordb_client() + self._openai_client = self.store.get_openai_client() + self._cache_dir = self.store.get_cache_dir() router = APIRouter() self._artifacts_dir = artifacts_dir @@ -208,8 +212,8 @@ async def get_artifact( async def list_children( workspace: str, artifact_alias: str, - page: int = 0, - page_size: int = 100, + offset: int = 0, + limit: int = 100, order_by: str = None, user_info: self.store.login_optional = Depends(self.store.login_optional), ): @@ -230,8 +234,8 @@ async def list_children( ArtifactModel.workspace == workspace, ArtifactModel.parent_id == parent_artifact.id, ) - .limit(page_size) - .offset(page * page_size) + .limit(limit) + .offset(offset) ) if order_by: @@ -872,6 +876,7 @@ async def create( publish_to=None, version: str = None, comment: str = None, + overwrite: bool = False, context: dict = None, ): """Create a new artifact and store its manifest in the database.""" @@ -898,7 +903,7 @@ async def create( "Workspace must match the alias workspace, if provided." ) workspace = ws - + created_at = int(time.time()) session = await self._get_session() try: async with session.begin(): @@ -981,16 +986,17 @@ async def create( existing_artifact = await self._get_artifact( session, f"{workspace}/{alias}" ) - if parent_id != existing_artifact.parent_id: + if parent_id != existing_artifact.parent_id and not overwrite: raise FileExistsError( f"Artifact with alias '{alias}' already exists under a different parent artifact, please choose a different alias or remove the existing artifact (ID: {existing_artifact.workspace}/{existing_artifact.alias})" ) - else: + elif not overwrite: raise FileExistsError( f"Artifact with alias '{alias}' already exists, please choose a different alias or remove the existing artifact (ID: {existing_artifact.workspace}/{existing_artifact.alias})." ) + id = existing_artifact.id except KeyError: - pass + overwrite = False parent_permissions = ( parent_artifact.config["permissions"] if parent_artifact else {} @@ -1021,7 +1027,7 @@ async def create( staging=[] if version == "stage" else None, manifest=manifest, created_by=user_info.id, - created_at=int(time.time()), + created_at=created_at, last_modified=int(time.time()), config=config, secrets=secrets, @@ -1029,7 +1035,24 @@ async def create( type=type, ) version_index = self._get_version_index(new_artifact, version) - session.add(new_artifact) + if overwrite: + await session.merge(new_artifact) + else: + session.add(new_artifact) + if new_artifact.type == "vector-collection": + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." 
+ from qdrant_client.models import Distance, VectorParams + + vectors_config = config.get("vectors_config", {}) + await self._vectordb_client.create_collection( + collection_name=f"{new_artifact.workspace}/{new_artifact.alias}", + vectors_config=VectorParams( + size=vectors_config.get("size", 128), + distance=Distance(vectors_config.get("distance", "Cosine")), + ), + ) await session.commit() await self._save_version_to_s3( version_index, @@ -1295,6 +1318,15 @@ async def commit( ), ) + if artifact.type == "vector-collection": + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." + artifact.manifest["points"] = self._vectordb_client.count( + collection_name=f"{artifact.workspace}/{artifact.alias}" + ) + flag_modified(artifact, "manifest") + parent_artifact_config = ( parent_artifact.config if parent_artifact else {} ) @@ -1362,6 +1394,14 @@ async def delete( user_info, artifact_id, "delete", session ) + if artifact.type == "vector-collection": + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." + self._vectordb_client.delete_collection( + collection_name=f"{artifact.workspace}/{artifact.alias}" + ) + s3_config = self._get_s3_config(artifact, parent_artifact) if version is None: # Handle recursive deletion first @@ -1412,6 +1452,310 @@ async def delete( if session: await session.close() + async def add_vectors( + self, + artifact_id: str, + vectors: list, + context: dict = None, + ): + """ + Add vectors to a vector collection. + """ + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "put_file", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." + assert artifact.manifest, "Artifact must be committed before upserting." + assert isinstance( + vectors, list + ), "Vectors must be a list of dictionaries." + assert all( + isinstance(v, dict) for v in vectors + ), "Vectors must be a list of dictionaries." + from qdrant_client.models import PointStruct + + _points = [] + for p in vectors: + p["id"] = p.get("id") or str(uuid.uuid4()) + _points.append(PointStruct(**p)) + await self._vectordb_client.upsert( + collection_name=f"{artifact.workspace}/{artifact.alias}", + points=_points, + ) + # TODO: Update file_count + logger.info(f"Upserted vectors to artifact with ID: {artifact_id}") + except Exception as e: + raise e + finally: + await session.close() + + async def _embed_texts(self, config, texts): + embedding_model = config.get("embedding_model") # "text-embedding-3-small" + assert ( + embedding_model + ), "Embedding model must be provided, e.g. 'fastembed', 'text-embedding-3-small' for openai or 'all-minilm' for ollama." + if embedding_model.startswith("fastembed"): + from fastembed import TextEmbedding + + if ":" in embedding_model: + model_name = embedding_model.split(":")[-1] + else: + model_name = "BAAI/bge-small-en-v1.5" + embedding_model = TextEmbedding( + model_name=model_name, cache_dir=self._cache_dir + ) + loop = asyncio.get_event_loop() + embeddings = list( + await loop.run_in_executor(None, embedding_model.embed, texts) + ) + else: + assert ( + self._openai_client + ), "The server is not configured to use an OpenAI client." 
+ result = await self._openai_client.embeddings.create( + input=texts, model=embedding_model + ) + embeddings = [data.embedding for data in result.data] + return embeddings + + async def add_documents( + self, + artifact_id: str, + documents: str, # `id`, `text` and other fields + context: dict = None, + ): + """ + Add documents to the artifact. + """ + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "put_file", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." + texts = [doc["text"] for doc in documents] + embeddings = await self._embed_texts(artifact.config, texts) + from qdrant_client.models import PointStruct + + points = [ + PointStruct( + id=doc.get("id") or str(uuid.uuid4()), + vector=embedding, + payload=doc, + ) + for embedding, doc in zip(embeddings, documents) + ] + await self._vectordb_client.upsert( + collection_name=f"{artifact.workspace}/{artifact.alias}", + points=points, + ) + logger.info(f"Upserted documents to artifact with ID: {artifact_id}") + except Exception as e: + raise e + finally: + await session.close() + + async def search_by_vector( + self, + artifact_id: str, + query_vector, + query_filter: dict = None, + offset: int = 0, + limit: int = 10, + with_payload: bool = True, + with_vectors: bool = False, + context: dict = None, + ): + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "list", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." + # if it's a numpy array, convert it to a list + if isinstance(query_vector, np.ndarray): + query_vector = query_vector.tolist() + from qdrant_client.models import Filter + + if query_filter: + query_filter = Filter.model_validate(query_filter) + search_results = await self._vectordb_client.search( + collection_name=f"{artifact.workspace}/{artifact.alias}", + query_vector=query_vector, + query_filter=query_filter, + limit=limit, + offset=offset, + with_payload=with_payload, + with_vectors=with_vectors, + ) + return search_results + except Exception as e: + raise e + finally: + await session.close() + + async def search_by_text( + self, + artifact_id: str, + query: str, + query_filter: dict = None, + offset: int = 0, + limit: int = 10, + with_payload: bool = True, + with_vectors: bool = False, + context: dict = None, + ): + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "list", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." 
+ (query_vector,) = await self._embed_texts(artifact.config, [query]) + from qdrant_client.models import Filter + + if query_filter: + query_filter = Filter.model_validate(query_filter) + search_results = await self._vectordb_client.search( + collection_name=f"{artifact.workspace}/{artifact.alias}", + query_vector=query_vector, + query_filter=query_filter, + limit=limit, + offset=offset, + with_payload=with_payload, + with_vectors=with_vectors, + ) + return search_results + except Exception as e: + raise e + finally: + await session.close() + + async def remove_vectors( + self, + artifact_id: str, + ids: list, + context: dict = None, + ): + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "remove_file", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." + await self._vectordb_client.delete( + collection_name=f"{artifact.workspace}/{artifact.alias}", + points_selector=ids, + ) + logger.info(f"Removed vectors from artifact with ID: {artifact_id}") + except Exception as e: + raise e + finally: + await session.close() + + async def get_vector( + self, + artifact_id: str, + id: int, + context: dict = None, + ): + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "get_file", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." + points = await self._vectordb_client.retrieve( + collection_name=f"{artifact.workspace}/{artifact.alias}", + ids=[id], + with_payload=True, + with_vectors=True, + ) + return points[0] + except Exception as e: + raise e + finally: + await session.close() + + async def list_vectors( + self, + artifact_id: str, + query_filter: dict = None, + offset: int = 0, + limit: int = 10, + order_by: str = None, + with_payload: bool = True, + with_vectors: bool = False, + context: dict = None, + ): + user_info = UserInfo.model_validate(context["user"]) + session = await self._get_session() + try: + async with session.begin(): + artifact, _ = await self._get_artifact_with_permission( + user_info, artifact_id, "list", session + ) + assert ( + artifact.type == "vector-collection" + ), "Artifact must be a vector collection." + assert ( + self._vectordb_client + ), "The server is not configured to use a VectorDB client." 
+ from qdrant_client.models import Filter + + if query_filter: + query_filter = Filter.model_validate(query_filter) + points, _ = await self._vectordb_client.scroll( + collection_name=f"{artifact.workspace}/{artifact.alias}", + scroll_filter=query_filter, + limit=limit, + offset=offset, + order_by=order_by, + with_payload=with_payload, + with_vectors=with_vectors, + ) + return points + + except Exception as e: + raise e + finally: + await session.close() + async def put_file( self, artifact_id, file_path, download_weight: float = 0, context: dict = None ): @@ -1638,8 +1982,8 @@ async def list_children( keywords=None, filters=None, mode="AND", - page: int = 0, - page_size: int = 100, + offset: int = 0, + limit: int = 100, order_by=None, silent=False, context: dict = None, @@ -1824,7 +2168,6 @@ async def list_children( ) # Pagination and ordering - offset = page * page_size order_field_map = { "id": ArtifactModel.id, "view_count": ArtifactModel.view_count, @@ -1841,7 +2184,7 @@ async def list_children( query.order_by( order_field.asc() if ascending else order_field.desc() ) - .limit(page_size) + .limit(limit) .offset(offset) ) @@ -1959,5 +2302,12 @@ def get_artifact_service(self): "get_file": self.get_file, "list": self.list_children, "list_files": self.list_files, + "add_vectors": self.add_vectors, + "add_documents": self.add_documents, + "search_by_vector": self.search_by_vector, + "search_by_text": self.search_by_text, + "remove_vectors": self.remove_vectors, + "get_vector": self.get_vector, + "list_vectors": self.list_vectors, "publish": self.publish, } diff --git a/hypha/core/__init__.py b/hypha/core/__init__.py index 2860130e..95eed2c8 100644 --- a/hypha/core/__init__.py +++ b/hypha/core/__init__.py @@ -55,6 +55,7 @@ class ServiceConfig(BaseModel): flags: List[str] = [] singleton: Optional[bool] = False created_by: Optional[Dict] = None + service_embedding: Optional[Any] = None class ServiceInfo(BaseModel): @@ -62,50 +63,135 @@ class ServiceInfo(BaseModel): model_config = ConfigDict(extra="allow") - config: SerializeAsAny[ServiceConfig] + config: Optional[SerializeAsAny[ServiceConfig]] = None id: str name: str type: Optional[str] = "generic" - description: Optional[constr(max_length=256)] = "" # type: ignore + description: Optional[constr(max_length=256)] = None # type: ignore docs: Optional[str] = None app_id: Optional[str] = None service_schema: Optional[Dict[str, Any]] = None + score: Optional[float] = None # score generated by vector search def is_singleton(self): """Check if the service is singleton.""" return "single-instance" in self.config.flags def to_redis_dict(self): + """ + Serialize the model to a Redis-compatible dictionary based on field types. 
+ """ data = self.model_dump() - data["config"] = self.config.model_dump_json() - return { - k: json.dumps(v) if not isinstance(v, str) else v for k, v in data.items() - } + redis_data = {} + # Note: Here we only store the fields that are in the model + # and ignore any extra fields that might be present in the data + # Iterate over fields and encode based on their type + for field_name, field_info in self.model_fields.items(): + value = data.get(field_name) + if value is None or field_name == "score": + continue + elif field_name == "config": + redis_data[field_name] = value + elif field_info.annotation in {str, Optional[str]}: + redis_data[field_name] = value + elif field_info.annotation in {list, List[str], Optional[List[str]]}: + redis_data[field_name] = ",".join(value) + else: + redis_data[field_name] = json.dumps(value) + + # Expand config fields to store as separate keys + if "config" in redis_data and redis_data["config"]: + # Iterate through the keys for ServiceConfig and store them to self + for field_name, field_info in ServiceConfig.model_fields.items(): + value = redis_data["config"].get(field_name) + if value is None: + continue + elif field_name == "service_embedding": + redis_data[field_name] = value + elif field_info.annotation in {str, Optional[str]}: + redis_data[field_name] = value + elif field_info.annotation in {list, List[str], Optional[List[str]]}: + redis_data[field_name] = ",".join(value) + else: + redis_data[field_name] = json.dumps(value) + del redis_data["config"] + return redis_data @classmethod - def from_redis_dict(cls, service_data): - converted_service_data = {} - for k, v in service_data.items(): - key_str = k.decode("utf-8") - value_str = v.decode("utf-8") - if ( - value_str.startswith("{") - and value_str.endswith("}") - or value_str.startswith("[") - and value_str.endswith("]") - ): - converted_service_data[key_str] = json.loads(value_str) + def from_redis_dict(cls, service_data: Dict[str, Any], in_bytes=True): + """ + Deserialize a Redis-compatible dictionary back to a model instance. 
+ """ + converted_data = {} + # Note: Here we only convert the fields that are in the model + # and ignore any extra fields that might be present in the data + # Iterate over fields and decode based on their type + # Extract the fields form ServiceConfig first + config_data = {} + for field_name, field_info in ServiceConfig.model_fields.items(): + if not in_bytes: + if field_name not in service_data: + continue + value = service_data.get(field_name) + del service_data[field_name] else: - converted_service_data[key_str] = value_str - converted_service_data["config"] = ServiceConfig.model_validate( - converted_service_data["config"] - ) - return cls.model_validate(converted_service_data) + if field_name.encode("utf-8") not in service_data: + continue + value = service_data.get(field_name.encode("utf-8")) + del service_data[field_name.encode("utf-8")] + if value is None: + config_data[field_name] = None + elif field_name == "service_embedding": + config_data[field_name] = value + elif field_info.annotation in {str, Optional[str]}: + config_data[field_name] = ( + value if isinstance(value, str) else value.decode("utf-8") + ) + elif field_info.annotation in {list, List[str], Optional[List[str]]}: + config_data[field_name] = ( + value.split(",") + if isinstance(value, str) + else value.decode("utf-8").split(",") + ) + else: + value_str = value if isinstance(value, str) else value.decode("utf-8") + config_data[field_name] = json.loads(value_str) + + if config_data: + if in_bytes: + service_data[b"config"] = ServiceConfig.model_validate(config_data) + else: + service_data["config"] = ServiceConfig.model_validate(config_data) + + for field_name, field_info in cls.model_fields.items(): + if not in_bytes: + value = service_data.get(field_name) + else: + value = service_data.get(field_name.encode("utf-8")) + if value is None: + converted_data[field_name] = None + elif field_name == "config": + converted_data[field_name] = value + elif field_info.annotation in {str, Optional[str]}: + converted_data[field_name] = ( + value if isinstance(value, str) else value.decode("utf-8") + ) + elif field_info.annotation in {list, List[str], Optional[List[str]]}: + config_data[field_name] = ( + value.split(",") + if isinstance(value, str) + else value.decode("utf-8").split(",") + ) + else: + value_str = value if isinstance(value, str) else value.decode("utf-8") + converted_data[field_name] = json.loads(value_str) + return cls.model_validate(converted_data) @classmethod def model_validate(cls, data): data = data.copy() - data["config"] = ServiceConfig.model_validate(data["config"]) + if "config" in data and data["config"] is not None: + data["config"] = ServiceConfig.model_validate(data["config"]) return super().model_validate(data) diff --git a/hypha/core/store.py b/hypha/core/store.py index a5cb907a..40453174 100644 --- a/hypha/core/store.py +++ b/hypha/core/store.py @@ -25,8 +25,6 @@ from sqlalchemy.ext.asyncio import ( create_async_engine, ) -from sqlalchemy import inspect, text -from sqlalchemy.exc import NoInspectionAvailable, NoSuchTableError from hypha.core.auth import ( create_scope, @@ -95,6 +93,11 @@ def __init__( local_base_url=None, redis_uri=None, database_uri=None, + vectordb_uri=None, + ollama_host=None, + openai_config=None, + cache_dir=None, + enable_service_search=False, reconnection_token_life_time=2 * 24 * 60 * 60, ): """Initialize the redis store.""" @@ -112,9 +115,11 @@ def __init__( self._ready = False self._workspace_manager = None self._websocket_server = None + self._cache_dir = cache_dir 
self._server_id = server_id or random_id(readable=True) self._manager_id = "manager-" + self._server_id self.reconnection_token_life_time = reconnection_token_life_time + self._enable_service_search = enable_service_search self._server_info = { "server_id": self._server_id, "hypha_version": __version__, @@ -128,17 +133,41 @@ def __init__( } logger.info("Server info: %s", self._server_info) + self._vectordb_uri = vectordb_uri + if self._vectordb_uri is not None: + from qdrant_client import AsyncQdrantClient + + self._vectordb_client = AsyncQdrantClient(self._vectordb_uri) + else: + self._vectordb_client = None self._database_uri = database_uri if self._database_uri is None: - database_uri = ( + self._database_uri = ( "sqlite+aiosqlite:///:memory:" # In-memory SQLite for testing ) logger.warning( "Using in-memory SQLite database for event logging, all data will be lost on restart!" ) - self._sql_engine = create_async_engine(database_uri, echo=False) + self._ollama_host = ollama_host + if self._ollama_host is not None: + import ollama + + self._ollama_client = ollama.AsyncClient(host=self._ollama_host) + + self._openai_config = openai_config + if ( + self._openai_config is not None + and self._openai_config.get("api_key") is not None + ): + from openai import AsyncClient + + self._openai_client = AsyncClient(**self._openai_config) + else: + self._openai_client = None + + self._sql_engine = create_async_engine(self._database_uri, echo=False) if redis_uri and redis_uri.startswith("redis://"): from redis import asyncio as aioredis @@ -170,10 +199,22 @@ def get_redis(self): def get_sql_engine(self): return self._sql_engine + def get_vectordb_client(self): + return self._vectordb_client + + def get_openai_client(self): + return self._openai_client + + def get_ollama_client(self): + return self._ollama_client + def get_event_bus(self): """Get the event bus.""" return self._event_bus + def get_cache_dir(self): + return self._cache_dir + async def setup_root_user(self) -> UserInfo: """Setup the root user.""" self._root_user = UserInfo( @@ -634,6 +675,8 @@ async def register_workspace_manager(self): self._sql_engine, self._s3_controller, self._artifact_manager, + self._enable_service_search, + self._cache_dir, ) await manager.setup() return manager diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py index cfe3a3c3..5944bf88 100644 --- a/hypha/core/workspace.py +++ b/hypha/core/workspace.py @@ -1,5 +1,6 @@ import re import json +import asyncio import logging import time import sys @@ -7,6 +8,7 @@ from contextlib import asynccontextmanager import random import datetime +import numpy as np from fakeredis import aioredis from prometheus_client import Gauge @@ -51,6 +53,21 @@ def naive_utc_now(): return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) +def escape_redis_syntax(value: str) -> str: + """Escape Redis special characters in a query string, except '*'.""" + # Escape all special characters except '*' + return re.sub(r"([{}|@$\\\-\[\]\(\)\!&~:\"])", r"\\\1", value) + + +def sanitize_search_value(value: str) -> str: + """Sanitize a value to prevent injection attacks, allowing '*' for wildcard support.""" + # Allow alphanumeric characters, spaces, underscores, hyphens, dots, slashes, and '*' + value = re.sub( + r"[^a-zA-Z0-9 _\-./*]", "", value + ) # Remove unwanted characters except '*' + return escape_redis_syntax(value.strip()) + + # SQLModel model for storing events class EventLog(SQLModel, table=True): __tablename__ = "event_logs" @@ -108,6 +125,8 @@ def __init__( 
sql_engine: Optional[str] = None, s3_controller: Optional[Any] = None, artifact_manager: Optional[Any] = None, + enable_service_search: bool = False, + cache_dir: str = None, ): self._redis = redis self._initialized = False @@ -119,6 +138,7 @@ def __init__( self._s3_controller = s3_controller self._artifact_manager = artifact_manager self._sql_engine = sql_engine + self._cache_dir = cache_dir if self._sql_engine: self.SessionLocal = async_sessionmaker( self._sql_engine, expire_on_commit=False, class_=AsyncSession @@ -129,6 +149,7 @@ def __init__( self._active_svc = Gauge( "active_services", "Number of active services", ["workspace"] ) + self._enable_service_search = enable_service_search async def _get_sql_session(self): """Return an async session for the database.""" @@ -158,6 +179,50 @@ async def setup( async with self._sql_engine.begin() as conn: await conn.run_sync(EventLog.metadata.create_all) logger.info("Database tables created successfully.") + + self._embedding_model = None + self._search_fields = None + if self._enable_service_search: + from fastembed import TextEmbedding + + self._embedding_model = TextEmbedding( + model_name="BAAI/bge-small-en-v1.5", cache_dir=self._cache_dir + ) + + from redis.commands.search.field import VectorField, TextField, TagField + from redis.commands.search.indexDefinition import IndexDefinition, IndexType + + # Define vector field for RedisSearch (assuming cosine similarity) + # Manually define Redis fields for each ServiceInfo attribute + self._search_fields = [ + TagField(name="id"), # id as tag + TextField(name="name"), # name as text + TagField(name="type"), # type as tag (enum-like) + TextField(name="description"), # description as text + TextField(name="docs"), # docs as text + TagField(name="app_id"), # app_id as tag + TextField( + name="service_schema" + ), # service_schema as text (you can store a serialized JSON or string representation) + VectorField( + "service_embedding", + "FLAT", + {"TYPE": "FLOAT32", "DIM": 384, "DISTANCE_METRIC": "COSINE"}, + ), + TagField(name="visibility"), # visibility as tag + TagField(name="require_context"), # require_context as tag + TagField(name="workspace"), # workspace as tag + TagField(name="flags", separator=","), # flags as tag + TagField(name="singleton"), # singleton as tag + TextField(name="created_by"), # created_by as text + ] + # Create the index with vector field and additional fields for metadata (e.g., title) + await self._redis.ft("service_info_index").create_index( + fields=self._search_fields, + definition=IndexDefinition( + prefix=["services:"], index_type=IndexType.HASH + ), + ) self._initialized = True return rpc @@ -189,7 +254,6 @@ async def log_event( logger.info( f"Logged event: {event_type} by {user_info.id} in {workspace}" ) - logger.info(f"Event logged: {event_type}") except Exception as e: logger.error(f"Failed to log event: {event_type}, {e}") raise @@ -709,6 +773,168 @@ async def ping_client( except Exception as e: return f"Failed to ping client {client_id}: {e}" + def _convert_filters_to_hybrid_query(self, filters: dict) -> str: + """ + Convert a filter dictionary to a Redis hybrid query string. + + Args: + filters (dict): Dictionary of filters, e.g., {"type": "my-type", "year": [2011, 2012]}. + + Returns: + str: Redis hybrid query string, e.g., "(@type:{my-type} @year:[2011 2012])". 
+ """ + from redis.commands.search.field import TextField, TagField, NumericField + + conditions = [] + + for field_name, value in filters.items(): + # Find the field type in the schema + field_type = None + for field in self._search_fields: + if field.name == field_name: + field_type = type(field) + break + + if not field_type: + raise ValueError(f"Unknown field '{field_name}' in filters.") + + # Sanitize the field name + sanitized_field_name = sanitize_search_value(field_name) + + if field_type == TagField: + # Use `{value}` for TagField + if not isinstance(value, str): + raise ValueError( + f"TagField '{field_name}' requires a string value." + ) + sanitized_value = sanitize_search_value(value) + conditions.append(f"@{sanitized_field_name}:{{{sanitized_value}}}") + + elif field_type == NumericField: + # Use `[min max]` for NumericField + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError( + f"NumericField '{field_name}' requires a list or tuple with two elements." + ) + min_val, max_val = value + conditions.append(f"@{sanitized_field_name}:[{min_val} {max_val}]") + + elif field_type == TextField: + # Use `"value"` for TextField + if not isinstance(value, str): + raise ValueError( + f"TextField '{field_name}' requires a string value." + ) + if "*" in value: + assert value.endswith("*"), "Wildcard '*' must be at the end." + sanitized_value = sanitize_search_value(value) + conditions.append(f"@{sanitized_field_name}:{sanitized_value}") + else: + sanitized_value = escape_redis_syntax(value) + conditions.append(f'@{sanitized_field_name}:"{sanitized_value}"') + + else: + raise ValueError(f"Unsupported field type for '{field_name}'.") + + return " ".join(conditions) + + @schema_method + async def search_services( + self, + text_query: Optional[str] = Field( + None, description="Text query for semantic search." + ), + vector_query: Optional[Any] = Field( + None, + description="Precomputed embedding vector for vector search in numpy format.", + ), + filters: Optional[Dict[str, Any]] = Field( + None, description="Filter dictionary for hybrid search." + ), + limit: Optional[int] = Field( + 5, description="Maximum number of results to return." + ), + offset: Optional[int] = Field(0, description="Offset for pagination."), + fields: Optional[List[str]] = Field(None, description="Fields to return."), + order_by: Optional[str] = Field( + None, + description="Order by field, default is score if embedding or text_query is provided.", + ), + context: Optional[dict] = None, + ): + """ + Search services with support for hybrid queries and pure filter-based queries. 
+ """ + if not self._enable_service_search: + raise RuntimeError("Service search is not enabled.") + from redis.commands.search.query import Query + + current_workspace = context["ws"] + # Generate embedding if text_query is provided + if text_query and not vector_query: + loop = asyncio.get_event_loop() + embeddings = list( + await loop.run_in_executor( + None, self._embedding_model.embed, [text_query] + ) + ) + vector_query = embeddings[0] + + auth_filter = f"@visibility:{{public}} | @workspace:{{{sanitize_search_value(current_workspace)}}}" + # If service_embedding is provided, prepare KNN search query + if vector_query is not None: + query_vector = vector_query.astype("float32").tobytes() + query_params = {"vector": query_vector} + knn_query = f"[KNN {limit} @service_embedding $vector AS score]" + # Combine filters into the query string + if filters: + filter_query = self._convert_filters_to_hybrid_query(filters) + query_string = f"(({filter_query}) ({auth_filter}))=>{knn_query}" + else: + query_string = f"({auth_filter})=>{knn_query}" + else: + query_params = {} + if filters: + filter_query = self._convert_filters_to_hybrid_query(filters) + query_string = f"({filter_query}) ({auth_filter})" + else: + query_string = auth_filter + + all_fields = [field.name for field in self._search_fields] + ["score"] + if fields is None: + # exclude embedding + fields = [field for field in all_fields if field != "service_embedding"] + else: + for field in fields: + if field not in all_fields: + raise ValueError(f"Invalid field: {field}") + if order_by is None: + order_by = "score" if vector_query is not None else "id" + else: + if order_by not in all_fields: + raise ValueError(f"Invalid order_by field: {order_by}") + + # Build the RedisSearch query + query = ( + Query(query_string) + .return_fields(*fields) + .sort_by(order_by, asc=True) + .paging(offset, limit) + .dialect(2) + ) + + # Perform the search using the RedisSearch index + results = await self._redis.ft("service_info_index").search( + query, query_params=query_params + ) + + # Convert results to dictionaries and return + services = [ + ServiceInfo.from_redis_dict(vars(doc), in_bytes=False) + for doc in results.docs + ] + return [service.model_dump() for service in services] + @schema_method async def list_services( self, @@ -883,6 +1109,28 @@ async def list_services( return services + async def _embed_service(self, redis_data): + if "service_embedding" in redis_data: + if isinstance(redis_data["service_embedding"], np.ndarray): + redis_data["service_embedding"] = redis_data[ + "service_embedding" + ].tobytes() + elif isinstance(redis_data["service_embedding"], bytes): + pass + else: + raise ValueError( + f"Invalid service_embedding type: {type(redis_data['service_embedding'])}, it must be a numpy array or bytes." + ) + elif redis_data.get("docs"): # Only embed the service if it has docs + assert self._embedding_model, "Embedding model is not available." 
+ summary = f"{redis_data.get('name', '')}\n{redis_data.get('description', '')}\n{redis_data.get('docs', '')}" + loop = asyncio.get_event_loop() + embeddings = list( + await loop.run_in_executor(None, self._embedding_model.embed, [summary]) + ) + redis_data["service_embedding"] = embeddings[0].tobytes() + return redis_data + @schema_method async def register_service( self, @@ -952,7 +1200,14 @@ async def register_service( for k in service_exists: logger.info(f"Replacing existing service: {k}") await self._redis.delete(k) - await self._redis.hset(key, mapping=service.to_redis_dict()) + + if self._enable_service_search: + redis_data = await self._embed_service(service.to_redis_dict()) + else: + redis_data = service.to_redis_dict() + if "service_embedding" in redis_data: + del redis_data["service_embedding"] + await self._redis.hset(key, mapping=redis_data) if ":built-in@" in key: await self._event_bus.emit( "client_updated", {"id": client_id, "workspace": ws} @@ -962,7 +1217,13 @@ async def register_service( await self._event_bus.emit("service_updated", service.model_dump()) logger.info(f"Updating service: {service.id}") else: - await self._redis.hset(key, mapping=service.to_redis_dict()) + if self._enable_service_search: + redis_data = await self._embed_service(service.to_redis_dict()) + else: + redis_data = service.to_redis_dict() + if "service_embedding" in redis_data: + del redis_data["service_embedding"] + await self._redis.hset(key, mapping=redis_data) # Default service created by api.export({}), typically used for hypha apps if ":default@" in key: try: @@ -982,6 +1243,9 @@ async def register_service( f"services:*|*:{client_id}:built-in@*" ) else: + # Remove the service embedding from the config + if service.config and service.config.service_embedding is not None: + service.config.service_embedding = None await self._event_bus.emit( "service_added", service.model_dump(mode="json") ) @@ -1374,24 +1638,24 @@ async def get_service( async def list_workspaces( self, match: dict = Field(None, description="Match pattern for filtering workspaces"), - page: int = Field(1, description="Page number for pagination"), - page_size: int = Field(256, description="Number of items per page"), + offset: int = Field(0, description="Offset for pagination"), + limit: int = Field(256, description="Maximum number of workspaces to return"), context=None, ) -> List[Dict[str, Any]]: """Get all workspaces with pagination.""" self.validate_context(context, permission=UserPermission.read) user_info = UserInfo.model_validate(context["user"]) - # Validate page and page_size - if page < 1: - raise ValueError("Page number must be greater than 0") - if page_size < 1 or page_size > 256: - raise ValueError("Page size must be greater than 0 and less than 256") + # Validate page and limit + if offset < 0: + raise ValueError("Offset number must be greater than or equal to 0") + if limit < 1 or limit > 256: + raise ValueError("Limit must be greater than 0 and less than 256") cursor = 0 workspaces = [] - start_index = (page - 1) * page_size - end_index = page * page_size + start_index = offset + end_index = offset + limit current_index = 0 while True: @@ -1606,6 +1870,7 @@ def create_service(self, service_id, service_name=None): "unregister_service": self.unregister_service, "list_workspaces": self.list_workspaces, "list_services": self.list_services, + "search_services": self.search_services, "list_clients": self.list_clients, "register_service_type": self.register_service_type, "get_service_type": self.get_service_type, diff 
--git a/hypha/s3.py b/hypha/s3.py index b45bceec..3b8bba63 100644 --- a/hypha/s3.py +++ b/hypha/s3.py @@ -170,7 +170,6 @@ def __init__( s3_admin_type="generic", enable_s3_proxy=False, workspace_bucket="hypha-workspaces", - # local_log_dir="./logs", workspace_etc_dir="etc", executable_path="", ): diff --git a/hypha/server.py b/hypha/server.py index e20d3559..25d202f4 100644 --- a/hypha/server.py +++ b/hypha/server.py @@ -98,7 +98,7 @@ def start_builtin_services( s3_admin_type=args.s3_admin_type, enable_s3_proxy=args.enable_s3_proxy, workspace_bucket=args.workspace_bucket, - executable_path=args.executable_path, + executable_path=args.executable_path or args.cache_dir, ) artifact_manager = ArtifactController( store, @@ -221,6 +221,14 @@ async def lifespan(app: FastAPI): local_base_url=local_base_url, redis_uri=args.redis_uri, database_uri=args.database_uri, + vectordb_uri=args.vectordb_uri, + ollama_host=args.ollama_host, + cache_dir=args.cache_dir, + openai_config={ + "base_url": args.openai_base_url, + "api_key": args.openai_api_key, + }, + enable_service_search=args.enable_service_search, reconnection_token_life_time=float( env.get("RECONNECTION_TOKEN_LIFE_TIME", str(2 * 24 * 60 * 60)) ), @@ -373,6 +381,30 @@ def get_argparser(add_help=True): default=None, help="set SecretAccessKey for S3", ) + parser.add_argument( + "--ollama-host", + type=str, + default=None, + help="set host for the ollama server", + ) + parser.add_argument( + "--openai-base-url", + type=str, + default=None, + help="set OpenAI API type", + ) + parser.add_argument( + "--openai-api-key", + type=str, + default=None, + help="set OpenAI API key", + ) + parser.add_argument( + "--vectordb-uri", + type=str, + default=None, + help="set URI for the vector database", + ) parser.add_argument( "--database-uri", type=str, @@ -420,7 +452,17 @@ def get_argparser(add_help=True): action="store_true", help="enable S3 proxy for serving pre-signed URLs", ) - + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="set the cache directory for the server", + ) + parser.add_argument( + "--enable-service-search", + action="store_true", + help="enable semantic service search via vector database", + ) return parser diff --git a/hypha/utils/zenodo.py b/hypha/utils/zenodo.py index a556391a..c0bf7e79 100644 --- a/hypha/utils/zenodo.py +++ b/hypha/utils/zenodo.py @@ -13,7 +13,7 @@ def __init__( self.zenodo_server = zenodo_server self.headers = {"Content-Type": "application/json"} self.params = {"access_token": self.access_token} - self.client = httpx.AsyncClient() + self.client = httpx.AsyncClient(headers={"Connection": "close"}) async def create_deposition(self) -> Dict[str, Any]: """Creates a new empty deposition and returns its info.""" @@ -97,7 +97,7 @@ async def file_chunk_reader(file_path: str, chunk_size: int = 1024): async def import_file(self, deposition_info, name, target_url): bucket_url = deposition_info["links"]["bucket"] - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(headers={"Connection": "close"}) as client: async with client.stream("GET", target_url) as response: async def s3_response_chunk_reader(response, chunk_size: int = 2048): diff --git a/requirements.txt b/requirements.txt index 6743ec22..cc3bfab9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,6 @@ asyncpg==0.30.0 sqlmodel==0.0.22 alembic==1.14.0 hrid==0.2.4 +qdrant-client==1.12.1 +ollama==0.3.3 +fastembed==0.4.2 diff --git a/setup.py b/setup.py index 65c997fd..5f5cbea1 100644 --- a/setup.py +++ b/setup.py @@ 
-79,7 +79,12 @@ "base58>=2.1.0", "pymultihash>=0.8.2", ], - "postgres": ["psycopg2-binary>=2.9.10", "asyncpg>=0.30.0"], + "db": [ + "psycopg2-binary>=2.9.10", + "asyncpg>=0.30.0", + "qdrant-client>=1.12.1", + "fastembed>=0.4.2", + ], }, zip_safe=False, entry_points={"console_scripts": ["hypha = hypha.__main__:main"]}, diff --git a/test-openai.py b/test-openai.py new file mode 100644 index 00000000..13170140 --- /dev/null +++ b/test-openai.py @@ -0,0 +1,3 @@ +import ollama + +ollama.pull("all-minilm") diff --git a/tests/conftest.py b/tests/conftest.py index 396e6d5f..8b85095d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,6 +108,12 @@ def generate_authenticated_user_temporary(): yield from _generate_token("test-user", ["temporary-test-user"]) +@pytest_asyncio.fixture(name="test_user_token_5", scope="session") +def generate_authenticated_user_5(): + """Generate a test user token.""" + yield from _generate_token("user-5", []) + + @pytest_asyncio.fixture(name="triton_server", scope="session") def triton_server(): """Start a triton server as test fixture and tear down after test.""" @@ -223,7 +229,14 @@ def redis_server(): except Exception: # user docker to start redis subprocess.Popen( - ["docker", "run", "-d", "-p", f"{REDIS_PORT}:6379", "redis:6.2.5"] + [ + "docker", + "run", + "-d", + "-p", + f"{REDIS_PORT}:6379", + "redis/redis-stack:7.2.0-v13", + ] ) timeout = 10 while timeout > 0: @@ -262,6 +275,8 @@ def fastapi_server_fixture(minio_server, postgres_server): "--enable-s3-proxy", f"--workspace-bucket=my-workspaces", "--s3-admin-type=minio", + "--vectordb-uri=:memory:", + "--cache-dir=./bin/cache", f"--triton-servers=http://127.0.0.1:{TRITON_PORT}", "--static-mounts=/tests:./tests", "--startup-functions", @@ -346,17 +361,18 @@ def fastapi_server_redis_1(redis_server, minio_server): "-m", "hypha.server", f"--port={SIO_PORT_REDIS_1}", - # "--enable-server-apps", - # "--enable-s3", + "--enable-server-apps", + "--enable-s3", # need to define it so the two server can communicate f"--public-base-url=http://my-public-url.com", "--server-id=server-0", f"--redis-uri=redis://127.0.0.1:{REDIS_PORT}/0", "--reset-redis", - # f"--endpoint-url={MINIO_SERVER_URL}", - # f"--access-key-id={MINIO_ROOT_USER}", - # f"--secret-access-key={MINIO_ROOT_PASSWORD}", - # f"--endpoint-url-public={MINIO_SERVER_URL_PUBLIC}", + f"--endpoint-url={MINIO_SERVER_URL}", + f"--access-key-id={MINIO_ROOT_USER}", + f"--secret-access-key={MINIO_ROOT_PASSWORD}", + f"--endpoint-url-public={MINIO_SERVER_URL_PUBLIC}", + "--enable-service-search", ], env=test_env, ) as proc: diff --git a/tests/test_artifact.py b/tests/test_artifact.py index e69c391f..777e1709 100644 --- a/tests/test_artifact.py +++ b/tests/test_artifact.py @@ -2,6 +2,8 @@ import pytest import requests import os +import numpy as np +import random from hypha_rpc import connect_to_server from . 
import SERVER_URL, SERVER_URL_SQLITE, find_item @@ -10,6 +12,152 @@ pytestmark = pytest.mark.asyncio +async def test_artifact_vector_collection( + minio_server, fastapi_server, test_user_token +): + """Test vector-related functions within a vector-collection artifact.""" + + # Connect to the server and set up the artifact manager + api = await connect_to_server( + { + "name": "test deploy client", + "server_url": SERVER_URL, + "token": test_user_token, + } + ) + artifact_manager = await api.get_service("public/artifact-manager") + + # Create a vector-collection artifact + vector_collection_manifest = { + "name": "vector-collection", + "description": "A test vector collection", + } + vector_collection_config = { + "vectors_config": { + "size": 384, + "distance": "Cosine", + }, + "embedding_model": "fastembed:BAAI/bge-small-en-v1.5", + } + vector_collection = await artifact_manager.create( + type="vector-collection", + manifest=vector_collection_manifest, + config=vector_collection_config, + ) + # Add vectors to the collection + vectors = [ + { + "vector": [random.random() for _ in range(384)], + "payload": { + "text": "This is a test document.", + "label": "doc1", + "rand_number": random.randint(0, 10), + }, + }, + { + "vector": [np.random.rand(384)], + "payload": { + "text": "Another document.", + "label": "doc2", + "rand_number": random.randint(0, 10), + }, + }, + { + "vector": [np.random.rand(384)], + "payload": { + "text": "Yet another document.", + "label": "doc3", + "rand_number": random.randint(0, 10), + }, + }, + ] + await artifact_manager.add_vectors( + artifact_id=vector_collection.id, + vectors=vectors, + ) + + # Search for vectors by query vector + query_vector = [random.random() for _ in range(384)] + search_results = await artifact_manager.search_by_vector( + artifact_id=vector_collection.id, + query_vector=query_vector, + limit=2, + ) + assert len(search_results) <= 2 + + query_filter = { + "should": None, + "min_should": None, + "must": [ + { + "key": "rand_number", + "match": None, + "range": {"lt": None, "gt": None, "gte": 3.0, "lte": None}, + "geo_bounding_box": None, + "geo_radius": None, + "geo_polygon": None, + "values_count": None, + } + ], + "must_not": None, + } + + search_results = await artifact_manager.search_by_vector( + artifact_id=vector_collection.id, + query_filter=query_filter, + query_vector=np.random.rand(384), + limit=2, + ) + assert len(search_results) <= 2 + + # Search for vectors by text + documents = [ + {"text": "This is a test document.", "label": "doc1"}, + {"text": "Another test document.", "label": "doc2"}, + ] + await artifact_manager.add_documents( + artifact_id=vector_collection.id, + documents=documents, + ) + text_query = "test document" + text_search_results = await artifact_manager.search_by_text( + artifact_id=vector_collection.id, + query=text_query, + limit=2, + ) + assert len(text_search_results) <= 2 + + # Retrieve a specific vector + retrieved_vector = await artifact_manager.get_vector( + artifact_id=vector_collection.id, + id=text_search_results[0]["id"], + ) + assert retrieved_vector.id == text_search_results[0]["id"] + + # List vectors in the collection + vector_list = await artifact_manager.list_vectors( + artifact_id=vector_collection.id, + offset=0, + limit=10, + ) + assert len(vector_list) > 0 + + # Remove a vector from the collection + await artifact_manager.remove_vectors( + artifact_id=vector_collection.id, + ids=[vector_list[0]["id"]], + ) + remaining_vectors = await artifact_manager.list_vectors( + 
artifact_id=vector_collection.id, + offset=0, + limit=10, + ) + assert all(v["id"] != vector_list[0]["id"] for v in remaining_vectors) + + # Clean up by deleting the vector collection + await artifact_manager.delete(artifact_id=vector_collection.id) + + async def test_sqlite_create_and_search_artifacts( minio_server, fastapi_server_sqlite, test_user_token ): diff --git a/tests/test_server.py b/tests/test_server.py index f1944cd7..3f8f454c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,6 +3,7 @@ import subprocess import sys import asyncio +import numpy as np import pytest import requests @@ -323,6 +324,111 @@ async def test_workspace_owners( await api.disconnect() +async def test_service_search(fastapi_server_redis_1, test_user_token): + """Test service search with registered services.""" + api = await connect_to_server( + { + "client_id": "my-app-99", + "server_url": SERVER_URL_REDIS_1, + "token": test_user_token, + } + ) + + # Register sample services with unique `docs` field + await api.register_service( + { + "id": "service-1", + "name": "example service one", + "type": "my-type", + "description": "This is the first test service.", + "app_id": "my-app", + "service_schema": {"example_key": "example_value"}, + "config": {"setting": "value1"}, + "docs": "This service handles data analysis workflows for genomics.", + } + ) + await api.register_service( + { + "id": "service-2", + "name": "example service two", + "type": "another-type", + "description": "This is the second test service.", + "app_id": "another-app", + "service_schema": {"example_key": "another_value"}, + "config": {"setting": "value2"}, + "docs": "This service focuses on image processing for medical imaging.", + } + ) + + await api.register_service( + { + "id": "service-3", + "name": "example service three", + "type": "my-type", + "description": "This is the third test service.", + "app_id": "my-app", + "service_schema": {"example_key": "yet_another_value"}, + "config": {"setting": "value3"}, + "docs": "This service specializes in natural language processing and AI chatbots.", + } + ) + + # Test semantic search using `text_query` + text_query = "NLP" + services = await api.search_services(text_query=text_query, limit=3) + assert isinstance(services, list) + assert len(services) <= 3 + # The top hit should be the service with "natural language processing" in the `docs` field + assert "natural language processing" in services[0]["docs"] + assert services[0]["score"] < services[1]["score"] + + embedding = np.ones(384).astype(np.float32) + await api.register_service( + { + "id": "service-88", + "name": "example service 88", + "type": "another-type", + "description": "This is the 88-th test service.", + "app_id": "another-app", + "service_schema": {"example_key": "another_value"}, + "config": {"setting": "value2", "service_embedding": embedding}, + "docs": "This service is used for performing alphafold calculations.", + } + ) + + # Test vector query with the exact embedding + services = await api.search_services(vector_query=embedding, limit=3) + assert isinstance(services, list) + assert len(services) <= 3 + assert "service-88" in services[0]["id"] + + # Test filter-based search with fuzzy matching on the `docs` field + filters = {"docs": "calculations*"} + services = await api.search_services(filters=filters, limit=3) + assert isinstance(services, list) + assert len(services) <= 3 + assert "calculations" in services[0]["docs"] + + # Test hybrid search (text query + filters) + filters = {"type": "my-type"} + 
text_query = "genomics workflows" + services = await api.search_services( + text_query=text_query, filters=filters, limit=3 + ) + assert isinstance(services, list) + assert all(service["type"] == "my-type" for service in services) + # The top hit should be the service with "genomics" in the `docs` field + assert "genomics" in services[0]["docs"].lower() + + # Test hybrid search (embedding + filters) + filters = {"type": "my-type"} + services = await api.search_services( + vector_query=np.random.rand(384), filters=filters, limit=3 + ) + assert isinstance(services, list) + assert all(service["type"] == "my-type" for service in services) + + async def test_server_scalability( fastapi_server_redis_1, fastapi_server_redis_2, test_user_token ): diff --git a/tests/test_server_apps.py b/tests/test_server_apps.py index 5d3c7c91..7c9cf633 100644 --- a/tests/test_server_apps.py +++ b/tests/test_server_apps.py @@ -38,7 +38,7 @@ async def test_server_apps_unauthorized( - fastapi_server, test_user_token, root_user_token + fastapi_server, test_user_token_5, root_user_token ): """Test the server apps.""" api = await connect_to_server( @@ -46,7 +46,7 @@ async def test_server_apps_unauthorized( "name": "test client", "server_url": WS_SERVER_URL, "method_timeout": 30, - "token": test_user_token, + "token": test_user_token_5, } ) controller = await api.get_service("public/server-apps") diff --git a/tests/test_server_disconnection.py b/tests/test_server_disconnection.py index f54fe1d4..bdc95322 100644 --- a/tests/test_server_disconnection.py +++ b/tests/test_server_disconnection.py @@ -57,29 +57,3 @@ async def test_server_reconnection(fastapi_server, root_user_token): except Exception as e: assert "Connection is closed" in str(e) await api.disconnect() - - -async def test_server_reconnection_by_workspace_unload(fastapi_server): - """Test the server reconnection.""" - # connect to the server with a user - api = await connect_to_server({"server_url": WS_SERVER_URL, "client_id": "client1"}) - token = await api.generate_token() - - # connect to the server with the same user, to the same workspace - api2 = await connect_to_server( - { - "server_url": WS_SERVER_URL, - "client_id": "client2", - "workspace": api.config["workspace"], - "token": token, - } - ) - # force a server side disconnect to the second client - await api.disconnect() - try: - assert await api2.echo("hi") == "hi" - except Exception as e: - # timeout due to the server side disconnect - assert "Method call time out:" in str(e) - - await asyncio.sleep(100)