Skip to content

Commit

Permalink
Expose count for vectors and artifacts
Browse files (browse the repository at this point in the history)
  • Loading branch information
oeway committed Nov 21, 2024
1 parent b7e615b commit 64c4c97
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
2 changes: 1 addition & 1 deletion hypha/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.20.39.post10"
"version": "0.20.39.post11"
}
43 changes: 27 additions & 16 deletions hypha/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@
text,
and_,
or_,
update,
)
from hrid import HRID
from sqlalchemy.sql import func
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncSession,
)

from hrid import HRID
from hypha.utils import remove_objects_async, list_objects_async, safe_join
from hypha.utils.zenodo import ZenodoClient
from botocore.exceptions import ClientError
from hypha.s3 import FSFileResponse
from aiobotocore.session import get_session
from sqlalchemy import update
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncSession,
)

from fastapi import APIRouter, Depends, HTTPException
from hypha.core import (
UserInfo,
Expand Down Expand Up @@ -1257,6 +1260,23 @@ async def read(
artifact, version_index, s3_config
)

if artifact.type == "collection":
# Count the children with a single SELECT COUNT(*) filtered on parent_id, rather than loading the child rows
count_q = select(func.count()).where(
ArtifactModel.parent_id == artifact.id
)
result = await session.execute(count_q)
child_count = result.scalar()
artifact_data["config"] = artifact_data.get("config", {})
artifact_data["config"]["child_count"] = child_count
elif artifact.type == "vector-collection" and self._vectordb_client:
artifact_data["config"] = artifact_data.get("config", {})
artifact_data["config"]["vector_count"] = (
await self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
)
).count

if not silent:
await session.commit()

Expand Down Expand Up @@ -1337,15 +1357,6 @@ async def commit(
),
)

if artifact.type == "vector-collection":
assert (
self._vectordb_client
), "The server is not configured to use a VectorDB client."
artifact.manifest["points"] = self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
)
flag_modified(artifact, "manifest")

parent_artifact_config = (
parent_artifact.config if parent_artifact else {}
)
Expand Down Expand Up @@ -1417,7 +1428,7 @@ async def delete(
assert (
self._vectordb_client
), "The server is not configured to use a VectorDB client."
self._vectordb_client.delete_collection(
await self._vectordb_client.delete_collection(
collection_name=f"{artifact.workspace}/{artifact.alias}"
)

Expand Down
6 changes: 6 additions & 0 deletions tests/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ async def test_artifact_vector_collection(
vectors=vectors,
)

vc = await artifact_manager.read(artifact_id=vector_collection.id)
assert vc["config"]["vector_count"] == 3

# Search for vectors by query vector
query_vector = [random.random() for _ in range(384)]
search_results = await artifact_manager.search_by_vector(
Expand Down Expand Up @@ -929,6 +932,9 @@ async def test_edit_existing_artifact(minio_server, fastapi_server, test_user_to
version="stage",
)

collection = await artifact_manager.read(artifact_id=collection.id)
assert collection["config"]["child_count"] == 1

# Commit the artifact
dataset = await artifact_manager.commit(artifact_id=dataset.id)
versions = dataset["versions"]
Expand Down

0 comments on commit 64c4c97

Please sign in to comment.