diff --git a/programs/db/chromadb/__init__.py b/programs/db/chromadb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/programs/db/chromadb/main.py b/programs/db/chromadb/main.py new file mode 100644 index 0000000..7fbf76c --- /dev/null +++ b/programs/db/chromadb/main.py @@ -0,0 +1,48 @@ +from chromadb import PersistentClient, Documents, Embeddings + +from utca.core import AddData, ReplacingScope +from utca.implementation.tasks import TransformersTextEmbedding +from utca.implementation.datasources.db import ( + ChromaDBGetOrCreateCollection, + ChromaDBCollectionAdd, + ChromaDBEmbeddingFunctionComponent, + ChromaDBCollectionQuery, +) + +# Sentences for dataset +sentences = [ + "People who test positive for Covid-19 no longer need to routinely stay away from others for at least five days, according to new guidelines from the US Centers for Disease Control and Prevention issued Friday.", + "The change ends a strategy from earlier in the pandemic that experts said has been important to controlling the spread of the infection.", + "Whether it be the latest prized Stanley cup or that 10-year-old plastic spout bottle you don’t go anywhere without, “emotional support water bottles” seem to be stuck to our sides and not going anywhere.", + "Health officials in Alaska recently reported the first known human death from a virus called Alaskapox.", + "Blizzard conditions continued to slam Northern California over the weekend with damaging winds and heavy snow dumping on mountain ridges down to the valleys.", +] + +class EmbeddingFunction(ChromaDBEmbeddingFunctionComponent[Documents]): + def __call__(self, documents: Documents) -> Embeddings: + return embedding_pipe.run({"texts": documents})["embeddings"].tolist() + +if __name__ == "__main__": + embedding_pipe = TransformersTextEmbedding() + + pipe = ( + AddData({ + "collection_name": "test", + }) + | ChromaDBGetOrCreateCollection( + client=PersistentClient(), 
embedding_function=EmbeddingFunction(embedding_pipe) # type: ignore + ).use(get_key="collection_name") + | AddData({ + "documents": sentences, + "ids": [f"id_{i}" for i in range(1, len(sentences)+1)] + }) + | ChromaDBCollectionAdd() + | AddData({ + "query_texts": ["Bad weather"], + "n_results": 1, + "include": ["documents", "distances"] + }) + | ChromaDBCollectionQuery().use(set_key="results", replace=ReplacingScope.GLOBAL) + ) + + print(pipe.run()["results"]) \ No newline at end of file diff --git a/programs/db/graph/README.md b/programs/db/graph/README.md new file mode 100644 index 0000000..b8b79d4 --- /dev/null +++ b/programs/db/graph/README.md @@ -0,0 +1,13 @@ +# Neo4j + +To run this example, you can run Docker locally. Execute the following command to start the container before running the program: + +``` console +sh run_neo4j_docker.sh +``` + +Then, run the program: + +``` console +python main.py +``` \ No newline at end of file diff --git a/programs/db/graph/__init__.py b/programs/db/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/programs/db/graph/main.py b/programs/db/graph/main.py new file mode 100644 index 0000000..61d3dbd --- /dev/null +++ b/programs/db/graph/main.py @@ -0,0 +1,86 @@ +from __future__ import annotations +from typing import Any, Dict, cast + +from neo4j import ManagedTransaction + +from utca.core import While, ExecuteFunction +from utca.implementation.datasources.db import ( + Neo4jClient, Neo4jWriteAction +) + +employee_threshold=10 + +def employ_person_tx(tx: ManagedTransaction, name: str) -> str: + # Create new Person node with given name, if not exists already + result = tx.run(""" + MERGE (p:Person {name: $name}) + RETURN p.name AS name + """, name=name + ) + + # Obtain most recent organization ID and the number of people linked to it + result = tx.run(""" + MATCH (o:Organization) + RETURN o.id AS id, COUNT{(p:Person)-[r:WORKS_FOR]->(o)} AS employees_n + ORDER BY o.created_date DESC + LIMIT 1 + """) + org = 
result.single() + + if org is not None and org["employees_n"] == 0: + raise Exception("Most recent organization is empty.") + # Transaction will roll back -> not even Person is created! + + # If org does not have too many employees, add this Person to that + if org is not None and org.get("employees_n") < employee_threshold: + result = tx.run(""" + MATCH (o:Organization {id: $org_id}) + MATCH (p:Person {name: $name}) + MERGE (p)-[r:WORKS_FOR]->(o) + RETURN $org_id AS id + """, org_id=org["id"], name=name + ) + + # Otherwise, create a new Organization and link Person to it + else: + result = tx.run(""" + MATCH (p:Person {name: $name}) + CREATE (o:Organization {id: randomuuid(), created_date: datetime()}) + MERGE (p)-[r:WORKS_FOR]->(o) + RETURN o.id AS id + """, name=name + ) + + # Return the Organization ID to which the new Person ends up in + return cast(str, result.single(strict=True)["id"]) + +if __name__ == "__main__": + # See shell script for docker + client = Neo4jClient( + url="neo4j://localhost:7687", + user="neo4j", + password="password", + ) + + employee_id = 0 + def employee_name(_: Any) -> Dict[str, Any]: + global employee_id + employee_id += 1 + return {"kwargs": {"name": f"Thor{employee_id}"}} + + def print_message(input_data: Dict[str, Any]) -> None: + print(f'User {input_data["kwargs"]["name"]} added to organization {input_data["org_id"]}') + + p = While( + ExecuteFunction(employee_name) + | Neo4jWriteAction( + database="neo4j", + transaction_function=employ_person_tx, + client=client, + ).use(set_key="org_id") + | ExecuteFunction(print_message), + max_iterations=100, + ) + p.run() + + client.close() \ No newline at end of file diff --git a/programs/db/graph/run_neo4j_docker.sh b/programs/db/graph/run_neo4j_docker.sh new file mode 100644 index 0000000..fe3cd52 --- /dev/null +++ b/programs/db/graph/run_neo4j_docker.sh @@ -0,0 +1,4 @@ +docker run \ + --publish=7474:7474 --publish=7687:7687 \ + --env NEO4J_AUTH=neo4j/password \ + neo4j \ No newline at 
end of file diff --git a/programs/db/sql/__init__.py b/programs/db/sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/programs/db/main.py b/programs/db/sql/main.py similarity index 96% rename from programs/db/main.py rename to programs/db/sql/main.py index 817dee8..28f92d5 100644 --- a/programs/db/main.py +++ b/programs/db/sql/main.py @@ -1,7 +1,6 @@ from __future__ import annotations +from typing import List, Optional -from typing import List -from typing import Optional from sqlalchemy import ( ForeignKey, String, diff --git a/pyproject.toml b/pyproject.toml index 3d923d7..010ada1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "utca" -version = "0.1.0" +version = "0.1.1" description = "" authors = ["knowledgator.com"] readme = "README.md" @@ -27,6 +27,8 @@ requests-html = "^0.10.0" reportlab = "^4.2.0" requests = "^2.32.2" gliner = "^0.2.2" +neo4j = "^5.20.0" +chromadb = "^0.5.0" [build-system] diff --git a/src/utca/implementation/datasources/db/__init__.py b/src/utca/implementation/datasources/db/__init__.py index 000f1e8..e82bb84 100644 --- a/src/utca/implementation/datasources/db/__init__.py +++ b/src/utca/implementation/datasources/db/__init__.py @@ -1,10 +1,40 @@ from utca.implementation.datasources.db.sqlalchemy.main import ( SQLSessionFactory, BaseModel, SQLAction, SQLActionWithReturns ) +from utca.implementation.datasources.db.neo4j.main import ( + Neo4jClient, Neo4jReadAction, Neo4jWriteAction +) +from utca.implementation.datasources.db.chroma.main import ( + ChromaDBCollectionAdd, + ChromaDBCollectionUpdate, + ChromaDBCollectionUpsert, + ChromaDBGetCollection, + ChromaDBCreateCollection, + ChromaDBGetOrCreateCollection, + ChromaDBDeleteCollection, + ChromaDBCollectionGet, + ChromaDBCollectionQuery, +) +from utca.implementation.datasources.db.chroma.schema import ( + ChromaDBEmbeddingFunctionComponent, +) __all__ = [ "SQLSessionFactory", "BaseModel", "SQLAction", "SQLActionWithReturns", + 
"Neo4jClient", + "Neo4jReadAction", + "Neo4jWriteAction", + "ChromaDBGetCollection", + "ChromaDBCreateCollection", + "ChromaDBGetOrCreateCollection", + "ChromaDBDeleteCollection", + "ChromaDBCollectionAdd", + "ChromaDBCollectionUpdate", + "ChromaDBCollectionUpsert", + "ChromaDBCollectionGet", + "ChromaDBCollectionQuery", + "ChromaDBEmbeddingFunctionComponent", ] \ No newline at end of file diff --git a/src/utca/implementation/datasources/db/chroma/__init__.py b/src/utca/implementation/datasources/db/chroma/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utca/implementation/datasources/db/chroma/main.py b/src/utca/implementation/datasources/db/chroma/main.py new file mode 100644 index 0000000..8a0a707 --- /dev/null +++ b/src/utca/implementation/datasources/db/chroma/main.py @@ -0,0 +1,344 @@ +from typing import Any, Dict, Optional, cast + +from chromadb import Collection, EmbeddingFunction +from chromadb.api import ClientAPI +from chromadb.api.types import Embeddable, QueryResult, GetResult +from chromadb.utils.embedding_functions import DefaultEmbeddingFunction + +from utca.core.executable_level_1.actions import Action, ActionInput, ActionOutput + +class ChromaDBAction(Action[ActionInput, ActionOutput]): + def __init__( + self, + client: ClientAPI, + name: Optional[str]=None, + default_key: str="output", + ) -> None: + """ + Args: + client (ClientAPI): ChromaDB client to use + + name (Optional[str], optional): Name for identification. + If equals to None, class name will be used. Defaults to None. + + default_key (str, optional): Default key used for results that is not of type Dict. + Defaults to "output". 
+ """ + super().__init__(name, default_key=default_key) + self.client = client + + +class ChromaDBRetrieveCollectionAction(ChromaDBAction[str, Collection]): + def __init__( + self, + client: ClientAPI, + embedding_function: Optional[EmbeddingFunction[Embeddable]]=None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str]=None, + ) -> None: + """ + Args: + client (ClientAPI): ChromaDB client to use + + embedding_function (Optional[EmbeddingFunction[Embeddable]], optional): Embedding function to use. + If equals to None, default will be used. Defaults to None. + + metadata (Optional[Dict[str, Any]], optional): Collection metadata. Defaults to None. + + name (Optional[str], optional): Name for identification. + If equals to None, class name will be used. Defaults to None. + """ + super().__init__(client, name, "collection") + self.embedding_function = cast( + EmbeddingFunction[Embeddable], embedding_function or DefaultEmbeddingFunction() + ) + self.metadata = metadata + + +class ChromaDBCreateCollection(ChromaDBRetrieveCollectionAction): + """ + Create collection + + Args: + input_data: The name of the collection to create. + Returns: + Collection: The newly created collection. + """ + def execute(self, input_data: str) -> Collection: + """ + Args: + input_data: The name of the collection to create. + Returns: + Collection: The newly created collection. + """ + return self.client.create_collection( + name=input_data, + metadata=self.metadata, + embedding_function=self.embedding_function, + ) + + +class ChromaDBGetCollection(ChromaDBRetrieveCollectionAction): + """ + Get collection + + Args: + input_data: The name of the collection to get. + Returns: + Collection: The newly created collection. + """ + def execute(self, input_data: str) -> Collection: + """ + Args: + input_data: The name of the collection to get. + Returns: + Collection: The newly created collection. 
+ """ + return self.client.get_collection( + name=input_data, + embedding_function=self.embedding_function + ) + + +class ChromaDBGetOrCreateCollection(ChromaDBRetrieveCollectionAction): + """ + Get or Create collection + """ + def execute(self, input_data: str) -> Collection: + """ + Args: + input_data: The name of the collection to get or create. + Returns: + Collection: The existing or newly created collection. + """ + return self.client.get_or_create_collection( + name=input_data, metadata=self.metadata, # honour the metadata accepted by __init__ + embedding_function=self.embedding_function + ) + + +class ChromaDBDeleteCollection(ChromaDBAction[str, None]): + """ + Delete collection + + Args: + input_data (str): Collection name + """ + def execute(self, input_data: str) -> None: + """ + Args: + input_data (str): Collection name + """ + self.client.delete_collection( + name=input_data + ) + + +class ChromaDBCollectionAdd(Action[Dict[str, Any], None]): + """ + Add embeddings to the data store. + + Args: + collection (Collection): Collection to use. + + ids: The ids of the embeddings you wish to add + + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + + metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. + + documents: The documents to associate with the embeddings. Optional. + + images: The images to associate with the embeddings. Optional. + + uris: The uris of the images to associate with the embeddings. Optional. + """ + def execute(self, input_data: Dict[str, Any]) -> None: + """ + Args: + collection (Collection): Collection to use. + + ids: The ids of the embeddings you wish to add + + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + + metadatas: The metadata to associate with the embeddings. 
When querying, you can filter on this metadata. Optional. + + documents: The documents to associate with the embeddings. Optional. + + images: The images to associate with the embeddings. Optional. + + uris: The uris of the images to associate with the embeddings. Optional. + """ + collection: Collection = input_data["collection"] + collection.add( # type: ignore + ids=input_data["ids"], + embeddings=input_data.get("embeddings"), + metadatas=input_data.get("metadatas"), + documents=input_data.get("documents"), + images=input_data.get("images"), + uris=input_data.get("uris"), + ) + + +class ChromaDBCollectionUpdate(Action[Dict[str, Any], None]): + """ + Update the embeddings, metadatas or documents for provided ids. + + Args: + collection (Collection): Collection to use. + + ids: The ids of the embeddings you wish to add + + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + + metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. + + documents: The documents to associate with the embeddings. Optional. + + images: The images to associate with the embeddings. Optional. + + uris: The uris of the images to associate with the embeddings. Optional. + """ + def execute(self, input_data: Dict[str, Any]) -> None: + """ + Args: + collection (Collection): Collection to use. + + ids: The ids of the embeddings you wish to add + + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + + metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. + + documents: The documents to associate with the embeddings. Optional. + + images: The images to associate with the embeddings. Optional. 
+ + uris: The uris of the images to associate with the embeddings. Optional. + """ + collection: Collection = input_data["collection"] + collection.update( # type: ignore + ids=input_data["ids"], + embeddings=input_data.get("embeddings"), + metadatas=input_data.get("metadatas"), + documents=input_data.get("documents"), + images=input_data.get("images"), + uris=input_data.get("uris"), + ) + + +class ChromaDBCollectionUpsert(Action[Dict[str, Any], None]): + """ + Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist. + + Args: + collection (Collection): Collection to use. + + ids: The ids of the embeddings you wish to add + + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + + metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. + + documents: The documents to associate with the embeddings. Optional. + + images: The images to associate with the embeddings. Optional. + + uris: The uris of the images to associate with the embeddings. Optional. + """ + def execute(self, input_data: Dict[str, Any]) -> None: + collection: Collection = input_data["collection"] + collection.upsert( # type: ignore + ids=input_data["ids"], + embeddings=input_data.get("embeddings"), + metadatas=input_data.get("metadatas"), + documents=input_data.get("documents"), + images=input_data.get("images"), + uris=input_data.get("uris"), + ) + + +class ChromaDBCollectionQuery(Action[Dict[str, Any], QueryResult]): + """ + Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. + + Args: + query_embeddings: The embeddings to get the closes neighbors of. Optional. + + query_texts: The document texts to get the closes neighbors of. Optional. + + query_images: The images to get the closes neighbors of. Optional. 
+ + n_results: The number of neighbors to return for each query_embedding or query_texts. Optional. + + where: A Where type dict used to filter results by. E.g. {"$and": ["color" : "red", "price": {"$gte": 4.20}]}. Optional. + + where_document: A WhereDocument type dict used to filter by the documents. E.g. {$contains: {"text": "hello"}}. Optional. + + include: A list of what to include in the results. Can contain "embeddings", "metadatas", "documents", "distances". Ids are always included. Defaults to ["metadatas", "documents", "distances"]. Optional. + + Returns: + QueryResult: A QueryResult object containing the results. + """ + def execute(self, input_data: Dict[str, Any]) -> QueryResult: + collection: Collection = input_data["collection"] + return collection.query( # type: ignore + query_embeddings=input_data.get("query_embeddings"), + query_texts=input_data.get("query_texts"), + query_uris=input_data.get("query_uris"), + n_results=input_data.get("n_results", 10), + where=input_data.get("where"), + where_document=input_data.get("where_document"), + include=input_data.get("include", ["metadatas", "documents", "distances"]), + ) + + +class ChromaDBCollectionGet(Action[Dict[str, Any], GetResult]): + """ + Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. + + Args: + ids: The ids of the embeddings to get. Optional. + + where: A Where type dict used to filter results by. E.g. {"$and": ["color" : "red", "price": {"$gte": 4.20}]}. Optional. + + limit: The number of documents to return. Optional. + + offset: The offset to start returning results from. Useful for paging results with limit. Optional. + + where_document: A WhereDocument type dict used to filter by the documents. E.g. {$contains: {"text": "hello"}}. Optional. + + include: A list of what to include in the results. Can contain "embeddings", "metadatas", "documents". Ids are always included. 
Defaults to ["metadatas", "documents"]. Optional. + + Returns: + GetResult: A GetResult object containing the results. + """ + def execute(self, input_data: Dict[str, Any]) -> GetResult: + """ + Args: + ids: The ids of the embeddings to get. Optional. + + where: A Where type dict used to filter results by. E.g. {"$and": ["color" : "red", "price": {"$gte": 4.20}]}. Optional. + + limit: The number of documents to return. Optional. + + offset: The offset to start returning results from. Useful for paging results with limit. Optional. + + where_document: A WhereDocument type dict used to filter by the documents. E.g. {$contains: {"text": "hello"}}. Optional. + + include: A list of what to include in the results. Can contain "embeddings", "metadatas", "documents". Ids are always included. Defaults to ["metadatas", "documents"]. Optional. + + Returns: + GetResult: A GetResult object containing the results. + """ + collection: Collection = input_data["collection"] + return collection.get( # type: ignore + ids=input_data.get("ids"), + limit=input_data.get("limit"), + offset=input_data.get("offset"), + where=input_data.get("where"), + where_document=input_data.get("where_document"), + include=input_data.get("include", ["metadatas", "documents"]), + ) + diff --git a/src/utca/implementation/datasources/db/chroma/schema.py b/src/utca/implementation/datasources/db/chroma/schema.py new file mode 100644 index 0000000..8996450 --- /dev/null +++ b/src/utca/implementation/datasources/db/chroma/schema.py @@ -0,0 +1,20 @@ +from abc import abstractmethod + +from chromadb import EmbeddingFunction, Embeddings +from chromadb.api.types import D + +from utca.core.executable_level_1.component import Component + +class ChromaDBEmbeddingFunctionComponent(EmbeddingFunction[D]): + """ + Embedding function wrapper for components + """ + def __init__(self, component: Component) -> None: + super().__init__() + self.component = component + + + + @abstractmethod + def __call__(self, input: D) -> 
Embeddings: + ... \ No newline at end of file diff --git a/src/utca/implementation/datasources/db/neo4j/__init__.py b/src/utca/implementation/datasources/db/neo4j/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utca/implementation/datasources/db/neo4j/main.py b/src/utca/implementation/datasources/db/neo4j/main.py new file mode 100644 index 0000000..ba750d6 --- /dev/null +++ b/src/utca/implementation/datasources/db/neo4j/main.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, Callable, Optional + +from neo4j import GraphDatabase, Session, ManagedTransaction + +from utca.core.executable_level_1.actions import Action + +class Neo4jClient: + """ + Neo4j client + """ + def __init__( + self, + url: str, + user: str, + password: str, + ) -> None: + """ + Args: + url (str): Connection URL. + + user (str): Authentication user. + + password (str): Authentication password. + """ + self.driver = GraphDatabase.driver(url, auth=(user, password)) + + + def close(self) -> None: + self.driver.close() + + + def session(self, database: str) -> Session: + return self.driver.session(database=database) + + +class Neo4jWriteAction(Action[Dict[str, Any], Any]): + """ + Neo4j write transaction + + Args: + input_data (Dict[str, Any]): Expected keys: + "args" (List[Any], optional): Positional arguments for transaction function. + + "kwargs" (Dict[str, Any], optional): Keyword arguments for transaction function. + + Returns: + Any: result of the executed transaction + """ + def __init__( + self, + database: str, + transaction_function: Callable[[ManagedTransaction, Any], Any], + client: Neo4jClient, + name: Optional[str]=None + ): + """ + Args: + database (str): Database name. + + transaction_function(Callable[[ManagedTransaction, Any], Any]): Transaction that will be executed. + + client (Neo4jClient): Client that will be used. + + name (Optional[str], optional): Name for identification. + If equals to None, class name will be used. Defaults to None. 
+ """ + super().__init__(name) + self.database = database + self.transaction_function = transaction_function + self.client = client + + + def execute(self, input_data: Dict[str, Any]) -> Any: + """ + Args: + input_data (Dict[str, Any]): Expected keys: + "args" (List[Any], optional): Positional arguments for transaction function. + + "kwargs" (Dict[str, Any], optional): Keyword arguments for transaction function. + + Returns: + Any: result of the executed transaction + """ + with self.client.session(self.database) as session: + return session.execute_write( + self.transaction_function, # positional: a keyword here collides with the unpacked "args" + *(input_data.get("args", [])), + **(input_data.get("kwargs", {})), + ) + + +class Neo4jReadAction(Action[Dict[str, Any], Any]): + """ + Neo4j read transaction + """ + def __init__( + self, + database: str, + transaction_function: Callable[[ManagedTransaction, Any], Any], + client: Neo4jClient, + name: Optional[str]=None + ): + """ + Args: + database (str): Database name. + + transaction_function(Callable[[ManagedTransaction, Any], Any]): Transaction that will be executed. + + client (Neo4jClient): Client that will be used. + + name (Optional[str], optional): Name for identification. + If equals to None, class name will be used. Defaults to None. + """ + super().__init__(name) + self.database = database + self.transaction_function = transaction_function + self.client = client + + + def execute(self, input_data: Dict[str, Any]) -> Any: + """ + Args: + input_data (Dict[str, Any]): Expected keys: + "args" (List[Any], optional): Positional arguments for transaction function. + + "kwargs" (Dict[str, Any], optional): Keyword arguments for transaction function. 
+ + Returns: + Any: result of the executed transaction + """ + with self.client.session(self.database) as session: + return session.execute_read( + self.transaction_function, # positional: a keyword here collides with the unpacked "args" + *(input_data.get("args", [])), + **(input_data.get("kwargs", {})), + ) \ No newline at end of file diff --git a/src/utca/implementation/predictors/__init__.py b/src/utca/implementation/predictors/__init__.py index 2a564ca..33ed6fc 100644 --- a/src/utca/implementation/predictors/__init__.py +++ b/src/utca/implementation/predictors/__init__.py @@ -1,5 +1,5 @@ from utca.implementation.predictors.transformers_predictor.transformers_model import ( - TransformersModel, + TransformersModel, TransformersGenerativeModel ) from utca.implementation.predictors.transformers_predictor.schema import ( TransformersModelConfig, @@ -65,6 +65,7 @@ __all__ = [ "TransformersModel", + "TransformersGenerativeModel", "TransformersModelConfig", "TransformersPipelineConfig", "TransformersImageClassificationModelInput", diff --git a/src/utca/implementation/schemas/semantic_search/semantic_search_schema.py b/src/utca/implementation/schemas/semantic_search/semantic_search_schema.py index 6ba56c9..fbb1833 100644 --- a/src/utca/implementation/schemas/semantic_search/semantic_search_schema.py +++ b/src/utca/implementation/schemas/semantic_search/semantic_search_schema.py @@ -5,9 +5,7 @@ from utca.core.executable_level_1.interpreter import Evaluator from utca.core.executable_level_1.executable import Executable -from utca.core.executable_level_1.schema import ( - IOModel, Transformable -) +from utca.core.executable_level_1.schema import IOModel from utca.implementation.tasks.text_processing.embedding.transformers_task.transformers_embedding import ( TransformersTextEmbedding ) @@ -83,12 +81,9 @@ def __init__( def get_embeddings(self, texts: List[str]) -> npt.NDArray[Any]: - return getattr( - self.encoder(Transformable({ + return self.encoder.run({ "texts": texts - })), - "embeddings" - ) + 
})["embeddings"] def add(self, dataset: List[str]) -> SemanticSearchSchema: diff --git a/src/utca/implementation/tasks/__init__.py b/src/utca/implementation/tasks/__init__.py index fc3f67f..ce9eb7a 100644 --- a/src/utca/implementation/tasks/__init__.py +++ b/src/utca/implementation/tasks/__init__.py @@ -194,6 +194,14 @@ GLiNERRelationExtractionPostprocessor, ) +from utca.implementation.tasks.text_processing.textual_q_and_a.gliner_task.q_and_a import ( + GLiNERQandA, +) +from utca.implementation.tasks.text_processing.textual_q_and_a.gliner_task.actions import ( + GLiNERQandAPreprocessor, + GLiNERQandAPostprocessor, +) + __all__ = [ # Audio processing "TransformersTextToSpeech", @@ -307,4 +315,8 @@ "GLiNERRelationExtraction", "GLiNERRelationExtractionPreprocessor", "GLiNERRelationExtractionPostprocessor", + + "GLiNERQandA", + "GLiNERQandAPreprocessor", + "GLiNERQandAPostprocessor", ] \ No newline at end of file diff --git a/src/utca/implementation/tasks/text_processing/entity_linking/transformers_task/transformers_entity_linking.py b/src/utca/implementation/tasks/text_processing/entity_linking/transformers_task/transformers_entity_linking.py index 2aef8af..6a00ee8 100644 --- a/src/utca/implementation/tasks/text_processing/entity_linking/transformers_task/transformers_entity_linking.py +++ b/src/utca/implementation/tasks/text_processing/entity_linking/transformers_task/transformers_entity_linking.py @@ -204,21 +204,32 @@ class name will be used. Defaults to None. 
self.pad_token_id: int = cast(int, self.tokenizer.pad_token_id) + expected_kwargs = { + "max_new_tokens": 512, + "pad_token_id": self.pad_token_id, + } + required_kwargs = { + "return_dict_in_generate": True, + "output_scores": True, + } if not predictor: model = self.initialize_model(self.default_model) predictor = TransformersGenerativeModel( TransformersModelConfig( model=model, # type: ignore - kwargs={ - "max_new_tokens": 512, - "pad_token_id": self.pad_token_id, - "return_dict_in_generate": True, - "output_scores": True, - } + kwargs={**expected_kwargs, **required_kwargs} ), input_class=TransformersEntityLinkingInput, output_class=TransformersEntityLinkingOutput, ) + else: + if not predictor.cfg.kwargs: # type: ignore + predictor.cfg.kwargs = {**expected_kwargs, **required_kwargs} # type: ignore + else: + for k, v in expected_kwargs.items(): + if not k in predictor.cfg.kwargs: # type: ignore + predictor.cfg.kwargs[k] = v # type: ignore + predictor.cfg.kwargs.update(required_kwargs) # type: ignore self.encoder_decoder: bool = predictor.config.is_encoder_decoder self.initialize_labels_trie(labels) diff --git a/src/utca/implementation/tasks/text_processing/ner/gliner_task/actions.py b/src/utca/implementation/tasks/text_processing/ner/gliner_task/actions.py index addc92d..94f55cb 100644 --- a/src/utca/implementation/tasks/text_processing/ner/gliner_task/actions.py +++ b/src/utca/implementation/tasks/text_processing/ner/gliner_task/actions.py @@ -19,8 +19,8 @@ class GLiNERPreprocessor(Action[Dict[str, Any], Dict[str, Any]]): "inputs" (List[str]): Model inputs; "chunks_starts" (List[int]): Chunks start positions. Used by postprocessor; - - "prompt_lengths" (List[int]): Prompt lenghts. 
Used by postprocessor; + + "threshold" (float): Minimal score for an entity to put into output; """ def __init__( diff --git a/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/__init__.py b/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/actions.py b/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/actions.py new file mode 100644 index 0000000..628dda3 --- /dev/null +++ b/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/actions.py @@ -0,0 +1,172 @@ +from typing import Any, Dict, List, Generator, Optional, Tuple + +from utca.core.executable_level_1.actions import Action +from utca.core.task_level_3.objects.objects import ( + Entity +) +from utca.implementation.tasks.text_processing.utils import sent_tokenizer + +class GLiNERQandAPreprocessor(Action[Dict[str, Any], Dict[str, Any]]): + """ + Preprocess inputs + + Args: + input_data (Dict[str, Any]): Expected keys: + "text" (str): Text to process; + + "question" (str): Question to answer; + + Returns: + Dict[str, Any]: Expected keys: + "texts" (List[str]): Model inputs; + + "labels" (List[str]): Labels model inputs; + + "chunks_starts" (List[int]): Chunks start positions. Used by postprocessor; + + "threshold" (float): Minimal score for an entity to put into output; + """ + def __init__( + self, + sents_batch: int=10, + threshold: float=0.5, + name: Optional[str]=None, + ) -> None: + """ + Args: + sents_batch (int): Chunks size in sentences. Defaults to 10. + + threshold (float): Minimial score to put entities into the output. + + name (Optional[str], optional): Name for identification. If equals to None, + class name will be used. Defaults to None. 
+ """ + super().__init__(name) + self.threshold = threshold + self.sents_batch = sents_batch + + + def get_last_sentence_id(self, i: int, sentences_len: int) -> int: + return min(i + self.sents_batch, sentences_len) - 1 + + + def chunkanize(self, text: str) -> Tuple[List[str], List[int]]: + chunks: List[str] = [] + starts: List[int] = [] + + sentences: List[Tuple[int, int]] = list(sent_tokenizer(text)) + + for i in range(0, len(sentences), self.sents_batch): + start = sentences[i][0] + starts.append(start) + + last_sentence = self.get_last_sentence_id(i, len(sentences)) + end = sentences[last_sentence][-1] + + chunks.append(text[start:end]) + return chunks, starts + + + def execute( + self, input_data: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Args: + input_data (Dict[str, Any]): Expected keys: + "text" (str): Text to process; + + "question" (str): Question to answer; + + Returns: + Dict[str, Any]: Expected keys: + "texts" (List[str]): Model inputs; + + "labels" (List[str]): Labels model inputs; + + "chunks_starts" (List[int]): Chunks start positions. Used by postprocessor; + + "threshold" (float): Minimal score for an entity to put into output; + """ + chunks, chunks_starts = ( + self.chunkanize(input_data["text"]) + ) + return { + "texts": [f'{input_data["question"]} {c}' for c in chunks], + "labels": ["answer"], + "chunks_starts": chunks_starts, + "threshold": self.threshold, + } + + +class GLiNERQandAPostprocessor(Action[Dict[str, Any], Dict[str, Any]]): + """ + Format output + + Args: + input_data (Dict[str, Any]): Expected keys: + "output" (List[List[Dict[str, Any]]]): Model output; + + "chunks_starts" (List[int]): Chunks starts; + + "text" (str): Processed text; + + "question" (str): Answered question. + + Returns: + Dict[str, Any]: Expected keys: + "text" (str): Processed text; + + "question" (str): Answered question. 
+ + "output" (List[Entity]): Answers; + """ + def process_entities( + self, + raw_entities: List[List[Dict[str, Any]]], + chunks_starts: List[int], + question_length: int, + ) -> Generator[Entity, None, None]: + for id, output in enumerate(raw_entities): + shift = chunks_starts[id] - question_length - 1 + for ent in output: + start = ent['start'] + shift + end = ent['end'] + shift + yield Entity( + start=start, + end=end, + span=ent['text'], + score=ent['score'], + ) + + + def execute( + self, input_data: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Args: + input_data (Dict[str, Any]): Expected keys: + "output" (List[List[Dict[str, Any]]]): Model output; + + "chunks_starts" (List[int]): Chunks starts; + + "text" (str): Processed text; + + "question" (str): Answered question. + + Returns: + Dict[str, Any]: Expected keys: + "text" (str): Processed text; + + "question" (str): Answered question. + + "output" (List[Entity]): Answers; + """ + return { + 'text': input_data["text"], + 'question': input_data["question"], + 'output': list(self.process_entities( + input_data["output"], + input_data["chunks_starts"], + len(input_data["question"]) + )) + } \ No newline at end of file diff --git a/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/q_and_a.py b/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/q_and_a.py new file mode 100644 index 0000000..e66e48c --- /dev/null +++ b/src/utca/implementation/tasks/text_processing/textual_q_and_a/gliner_task/q_and_a.py @@ -0,0 +1,95 @@ +from typing import Any, Type, Optional + +from utca.core.executable_level_1.component import Component +from utca.core.executable_level_1.schema import IOModel, Input +from utca.core.predictor_level_2.predictor import Predictor +from utca.core.task_level_3.task import NERTask +from utca.core.task_level_3.schema import ( + NEROutput, NEROutputType +) +from utca.core.task_level_3.objects.objects import ( + Entity +) +from 
utca.implementation.predictors.gliner_predictor.predictor import ( + GLiNERPredictor, GLiNERPredictorConfig +) +from utca.implementation.tasks.text_processing.textual_q_and_a.gliner_task.actions import ( + GLiNERQandAPreprocessor, + GLiNERQandAPostprocessor +) + +class GLiNERQandAInput(IOModel): + """ + Arguments: + text (str): Text to use. + + question (str): Question to answer. + """ + text: str + question: str + + +class GLiNERQandAOutput(NEROutput[Entity]): + """ + Arguments: + text (str): Input text. + + question (str): Answered question. + + output (List[Entity]): Answers. + """ + text: str + question: str + + +class GLiNERQandA( + NERTask[Input, NEROutputType] +): + """ + Textual Q&A task + """ + default_model: str = "knowledgator/gliner-multitask-large-v0.5" + + def __init__( + self, + predictor: Optional[Predictor[Any, Any]]=None, + preprocess: Optional[Component]=None, + postprocess: Optional[Component]=None, + input_class: Type[Input]=GLiNERQandAInput, + output_class: Type[NEROutputType]=GLiNERQandAOutput, + name: Optional[str]=None, + ) -> None: + """ + Args: + predictor (Optional[Predictor[Any, Any]], optional): Predictor that will be used in task. + If equals to None, default predictor will be used. Defaults to None. + + preprocess (Optional[Component], optional): Component executed + before predictor. If equals to None, default component will be used. Defaults to None. + + Default component: + GLiNERQandAPreprocessor + + postprocess (Optional[Component], optional): Component executed + after predictor. If equals to None, default component will be used. Defaults to None. + + Default component: + GLiNERQandAPostprocessor + + input_class (Type[Input], optional): Class for input validation. + Defaults to GLiNERQandAInput. + + output_class (Type[NEROutputType], optional): Class for output validation. + Defaults to GLiNERQandAOutput. + + name (Optional[str], optional): Name for identification. If equals to None, + class name will be used. Defaults to None. 
+ """ + super().__init__( + predictor=predictor or GLiNERPredictor(GLiNERPredictorConfig(model_name=self.default_model)), + preprocess=preprocess or GLiNERQandAPreprocessor(), + postprocess=postprocess or GLiNERQandAPostprocessor(), + input_class=input_class, + output_class=output_class, + name=name, + ) \ No newline at end of file diff --git a/src/utca/implementation/tasks/text_processing/textual_q_and_a/token_searcher/token_searcher.py b/src/utca/implementation/tasks/text_processing/textual_q_and_a/token_searcher/token_searcher.py index e173fbb..62f93b1 100644 --- a/src/utca/implementation/tasks/text_processing/textual_q_and_a/token_searcher/token_searcher.py +++ b/src/utca/implementation/tasks/text_processing/textual_q_and_a/token_searcher/token_searcher.py @@ -21,7 +21,7 @@ class TokenSearcherQandAInput(IOModel): """ Arguments: - text (str): Text to clean. + text (str): Text to use. question (str): Question to answer. """ diff --git a/tests/implementations/tasks/test_gliner_q_and_a.py b/tests/implementations/tasks/test_gliner_q_and_a.py new file mode 100644 index 0000000..4a5730c --- /dev/null +++ b/tests/implementations/tasks/test_gliner_q_and_a.py @@ -0,0 +1,12 @@ +from utca.implementation.tasks import ( + GLiNERQandA +) + +def test_default(): + text = "Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975, to develop and sell BASIC interpreters for the Altair 8800. During his career at Microsoft, Gates held the positions of chairman, chief executive officer, president and chief software architect, while also being the largest individual shareholder until May 2014." + pipe = GLiNERQandA() + answer = pipe.run({ + "text": text, + "question": "Who was the CEO of Microsoft?" + })["output"][0] + assert answer["span"] == text[answer["start"]:answer["end"]] \ No newline at end of file