From f43da31b87df7de5b37e9c3a75925877c2dd7cb2 Mon Sep 17 00:00:00 2001
From: Pierce Kelaita
Date: Wed, 19 Jun 2024 16:41:08 -0700
Subject: [PATCH] [json mode] implement json mode

---
 l2m2/client/async_llm_client.py    | 42 ++++++++++++++-
 l2m2/client/llm_client.py          | 82 +++++++++++++++++++++++++++++-
 l2m2/model_info.py                 | 19 ++++++-
 l2m2/tools/__init__.py             |  3 +-
 l2m2/tools/json_mode_strategies.py | 50 ++++++++++++++++++
 5 files changed, 190 insertions(+), 6 deletions(-)
 create mode 100644 l2m2/tools/json_mode_strategies.py

diff --git a/l2m2/client/async_llm_client.py b/l2m2/client/async_llm_client.py
index b481f52..f1a739f 100644
--- a/l2m2/client/async_llm_client.py
+++ b/l2m2/client/async_llm_client.py
@@ -3,6 +3,7 @@
 from typing import Optional, List, Any

 from l2m2.client import LLMClient
+from l2m2.tools.json_mode_strategies import JsonModeStrategy


 class AsyncLLMClient(LLMClient):
@@ -30,6 +31,8 @@ async def call_async(
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         prefer_provider: Optional[str] = None,
+        json_mode: bool = False,
+        json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
     ) -> str:
         """Asynchronously performs inference on any active model.

@@ -45,6 +48,9 @@ async def call_async(
                 the provider's default value for the model is used. Defaults to None.
             prefer_provider (str, optional): The preferred provider to use for the model, if the
                 model is available from multiple active providers. Defaults to None.
+            json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
+            json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
+                when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.

         Raises:
             ValueError: If the provided model is not active and/or not available.
@@ -63,6 +69,8 @@ async def call_async(
             temperature=temperature,
             max_tokens=max_tokens,
             prefer_provider=prefer_provider,
+            json_mode=json_mode,
+            json_mode_strategy=json_mode_strategy,
         )

     async def call_custom_async(
@@ -74,6 +82,8 @@ async def call_custom_async(
         system_prompt: Optional[str] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
+        json_mode: bool = False,
+        json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
     ) -> str:
         """Asynchronously Performs inference on any model from an active provider that is not
         officially supported by L2M2. This method does not guarantee correctness.
