From 23f92e363776ab7ce2e0f06f336cca645ddf8d87 Mon Sep 17 00:00:00 2001
From: Yx Jiang <2237303+yxjiang@users.noreply.github.com>
Date: Fri, 29 Mar 2024 18:51:53 -0700
Subject: [PATCH] Update retrieve tool

---
Reviewer amendments (text in this region is ignored by `git am`):
 * RetrieveTool._execute now calls self._retrieve(input, query_embedding).
   The original call passed only the "embeddings" value, leaving the
   required second argument of _retrieve missing (TypeError at runtime).
 * The no-op expression statement in _execute became the assignment of the
   single query embedding, so hunk line counts and the diffstat are unchanged.
 * _retrieve docstring: query_embedding type corrected to List[float] to
   match the signature.
 * Line breaks and indentation were restored from a whitespace-collapsed
   copy; run `git apply --check` (optionally with --ignore-whitespace)
   before applying.

 polymind/core_tools/retrieve_tool.py | 46 +++++++++++++++++++++-------
 pyproject.toml                       |  2 +-
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/polymind/core_tools/retrieve_tool.py b/polymind/core_tools/retrieve_tool.py
index 7254102..0ccfbe5 100644
--- a/polymind/core_tools/retrieve_tool.py
+++ b/polymind/core_tools/retrieve_tool.py
@@ -15,7 +15,8 @@
 class RetrieveTool(BaseTool, ABC):
     """The base class for the retrieval tools."""
 
-    query_key: str = Field(default="query", description="The key to retrieve the query from the input message.")
+    query_key: str = Field(default="input", description="The key to retrieve the query from the input message.")
+    result_key: str = Field(default="results", description="The key to store the results in the output message.")
     embedder: Embedder = Field(description="The embedder to generate the embedding for the descriptions.")
     top_k: int = Field(default=3, description="The number of top results to retrieve.")
 
@@ -54,7 +55,7 @@ def input_spec(self) -> List[Param]:
     def output_spec(self) -> List[Param]:
         output_spec = [
             Param(
-                name="results",
+                name=self.result_key,
                 type="List[str]",
                 required=True,
                 description="The top k results retrieved by the tool.",
@@ -67,6 +68,34 @@ def output_spec(self) -> List[Param]:
         ]
         return output_spec
 
+    @abstractmethod
+    async def _retrieve(self, input: Message, query_embedding: List[float]) -> Message:
+        """Retrieve the information based on the query.
+
+        Args:
+            input (Message): The input message containing the query. It should have fields defined in the input_spec.
+            query_embedding (List[float]): The embedding of the query.
+
+        Returns:
+            Message: The message containing the retrieved information.
+        """
+        pass
+
+    async def _execute(self, input: Message) -> Message:
+        """Retrieve the information based on the query.
+
+        Args:
+            input (Message): The input message containing the query. It should have fields defined in the input_spec.
+        """
+        # Embed the query; the embedder takes a batch, so the single query is wrapped in a list.
+        query = input.content.get(self.query_key, "")
+        embed_message = Message(content={"input": [query]})
+        embedding_message = await self.embedder(embed_message)
+        query_embedding = embedding_message.content["embeddings"][0]
+        # Forward the original input as well so subclasses can read per-call overrides such as "top_k".
+        response_message = await self._retrieve(input, query_embedding)
+        return response_message
+
 
 class IndexTool(BaseTool, ABC):
     """The base class for the indexing tools."""
@@ -137,18 +166,13 @@ def _set_client(self):
             self._client.create_collection(self.collection_name, dimension=self.embed_dim, auto_id=True)
         self.embedder = OpenAIEmbeddingTool(embed_dim=self.embed_dim)
 
-    async def _execute(self, input: Message) -> Message:
-        if self.query_key not in input.content:
-            raise ValueError(f"Cannot find the key {self.query_key} in the input message.")
-        query = input.content[self.query_key]
-        embed_message = Message(content={"input": [query]})
-        embedding_message = await self.embedder(embed_message)
-        embedding_ndarray = np.array(embedding_message.content["embeddings"])
+    async def _retrieve(self, input: Message, query_embedding: List[float]) -> Message:
+        embedding_ndarray = np.array([query_embedding])
         # Convert from ndarray to list of list of float and there should be only one embedding.
         search_params = {
             "collection_name": self.collection_name,
             "data": embedding_ndarray.tolist(),
-            "limit": self.top_k,
+            "limit": input.content.get("top_k", self.top_k),
             "anns_field": "vector",
             "output_fields": self.keys_to_retrieve,
         }
@@ -163,7 +187,7 @@ async def _execute(self, input: Message) -> Message:
         # Construct the response message.
         response_message = Message(
             content={
-                "results": results,
+                self.result_key: results,
             }
         )
         return response_message
diff --git a/pyproject.toml b/pyproject.toml
index bb36ff2..ccb4dd3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "polymind"
-version = "0.0.32"  # Update this version before publishing to PyPI
+version = "0.0.33"  # Update this version before publishing to PyPI
 description = "PolyMind is a customizable collaborative multi-agent framework for collective intelligence and distributed problem solving."
 authors = ["TechTao"]
 license = "MIT License"