From 224f143caa8d12181c8ccbb0c2be141867f5cfb4 Mon Sep 17 00:00:00 2001 From: Shroominic Date: Tue, 26 Dec 2023 20:49:22 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20improve=20chain=20creation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/funcchain/chain/creation.py | 161 +++++++++++++++++--------------- 1 file changed, 86 insertions(+), 75 deletions(-) diff --git a/src/funcchain/chain/creation.py b/src/funcchain/chain/creation.py index afdefb6..31dcc0a 100644 --- a/src/funcchain/chain/creation.py +++ b/src/funcchain/chain/creation.py @@ -1,29 +1,25 @@ from types import UnionType -from typing import TypeVar, Type +from typing import Type, TypeVar +from langchain_core.callbacks import Callbacks +from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.language_models import BaseChatModel -from langchain_core.prompts import ChatPromptTemplate from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.output_parsers import BaseOutputParser -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.runnables import ( - RunnableSerializable, - RunnableWithFallbacks, -) +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnableSerializable from PIL import Image from pydantic import BaseModel -from funcchain._llms import ChatLlamaCpp - from ..parser import MultiToolParser, ParserBaseModel, PydanticFuncParser from ..settings import FuncchainSettings from ..streaming import stream_handler from ..utils import ( - parser_for, count_tokens, is_function_model, is_vision_model, multi_pydantic_to_functions, + parser_for, pydantic_to_functions, pydantic_to_grammar, univeral_model_selector, @@ -44,7 +40,7 @@ def create_union_chain( system: str, memory: BaseChatMessageHistory, context: list[BaseMessage], - llm: BaseChatModel | RunnableWithFallbacks, + llm: BaseChatModel, input_kwargs: dict[str, str], ) -> RunnableSerializable[dict[str, str], BaseModel]: """ @@ -64,16 +60,7 @@ def create_union_chain( functions = multi_pydantic_to_functions(output_types) - if isinstance(llm, RunnableWithFallbacks): - llm = llm.runnable.bind(**functions).with_fallbacks( - [ - fallback.bind(**functions) - for fallback in llm.fallbacks - if hasattr(llm, "fallbacks") - ] - ) - else: - llm = llm.bind(**functions) # type: ignore + llm = llm.bind(**functions) # type: ignore prompt = create_chat_prompt( system, @@ -93,24 +80,15 @@ def create_union_chain( def create_pydanctic_chain( output_type: type[BaseModel], prompt: ChatPromptTemplate, - llm: BaseChatModel | RunnableWithFallbacks, + llm: BaseChatModel, input_kwargs: dict[str, str], ) -> RunnableSerializable[dict[str, str], BaseModel]: # TODO: check these format_instructions input_kwargs["format_instructions"] = f"Extract to {output_type.__name__}." functions = pydantic_to_functions(output_type) - llm = ( - llm.runnable.bind(**functions).with_fallbacks( # type: ignore - [ - fallback.bind(**functions) - for fallback in llm.fallbacks - if hasattr(llm, "fallbacks") - ] - ) - if isinstance(llm, RunnableWithFallbacks) - else llm.bind(**functions) - ) + llm = llm.bind(**functions) # type: ignore + return prompt | llm | PydanticFuncParser(pydantic_schema=output_type) @@ -127,13 +105,17 @@ def create_chain( Compile a langchain runnable chain from the funcchain syntax. """ # large language model - llm = _gather_llm(settings) + _llm = _gather_llm(settings) + llm = _add_custom_callbacks(_llm, settings) parser = parser_for(output_type) # add format instructions for parser - if parser and not is_function_model(llm): - instruction = _add_format_instructions( + f_instructions = None + if parser and (settings.streaming or not is_function_model(llm)): + # streaming behavior is not supported for function models + # but for normal function models we do not need to add format instructions + instruction, f_instructions = _add_format_instructions( parser, instruction, input_kwargs, @@ -151,29 +133,18 @@ def create_chain( images = _handle_images(llm, input_kwargs) # create prompts - instruction_prompt = create_instruction_prompt(instruction, images, input_kwargs) + instruction_prompt = create_instruction_prompt( + instruction, + images, + input_kwargs, + format_instructions=f_instructions, + ) chat_prompt = create_chat_prompt(system, instruction_prompt, context, memory) # add formatted instruction to chat history memory.add_message(instruction_prompt.format(**input_kwargs)) - if isinstance(llm, ChatLlamaCpp): - if isinstance(output_type, UnionType): - # TODO: implement Union Type grammar - raise NotImplementedError( - "Union types are not yet supported for LlamaCpp models." - ) - if issubclass(output_type, BaseModel) and not issubclass( - output_type, ParserBaseModel - ): - from llama_cpp import LlamaGrammar - - grammar = pydantic_to_grammar(output_type) - setattr( - llm, - "grammar", - LlamaGrammar.from_string(grammar, verbose=False), - ) + _inject_grammar_for_local_models(llm, output_type) # function model patches if is_function_model(llm): @@ -191,13 +162,16 @@ def create_chain( if issubclass(output_type, BaseModel) and not issubclass( output_type, ParserBaseModel ): - return create_pydanctic_chain( # type: ignore - output_type, - chat_prompt, - llm, - input_kwargs, - ) - + if settings.streaming and hasattr(llm, "model_kwargs"): + llm.model_kwargs = {"response_format": {"type": "json_object"}} + else: + return create_pydanctic_chain( # type: ignore + output_type, + chat_prompt, + llm, + input_kwargs, + ) + assert parser is not None return chat_prompt | llm | parser @@ -205,7 +179,7 @@ def _add_format_instructions( parser: BaseOutputParser, instruction: str, input_kwargs: dict[str, str], -) -> str: +) -> tuple[str, str | None]: """ Add parsing format instructions to the instruction message and input_kwargs @@ -215,9 +189,9 @@ def _add_format_instructions( if format_instructions := parser.get_format_instructions(): instruction += "\n{format_instructions}" input_kwargs["format_instructions"] = format_instructions - return instruction + return instruction, format_instructions except NotImplementedError: - return instruction + return instruction, None def _crop_large_inputs( @@ -239,7 +213,7 @@ def _crop_large_inputs( def _handle_images( - llm: BaseChatModel | RunnableWithFallbacks, + llm: BaseChatModel, input_kwargs: dict[str, str], ) -> list[Image.Image]: """ @@ -256,12 +230,33 @@ def _handle_images( return images +def _inject_grammar_for_local_models(llm: BaseChatModel, output_type: type) -> None: + """ + Inject GBNF grammar into local models. + """ + try: + from funcchain._llms import ChatOllama + except: # noqa + pass + else: + if isinstance(llm, ChatOllama): + if isinstance(output_type, UnionType): + raise NotImplementedError( + "Union types are not yet supported for LlamaCpp models." + ) # TODO: implement + + if issubclass(output_type, BaseModel) and not issubclass( + output_type, ParserBaseModel + ): + llm.grammar = pydantic_to_grammar(output_type) + if issubclass(output_type, ParserBaseModel): + llm.grammar = output_type.custom_grammar() + + def _gather_llm( settings: FuncchainSettings, -) -> BaseChatModel | RunnableWithFallbacks: - if isinstance(settings.llm, RunnableWithFallbacks) or isinstance( - settings.llm, BaseChatModel - ): +) -> BaseChatModel: + if isinstance(settings.llm, BaseChatModel): llm = settings.llm else: llm = univeral_model_selector(settings) @@ -271,12 +266,28 @@ def _gather_llm( "No language model provided. Either set the llm environment variable or " "pass a model to the `chain` function." ) + return llm + + +def _add_custom_callbacks( + llm: BaseChatModel, settings: FuncchainSettings +) -> BaseChatModel: + callbacks: Callbacks = [] + if handler := stream_handler.get(): + callbacks = [handler] + + if settings.console_stream: + from ..streaming import AsyncStreamHandler + + callbacks = [ + AsyncStreamHandler(print, {"end": "", "flush": True}), + ] + + if callbacks: settings.streaming = True - if isinstance(llm, RunnableWithFallbacks) and isinstance( - llm.runnable, BaseChatModel - ): - llm.runnable.callbacks = [handler] - elif isinstance(llm, BaseChatModel): - llm.callbacks = [handler] + if hasattr(llm, "streaming"): + llm.streaming = True + llm.callbacks = callbacks + return llm