@@ -90,6 +100,9 @@ async def call_custom_async(
                 the provider's default value for the model is used. Defaults to None.
             max_tokens (int, optional): The maximum number of tokens to generate. If not specified,
                 the provider's default value for the model is used. Defaults to None.
+            json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
+            json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
+                when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.

         Raises:
             ValueError: If the provided model is not active and/or not available.
@@ -104,6 +117,8 @@ async def call_custom_async(
             system_prompt=system_prompt,
             temperature=temperature,
             max_tokens=max_tokens,
+            json_mode=json_mode,
+            json_mode_strategy=json_mode_strategy,
         )

     async def call_concurrent(
@@ -116,6 +131,8 @@ async def call_concurrent(
         temperatures: Optional[List[float]] = None,
         max_tokens: Optional[List[int]] = None,
         prefer_providers: Optional[List[str]] = None,
+        json_modes: Optional[List[bool]] = None,
+        json_mode_strategies: Optional[List[JsonModeStrategy]] = None,
     ) -> List[str]:
         """Makes multiple concurrent calls to a given set of models, with a given set oof
         parameters (e.g. model, prompt, temperature, etc). Each parameter is passed in as a list
@@ -160,6 +177,9 @@ async def call_concurrent(
                 to use for all calls. Defaults to None.
             prefer_providers ([List[str]], optional): List of preferred providers to use for each
                 model, if the model is available from multiple active providers. Defaults to None.
+            json_modes ([List[bool]], optional): Whether to use JSON mode for each call. Defaults to None.
+            json_mode_strategies ([List[JsonModeStrategy]], optional): The strategies to use to enforce
+                JSON outputs. Defaults to None.

         Raises:
             ValueError: If `n < 1`, or if any of the parameters are not of length `1` or `n`.
@@ -173,7 +193,7 @@ async def call_concurrent(
         """
         _check_concurrent_params(
             n,
-            [models, prompts, system_prompts, temperatures, max_tokens],
+            [models, prompts, system_prompts, temperatures, max_tokens, json_modes],
             inspect.getfullargspec(self.call_concurrent).kwonlyargs,
         )

@@ -185,6 +205,8 @@ async def call_concurrent(
                 temperature=_get_helper(temperatures, i),
                 max_tokens=_get_helper(max_tokens, i),
                 prefer_provider=_get_helper(prefer_providers, i),
+                json_mode=_get_helper(json_modes, i),
+                json_mode_strategy=_get_helper(json_mode_strategies, i),
             )
             for i in range(n)
         ]
@@ -200,6 +222,8 @@ async def call_custom_concurrent(
         system_prompts: Optional[List[str]] = None,
         temperatures: Optional[List[float]] = None,
         max_tokens: Optional[List[int]] = None,
+        json_modes: Optional[List[bool]] = None,
+        json_mode_strategies: Optional[List[JsonModeStrategy]] = None,
     ) -> List[str]:
         """Makes multiple concurrent calls to a given set of user-given models, with a given set oof
         parameters (e.g. model_d, prompt, temperature, etc). Each parameter is passed in as a list
@@ -250,6 +274,9 @@ async def call_custom_concurrent(
                 temperature to use for all calls. Defaults to None.
             max_tokens ([List[int]], optional): List of max_tokens to use, or a single max_tokens
                 to use for all calls. Defaults to None.
+            json_modes ([List[bool]], optional): Whether to use JSON mode for each call. Defaults to None.
+            json_mode_strategies ([List[JsonModeStrategy]], optional): The strategies to use to enforce
+                JSON outputs. Defaults to None.

         Raises:
             ValueError: If `n < 1`, or if any of the parameters are not of length `1` or `n`.
@@ -262,7 +289,16 @@ async def call_custom_concurrent(
         """
         _check_concurrent_params(
             n,
-            [providers, model_ids, prompts, system_prompts, temperatures, max_tokens],
+            [
+                providers,
+                model_ids,
+                prompts,
+                system_prompts,
+                temperatures,
+                max_tokens,
+                json_modes,
+                json_mode_strategies,
+            ],
             inspect.getfullargspec(self.call_custom_concurrent).kwonlyargs,
         )

@@ -274,6 +310,8 @@ async def call_custom_concurrent(
                 system_prompt=_get_helper(system_prompts, i),
                 temperature=_get_helper(temperatures, i),
                 max_tokens=_get_helper(max_tokens, i),
+                json_mode=_get_helper(json_modes, i),
+                json_mode_strategy=_get_helper(json_mode_strategies, i),
             )
             for i in range(n)
         ]
diff --git a/l2m2/client/llm_client.py b/l2m2/client/llm_client.py
index ceb265f..93c2056 100644
--- a/l2m2/client/llm_client.py
+++ b/l2m2/client/llm_client.py
@@ -22,6 +22,12 @@
     BaseMemory,
     MemoryType,
 )
+from l2m2.tools.json_mode_strategies import (
+    JsonModeStrategy,
+    StrategyName,
+    get_extra_message,
+    run_json_strats_out,
+)


 class LLMClient:
@@ -237,6 +243,8 @@ def call(
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         prefer_provider: Optional[str] = None,
+        json_mode: bool = False,
+        json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
     ) -> str:
         """Performs inference on any active model.

@@ -252,6 +260,9 @@ def call(
                 the provider's default value for the model is used. Defaults to None.
             prefer_provider (str, optional): The preferred provider to use for the model, if the
                 model is available from multiple active providers. Defaults to None.
+            json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
+            json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
+                when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.

         Raises:
             ValueError: If the provided model is not active and/or not available.
@@ -304,6 +315,8 @@ def call(
             system_prompt,
             temperature,
             max_tokens,
+            json_mode,
+            json_mode_strategy,
         )

     def call_custom(
@@ -315,6 +328,8 @@ def call_custom(
         system_prompt: Optional[str] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
+        json_mode: bool = False,
+        json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
     ) -> str:
         """Performs inference on any model from an active provider that is not officially supported
         by L2M2. This method does not guarantee correctness.
