-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ce5a2a5
commit 002c3fe
Showing
2 changed files
with
123 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from funcchain.components import ChatRouter | ||
|
||
|
||
def handle_pdf_requests(user_query: str) -> None: | ||
print("Handling PDF requests with user query: ", user_query) | ||
|
||
|
||
def handle_csv_requests(user_query: str) -> None: | ||
print("Handling CSV requests with user query: ", user_query) | ||
|
||
|
||
def handle_default_requests(user_query: str) -> None: | ||
print("Handling DEFAULT requests with user query: ", user_query) | ||
|
||
|
||
router = ChatRouter( | ||
routes={ | ||
"pdf": { | ||
"handler": handle_pdf_requests, | ||
"description": "Call this for requests including PDF Files.", | ||
}, | ||
"csv": { | ||
"handler": handle_csv_requests, | ||
"description": "Call this for requests including CSV Files.", | ||
}, | ||
"default": handle_default_requests, | ||
}, | ||
) | ||
|
||
|
||
router.invoke_route("Can you summarize this csv?") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from enum import Enum | ||
from typing import Union, Callable, TypedDict, Any, Coroutine | ||
from pydantic import BaseModel, Field, field_validator | ||
from funcchain import runnable | ||
|
||
|
||
class Route(TypedDict): | ||
handler: Union[Callable, Coroutine] | ||
description: str | ||
|
||
|
||
Routes = dict[str, Union[Route, Callable, Coroutine]] | ||
|
||
|
||
class ChatRouter(BaseModel): | ||
routes: Routes | ||
|
||
@field_validator("routes") | ||
def validate_routes(cls, v: Routes) -> Routes: | ||
if "default" not in v.keys(): | ||
raise ValueError("`default` route is missing") | ||
return v | ||
|
||
def create_route(self) -> Any: | ||
RouteChoices = Enum( # type: ignore | ||
"RouteChoices", | ||
{r: r for r in self.routes.keys()}, | ||
type=str, | ||
) | ||
|
||
class RouterModel(BaseModel): | ||
selector: RouteChoices = Field( | ||
default="default", | ||
description="Enum of the available routes.", | ||
) | ||
|
||
return runnable( | ||
instruction="Given the user query select the best query handler for it.", | ||
input_args=["user_query", "query_handlers"], | ||
output_type=RouterModel, | ||
) | ||
|
||
def show_routes(self) -> str: | ||
return "\n".join( | ||
[ | ||
f"{route_name}: {route['description']}" | ||
if isinstance(route, dict) | ||
else f"{route_name}: {route.__name__}" | ||
for route_name, route in self.routes.items() | ||
] | ||
) | ||
|
||
def invoke_route(self, user_query: str, /, **kwargs: Any) -> Any: | ||
route_query = self.create_route() | ||
|
||
selected_route = route_query.invoke( | ||
input={ | ||
"user_query": user_query, | ||
"query_handlers": self.show_routes(), | ||
} | ||
).selector | ||
assert isinstance(selected_route, str) | ||
|
||
if isinstance(self.routes[selected_route], dict): | ||
return self.routes[selected_route]["handler"](user_query, **kwargs) # type: ignore | ||
return self.routes[selected_route](user_query, **kwargs) # type: ignore | ||
|
||
async def ainvoke_route(self, user_query: str, /, **kwargs: Any) -> Any: | ||
import asyncio | ||
|
||
if not all( | ||
[ | ||
asyncio.iscoroutinefunction(route["handler"]) | ||
if isinstance(route, dict) | ||
else asyncio.iscoroutinefunction(route) | ||
for route in self.routes.values() | ||
] | ||
): | ||
raise ValueError("All routes must be awaitable when using `ainvoke_route`") | ||
|
||
route_query = self.create_route() | ||
selected_route = route_query.invoke( | ||
input={ | ||
"user_query": user_query, | ||
"query_handlers": self.show_routes(), | ||
} | ||
).selector | ||
assert isinstance(selected_route, str) | ||
|
||
if isinstance(self.routes[selected_route], dict): | ||
return await self.routes[selected_route]["handler"](user_query, **kwargs) # type: ignore | ||
return await self.routes[selected_route](user_query, **kwargs) # type: ignore |