From f0812f3453fc19e8842a73dc8fabcefe3d10abca Mon Sep 17 00:00:00 2001 From: Shroominic Date: Wed, 20 Dec 2023 21:35:12 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20router=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/funcchain/components.py | 3 +++ tests/router_test.py | 40 +++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 tests/router_test.py diff --git a/src/funcchain/components.py b/src/funcchain/components.py index 3e73866..4654dcf 100644 --- a/src/funcchain/components.py +++ b/src/funcchain/components.py @@ -15,6 +15,9 @@ class Route(TypedDict): class ChatRouter(BaseModel): routes: Routes + class Config: + arbitrary_types_allowed = True + @field_validator("routes") def validate_routes(cls, v: Routes) -> Routes: if "default" not in v.keys(): diff --git a/tests/router_test.py b/tests/router_test.py new file mode 100644 index 0000000..67b38a6 --- /dev/null +++ b/tests/router_test.py @@ -0,0 +1,40 @@ +from funcchain.components import ChatRouter + + +def handle_pdf_requests(user_query: str) -> str: + return f"Handling PDF requests with user query: {user_query}" + + +def handle_csv_requests(user_query: str) -> str: + return f"Handling CSV requests with user query: {user_query}" + + +def handle_default_requests(user_query: str) -> str: + return f"Handling DEFAULT requests with user query: {user_query}" + + +router = ChatRouter( + routes={ + "pdf": { + "handler": handle_pdf_requests, + "description": "Call this for requests including PDF Files.", + }, + "csv": { + "handler": handle_csv_requests, + "description": "Call this for requests including CSV Files.", + }, + "default": handle_default_requests, + }, +) + + +def test_router() -> None: + assert "Handling CSV" in router.invoke_route("Can you summarize this csv?") + + assert "Handling PDF" in router.invoke_route("Can you summarize this pdf?") + + assert "Handling DEFAULT" in router.invoke_route("Hey, whatsup?") + + +if __name__ == "__main__": + test_router()