@@ -331,6 +346,9 @@ def call_custom(
                 the provider's default value for the model is used. Defaults to None.
             max_tokens (int, optional): The maximum number of tokens to generate. If not specified,
                 the provider's default value for the model is used. Defaults to None.
+            json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
+            json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
+                when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.

         Raises:
             ValueError: If the provided model is not active and/or not available.
@@ -354,6 +372,7 @@ def call_custom(
                     if provider in MODEL_INFO[model].keys()
                 )
             ][provider]["params"],
+            "extras": {},
         }

         return self._call_impl(
@@ -363,6 +382,8 @@ def call_custom(
             system_prompt,
             temperature,
             max_tokens,
+            json_mode,
+            json_mode_strategy,
         )

     def _call_impl(
@@ -373,6 +394,8 @@ def _call_impl(
         system_prompt: Optional[str],
         temperature: Optional[float],
         max_tokens: Optional[int],
+        json_mode: bool = False,
+        json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
     ) -> str:
         param_info = model_info["params"]
         params = {}
@@ -396,15 +419,34 @@ def add_param(name: ParamName, value: Any) -> None:
         add_param("temperature", temperature)
         add_param("max_tokens", max_tokens)

+        # Handle native JSON mode
+        has_native_json_mode = "json_mode_arg" in model_info["extras"]
+        if json_mode and has_native_json_mode:
+            arg = model_info["extras"]["json_mode_arg"]
+            key, value = next(iter(arg.items()))
+            params[key] = value
+
+        # Update prompts if we're using external memory
         if isinstance(self.memory, ExternalMemory):
             system_prompt, prompt = self._get_external_memory_prompts(
                 system_prompt, prompt
             )

+        # Run the LLM
         result = getattr(self, f"_call_{provider}")(
-            model_info["model_id"], prompt, system_prompt, params
+            model_info["model_id"],
+            prompt,
+            system_prompt,
+            params,
+            json_mode,
+            json_mode_strategy,
         )

+        # Handle JSON mode strategies for the output (but only if we don't have native support)
+        if json_mode and not has_native_json_mode:
+            result = run_json_strats_out(json_mode_strategy, result)
+
+        # Lastly, update chat memory if applicable
         if isinstance(self.memory, ChatMemory):
             self.memory.add_user_message(prompt)
             self.memory.add_agent_message(result)
@@ -417,6 +459,7 @@ def _call_openai(
         prompt: str,
         system_prompt: Optional[str],
         params: Dict[str, Any],
+        *_: Any,  # json_mode and json_mode_strategy are not used here
     ) -> str:
         oai = OpenAI(api_key=self.api_keys["openai"])
         messages = []
@@ -438,6 +481,8 @@ def _call_anthropic(
         prompt: str,
         system_prompt: Optional[str],
         params: Dict[str, Any],
+        json_mode: bool,
+        json_mode_strategy: JsonModeStrategy,
     ) -> str:
         anthr = Anthropic(api_key=self.api_keys["anthropic"])
         if system_prompt is not None:
@@ -446,6 +491,12 @@ def _call_anthropic(
         if isinstance(self.memory, ChatMemory):
             messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
         messages.append({"role": "user", "content": prompt})
+
+        if json_mode:
+            append_msg = get_extra_message(json_mode_strategy)
+            if append_msg:
+                messages.append({"role": "assistant", "content": append_msg})
+
         result = anthr.messages.create(
             model=model_id,
             messages=messages,  # type: ignore
@@ -459,6 +510,8 @@ def _call_cohere(
         prompt: str,
         system_prompt: Optional[str],
         params: Dict[str, Any],
+        json_mode: bool,
+        json_mode_strategy: JsonModeStrategy,
     ) -> str:
         cohere = CohereClient(api_key=self.api_keys["cohere"])
         if system_prompt is not None:
@@ -467,6 +520,13 @@ def _call_cohere(
             params["chat_history"] = self.memory.unpack(
                 "role", "message", "USER", "CHATBOT"
             )
+
+        if json_mode:
+            append_msg = get_extra_message(json_mode_strategy)
+            if append_msg:
+                entry = {"role": "CHATBOT", "message": append_msg}
+                params.setdefault("chat_history", []).append(entry)
+
         result = cohere.chat(
             model=model_id,
             message=prompt,
@@ -480,6 +540,8 @@ def _call_groq(
         prompt: str,
         system_prompt: Optional[str],
         params: Dict[str, Any],
+        json_mode: bool,
+        json_mode_strategy: JsonModeStrategy,
     ) -> str:
         groq = Groq(api_key=self.api_keys["groq"])
         messages = []
@@ -488,6 +550,12 @@ def _call_groq(
         if isinstance(self.memory, ChatMemory):
             messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
         messages.append({"role": "user", "content": prompt})
+
+        if json_mode:
+            append_msg = get_extra_message(json_mode_strategy)
+            if append_msg:
+                messages.append({"role": "assistant", "content": append_msg})
+
         result = groq.chat.completions.create(
             model=model_id,
             messages=messages,  # type: ignore
@@ -501,6 +569,7 @@ def _call_google(
         prompt: str,
         system_prompt: Optional[str],
         params: Dict[str, Any],
+        *_: Any,  # json_mode and json_mode_strategy are not used here
     ) -> str:
         google.configure(api_key=self.api_keys["google"])

@@ -533,9 +602,18 @@ def _call_replicate(
         prompt: str,
         system_prompt: Optional[str],
         params: Dict[str, Any],
+        _: bool,  # json_mode is not used here
+        json_mode_strategy: JsonModeStrategy,
     ) -> str:
         if isinstance(self.memory, ChatMemory):
-            raise ValueError("Chat memory is not supported with Replicate models.")
+            raise ValueError(
+                "Chat memory is not supported with Replicate. Try using Groq."
+            )
+        if json_mode_strategy.strategy_name == StrategyName.PREPEND:
+            raise ValueError(
+                "JsonModeStrategy.prepend() is not supported with Replicate."
+                + " Try using Groq, or using JsonModeStrategy.strip() instead."
+            )

         client = replicate.Client(api_token=self.api_keys["replicate"])
         if system_prompt is not None:
diff --git a/l2m2/model_info.py b/l2m2/model_info.py
index 2e48413..27e96aa 100644
--- a/l2m2/model_info.py
+++ b/l2m2/model_info.py
@@ -1,6 +1,6 @@
 """Information about models and providers supported by L2M2."""

-from typing import Dict, Union
+from typing import Dict, Union, Any
 from typing_extensions import TypedDict, NotRequired, TypeVar, Generic, Literal
 from enum import Enum
 import sys
@@ -37,6 +37,7 @@ class ModelParams(TypedDict):
 class ModelEntry(TypedDict):
     model_id: str
     params: ModelParams
+    extras: Dict[str, Any]


 INF: int = sys.maxsize
@@ -83,6 +84,7 @@ class ModelEntry(TypedDict):
                     "max": 4096,
                 },
             },
+            "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
         },
     },
     "gpt-4-turbo": {
@@ -98,6 +100,7 @@ class ModelEntry(TypedDict):
                     "max": 4096,
                 },
             },
+            "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
         },
     },
     "gpt-3.5-turbo": {
@@ -113,6 +116,7 @@ class ModelEntry(TypedDict):
                     "max": 4096,
                 },
             },
