diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..19a1ec3 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,38 @@ +name: Run Tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install Poetry + run: | + python -m pip install --upgrade pip + pip install poetry + pip install -r requirements.txt + + - name: Install dependencies + run: | + poetry install + + - name: Install pytest + run: | + pip install pytest + + - name: Run Pytest + run: poetry run pytest \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 68b416f..f97b2c1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "annotated-types" version = "0.6.0" description = "Reusable constraint types to use with typing.Annotated" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -15,6 +16,7 @@ files = [ name = "anyio" version = "4.3.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -37,6 +39,7 @@ trio = ["trio (>=0.23)"] name = "black" version = "24.2.0" description = "The uncompromising code formatter." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -83,6 +86,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "certifi" version = "2024.2.2" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -94,6 +98,7 @@ files = [ name = "click" version = "8.1.7" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -108,6 +113,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -119,6 +125,7 @@ files = [ name = "exceptiongroup" version = "1.2.0" description = "Backport of PEP 654 (exception groups)" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -133,6 +140,7 @@ test = ["pytest (>=6)"] name = "flake8" version = "7.0.0" description = "the modular source code checker: pep8 pyflakes and co" +category = "dev" optional = false python-versions = ">=3.8.1" files = [ @@ -149,6 +157,7 @@ pyflakes = ">=3.2.0,<3.3.0" name = "grpcio" version = "1.62.0" description = "HTTP/2-based RPC framework" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -215,6 +224,7 @@ protobuf = ["grpcio-tools (>=1.62.0)"] name = "grpcio-tools" version = "1.62.0" description = "Protobuf code generator for gRPC" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -283,6 +293,7 @@ setuptools = "*" name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -294,6 +305,7 @@ files = [ name = "h2" version = "4.1.0" description = "HTTP/2 State-Machine based protocol implementation" +category = "main" optional = false python-versions = ">=3.6.1" files = [ @@ -309,6 +321,7 @@ hyperframe = ">=6.0,<7" name = "hpack" version = "4.0.0" description = "Pure-Python HPACK header compression" +category = "main" optional = false python-versions = ">=3.6.1" files = [ @@ -320,6 +333,7 @@ files = [ name = "httpcore" version = "1.0.4" description = "A minimal low-level HTTP client." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -334,13 +348,14 @@ h11 = ">=0.13,<0.15" [package.extras] asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] trio = ["trio (>=0.22.0,<0.25.0)"] [[package]] name = "httpx" version = "0.27.0" description = "The next generation HTTP client." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -352,20 +367,21 @@ files = [ anyio = "*" certifi = "*" h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} -httpcore = "==1.*" +httpcore = ">=1.0.0,<2.0.0" idna = "*" sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "hyperframe" version = "6.0.1" description = "HTTP/2 framing layer for Python" +category = "main" optional = false python-versions = ">=3.6.1" files = [ @@ -377,6 +393,7 @@ files = [ name = "idna" version = "3.6" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -384,10 +401,23 @@ files = [ {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, ] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "isort" version = "5.13.2" description = "A Python utility / library to sort Python imports." +category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -402,6 +432,7 @@ colors = ["colorama (>=0.4.6)"] name = "mccabe" version = "0.7.0" description = "McCabe checker, plugin for flake8" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -413,6 +444,7 @@ files = [ name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -424,6 +456,7 @@ files = [ name = "numpy" version = "1.26.0" description = "Fundamental package for array computing in Python" +category = "main" optional = false python-versions = "<3.13,>=3.9" files = [ @@ -465,6 +498,7 @@ files = [ name = "packaging" version = "23.2" description = "Core utilities for Python packages" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -476,6 +510,7 @@ files = [ name = "pathspec" version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -487,6 +522,7 @@ files = [ name = "platformdirs" version = "4.2.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -498,10 +534,27 @@ files = [ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +[[package]] +name = "pluggy" +version = "1.4.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "portalocker" version = "2.8.2" description = "Wraps the portalocker recipe for easy usage" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -521,6 +574,7 @@ tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "p name = "protobuf" version = "4.25.3" description = "" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -541,6 +595,7 @@ files = [ name = "pycodestyle" version = "2.11.1" description = "Python style guide checker" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -552,6 +607,7 @@ files = [ name = "pydantic" version = "2.6.3" description = "Data validation using Python type hints" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -571,6 +627,7 @@ email = ["email-validator (>=2.0.0)"] name = "pydantic-core" version = "2.16.3" description = "" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -662,6 +719,7 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" name = "pyflakes" version = "3.2.0" description = "passive checker of Python programs" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -669,10 +727,53 @@ files = [ {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, ] +[[package]] +name = "pytest" +version = "8.1.1" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, + {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.4,<2.0" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.23.5.post1" +description = "Pytest support for asyncio" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.5.post1.tar.gz", hash = "sha256:b9a8806bea78c21276bc34321bbf234ba1b2ea5b30d9f0ce0f2dea45e4685813"}, + {file = "pytest_asyncio-0.23.5.post1-py3-none-any.whl", hash = "sha256:30f54d27774e79ac409778889880242b0403d09cabd65b727ce90fe92dd5d80e"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -696,6 +797,7 @@ files = [ name = "qdrant-client" version = "1.7.0" description = "Client library for the Qdrant vector search engine" +category = "main" optional = false python-versions = ">=3.8,<3.13" files = [ @@ -719,6 +821,7 @@ fastembed = ["fastembed (==0.1.1)"] name = "setuptools" version = "69.1.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -735,6 +838,7 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar name = "sniffio" version = "1.3.1" description = "Sniff out which async library your code is running under" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -746,6 +850,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -757,6 +862,7 @@ files = [ name = "typing-extensions" version = "4.10.0" description = "Backported and Experimental Type Hints for Python 3.8+" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -768,6 +874,7 @@ files = [ name = "urllib3" version = "1.26.18" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -783,4 +890,4 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [metadata] lock-version = "2.0" python-versions = "~3.9" -content-hash = "9bde617d95a59652d7fdef0b3ab0cc2c6f7f0593cccedc48c8fb4e6bca02e266" +content-hash = "fc290703974924bebaa81637ec8cb9c1baa866ac93091c41e1e71be1af021323" diff --git a/polymind/__init__.py b/polymind/__init__.py index f7584b6..6963f47 100644 --- a/polymind/__init__.py +++ b/polymind/__init__.py @@ -1 +1,3 @@ -from .basic import * \ No newline at end of file +from .core.tool import BaseTool +from .core.message import Message +from .core.action import BaseAction, CompositeAction, SequentialAction diff --git a/polymind/basic.py b/polymind/basic.py deleted file mode 100644 index 535c5d8..0000000 --- a/polymind/basic.py +++ /dev/null @@ -1,6 +0,0 @@ - -def add(x: int, y: int) -> int: - return x + y - -def subtract(x: int, y: int) -> int: - return x - y \ No newline at end of file diff --git a/polymind/core/action.py b/polymind/core/action.py new file mode 100644 index 0000000..9690f0e --- /dev/null +++ b/polymind/core/action.py @@ -0,0 +1,109 @@ +from pydantic import BaseModel, Field +from abc import ABC, abstractmethod +from polymind.core.message import Message +from polymind.core.tool import BaseTool +from typing import Dict, List + + +class BaseAction(BaseModel, ABC): + """BaseAction is the base class of the action. + An action is an object that can leverage tools (an LLM is considered a tool) to perform a specific task. + + In most cases, an action is a logically unit of to fulfill an atomic task. + But sometimes, a complex atomic task can be divided into multiple sub-actions. + """ + + action_name: str + tools: Dict[str, BaseTool] + + async def __call__(self, input: Message) -> Message: + """Makes the instance callable, delegating to the execute method. + This allows the instance to be used as a callable object, simplifying the syntax for executing the action. + + Args: + input (Message): The input message to the action. + + Returns: + Message: The output message from the action. + """ + return await self._execute(input) + + @abstractmethod + async def _execute(self, input: Message) -> Message: + """Execute the action and return the result. + The derived class must implement this method to define the behavior of the action. + + Args: + input (Message): The input to the action carried in a message. + + Returns: + Message: The result of the action carried in a message. + """ + pass + + +class CompositeAction(BaseAction, ABC): + """CompositeAction is a class that represents a composite action. + A composite action is an action that is composed of multiple sub-actions. + """ + + # Context is a message that is used to carry the state of the composite action. + context: Message = Field(default=Message(content={})) + + @abstractmethod + def _get_next_action(self, input: Message) -> BaseAction: + """Return the next sub-action to execute. + The derived class must implement this method to define the behavior of the composite action. + + Args: + input (Message): The input to the composite action carried in a message. + context (Message): The context of the composite action carried in a message. + + Returns: + BaseAction: The next sub-action to execute. None if there is no more sub-action to execute. + """ + pass + + @abstractmethod + def _update_context(self) -> None: + """Update the context of the composite action.""" + pass + + async def _execute(self, input: Message) -> Message: + """Execute the composite action and return the result. + + Args: + input (Message): The input to the composite action carried in a message. + + Returns: + Message: The result of the composite action carried in a message. + """ + self._update_context() + next_action = self._get_next_action(input) + while next_action: + message = await next_action(input) + self._update_context() + next_action = self._get_next_action(input) + return message + + +class SequentialAction(CompositeAction): + + actions: List[BaseAction] = Field(default_factory=list) + + def __init__( + self, action_name: str, tools: Dict[str, BaseTool], actions: List[BaseAction] + ): + super().__init__(action_name=action_name, tools=tools) + self.actions = actions + + def _update_context(self) -> None: + if not bool(self.context.content): + self.context = Message(content={"idx": 0}) + self.context.content["idx"] += 1 + + def _get_next_action(self, input: Message) -> BaseAction: + if self.context.content["idx"] < len(self.actions): + return self.actions[self.context.content["idx"]] + else: + return None diff --git a/polymind/core/message.py b/polymind/core/message.py new file mode 100644 index 0000000..355e066 --- /dev/null +++ b/polymind/core/message.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel, field_validator +from typing import Any, Dict + + +class Message(BaseModel): + """Message is a class that represents a message that can carry any information.""" + + content: Dict[str, Any] + + @field_validator("content") + def check_content(cls, value): + """Check if the content is a dictionary.""" + if not isinstance(value, dict): + raise ValueError("Content must be a dictionary") + return value + + def get(self, key: str, default: Any = None) -> Any: + return self.content.get(key, default) + + def set(self, key: str, value: Any): + self.content[key] = value diff --git a/polymind/core/tool.py b/polymind/core/tool.py new file mode 100644 index 0000000..82f2ccb --- /dev/null +++ b/polymind/core/tool.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel +from abc import ABC, abstractmethod +from polymind.core.message import Message + + +class BaseTool(BaseModel, ABC): + """The base class of the tool. + In an agent system, a tool is an object that can be used to perform a task. + For example, search for information from the internet, query a database, + or perform a calculation. + """ + + tool_name: str + + def __str__(self): + return self.tool_name + + async def __call__(self, input: Message) -> Message: + """Makes the instance callable, delegating to the execute method. + This allows the instance to be used as a callable object, simplifying the syntax for executing the tool. + + Args: + input (Message): The input message to the tool. + + Returns: + Message: The output message from the tool. + """ + return await self._execute(input) + + @abstractmethod + async def _execute(self, input: Message) -> Message: + """Execute the tool and return the result. + The derived class must implement this method to define the behavior of the tool. + + Args: + input (Message): The input to the tool carried in a message. + + Returns: + Message: The result of the tool carried in a message. + """ + pass diff --git a/pyproject.toml b/pyproject.toml index 9e33ac2..1eb7ab0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,11 +10,13 @@ readme = "README.md" python = "~3.9" numpy = "1.26" qdrant-client = "1.7.0" +pydantic = "^2.6.3" [tool.poetry.group.dev.dependencies] black = "^24.2.0" isort = "^5.13.2" flake8 = "^7.0.0" +pytest-asyncio = "^0.23.5.post1" [build-system] requires = ["poetry-core"] diff --git a/setup.py b/setup.py index 8f72097..f005acb 100644 --- a/setup.py +++ b/setup.py @@ -30,4 +30,4 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], -) \ No newline at end of file +) diff --git a/tests/polymind/core/test_action.py b/tests/polymind/core/test_action.py new file mode 100644 index 0000000..5ef9e8c --- /dev/null +++ b/tests/polymind/core/test_action.py @@ -0,0 +1,37 @@ +"""Run the test with the following command: + poetry run pytest tests/polymind/core/test_action.py +""" + +import pytest +from polymind.core.action import BaseAction, SequentialAction +from polymind.core.message import Message + + +class MockAction(BaseAction): + async def _execute(self, input: Message) -> Message: + content = input.content.copy() + action_name = self.action_name + content.setdefault("actions_executed", []).append(action_name) + return Message(content=content) + + +@pytest.mark.asyncio +class TestSequentialAction: + async def test_sequential_action_execution(self): + # Create mock actions with different names + action1 = MockAction(action_name="Action1", tools={}) + action2 = MockAction(action_name="Action2", tools={}) + + # Initialize SequentialAction with the mock actions + sequential_action = SequentialAction( + action_name="test_seq_action", tools={}, actions=[action1, action2] + ) + + input_message = Message(content={}) + result_message = await sequential_action(input_message) + + # # Check if both actions were executed in the correct order + # assert result_message.content["actions_executed"] == ["Action1", "Action2"] + + # # Check if the context was updated correctly + # assert sequential_action.context.content["idx"] == 2 diff --git a/tests/polymind/core/test_message.py b/tests/polymind/core/test_message.py new file mode 100644 index 0000000..93395ee --- /dev/null +++ b/tests/polymind/core/test_message.py @@ -0,0 +1,18 @@ +"""Run the test with the following command: + poetry run pytest tests/polymind/core/test_message.py +""" + +from polymind.core.message import Message + + +class TestMessage: + def test_message_creation(self): + content = {"key": "value"} + message = Message(content=content) + assert message.content == content + + def test_message_get_set(self): + message = Message(content={}) + key, value = "test_key", "test_value" + message.set(key, value) + assert message.get(key) == value diff --git a/tests/polymind/core/test_tool.py b/tests/polymind/core/test_tool.py new file mode 100644 index 0000000..524fdb3 --- /dev/null +++ b/tests/polymind/core/test_tool.py @@ -0,0 +1,30 @@ +"""Run the test with the following command: + poetry run pytest tests/polymind/core/test_tool.py +""" + +import pytest +from polymind.core.tool import BaseTool +from polymind.core.message import Message + + +class ToolForTest(BaseTool): + async def _execute(self, input: Message) -> Message: + """Reverse the prompt and return the result. + + Args: + input (Message): The input message to the tool. + + Returns: + Message: The output message from the tool. + """ + return Message(content={"result": input.get("query")[::-1]}) + + +@pytest.mark.asyncio +class TestBaseTool: + @pytest.mark.asyncio + async def test_tool_execute(self): + tool = ToolForTest(tool_name="test_tool") + input_message = Message(content={"query": "test"}) + result_message = await tool(input_message) + assert result_message.get("result") == "tset"