From fdcc43385196a1e20850b8150414054e58e40bd4 Mon Sep 17 00:00:00 2001
From: Wei Ouyang
Date: Thu, 10 Oct 2024 11:12:21 -0700
Subject: [PATCH] Log event to sql database (#696)

* support _token and remove prefix in login
* allow pass workspace when login
* Support current workspace for http endpoint
* fix login optional
* increase page size for list workspaces
* Bump version for hypha-rpc 0.20.38
* Update change logs and login instructions
* Support artifact endpoint via http
* Implement sql database
* change _id to _prefix
* clean up
* redirect login
* use sql to store workspace info
* add stage_files
* Update helm charts
* Fix workspaces db
* skip default database uri
* Use in-memory sql for artifacts
* Fix workspace loading error
* restore workspace info
* rename it to test-3
* restore version
* add logging service
* Fix artifacts
* Merge event log to workspace
* make sure error is raised
* support observability
* Add tests
* test observability
* Fix counter duplicated error
* add change log
* Support download statistics
* Update docs
* Update change log
* Remove set logging service
---
 CHANGELOG.md                |   5 +
 docs/artifact-manager.md    | 107 +++++++--------
 hypha/VERSION               |   2 +-
 hypha/artifact.py           | 216 +++++++++++++++++++++++++------
 hypha/core/__init__.py      |  16 +++
 hypha/core/store.py         |  20 +++
 hypha/core/workspace.py     | 251 +++++++++++++++++++++++++++++-------
 hypha/http.py               |   7 +-
 hypha/server.py             |   2 +-
 hypha/websocket.py          |   6 +
 requirements.txt            |   1 +
 setup.py                    |   1 +
 tests/test_artifact.py      | 130 +++++++++++++++++++
 tests/test_event_log.py     |  81 ++++++++++++
 tests/test_observability.py |  88 +++++++++++++
 tests/test_server.py        |  48 +++++++
 16 files changed, 835 insertions(+), 146 deletions(-)
 create mode 100644 tests/test_event_log.py
 create mode 100644 tests/test_observability.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff7ed89e..5c0a9c9e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,13 @@
 # Hypha Change Log
 
 ### 0.20.38
+
+ - Support event logging in the workspace: use `log_event` to log events in the workspace and `get_events` to retrieve them. The events will be persisted in the SQL database.
  - Allow passing workspace and expires_in to the `login` function to generate workspace specific token.
  - When using http endpoint to access the service, you can now pass workspace specific token to the http header `Authorization` to access the service. (Previously, all the services are assumed to be accessed from the same service provider workspace)
+ - Breaking Change: Remove `info`, `warning`, `error`, `critical`, `debug` from the `hypha` module; use `log` or `log_event` instead.
+ - Support basic observability for the workspace, including workspace status, event bus, and websocket connection status.
+ - Support download statistics for the artifacts in the artifact manager.
 
 ### 0.20.37
 - Add s3-proxy to allow accessing s3 presigned url in case the s3 server is not directly accessible. Use `--enable-s3-proxy` to enable the s3 proxy when starting Hypha.

diff --git a/docs/artifact-manager.md b/docs/artifact-manager.md
index 675830b5..e75c4990 100644
--- a/docs/artifact-manager.md
+++ b/docs/artifact-manager.md
@@ -1,17 +1,18 @@
 # Artifact Manager
 
-The `Artifact Manager` is a built-in Hypha service for indexing, managing, and storing resources such as datasets, AI models, and applications. It provides a structured way to manage datasets and similar resources, enabling efficient listing, uploading, updating, and deleting of files.
+The `Artifact Manager` is a built-in Hypha service for indexing, managing, and storing resources such as datasets, AI models, and applications. It provides a structured way to manage datasets and similar resources, enabling efficient listing, uploading, updating, and deleting of files. It also now supports tracking download statistics for each artifact. -A typical use case for the `Artifact Manager` is as a backend for a single-page web application that displays a gallery of datasets, AI models, applications, or other types of resources. The default metadata of an artifact is designed to render a grid of cards on a webpage. +A typical use case for the `Artifact Manager` is as a backend for a single-page web application that displays a gallery of datasets, AI models, applications, or other types of resources. The default metadata of an artifact is designed to render a grid of cards on a webpage. It also supports tracking download statistics. **Note:** The `Artifact Manager` is only available when your Hypha server has S3 storage enabled. +--- ## Getting Started ### Step 1: Connecting to the Artifact Manager Service -To use the `Artifact Manager`, you first need to connect to the Hypha server. This API allows you to create, read, edit, and delete datasets in the artifact registry (stored in a S3 bucket for each workspace). +To use the `Artifact Manager`, you first need to connect to the Hypha server. This API allows you to create, read, edit, and delete datasets in the artifact registry (stored in an S3 bucket for each workspace). ```python from hypha_rpc.websocket_client import connect_to_server @@ -25,7 +26,7 @@ artifact_manager = await server.get_service("public/artifact-manager") ### Step 2: Creating a Dataset Gallery Collection -Once connected, you can create a collection to organize datasets in the gallery. +Once connected, you can create a collection to organize datasets in the gallery. ```python # Create a collection for the Dataset Gallery @@ -59,13 +60,15 @@ await artifact_manager.create(prefix="collections/dataset-gallery/example-datase print("Dataset added to the gallery.") ``` -### Step 4: Uploading Files to the Dataset +### Step 4: Uploading Files to the Dataset with Download Statistics + +Once you have created a dataset, you can upload files to it by generating a pre-signed URL. This URL allows you to upload the actual files to the artifact's S3 bucket. -Once you have created a dataset, you can upload files to it by generating a pre-signed URL. +Additionally, when uploading files to an artifact, you can specify a `download_weight` for each file. This weight determines how the file impacts the artifact's download count when it is accessed. For example, primary files might have a higher `download_weight`, while secondary files might have no impact. The download count is automatically updated whenever users download files from the artifact. 
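To make the weighted counting concrete, here is a minimal sketch of how the count accumulates (a hypothetical illustration reusing the `example-dataset` prefix and the `0.5` weight assigned in the upload snippet that follows; `artifact_manager` is the service handle obtained in Step 1):

```python
# Each non-silent download of a weighted file adds its download_weight to the artifact's count.
url = await artifact_manager.get_file(
    prefix="collections/dataset-gallery/example-dataset", path="data.csv"
)  # data.csv was uploaded with download_weight=0.5, so this adds 0.5
url = await artifact_manager.get_file(
    prefix="collections/dataset-gallery/example-dataset", path="data.csv"
)  # a second download adds another 0.5

manifest = await artifact_manager.read(prefix="collections/dataset-gallery/example-dataset")
print(manifest["_stats"]["download_count"])  # 1.0

# Passing options={"silent": True} to get_file returns the URL without touching the count,
# and files uploaded without a download_weight never affect it.
```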
```python -# Get a pre-signed URL to upload a file -put_url = await artifact_manager.put_file(prefix="collections/dataset-gallery/example-dataset", file_path="data.csv") +# Get a pre-signed URL to upload a file, with a download_weight assigned +put_url = await artifact_manager.put_file(prefix="collections/dataset-gallery/example-dataset", file_path="data.csv", options={"download_weight": 0.5}) # Upload the file using an HTTP PUT request with open("path/to/local/data.csv", "rb") as f: @@ -99,7 +102,7 @@ print("Datasets in the gallery:", datasets) ## Full Example: Creating and Managing a Dataset Gallery -Here’s a full example that shows how to connect to the service, create a dataset gallery, add a dataset, upload files, and commit the dataset. +Here’s a full example that shows how to connect to the service, create a dataset gallery, add a dataset, upload files with download statistics, and commit the dataset. ```python import asyncio @@ -135,8 +138,8 @@ async def main(): await artifact_manager.create(prefix="collections/dataset-gallery/example-dataset", manifest=dataset_manifest, stage=True) print("Dataset added to the gallery.") - # Get a pre-signed URL to upload a file - put_url = await artifact_manager.put_file(prefix="collections/dataset-gallery/example-dataset", file_path="data.csv") + # Get a pre-signed URL to upload a file, with a download_weight assigned + put_url = await artifact_manager.put_file(prefix="collections/dataset-gallery/example-dataset", file_path="data.csv", options={"download_weight": 0.5}) # Upload the file using an HTTP PUT request with open("path/to/local/data.csv", "rb") as f: @@ -217,43 +220,10 @@ await artifact_manager.commit(prefix="collections/schema-dataset-gallery/valid-d print("Valid dataset committed.") ``` -### Step 3: Accessing the collection via HTTP API - -You can access the collection via the HTTP API to retrieve the schema and datasets. -This can be used for rendering a gallery of datasets on a webpage. - -```javascript -// Fetch the schema for the collection -fetch("https://hypha.aicell.io/my-workspace/artifact/public/collections/schema-dataset-gallery") - .then(response => response.json()) - .then(data => console.log("Schema:", data.collection_schema)); -``` - -## API Reference - -This section details the core functions provided by the `Artifact Manager` for creating, managing, and validating artifacts such as datasets and collections. - -### `create(prefix: str, manifest: dict, overwrite: bool = False, stage: bool = False) -> dict` - -Creates a new artifact or collection with the provided manifest. If the artifact already exists, you must set `overwrite=True` to overwrite it. - -**Parameters:** - -- `prefix`: The path where the artifact or collection will be created (e.g., `"collections/dataset-gallery"`). -- `manifest`: The manifest describing the artifact (must include fields like `id`, `name`, and `type`). -- `overwrite`: Optional. If `True`, it will overwrite an existing artifact. Default is `False`. -- `stage`: Optional. If `True`, it will put the artifact into staging mode. Default is `False`. - -**Returns:** The created manifest as a dictionary. - -**Example:** - -```python -await artifact_manager.create(prefix="collections/dataset-gallery", manifest=gallery_manifest) -``` - --- +## API References + ### `edit(prefix: str, manifest: dict) -> None` Edits an existing artifact. You provide the new manifest to update the artifact. The updated manifest is stored temporarily as `_manifest.yaml`. 
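A minimal usage sketch (assuming the `example-dataset` prefix used earlier in this document; the edited manifest stays in staging until `commit` is called):

```python
updated_manifest = {
    "id": "example-dataset",
    "name": "Example Dataset (updated)",
    "description": "Updated description for the example dataset",
    "type": "dataset",
}

# Stage the new manifest for the existing artifact
await artifact_manager.edit(prefix="collections/dataset-gallery/example-dataset", manifest=updated_manifest)

# Publish the staged changes
await artifact_manager.commit(prefix="collections/dataset-gallery/example-dataset")
```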
@@ -303,7 +273,7 @@ await artifact_manager.delete(prefix="collections/dataset-gallery/example-datase --- -### `put_file(prefix: str, file_path: str) -> str` +### `put_file(prefix: str, file_path: str, options: dict = None) -> str` Generates a pre-signed URL to upload a file to an artifact. You can then use the URL with an HTTP `PUT` request to upload the file. @@ -311,13 +281,16 @@ Generates a pre-signed URL to upload a file to an artifact. You can then use the - `prefix`: The path of the artifact where the file will be uploaded (e.g., `"collections/dataset-gallery/example-dataset"`). - `file_path`: The relative path of the file to upload within the artifact (e.g., `"data.csv"`). +- `options`: Optional. Additional options for the file upload. Default is `None`. +The options can include: + - `download_weight`: A float value representing the impact of the file on the download count. Default is `0`. **Returns:** A pre-signed URL for uploading the file. **Example:** ```python -put_url = await artifact_manager.put_file(prefix="collections/dataset-gallery/example-dataset", file_path="data.csv") +put_url = await artifact_manager.put_file(prefix="collections/dataset-gallery/example-dataset", file_path="data.csv", options={"download_weight": 1.0}) ``` --- @@ -339,7 +312,7 @@ await artifact_manager.remove_file(prefix="collections/dataset-gallery/example-d --- -### `get_file(prefix: str, path: str) -> str` +### `get_file(prefix: str, path: str, options: dict=None) -> str` Generates a pre-signed URL to download a file from the artifact. @@ -347,6 +320,9 @@ Generates a pre-signed URL to download a file from the artifact. - `prefix`: The path of the artifact (e.g., `"collections/dataset-gallery/example-dataset"`). - `path`: The relative path of the file to download (e.g., `"data.csv"`). +- `options`: Optional. Additional options for the file download. Default is `None`. +The options can include: + - `silent`: A boolean flag to suppress download statistics. Default is `False`. **Returns:** A pre-signed URL for downloading the file. @@ -460,6 +436,20 @@ print("Datasets in the gallery:", datasets) The `Artifact Manager` provides an HTTP endpoint for retrieving artifact manifests and data. This is useful for public-facing web applications that need to access datasets, models, or applications. + +### Resetting Download Statistics + +You can reset the download statistics of a dataset using the `reset_stats` function. + +```python +await artifact_manager.reset_stats(prefix="collections/dataset-gallery/example-dataset") +print("Download statistics reset.") +``` + +## HTTP API for Accessing Artifacts and Download Counts + +The `Artifact Manager` provides an HTTP endpoint for retrieving artifact manifests, data, and download statistics. This is useful for public-facing web applications that need to access datasets, models, or applications. + ### Endpoint: `/{workspace}/artifact/{path:path}` - **Workspace**: The workspace in which the artifact is stored. @@ -472,17 +462,18 @@ The `Artifact Manager` provides an HTTP endpoint for retrieving artifact manifes - **Method**: `GET` - **Parameters**: - `workspace`: The workspace in which the artifact is stored. - - `path`: The path to the artifact (e.g., `public/collections/dataset-gallery/example-dataset`). + - `path`: + + The path to the artifact (e.g., `public/collections/dataset-gallery/example-dataset`). - `stage` (optional): A boolean flag to indicate whether to fetch the staged version of the manifest (`_manifest.yaml`). Default is `False`. 
### Response: -- **For public artifacts**: Returns the artifact manifest if it exists under the `public/` prefix. +- **For public artifacts**: Returns the artifact manifest if it exists under the `public/` prefix, including any download statistics. - **For private artifacts**: Returns the artifact manifest if the user has the necessary permissions. -### Example: -#### Fetching a public artifact: +### Example: Fetching a public artifact with download statistics ```python import requests @@ -493,17 +484,9 @@ response = requests.get(f"{SERVER_URL}/{workspace}/artifact/public/collections/d if response.ok: artifact = response.json() print(artifact["name"]) # Output: Example Dataset + print(artifact["_stats"]["download_count"]) # Output: Download count for the dataset else: print(f"Error: {response.status_code}") ``` -#### Fetching a private artifact: -```python -response = requests.get(f"{SERVER_URL}/{workspace}/artifact/collections/private-dataset-gallery/private-example-dataset") -if response.ok: - artifact = response.json() - print(artifact["name"]) # Output: Private Example Dataset -else: - print(f"Error: {response.status_code}") -``` diff --git a/hypha/VERSION b/hypha/VERSION index 26d90df8..6e281531 100644 --- a/hypha/VERSION +++ b/hypha/VERSION @@ -1,3 +1,3 @@ { - "version": "0.20.37.post4" + "version": "0.20.38" } diff --git a/hypha/artifact.py b/hypha/artifact.py index 1521bcdc..ecf4c668 100644 --- a/hypha/artifact.py +++ b/hypha/artifact.py @@ -1,10 +1,12 @@ import logging import sys +import copy from sqlalchemy import ( event, Column, String, Integer, + Float, JSON, UniqueConstraint, select, @@ -14,9 +16,9 @@ ) from hypha.utils import remove_objects_async, list_objects_async, safe_join from botocore.exceptions import ClientError +from sqlalchemy import update from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.asyncio import ( - create_async_engine, async_sessionmaker, AsyncSession, ) @@ -51,6 +53,11 @@ class ArtifactModel(Base): manifest = Column(JSON, nullable=True) # Store committed manifest stage_manifest = Column(JSON, nullable=True) # Store staged manifest stage_files = Column(JSON, nullable=True) # Store staged files during staging + download_weights = Column( + JSON, nullable=True + ) # Store the weights for counting downloads; a dictionary of file paths and their weights 0-1 + download_count = Column(Float, nullable=False, default=0.0) # New counter field + view_count = Column(Float, nullable=False, default=0.0) # New counter field __table_args__ = ( UniqueConstraint("workspace", "prefix", name="_workspace_prefix_uc"), ) @@ -64,24 +71,15 @@ def __init__( store, s3_controller, workspace_bucket="hypha-workspaces", - database_uri=None, ): """Set up controller with SQLAlchemy database and S3 for file storage.""" - if database_uri is None: - # create an in-memory SQLite database for testing - database_uri = "sqlite+aiosqlite:///:memory:" - logger.warning( - "Using in-memory SQLite database for artifact manager, all data will be lost on restart!!!" 
- ) - self.engine = create_async_engine(database_uri, echo=False) + self.engine = store.get_sql_engine() self.SessionLocal = async_sessionmaker( self.engine, expire_on_commit=False, class_=AsyncSession ) self.s3_controller = s3_controller self.workspace_bucket = workspace_bucket - - store.register_public_service(self.get_artifact_service()) - store.set_artifact_manager(self) + self.store = store router = APIRouter() @@ -90,7 +88,7 @@ async def get_artifact( workspace: str, path: str, stage: bool = False, - user_info: store.login_optional = Depends(store.login_optional), + user_info: self.store.login_optional = Depends(self.store.login_optional), ): """Get artifact from the database.""" try: @@ -107,7 +105,9 @@ async def get_artifact( except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) - store.register_router(router) + self.store.set_artifact_manager(self) + self.store.register_public_service(self.get_artifact_service()) + self.store.register_router(router) async def init_db(self): """Initialize the database and create tables.""" @@ -135,6 +135,34 @@ def end_transaction(session, transaction): return session + async def _get_artifact(self, session, workspace, prefix): + query = select(ArtifactModel).filter( + ArtifactModel.workspace == workspace, + ArtifactModel.prefix == prefix, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def _get_session(self, read_only=False): + """Return an SQLAlchemy async session. If read_only=True, ensure no modifications are allowed.""" + session = self.SessionLocal() + + if read_only: + sync_session = session.sync_session # Access the synchronous session object + + @event.listens_for(sync_session, "before_flush") + def prevent_flush(session, flush_context, instances): + """Prevent any flush operations to keep the session read-only.""" + raise RuntimeError("This session is read-only.") + + @event.listens_for(sync_session, "after_transaction_end") + def end_transaction(session, transaction): + """Ensure rollback after a transaction in a read-only session.""" + if not transaction._parent: + session.rollback() + + return session + async def _read_manifest(self, workspace, prefix, stage=False): session = await self._get_session() try: @@ -181,7 +209,26 @@ async def _read_manifest(self, workspace, prefix, stage=False): if stage: manifest["stage_files"] = artifact.stage_files + else: + # increase view count + stmt = ( + update(ArtifactModel) + .where(ArtifactModel.id == artifact.id) + # atomically increment the view count + .values(view_count=ArtifactModel.view_count + 1) + .execution_options(synchronize_session="fetch") + ) + await session.execute(stmt) + await session.commit() + manifest["_stats"] = { + "download_count": artifact.download_count, + "view_count": artifact.view_count, + } + if manifest.get("type") == "collection": + manifest["_stats"]["child_count"] = len(collection) return manifest + except Exception as e: + raise e finally: await session.close() @@ -246,17 +293,51 @@ async def create( manifest=None if stage else manifest, stage_manifest=manifest if stage else None, stage_files=[] if stage else None, + download_weights=None, type=manifest["type"], ) session.add(new_artifact) await session.commit() logger.info(f"Created artifact under prefix: {prefix}") - + except Exception as e: + raise e finally: await session.close() return manifest + async def reset_stats(self, prefix, context: dict): + """Reset the artifact's manifest's download count and view count.""" + if context is None or 
"ws" not in context: + raise ValueError("Context must include 'ws' (workspace).") + ws = context["ws"] + + user_info = UserInfo.model_validate(context["user"]) + if not user_info.check_permission(ws, UserPermission.read_write): + raise PermissionError( + "User does not have write permission to the workspace." + ) + + session = await self._get_session() + try: + async with session.begin(): + artifact = await self._get_artifact(session, ws, prefix) + if not artifact: + raise KeyError(f"Artifact under prefix '{prefix}' does not exist.") + stmt = ( + update(ArtifactModel) + .where(ArtifactModel.id == artifact.id) + .values(download_count=0, view_count=0) + .execution_options(synchronize_session="fetch") + ) + await session.execute(stmt) + await session.commit() + logger.info(f"Reset artifact under prefix: {prefix}") + except Exception as e: + raise e + finally: + await session.close() + async def read(self, prefix, stage=False, context: dict = None): """Read the artifact's manifest from the database and populate collections dynamically.""" if context is None or "ws" not in context: @@ -284,17 +365,18 @@ async def edit(self, prefix, manifest=None, context: dict = None): "User does not have write permission to the workspace." ) - # Validate the manifest - if manifest["type"] == "collection": - CollectionArtifact.model_validate(manifest) - elif manifest["type"] == "application": - ApplicationArtifact.model_validate(manifest) - elif manifest["type"] == "workspace": - WorkspaceInfo.model_validate(manifest) + if manifest: + # Validate the manifest + if manifest["type"] == "collection": + CollectionArtifact.model_validate(manifest) + elif manifest["type"] == "application": + ApplicationArtifact.model_validate(manifest) + elif manifest["type"] == "workspace": + WorkspaceInfo.model_validate(manifest) - # Convert ObjectProxy to dict if necessary - if isinstance(manifest, ObjectProxy): - manifest = ObjectProxy.toDict(manifest) + # Convert ObjectProxy to dict if necessary + if isinstance(manifest, ObjectProxy): + manifest = ObjectProxy.toDict(manifest) session = await self._get_session() try: @@ -302,12 +384,15 @@ async def edit(self, prefix, manifest=None, context: dict = None): artifact = await self._get_artifact(session, ws, prefix) if not artifact: raise KeyError(f"Artifact under prefix '{prefix}' does not exist.") - + if manifest is None: + manifest = copy.deepcopy(artifact.manifest) artifact.stage_manifest = manifest flag_modified(artifact, "stage_manifest") # Mark JSON field as modified session.add(artifact) await session.commit() logger.info(f"Edited artifact under prefix: {prefix}") + except Exception as e: + raise e finally: await session.close() @@ -334,6 +419,7 @@ async def commit(self, prefix, context: dict): manifest = artifact.stage_manifest + download_weights = {} # Validate files exist in S3 if the staged files list is present if artifact.stage_files: async with self.s3_controller.create_client_async() as s3_client: @@ -347,7 +433,10 @@ async def commit(self, prefix, context: dict): raise FileNotFoundError( f"File '{file_info['path']}' does not exist in the artifact." 
) - + if file_info.get("download_weight") is not None: + download_weights[file_info["path"]] = file_info[ + "download_weight" + ] # Validate the schema if the artifact belongs to a collection parent_prefix = "/".join(prefix.split("/")[:-1]) if parent_prefix: @@ -370,10 +459,15 @@ async def commit(self, prefix, context: dict): artifact.manifest = manifest artifact.stage_manifest = None artifact.stage_files = None + artifact.download_weights = download_weights flag_modified(artifact, "manifest") + flag_modified(artifact, "stage_files") + flag_modified(artifact, "download_weights") session.add(artifact) await session.commit() logger.info(f"Committed artifact under prefix: {prefix}") + except Exception as e: + raise e finally: await session.close() @@ -397,6 +491,8 @@ async def delete(self, prefix, context: dict): await session.delete(artifact) await session.commit() logger.info(f"Deleted artifact under prefix: {prefix}") + except Exception as e: + raise e finally: await session.close() @@ -457,6 +553,8 @@ async def list_artifacts(self, prefix="", stage=False, context: dict = None): name = name.split("/")[0] collection.append(name) return collection + except Exception as e: + raise e finally: await session.close() @@ -546,7 +644,6 @@ async def search( summary_fields.append({"_prefix": artifact.prefix, **sub_manifest}) return summary_fields - except Exception as e: raise ValueError( f"An error occurred while executing the search query: {str(e)}" @@ -554,7 +651,9 @@ async def search( finally: await session.close() - async def put_file(self, prefix, file_path, context: dict = None): + async def put_file( + self, prefix, file_path, options: dict = None, context: dict = None + ): """Generate a pre-signed URL to upload a file to an artifact in S3 and update the manifest.""" ws = context["ws"] user_info = UserInfo.model_validate(context["user"]) @@ -563,6 +662,8 @@ async def put_file(self, prefix, file_path, context: dict = None): "User does not have write permission to the workspace." 
) + options = options or {} + async with self.s3_controller.create_client_async() as s3_client: file_key = safe_join(ws, f"{prefix}/{file_path}") presigned_url = await s3_client.generate_presigned_url( @@ -581,18 +682,25 @@ async def put_file(self, prefix, file_path, context: dict = None): artifact.stage_files = artifact.stage_files or [] if not any(f["path"] == file_path for f in artifact.stage_files): - artifact.stage_files.append({"path": file_path}) + artifact.stage_files.append( + { + "path": file_path, + "download_weight": options.get("download_weight"), + } + ) flag_modified(artifact, "stage_files") session.add(artifact) await session.commit() logger.info(f"Generated pre-signed URL for file upload: {file_path}") + except Exception as e: + raise e finally: await session.close() return presigned_url - async def get_file(self, prefix, path, context: dict): + async def get_file(self, prefix, path, options: dict = None, context: dict = None): """Generate a pre-signed URL to download a file from an artifact in S3.""" ws = context["ws"] @@ -610,6 +718,30 @@ async def get_file(self, prefix, path, context: dict): ExpiresIn=3600, ) logger.info(f"Generated pre-signed URL for file download: {path}") + + if options is None or not options.get("silent"): + session = await self._get_session() + try: + async with session.begin(): + artifact = await self._get_artifact(session, ws, prefix) + if artifact.download_weights and path in artifact.download_weights: + # if it has download_weights, increment the download count by the weight + stmt = ( + update(ArtifactModel) + .where(ArtifactModel.id == artifact.id) + # atomically increment the download count by the weight + .values( + download_count=ArtifactModel.download_count + + artifact.download_weights[path] + ) + .execution_options(synchronize_session="fetch") + ) + await session.execute(stmt) + await session.commit() + except Exception as e: + raise e + finally: + await session.close() return presigned_url async def remove_file(self, prefix, file_path, context: dict): @@ -630,13 +762,24 @@ async def remove_file(self, prefix, file_path, context: dict): raise KeyError( f"Artifact under prefix '{prefix}' is not in staging mode." 
) - # remove the file from the staged files list - artifact.stage_files = [ - f for f in artifact.stage_files if f["path"] != file_path - ] - flag_modified(artifact, "stage_files") + if artifact.stage_files: + # remove the file from the staged files list + artifact.stage_files = [ + f for f in artifact.stage_files if f["path"] != file_path + ] + flag_modified(artifact, "stage_files") + if artifact.download_weights: + # remove the file from download_weights if it's there + artifact.download_weights = { + k: v + for k, v in artifact.download_weights.items() + if k != file_path + } + flag_modified(artifact, "download_weights") session.add(artifact) await session.commit() + except Exception as e: + raise e finally: await session.close() @@ -654,6 +797,7 @@ def get_artifact_service(self): "name": "Artifact Manager", "description": "Manage artifacts in a workspace.", "create": self.create, + "reset_stats": self.reset_stats, "edit": self.edit, "read": self.read, "commit": self.commit, diff --git a/hypha/core/__init__.py b/hypha/core/__init__.py index 3aa89f18..e5b79a55 100644 --- a/hypha/core/__init__.py +++ b/hypha/core/__init__.py @@ -22,6 +22,8 @@ from hypha.utils import EventBus import jsonschema +from prometheus_client import Counter + logging.basicConfig(stream=sys.stdout) logger = logging.getLogger("core") logger.setLevel(logging.INFO) @@ -347,6 +349,8 @@ def validate_scopes(cls, v): class RedisRPCConnection: """Represent a Redis connection for handling RPC-like messaging.""" + _counter = Counter("rpc_call", "Counts the RPC calls", ["workspace"]) + def __init__( self, event_bus: EventBus, @@ -421,6 +425,7 @@ async def emit_message(self, data: Union[dict, bytes]): packed_message = msgpack.packb(message) + data[pos:] # logger.info(f"Sending message to channel {target_id}:msg") await self._event_bus.emit(f"{target_id}:msg", packed_message) + RedisRPCConnection._counter.labels(workspace=self._workspace).inc() async def disconnect(self, reason=None): """Handle disconnection.""" @@ -440,6 +445,10 @@ async def disconnect(self, reason=None): class RedisEventBus: """Represent a redis event bus.""" + _counter = Counter( + "event_bus", "Counts the events on the redis event bus", ["event"] + ) + def __init__(self, redis) -> None: """Initialize the event bus.""" self._redis = redis @@ -568,18 +577,25 @@ async def _subscribe_redis(self): try: if msg: channel = msg["channel"].decode("utf-8") + RedisEventBus._counter.labels(event="*").inc() if channel.startswith("event:b:"): event_type = channel[8:] data = msg["data"] await self._redis_event_bus.emit(event_type, data) + if ":" not in event_type: + RedisEventBus._counter.labels(event=event_type).inc() elif channel.startswith("event:d:"): event_type = channel[8:] data = json.loads(msg["data"]) await self._redis_event_bus.emit(event_type, data) + if ":" not in event_type: + RedisEventBus._counter.labels(event=event_type).inc() elif channel.startswith("event:s:"): event_type = channel[8:] data = msg["data"].decode("utf-8") await self._redis_event_bus.emit(event_type, data) + if ":" not in event_type: + RedisEventBus._counter.labels(event=event_type).inc() else: logger.info("Unknown channel: %s", channel) except Exception as exp: diff --git a/hypha/core/store.py b/hypha/core/store.py index b4d8ea43..09822871 100644 --- a/hypha/core/store.py +++ b/hypha/core/store.py @@ -21,6 +21,9 @@ UserInfo, WorkspaceInfo, ) +from sqlalchemy.ext.asyncio import ( + create_async_engine, +) from hypha.core.auth import ( create_scope, parse_token, @@ -87,6 +90,7 @@ def __init__( 
public_base_url=None, local_base_url=None, redis_uri=None, + database_uri=None, reconnection_token_life_time=2 * 24 * 60 * 60, ): """Initialize the redis store.""" @@ -121,6 +125,17 @@ def __init__( logger.info("Server info: %s", self._server_info) + self._database_uri = database_uri + if self._database_uri is None: + database_uri = ( + "sqlite+aiosqlite:///:memory:" # In-memory SQLite for testing + ) + logger.warning( + "Using in-memory SQLite database for event logging, all data will be lost on restart!" + ) + + self._sql_engine = create_async_engine(database_uri, echo=False) + if redis_uri and redis_uri.startswith("redis://"): from redis import asyncio as aioredis @@ -148,6 +163,9 @@ def kickout_client(self, workspace: str, client_id: str, code: int, reason: str) def get_redis(self): return self._redis + def get_sql_engine(self): + return self._sql_engine + def get_event_bus(self): """Get the event bus.""" return self._event_bus @@ -430,6 +448,7 @@ async def init(self, reset_redis, startup_functions=None): raise for service in self._public_services: try: + logger.info("Registering public service: %s", service.id) await api.register_service( service.model_dump(), {"notify": True}, @@ -615,6 +634,7 @@ async def register_workspace_manager(self): self._event_bus, self._server_info, self._manager_id, + self._sql_engine, self._s3_controller, self._artifact_manager, ) diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py index e5d15aeb..2fdaa04b 100644 --- a/hypha/core/workspace.py +++ b/hypha/core/workspace.py @@ -6,11 +6,24 @@ from typing import Optional, Union, List, Any, Dict from contextlib import asynccontextmanager import random +import datetime from fakeredis import aioredis +from prometheus_client import Gauge from hypha_rpc import RPC from hypha_rpc.utils.schema import schema_method from pydantic import BaseModel, Field +from sqlalchemy import ( + Column, + String, + Integer, + JSON, + DateTime, + select, + func, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from hypha.core import ( ApplicationArtifact, @@ -29,12 +42,40 @@ logger = logging.getLogger("workspace") logger.setLevel(logging.INFO) +Base = declarative_base() + + SERVICE_SUMMARY_FIELD = ["id", "name", "type", "description", "config"] # Ensure the client_id is safe _allowed_characters = re.compile(r"^[a-zA-Z0-9-_/|*]*$") +# SQLAlchemy model for storing events +class EventLog(Base): + __tablename__ = "event_logs" + + id = Column(Integer, primary_key=True, autoincrement=True) + event_type = Column(String, nullable=False) + workspace = Column(String, nullable=False) + user_id = Column(String, nullable=False) + timestamp = Column( + DateTime, default=datetime.datetime.now(datetime.timezone.utc), index=True + ) + data = Column(JSON, nullable=True) # Store any additional event metadata + + def to_dict(self): + """Convert the SQLAlchemy model instance to a dictionary.""" + return { + "id": self.id, + "event_type": self.event_type, + "workspace": self.workspace, + "user_id": self.user_id, + "timestamp": self.timestamp.isoformat(), # Convert datetime to ISO string + "data": self.data, + } + + def validate_key_part(key_part: str): """Ensure key parts only contain safe characters.""" if not _allowed_characters.match(key_part): @@ -64,6 +105,7 @@ def __init__( event_bus: EventBus, server_info: dict, client_id: str, + sql_engine: Optional[str] = None, s3_controller: Optional[Any] = None, artifact_manager: Optional[Any] = None, ): @@ -76,6 +118,21 
@@ def __init__( self._client_id = client_id self._s3_controller = s3_controller self._artifact_manager = artifact_manager + self._sql_engine = sql_engine + if self._sql_engine: + self.SessionLocal = async_sessionmaker( + self._sql_engine, expire_on_commit=False, class_=AsyncSession + ) + else: + self.SessionLocal = None + self._active_ws = Gauge("active_workspaces", "Number of active workspaces") + self._active_svc = Gauge( + "active_services", "Number of active services", ["workspace"] + ) + + async def _get_sql_session(self): + """Return an async session for the database.""" + return self.SessionLocal() def get_client_id(self): assert self._client_id, "client id must not be empty." @@ -97,9 +154,140 @@ async def setup( management_service, {"notify": False}, ) + if self._sql_engine: + async with self._sql_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("Database tables created successfully.") self._initialized = True return rpc + @schema_method + async def log_event( + self, + event_type: str = Field(..., description="Event type"), + data: Optional[dict] = Field(None, description="Additional event data"), + context: dict = None, + ): + """Log a new event, checking permissions.""" + assert " " not in event_type, "Event type must not contain spaces" + workspace = context["ws"] + user_info = UserInfo.model_validate(context["user"]) + if not user_info.check_permission(workspace, UserPermission.read_write): + raise PermissionError(f"Permission denied for workspace {workspace}") + + session = await self._get_sql_session() + try: + async with session.begin(): + event_log = EventLog( + event_type=event_type, + workspace=workspace, + user_id=user_info.id, + data=data, + ) + session.add(event_log) + await session.commit() + logger.info( + f"Logged event: {event_type} by {user_info.id} in {workspace}" + ) + logger.info(f"Event logged: {event_type}") + except Exception as e: + logger.error(f"Failed to log event: {event_type}, {e}") + raise + finally: + await session.close() + + @schema_method + async def get_event_stats( + self, + event_type: Optional[str] = Field(None, description="Event type"), + start_time: Optional[datetime.datetime] = Field( + None, description="Start time for filtering events" + ), + end_time: Optional[datetime.datetime] = Field( + None, description="End time for filtering events" + ), + context: Optional[dict] = None, + ): + """Get statistics for specific event types, filtered by workspace, user, and time range.""" + workspace = context["ws"] + user_info = UserInfo.model_validate(context["user"]) + if not user_info.check_permission(workspace, UserPermission.read): + raise PermissionError(f"Permission denied for workspace {workspace}") + + session = await self._get_sql_session() + try: + async with session.begin(): + query = select( + EventLog.event_type, func.count(EventLog.id).label("count") + ).filter( + EventLog.workspace == workspace, EventLog.user_id == user_info.id + ) + + # Apply optional filters + if event_type: + query = query.filter(EventLog.event_type == event_type) + if start_time: + query = query.filter(EventLog.timestamp >= start_time) + if end_time: + query = query.filter(EventLog.timestamp <= end_time) + + query = query.group_by(EventLog.event_type) + result = await session.execute(query) + stats = result.fetchall() + # Convert rows to dictionaries + stats_dicts = [dict(row._mapping) for row in stats] + return stats_dicts + except Exception as e: + raise e + finally: + await session.close() + + @schema_method + async def 
get_events( + self, + event_type: Optional[str] = Field(None, description="Event type"), + start_time: Optional[datetime.datetime] = Field( + None, description="Start time for filtering events" + ), + end_time: Optional[datetime.datetime] = Field( + None, description="End time for filtering events" + ), + context: Optional[dict] = None, + ): + """Search for events with various filters, enforcing workspace and permission checks.""" + workspace = context["ws"] + user_info = UserInfo.model_validate(context["user"]) + if not user_info.check_permission(workspace, UserPermission.read): + raise PermissionError(f"Permission denied for workspace {workspace}") + + session = await self._get_sql_session() + try: + async with session.begin(): + query = select(EventLog).filter( + EventLog.workspace == workspace, EventLog.user_id == user_info.id + ) + + # Apply optional filters + if event_type: + query = query.filter(EventLog.event_type == event_type) + if start_time: + query = query.filter(EventLog.timestamp >= start_time) + if end_time: + query = query.filter(EventLog.timestamp <= end_time) + + result = await session.execute(query) + # Use scalars() to get model instances, not rows + events = result.scalars().all() + + # Convert each EventLog instance to a dictionary using to_dict() + event_dicts = [event.to_dict() for event in events] + + return event_dicts + except Exception as e: + raise e + finally: + await session.close() + @schema_method async def get_summary(self, context: Optional[dict] = None) -> dict: """Get a summary about the workspace.""" @@ -233,16 +421,17 @@ async def create_workspace( assert "id" in config, "Workspace id must be provided." if not config.get("name"): config["name"] = config["id"] - ws = context["ws"] user_info = UserInfo.model_validate(context["user"]) if user_info.is_anonymous: raise Exception("Only registered user can create workspace.") - if not overwrite: - try: - if await self.load_workspace_info(config["id"]): - raise RuntimeError(f"Workspace already exists: {config['id']}") - except KeyError: - pass + + try: + await self.load_workspace_info(config["id"]) + if not overwrite: + raise RuntimeError(f"Workspace already exists: {config['id']}") + exists = True + except KeyError: + exists = False config["persistent"] = config.get("persistent") or False if user_info.is_anonymous and config["persistent"]: @@ -269,6 +458,8 @@ async def create_workspace( } ] await self._redis.hset("workspaces", workspace.id, workspace.model_dump_json()) + if not exists: + self._active_ws.inc() if self._s3_controller: await self._s3_controller.setup_workspace(workspace) await self._event_bus.emit("workspace_loaded", workspace.model_dump()) @@ -795,6 +986,7 @@ async def register_service( "service_added", service.model_dump(mode="json") ) logger.info(f"Adding service {service.id}") + self._active_svc.labels(workspace=ws).inc() @schema_method async def get_service_info( @@ -912,6 +1104,7 @@ async def unregister_service( ) else: await self._event_bus.emit("service_removed", service.model_dump()) + self._active_svc.labels(workspace=ws).dec() else: logger.warning(f"Service {key} does not exist and cannot be removed.") raise KeyError(f"Service not found: {service.id}") @@ -973,40 +1166,7 @@ async def log( self, msg: str = Field(..., description="log a message"), context=None ): """Log a app message.""" - self.validate_context(context, permission=UserPermission.read) - logger.info("%s: %s", context["from"], msg) - - @schema_method - async def info( - self, msg: str = Field(..., description="log a 
message as info"), context=None - ): - """Log a app message.""" - self.validate_context(context, permission=UserPermission.read) - logger.info("%s: %s", context["from"], msg) - - @schema_method - async def warning( - self, msg: str = Field(..., description="log a message as info"), context=None - ): - """Log a app message (warning).""" - self.validate_context(context, permission=UserPermission.read) - logger.warning("WARNING: %s: %s", context["from"], msg) - - @schema_method - async def error( - self, msg: str = Field(..., description="log an error message"), context=None - ): - """Log a app error message (error).""" - self.validate_context(context, permission=UserPermission.read) - logger.error("%s: %s", context["from"], msg) - - @schema_method - async def critical( - self, msg: str = Field(..., description="log an critical message"), context=None - ): - """Log a app error message (critical).""" - self.validate_context(context, permission=UserPermission.read) - logger.critical("%s: %s", context["from"], msg) + await self.log_event("log", msg, context=context) async def load_workspace_info(self, workspace: str, load=True) -> WorkspaceInfo: """Load info of the current workspace from the redis store.""" @@ -1377,6 +1537,8 @@ async def unload(self, context=None): if not winfo.persistent and not winfo.read_only: await self._redis.hdel("workspaces", ws) + self._active_ws.dec() + await self._event_bus.emit("workspace_unloaded", winfo.model_dump()) logger.info("Workspace %s unloaded.", ws) @@ -1431,10 +1593,9 @@ def create_service(self, service_id, service_name=None): }, "echo": self.echo, "log": self.log, - "info": self.info, - "error": self.error, - "warning": self.warning, - "critical": self.critical, + "log_event": self.log_event, + "get_event_stats": self.get_event_stats, + "get_events": self.get_events, "register_service": self.register_service, "unregister_service": self.unregister_service, "list_workspaces": self.list_workspaces, diff --git a/hypha/http.py b/hypha/http.py index e264c930..227d0f95 100644 --- a/hypha/http.py +++ b/hypha/http.py @@ -16,11 +16,11 @@ from starlette.routing import Route, Match from starlette.types import ASGIApp from jinja2 import Environment, PackageLoader, select_autoescape +from prometheus_client import generate_latest from fastapi.responses import ( JSONResponse, Response, RedirectResponse, - StreamingResponse, FileResponse, ) import jose @@ -845,6 +845,11 @@ async def login(request: Request): """Redirect to the login page.""" return RedirectResponse(norm_url("/public/apps/hypha-login/")) + @app.get(norm_url("/metrics")) + async def metrics(): + """Expose Prometheus metrics.""" + return Response(generate_latest(), media_type="text/plain") + @app.get(norm_url("/{page:path}")) async def get_pages( page: str, diff --git a/hypha/server.py b/hypha/server.py index 2e018463..d84ba8dc 100644 --- a/hypha/server.py +++ b/hypha/server.py @@ -104,7 +104,6 @@ def start_builtin_services( store, s3_controller=s3_controller, workspace_bucket=args.workspace_bucket, - database_uri=args.database_uri, ) if args.enable_server_apps: @@ -221,6 +220,7 @@ async def lifespan(app: FastAPI): public_base_url=public_base_url, local_base_url=local_base_url, redis_uri=args.redis_uri, + database_uri=args.database_uri, reconnection_token_life_time=float( env.get("RECONNECTION_TOKEN_LIFE_TIME", str(2 * 24 * 60 * 60)) ), diff --git a/hypha/websocket.py b/hypha/websocket.py index 9a4b252b..32a91407 100644 --- a/hypha/websocket.py +++ b/hypha/websocket.py @@ -5,6 +5,7 @@ from fastapi import 
Query, WebSocket, status from starlette.websockets import WebSocketDisconnect from fastapi import HTTPException +from prometheus_client import Gauge from hypha import __version__ from hypha.core import UserInfo, UserPermission @@ -30,6 +31,9 @@ def __init__(self, store: RedisStore, path="/ws"): self.store.set_websocket_server(self) self._stop = False self._websockets = {} + self._gauge = Gauge( + "websocket_connections", "Number of websocket connections", ["workspace"] + ) @app.websocket(path) async def websocket_endpoint( @@ -260,6 +264,7 @@ async def force_disconnect(_): conn = RedisRPCConnection(event_bus, workspace, client_id, user_info, None) self._websockets[f"{workspace}/{client_id}"] = websocket try: + self._gauge.labels(workspace=workspace).inc() event_bus.on_local(f"unload:{workspace}", force_disconnect) async def send_bytes(data): @@ -322,6 +327,7 @@ async def send_bytes(data): except Exception as e: raise e finally: + self._gauge.labels(workspace=workspace).dec() await conn.disconnect("disconnected") event_bus.off_local(f"unload:{workspace}", force_disconnect) if ( diff --git a/requirements.txt b/requirements.txt index f8acd16b..0d92ff2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ aiocache==0.12.2 jsonschema==3.2.0 sqlalchemy==2.0.35 aiosqlite==0.20.0 +prometheus-client==0.21.0 diff --git a/setup.py b/setup.py index 0a736131..099d60ca 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ "jsonschema>=3.2.0", "sqlalchemy>=2.0.35", "aiosqlite>=0.20.0", + "prometheus-client>=0.21.0", ] ROOT_DIR = Path(__file__).parent.resolve() diff --git a/tests/test_artifact.py b/tests/test_artifact.py index db18bc08..1d6fb534 100644 --- a/tests/test_artifact.py +++ b/tests/test_artifact.py @@ -168,6 +168,11 @@ async def test_edit_existing_artifact(minio_server, fastapi_server): "_prefix", "collections/edit-test-collection/edit-test-dataset", ) + initial_view_count = collection["_stats"]["view_count"] + assert initial_view_count > 0 + assert collection["_stats"]["child_count"] > 0 + collection = await artifact_manager.read(prefix="collections/edit-test-collection") + assert collection["_stats"]["view_count"] == initial_view_count + 1 # Edit the artifact's manifest edited_manifest = { @@ -725,3 +730,128 @@ async def test_artifact_search_with_filters(minio_server, fastapi_server): prefix=f"collections/search-test-collection/test-dataset-{i}" ) await artifact_manager.delete(prefix="collections/search-test-collection") + + +async def test_download_count(minio_server, fastapi_server): + """Test the download count functionality for artifacts.""" + api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL}) + artifact_manager = await api.get_service("public/artifact-manager") + + # Create a collection for testing download count + collection_manifest = { + "id": "download-test-collection", + "name": "Download Test Collection", + "description": "A collection to test download count functionality", + "type": "collection", + "collection": [], + } + await artifact_manager.create( + prefix="collections/download-test-collection", + manifest=collection_manifest, + stage=False, + ) + + # Create an artifact inside the collection + dataset_manifest = { + "id": "download-test-dataset", + "name": "Download Test Dataset", + "description": "A test dataset for download count", + "type": "dataset", + } + await artifact_manager.create( + prefix="collections/download-test-collection/download-test-dataset", + manifest=dataset_manifest, + stage=True, + ) + + # Put a file in 
the artifact + put_url = await artifact_manager.put_file( + prefix="collections/download-test-collection/download-test-dataset", + file_path="example.txt", + options={ + "download_weight": 0.5 + }, # Set the file as primary so downloading it will be count as a download + ) + source = "file contents of example.txt" + response = requests.put(put_url, data=source) + assert response.ok + + # put another file in the artifact but not setting weights + put_url = await artifact_manager.put_file( + prefix="collections/download-test-collection/download-test-dataset", + file_path="example2.txt", + ) + source = "file contents of example2.txt" + response = requests.put(put_url, data=source) + assert response.ok + + # Commit the artifact + await artifact_manager.commit( + prefix="collections/download-test-collection/download-test-dataset" + ) + + # Ensure that the download count is initially zero + artifact = await artifact_manager.read( + prefix="collections/download-test-collection/download-test-dataset" + ) + assert artifact["_stats"]["download_count"] == 0 + + # Increment the download count of the artifact by download the primary file + get_url = await artifact_manager.get_file( + prefix="collections/download-test-collection/download-test-dataset", + path="example.txt", + ) + response = requests.get(get_url) + assert response.ok + + # Ensure that the download count is incremented + artifact = await artifact_manager.read( + prefix="collections/download-test-collection/download-test-dataset" + ) + assert artifact["_stats"]["download_count"] == 0.5 + + # If we get the example file in silent mode, the download count won't increment + get_url = await artifact_manager.get_file( + prefix="collections/download-test-collection/download-test-dataset", + path="example.txt", + options={"silent": True}, + ) + response = requests.get(get_url) + assert response.ok + + # Ensure that the download count is not incremented + artifact = await artifact_manager.read( + prefix="collections/download-test-collection/download-test-dataset" + ) + assert artifact["_stats"]["download_count"] == 0.5 + + # download example 2 won't increment the download count + get_url = await artifact_manager.get_file( + prefix="collections/download-test-collection/download-test-dataset", + path="example2.txt", + ) + response = requests.get(get_url) + assert response.ok + + # Ensure that the download count is incremented + artifact = await artifact_manager.read( + prefix="collections/download-test-collection/download-test-dataset" + ) + assert artifact["_stats"]["download_count"] == 0.5 + + # Let's call reset_stats to reset the download count + await artifact_manager.reset_stats( + prefix="collections/download-test-collection/download-test-dataset" + ) + + # Ensure that the download count is reset + artifact = await artifact_manager.read( + prefix="collections/download-test-collection/download-test-dataset" + ) + assert artifact["_stats"]["download_count"] == 0 + + # Clean up by deleting the dataset and the collection + await artifact_manager.delete( + prefix="collections/download-test-collection/download-test-dataset" + ) + await artifact_manager.delete(prefix="collections/download-test-collection") diff --git a/tests/test_event_log.py b/tests/test_event_log.py new file mode 100644 index 00000000..0042968d --- /dev/null +++ b/tests/test_event_log.py @@ -0,0 +1,81 @@ +"""Test Event Log services.""" +import pytest +import asyncio +from hypha_rpc.websocket_client import connect_to_server + +from . 
import SERVER_URL, find_item + +# All test coroutines will be treated as marked. +pytestmark = pytest.mark.asyncio + + +async def test_log_event(fastapi_server): + """Test logging a new event via the service.""" + api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL}) + + # Log an event + await api.log_event("rpc_call", {"service": "test_service"}) + + # Assume no errors means success + + +async def test_get_event_stats(fastapi_server): + """Test fetching event statistics via the service.""" + api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL}) + # Log some events + await api.log_event("rpc_call", "Test RPC call 1") + await api.log_event("model_download", "Test model download") + await api.log_event("rpc_call", "Test RPC call 2") + + # Fetch event stats + stats = await api.get_event_stats(event_type="rpc_call") + + # Ensure we get correct stats for "rpc_call" events + assert len(stats) > 0 + assert any(stat["count"] > 0 for stat in stats) + + +async def test_get_events(fastapi_server): + """Test searching for specific events.""" + api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL}) + + # Log some events with specific types + await api.log_event("rpc_call", {"service": "rpc_service"}) + await api.log_event("dataset_access", {"dataset": "test_dataset"}) + + # Search for rpc_call events + events = await api.get_events(event_type="rpc_call") + + # Ensure we get the correct events + assert len(events) > 0 + assert any(event["event_type"] == "rpc_call" for event in events) + + +async def test_invalid_permissions(fastapi_server): + """Test handling of invalid permissions for a user.""" + api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL}) + + # Try to log an event without sufficient permissions + try: + await api.log_event("rpc_call", "Attempting an unauthorized event") + except PermissionError: + pass # This is expected + + # If no exception is raised, the test should fail + assert True + + +async def test_logging_multiple_events(fastapi_server): + """Test logging multiple events and retrieving stats for all.""" + api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL}) + + # Log multiple events + await api.log_event("rpc_call", "Test event 1") + await api.log_event("rpc_call", "Test event 2") + await api.log_event("dataset_download", "Dataset download event") + + # Fetch statistics for all events + stats = await api.get_event_stats() + + # Ensure the stats cover all event types + assert len(stats) >= 2 # Expect at least rpc_call and dataset_download stats diff --git a/tests/test_observability.py b/tests/test_observability.py new file mode 100644 index 00000000..a245253b --- /dev/null +++ b/tests/test_observability.py @@ -0,0 +1,88 @@ +"""Test the hypha server.""" +import os +import subprocess +import sys +import asyncio + +import pytest +import requests +from prometheus_client.parser import text_string_to_metric_families +from hypha_rpc.websocket_client import connect_to_server + +from . import ( + SERVER_URL, + SERVER_URL_REDIS_1, + SERVER_URL_REDIS_2, + SIO_PORT2, + WS_SERVER_URL, + find_item, +) + +# All test coroutines will be treated as marked. 
+pytestmark = pytest.mark.asyncio + + +def get_metric_value(metric_name, labels): + """Helper to parse Prometheus metrics response and extract the value for the specific metric""" + response = requests.get(f"{SERVER_URL}/metrics") + assert response.status_code == 200 + metrics_data = response.text + for family in text_string_to_metric_families(metrics_data): + if family.name == metric_name: + for sample in family.samples: + if all(sample.labels.get(k) == v for k, v in labels.items()): + return sample.value + return None + + +async def test_metrics(fastapi_server, test_user_token): + """Test Prometheus metrics for workspace and service creation.""" + # Connect to the server and create a new workspace + api = await connect_to_server( + { + "client_id": "my-app-99", + "server_url": SERVER_URL, + "token": test_user_token, + } + ) + await api.log("hello") + + # Check the initial number of active workspaces + initial_active_workspaces = get_metric_value("active_workspaces", {}) + assert initial_active_workspaces is not None + + # Create a new workspace + await api.create_workspace( + { + "name": "my-test-workspace-metrics", + "description": "This is a test workspace", + "owners": ["user1@imjoy.io", "user2@imjoy.io"], + }, + overwrite=True, + ) + + # Check if the number of active workspaces has increased + active_workspaces = get_metric_value("active_workspaces", {}) + assert ( + active_workspaces == initial_active_workspaces + 1 + ), "Active workspace count did not increase." + + # Check if a service was added to the workspace + active_services = get_metric_value( + "active_services", {"workspace": "my-test-workspace-metrics"} + ) + assert active_services is None, "Expected no active services in the new workspace." + + # Verify that other services and workspaces haven't changed unexpectedly + public_active_services = get_metric_value( + "active_services", {"workspace": "public"} + ) + assert ( + public_active_services is not None + ), "Public active services metric is missing." + + # Optionally check that the RPC call metric is functioning + rpc_call_count = get_metric_value("rpc_call", {"workspace": api.config.workspace}) + assert ( + rpc_call_count is not None + ), "Expected an RPC call metric for the new workspace." 
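Since `log` now delegates to `log_event` (see the `hypha/core/workspace.py` change above), ordinary log messages become queryable events. A minimal sketch of what that enables, under the same client setup used in these tests (`show_log_events` is a hypothetical helper, not part of the patch):

```python
from hypha_rpc.websocket_client import connect_to_server

async def show_log_events(server_url):
    """Log a message and read it back as a persisted 'log' event."""
    api = await connect_to_server({"name": "log-demo", "server_url": server_url})
    await api.log("hello from the demo client")  # now recorded via log_event("log", ...)
    events = await api.get_events(event_type="log")
    for event in events:
        print(event["timestamp"], event["user_id"], event["data"])
```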
diff --git a/tests/test_server.py b/tests/test_server.py index 0b490f75..d7802955 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -275,6 +275,54 @@ async def test_services(fastapi_server): ) +async def test_workspace_owners( + fastapi_server, test_user_token, test_user_token_temporary, test_user_token_2 +): + """Test workspace owners.""" + api = await connect_to_server( + { + "client_id": "my-app-99", + "server_url": WS_SERVER_URL, + "token": test_user_token, + } + ) + + ws = await api.create_workspace( + { + "name": "my-test-workspace-owners", + "description": "This is a test workspace", + "owners": ["user-1@test.com", "user-2@test.com"], + }, + ) + + assert ws["name"] == "my-test-workspace-owners" + + api2 = await connect_to_server( + { + "server_url": WS_SERVER_URL, + "token": test_user_token_2, + "workspace": "my-test-workspace-owners", + } + ) + + assert api2.config["workspace"] == "my-test-workspace-owners" + + try: + api3 = await connect_to_server( + { + "server_url": WS_SERVER_URL, + "token": test_user_token_temporary, + "workspace": "my-test-workspace-owners", + } + ) + assert api3.config["workspace"] == "my-test-workspace-owners" + except Exception as e: + assert "Permission denied for workspace" in str(e) + + await api2.disconnect() + await api.disconnect() + + async def test_server_scalability( fastapi_server_redis_1, fastapi_server_redis_2, test_user_token ):
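The patch also adds a `websocket_connections` gauge in `hypha/websocket.py` that the tests above do not read back. A sketch of how it could be inspected through the new `/metrics` endpoint, reusing the parsing approach from `tests/test_observability.py` (the helper name is assumed, not part of the patch):

```python
import requests
from prometheus_client.parser import text_string_to_metric_families

def get_websocket_connections(server_url: str, workspace: str):
    """Return the websocket_connections gauge value for a workspace, or None if absent."""
    text = requests.get(f"{server_url}/metrics").text
    for family in text_string_to_metric_families(text):
        if family.name == "websocket_connections":
            for sample in family.samples:
                if sample.labels.get("workspace") == workspace:
                    return sample.value
    return None
```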