+            "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
         },
     },
     "gemini-1.5-pro": {
@@ -130,6 +134,7 @@ class ModelEntry(TypedDict):
                     "max": 8192,
                 },
             },
+            "extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
         },
     },
     "gemini-1.0-pro": {
@@ -147,6 +152,7 @@ class ModelEntry(TypedDict):
                     "max": 8192,
                 },
             },
+            "extras": {},
         },
     },
     "claude-3-opus": {
@@ -162,6 +168,7 @@ class ModelEntry(TypedDict):
                     "max": 4096,
                 },
             },
+            "extras": {},
         },
     },
     "claude-3-sonnet": {
@@ -177,6 +184,7 @@ class ModelEntry(TypedDict):
                     "max": 4096,
                 },
             },
+            "extras": {},
         },
     },
     "claude-3-haiku": {
@@ -192,6 +200,7 @@ class ModelEntry(TypedDict):
                     "max": 4096,
                 },
             },
+            "extras": {},
         },
     },
     "command-r": {
@@ -207,6 +216,7 @@ class ModelEntry(TypedDict):
                     "max": 4000,
                 },
             },
+            "extras": {},
         },
     },
     "command-r-plus": {
@@ -222,6 +232,7 @@ class ModelEntry(TypedDict):
                     "max": 4000,
                 },
             },
