From 002c3fe8f9e61f9f9e2d7b8f4e0beb7fa901ea7f Mon Sep 17 00:00:00 2001 From: Shroominic Date: Wed, 20 Dec 2023 21:27:36 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20router=20component?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/router_component.py | 31 ++++++++++++ src/funcchain/components.py | 92 ++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) create mode 100644 examples/router_component.py create mode 100644 src/funcchain/components.py diff --git a/examples/router_component.py b/examples/router_component.py new file mode 100644 index 0000000..1e3b618 --- /dev/null +++ b/examples/router_component.py @@ -0,0 +1,31 @@ +from funcchain.components import ChatRouter + + +def handle_pdf_requests(user_query: str) -> None: + print("Handling PDF requests with user query: ", user_query) + + +def handle_csv_requests(user_query: str) -> None: + print("Handling CSV requests with user query: ", user_query) + + +def handle_default_requests(user_query: str) -> None: + print("Handling DEFAULT requests with user query: ", user_query) + + +router = ChatRouter( + routes={ + "pdf": { + "handler": handle_pdf_requests, + "description": "Call this for requests including PDF Files.", + }, + "csv": { + "handler": handle_csv_requests, + "description": "Call this for requests including CSV Files.", + }, + "default": handle_default_requests, + }, +) + + +router.invoke_route("Can you summarize this csv?") diff --git a/src/funcchain/components.py b/src/funcchain/components.py new file mode 100644 index 0000000..3e73866 --- /dev/null +++ b/src/funcchain/components.py @@ -0,0 +1,92 @@ +from enum import Enum +from typing import Union, Callable, TypedDict, Any, Coroutine +from pydantic import BaseModel, Field, field_validator +from funcchain import runnable + + +class Route(TypedDict): + handler: Union[Callable, Coroutine] + description: str + + +Routes = dict[str, Union[Route, Callable, Coroutine]] + + +class ChatRouter(BaseModel): + routes: Routes + + @field_validator("routes") + def validate_routes(cls, v: Routes) -> Routes: + if "default" not in v.keys(): + raise ValueError("`default` route is missing") + return v + + def create_route(self) -> Any: + RouteChoices = Enum( # type: ignore + "RouteChoices", + {r: r for r in self.routes.keys()}, + type=str, + ) + + class RouterModel(BaseModel): + selector: RouteChoices = Field( + default="default", + description="Enum of the available routes.", + ) + + return runnable( + instruction="Given the user query select the best query handler for it.", + input_args=["user_query", "query_handlers"], + output_type=RouterModel, + ) + + def show_routes(self) -> str: + return "\n".join( + [ + f"{route_name}: {route['description']}" + if isinstance(route, dict) + else f"{route_name}: {route.__name__}" + for route_name, route in self.routes.items() + ] + ) + + def invoke_route(self, user_query: str, /, **kwargs: Any) -> Any: + route_query = self.create_route() + + selected_route = route_query.invoke( + input={ + "user_query": user_query, + "query_handlers": self.show_routes(), + } + ).selector + assert isinstance(selected_route, str) + + if isinstance(self.routes[selected_route], dict): + return self.routes[selected_route]["handler"](user_query, **kwargs) # type: ignore + return self.routes[selected_route](user_query, **kwargs) # type: ignore + + async def ainvoke_route(self, user_query: str, /, **kwargs: Any) -> Any: + import asyncio + + if not all( + [ + asyncio.iscoroutinefunction(route["handler"]) + if isinstance(route, dict) + else asyncio.iscoroutinefunction(route) + for route in self.routes.values() + ] + ): + raise ValueError("All routes must be awaitable when using `ainvoke_route`") + + route_query = self.create_route() + selected_route = route_query.invoke( + input={ + "user_query": user_query, + "query_handlers": self.show_routes(), + } + ).selector + assert isinstance(selected_route, str) + + if isinstance(self.routes[selected_route], dict): + return await self.routes[selected_route]["handler"](user_query, **kwargs) # type: ignore + return await self.routes[selected_route](user_query, **kwargs) # type: ignore