Skip to content

Commit

Permalink
Merge pull request #60 from small-thinking/update-retriever
Browse files Browse the repository at this point in the history
Update retrieve tool
  • Loading branch information
yxjiang authored Mar 30, 2024
2 parents 0fd741a + 23f92e3 commit d37e10b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
46 changes: 35 additions & 11 deletions polymind/core_tools/retrieve_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
class RetrieveTool(BaseTool, ABC):
"""The base class for the retrieval tools."""

query_key: str = Field(default="query", description="The key to retrieve the query from the input message.")
query_key: str = Field(default="input", description="The key to retrieve the query from the input message.")
result_key: str = Field(default="results", description="The key to store the results in the output message.")
embedder: Embedder = Field(description="The embedder to generate the embedding for the descriptions.")
top_k: int = Field(default=3, description="The number of top results to retrieve.")

Expand Down Expand Up @@ -54,7 +55,7 @@ def input_spec(self) -> List[Param]:
def output_spec(self) -> List[Param]:
output_spec = [
Param(
name="results",
name=self.result_key,
type="List[str]",
required=True,
description="The top k results retrieved by the tool.",
Expand All @@ -67,6 +68,34 @@ def output_spec(self) -> List[Param]:
]
return output_spec

@abstractmethod
async def _retrieve(self, input: Message, query_embedding: List[float]) -> Message:
    """Retrieve the information based on the query.

    Subclasses implement the actual lookup against their backing store.

    Args:
        input (Message): The input message containing the query. It should have fields defined in the input_spec.
        query_embedding (List[float]): The embedding of the query.

    Return:
        Message: The message containing the retrieved information.
    """
    pass

async def _execute(self, input: Message) -> Message:
    """Embed the query from the input message and delegate the lookup to ``_retrieve``.

    Args:
        input (Message): The input message containing the query. It should have fields defined in the input_spec.
            The query is read from ``input.content[self.query_key]``; a missing key falls back to "".

    Return:
        Message: The message returned by ``_retrieve``.
    """
    # Get the embeddings for the query.
    query = input.content.get(self.query_key, "")
    embed_message = Message(content={"input": [query]})
    embedding_message = await self.embedder(embed_message)
    # Exactly one query is embedded above, and _retrieve is declared to take a single vector.
    # NOTE(review): assumes content["embeddings"] holds one vector per input text - confirm against the embedder.
    query_embedding = embedding_message.content["embeddings"][0]
    # Retrieve the information based on the query. The original input is forwarded as well so that
    # implementations can read per-request options from it.
    return await self._retrieve(input=input, query_embedding=query_embedding)


class IndexTool(BaseTool, ABC):
"""The base class for the indexing tools."""
Expand Down Expand Up @@ -137,18 +166,13 @@ def _set_client(self):
self._client.create_collection(self.collection_name, dimension=self.embed_dim, auto_id=True)
self.embedder = OpenAIEmbeddingTool(embed_dim=self.embed_dim)

async def _execute(self, input: Message) -> Message:
if self.query_key not in input.content:
raise ValueError(f"Cannot find the key {self.query_key} in the input message.")
query = input.content[self.query_key]
embed_message = Message(content={"input": [query]})
embedding_message = await self.embedder(embed_message)
embedding_ndarray = np.array(embedding_message.content["embeddings"])
async def _retrieve(self, input: Message, query_embedding: List[float]) -> Message:
embedding_ndarray = np.array([query_embedding])
# Convert from ndarray to list of list of float and there should be only one embedding.
search_params = {
"collection_name": self.collection_name,
"data": embedding_ndarray.tolist(),
"limit": self.top_k,
"limit": input.content.get("top_k", self.top_k),
"anns_field": "vector",
"output_fields": self.keys_to_retrieve,
}
Expand All @@ -163,7 +187,7 @@ async def _execute(self, input: Message) -> Message:
# Construct the response message.
response_message = Message(
content={
"results": results,
self.result_key: results,
}
)
return response_message
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "polymind"
version = "0.0.32" # Update this version before publishing to PyPI
version = "0.0.33" # Update this version before publishing to PyPI
description = "PolyMind is a customizable collaborative multi-agent framework for collective intelligence and distributed problem solving."
authors = ["TechTao"]
license = "MIT License"
Expand Down

0 comments on commit d37e10b

Please sign in to comment.