+            "extras": {},
         },
     },
     "mixtral-8x7b": {
@@ -237,6 +248,7 @@ class ModelEntry(TypedDict):
                     "max": 2**16 - 1,
                 },
             },
+            "extras": {},
         },
     },
     "gemma-7b": {
@@ -252,6 +264,7 @@ class ModelEntry(TypedDict):
                     "max": 2**16 - 1,
                 },
             },
+            "extras": {},
         },
     },
     "llama3-8b": {
@@ -267,6 +280,7 @@ class ModelEntry(TypedDict):
                     "max": 2**16 - 1,
                 },
             },
+            "extras": {},
         },
"replicate": { "model_id": "meta/meta-llama-3-8b-instruct", @@ -281,6 +295,7 @@ class ModelEntry(TypedDict): "max": INF, }, }, + "extras": {}, }, }, "llama3-70b": { @@ -296,6 +311,7 @@ class ModelEntry(TypedDict): "max": 2**16 - 1, }, }, + "extras": {}, }, "replicate": { "model_id": "meta/meta-llama-3-70b-instruct", @@ -310,6 +326,7 @@ class ModelEntry(TypedDict): "max": INF, }, }, + "extras": {}, }, }, } diff --git a/l2m2/tools/__init__.py b/l2m2/tools/__init__.py index 8366866..4e89a42 100644 --- a/l2m2/tools/__init__.py +++ b/l2m2/tools/__init__.py @@ -1,3 +1,4 @@ from .prompt_loader import PromptLoader +from .json_mode_strategies import JsonModeStrategy -__all__ = ["PromptLoader"] +__all__ = ["PromptLoader", "JsonModeStrategy"] diff --git a/l2m2/tools/json_mode_strategies.py b/l2m2/tools/json_mode_strategies.py new file mode 100644 index 0000000..678dd84 --- /dev/null +++ b/l2m2/tools/json_mode_strategies.py @@ -0,0 +1,50 @@ +from typing import Optional +from enum import Enum + +DEFAULT_PREFIX = "Here is the JSON response: " + + +class StrategyName(Enum): + STRIP = "strip" + PREPEND = "prepend" + + +class JsonModeStrategy: + def __init__( + self, + strategy_name: StrategyName, + prefix: Optional[str] = None, + ) -> None: + self.strategy_name = strategy_name + self.prefix = prefix + + @classmethod + def strip(cls) -> "JsonModeStrategy": + return cls(StrategyName.STRIP) + + @classmethod + def prepend(cls, custom_prefix: str = DEFAULT_PREFIX) -> "JsonModeStrategy": + return cls(StrategyName.PREPEND, custom_prefix) + + +def get_extra_message(strategy: JsonModeStrategy) -> Optional[str]: + if strategy.strategy_name == StrategyName.PREPEND: + assert strategy.prefix is not None + return strategy.prefix + "{" + + return None + + +def run_json_strats_out( + strategy: JsonModeStrategy, + output: str, +) -> str: + if strategy.strategy_name == StrategyName.PREPEND: + return "{" + output + + if strategy.strategy_name == StrategyName.STRIP: + start = output.find("{") + end = output.rfind("}") + if start == -1 or end == -1 or start >= end: + return output + return output[start : end + 1] # noqa: E203