[json mode] implement json mode
pkelaita committed Jun 19, 2024
1 parent 4e2c033 commit f43da31
Showing 5 changed files with 190 additions and 6 deletions.
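
For context, a minimal usage sketch of the API surface this commit adds, based only on the parameter names and defaults visible in the diff below. The add_provider setup call and the "gpt-4o" model alias are assumptions for illustration and are not part of this change.

import json

from l2m2.client import LLMClient
from l2m2.tools.json_mode_strategies import JsonModeStrategy

client = LLMClient()
client.add_provider("openai", "<your-openai-api-key>")  # assumed pre-existing setup API

response = client.call(
    model="gpt-4o",                               # illustrative model alias
    prompt="Give me the name and population of three countries.",
    system_prompt="Respond only with JSON.",
    json_mode=True,                               # new parameter in this commit
    json_mode_strategy=JsonModeStrategy.strip(),  # the default shown in the diff
)

# With json_mode enabled, the response is intended to parse cleanly as JSON.
data = json.loads(response)
print(data)
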
42 changes: 40 additions & 2 deletions l2m2/client/async_llm_client.py
@@ -3,6 +3,7 @@
from typing import Optional, List, Any

from l2m2.client import LLMClient
from l2m2.tools.json_mode_strategies import JsonModeStrategy


class AsyncLLMClient(LLMClient):
@@ -30,6 +31,8 @@ async def call_async(
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
prefer_provider: Optional[str] = None,
json_mode: bool = False,
json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
) -> str:
"""Asynchronously performs inference on any active model.
@@ -45,6 +48,9 @@
the provider's default value for the model is used. Defaults to None.
prefer_provider (str, optional): The preferred provider to use for the model, if the
model is available from multiple active providers. Defaults to None.
json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.
Raises:
ValueError: If the provided model is not active and/or not available.
@@ -63,6 +69,8 @@
temperature=temperature,
max_tokens=max_tokens,
prefer_provider=prefer_provider,
json_mode=json_mode,
json_mode_strategy=json_mode_strategy,
)

async def call_custom_async(
@@ -74,6 +82,8 @@ async def call_custom_async(
system_prompt: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
json_mode: bool = False,
json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
) -> str:
"""Asynchronously Performs inference on any model from an active provider that is not
officially supported by L2M2. This method does not guarantee correctness.
@@ -90,6 +100,9 @@ async def call_custom_async(
the provider's default value for the model is used. Defaults to None.
max_tokens (int, optional): The maximum number of tokens to generate. If not specified,
the provider's default value for the model is used. Defaults to None.
json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.
Raises:
ValueError: If the provided model is not active and/or not available.
@@ -104,6 +117,8 @@ async def call_custom_async(
system_prompt=system_prompt,
temperature=temperature,
max_tokens=max_tokens,
json_mode=json_mode,
json_mode_strategy=json_mode_strategy,
)

async def call_concurrent(
@@ -116,6 +131,8 @@ async def call_concurrent(
temperatures: Optional[List[float]] = None,
max_tokens: Optional[List[int]] = None,
prefer_providers: Optional[List[str]] = None,
json_modes: Optional[List[bool]] = None,
json_mode_strategies: Optional[List[JsonModeStrategy]] = None,
) -> List[str]:
"""Makes multiple concurrent calls to a given set of models, with a given set oof
parameters (e.g. model, prompt, temperature, etc). Each parameter is passed in as a list
@@ -160,6 +177,9 @@ async def call_concurrent(
to use for all calls. Defaults to None.
prefer_providers ([List[str]], optional): List of preferred providers to use for each
model, if the model is available from multiple active providers. Defaults to None.
json_modes ([List[bool]], optional): Whether to use JSON mode for each call. Defaults to None.
json_mode_strategies ([List[JsonModeStrategy]], optional): The strategies to use to enforce
JSON outputs. Defaults to None.
Raises:
ValueError: If `n < 1`, or if any of the parameters are not of length `1` or `n`.
@@ -173,7 +193,7 @@
"""
_check_concurrent_params(
n,
[models, prompts, system_prompts, temperatures, max_tokens],
[models, prompts, system_prompts, temperatures, max_tokens, json_modes],
inspect.getfullargspec(self.call_concurrent).kwonlyargs,
)

@@ -185,6 +205,8 @@ async def call_concurrent(
temperature=_get_helper(temperatures, i),
max_tokens=_get_helper(max_tokens, i),
prefer_provider=_get_helper(prefer_providers, i),
json_mode=_get_helper(json_modes, i),
json_mode_strategy=_get_helper(json_mode_strategies, i),
)
for i in range(n)
]
@@ -200,6 +222,8 @@ async def call_custom_concurrent(
system_prompts: Optional[List[str]] = None,
temperatures: Optional[List[float]] = None,
max_tokens: Optional[List[int]] = None,
json_modes: Optional[List[bool]] = None,
json_mode_strategies: Optional[List[JsonModeStrategy]] = None,
) -> List[str]:
"""Makes multiple concurrent calls to a given set of user-given models, with a given set oof
parameters (e.g. model_d, prompt, temperature, etc). Each parameter is passed in as a list
@@ -250,6 +274,9 @@ async def call_custom_concurrent(
temperature to use for all calls. Defaults to None.
max_tokens ([List[int]], optional): List of max_tokens to use, or a single max_tokens
to use for all calls. Defaults to None.
json_modes ([List[bool]], optional): Whether to use JSON mode for each call. Defaults to None.
json_mode_strategies ([List[JsonModeStrategy]], optional): The strategies to use to enforce
JSON outputs. Defaults to None.
Raises:
ValueError: If `n < 1`, or if any of the parameters are not of length `1` or `n`.
@@ -262,7 +289,16 @@
"""
_check_concurrent_params(
n,
[providers, model_ids, prompts, system_prompts, temperatures, max_tokens],
[
providers,
model_ids,
prompts,
system_prompts,
temperatures,
max_tokens,
json_modes,
json_mode_strategies,
],
inspect.getfullargspec(self.call_custom_concurrent).kwonlyargs,
)

@@ -274,6 +310,8 @@ async def call_custom_concurrent(
system_prompt=_get_helper(system_prompts, i),
temperature=_get_helper(temperatures, i),
max_tokens=_get_helper(max_tokens, i),
json_mode=_get_helper(json_modes, i),
json_mode_strategy=_get_helper(json_mode_strategies, i),
)
for i in range(n)
]
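A hedged sketch of how the new per-call lists might be used with call_concurrent, following the "length 1 or n" broadcasting rule described in the docstrings above. Provider setup via add_provider and the model aliases are assumptions for illustration, not part of this diff.

import asyncio

from l2m2.client.async_llm_client import AsyncLLMClient

async def main() -> None:
    client = AsyncLLMClient()
    client.add_provider("openai", "<openai-key>")        # assumed setup API
    client.add_provider("anthropic", "<anthropic-key>")  # assumed setup API

    results = await client.call_concurrent(
        n=2,
        models=["gpt-4o", "claude-3-opus"],              # illustrative aliases
        prompts=["List three primes as JSON.", "List three colors as JSON."],
        json_modes=[True],  # a length-1 list applies to all n calls
    )
    print(results)

asyncio.run(main())
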
82 changes: 80 additions & 2 deletions l2m2/client/llm_client.py
@@ -22,6 +22,12 @@
BaseMemory,
MemoryType,
)
from l2m2.tools.json_mode_strategies import (
JsonModeStrategy,
StrategyName,
get_extra_message,
run_json_strats_out,
)


class LLMClient:
@@ -237,6 +243,8 @@ def call(
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
prefer_provider: Optional[str] = None,
json_mode: bool = False,
json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
) -> str:
"""Performs inference on any active model.
@@ -252,6 +260,9 @@ def call(
the provider's default value for the model is used. Defaults to None.
prefer_provider (str, optional): The preferred provider to use for the model, if the
model is available from multiple active providers. Defaults to None.
json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.
Raises:
ValueError: If the provided model is not active and/or not available.
@@ -304,6 +315,8 @@ def call(
system_prompt,
temperature,
max_tokens,
json_mode,
json_mode_strategy,
)

def call_custom(
@@ -315,6 +328,8 @@ def call_custom(
system_prompt: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
json_mode: bool = False,
json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
) -> str:
"""Performs inference on any model from an active provider that is not officially supported
by L2M2. This method does not guarantee correctness.
@@ -331,6 +346,9 @@ def call_custom(
the provider's default value for the model is used. Defaults to None.
max_tokens (int, optional): The maximum number of tokens to generate. If not specified,
the provider's default value for the model is used. Defaults to None.
json_mode (bool, optional): Whether to return the response in JSON format. Defaults to False.
json_mode_strategy (JsonModeStrategy, optional): The strategy to use to enforce JSON outputs
when `json_mode` is True. Defaults to `JsonModeStrategy.strip()`.
Raises:
ValueError: If the provided model is not active and/or not available.
@@ -354,6 +372,7 @@ def call_custom(
if provider in MODEL_INFO[model].keys()
)
][provider]["params"],
"extras": {},
}

return self._call_impl(
@@ -363,6 +382,8 @@ def call_custom(
system_prompt,
temperature,
max_tokens,
json_mode,
json_mode_strategy,
)

def _call_impl(
Expand All @@ -373,6 +394,8 @@ def _call_impl(
system_prompt: Optional[str],
temperature: Optional[float],
max_tokens: Optional[int],
json_mode: bool = False,
json_mode_strategy: JsonModeStrategy = JsonModeStrategy.strip(),
) -> str:
param_info = model_info["params"]
params = {}
@@ -396,15 +419,34 @@ def add_param(name: ParamName, value: Any) -> None:
add_param("temperature", temperature)
add_param("max_tokens", max_tokens)

# Handle native JSON mode
has_native_json_mode = "json_mode_arg" in model_info["extras"]
if json_mode and has_native_json_mode:
arg = model_info["extras"]["json_mode_arg"]
key, value = next(iter(arg.items()))
params[key] = value

# Update prompts if we're using external memory
if isinstance(self.memory, ExternalMemory):
system_prompt, prompt = self._get_external_memory_prompts(
system_prompt, prompt
)

# Run the LLM
result = getattr(self, f"_call_{provider}")(
model_info["model_id"], prompt, system_prompt, params
model_info["model_id"],
prompt,
system_prompt,
params,
json_mode,
json_mode_strategy,
)

# Handle JSON mode strategies for the output (but only if we don't have native support)
if json_mode and not has_native_json_mode:
result = run_json_strats_out(json_mode_strategy, result)

# Lastly, update chat memory if applicable
if isinstance(self.memory, ChatMemory):
self.memory.add_user_message(prompt)
self.memory.add_agent_message(result)
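
The contents of MODEL_INFO and its "extras" entries are not shown in this diff, so the exact shape of "json_mode_arg" consumed in _call_impl above is an assumption. The sketch below models it on OpenAI's response_format parameter to illustrate how the key/value pair gets copied into params.

# Hypothetical extras entry; not taken from this diff.
model_info_extras_example = {
    "json_mode_arg": {"response_format": {"type": "json_object"}},
}

# Mirrors the copy performed in _call_impl when native JSON mode is available.
arg = model_info_extras_example["json_mode_arg"]
key, value = next(iter(arg.items()))
params = {key: value}
assert params == {"response_format": {"type": "json_object"}}
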
@@ -417,6 +459,7 @@ def _call_openai(
prompt: str,
system_prompt: Optional[str],
params: Dict[str, Any],
*_: Any, # json_mode and json_mode_strategy are not used here
) -> str:
oai = OpenAI(api_key=self.api_keys["openai"])
messages = []
@@ -438,6 +481,8 @@ def _call_anthropic(
prompt: str,
system_prompt: Optional[str],
params: Dict[str, Any],
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
) -> str:
anthr = Anthropic(api_key=self.api_keys["anthropic"])
if system_prompt is not None:
@@ -446,6 +491,12 @@ def _call_anthropic(
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})

if json_mode:
append_msg = get_extra_message(json_mode_strategy)
if append_msg:
messages.append({"role": "assistant", "content": append_msg})

result = anthr.messages.create(
model=model_id,
messages=messages, # type: ignore
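
The bodies of get_extra_message and run_json_strats_out live in l2m2.tools.json_mode_strategies and are not part of this diff. As a rough mental model only (not the library implementation), a strip-style output strategy could trim the raw completion down to its outermost JSON object:

def strip_to_json_object(raw: str) -> str:
    # Keep only the outermost {...} span so json.loads has a chance to succeed.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end == -1 or end < start:
        return raw  # nothing object-like found; return the text unchanged
    return raw[start : end + 1]

print(strip_to_json_object('Sure! Here you go: {"answer": 42} Hope that helps.'))
# -> {"answer": 42}
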
@@ -459,6 +510,8 @@ def _call_cohere(
prompt: str,
system_prompt: Optional[str],
params: Dict[str, Any],
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
) -> str:
cohere = CohereClient(api_key=self.api_keys["cohere"])
if system_prompt is not None:
@@ -467,6 +520,13 @@ def _call_cohere(
params["chat_history"] = self.memory.unpack(
"role", "message", "USER", "CHATBOT"
)

if json_mode:
append_msg = get_extra_message(json_mode_strategy)
if append_msg:
entry = {"role": "CHATBOT", "message": append_msg}
params.setdefault("chat_history", []).append(entry)

result = cohere.chat(
model=model_id,
message=prompt,
@@ -480,6 +540,8 @@ def _call_groq(
prompt: str,
system_prompt: Optional[str],
params: Dict[str, Any],
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
) -> str:
groq = Groq(api_key=self.api_keys["groq"])
messages = []
@@ -488,6 +550,12 @@ def _call_groq(
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})

if json_mode:
append_msg = get_extra_message(json_mode_strategy)
if append_msg:
messages.append({"role": "assistant", "content": append_msg})

result = groq.chat.completions.create(
model=model_id,
messages=messages, # type: ignore
@@ -501,6 +569,7 @@ def _call_google(
prompt: str,
system_prompt: Optional[str],
params: Dict[str, Any],
*_: Any, # json_mode and json_mode_strategy are not used here
) -> str:
google.configure(api_key=self.api_keys["google"])

@@ -533,9 +602,18 @@ def _call_replicate(
prompt: str,
system_prompt: Optional[str],
params: Dict[str, Any],
_: bool, # json_mode is not used here
json_mode_strategy: JsonModeStrategy,
) -> str:
if isinstance(self.memory, ChatMemory):
raise ValueError("Chat memory is not supported with Replicate models.")
raise ValueError(
"Chat memory is not supported with Replicate. Try using Groq."
)
if json_mode_strategy.strategy_name == StrategyName.PREPEND:
raise ValueError(
"JsonModeStrategy.prepend() is not supported with Replicate."
+ "Try using Groq, or using JsonModeStrategy.strip() instead."
)

client = replicate.Client(api_token=self.api_keys["replicate"])
if system_prompt is not None:
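
Given the two checks above, a JSON-mode call routed to Replicate would presumably need the default strip strategy rather than prepend. A hedged sketch, with the model alias, prefer_provider value, and add_provider setup assumed for illustration:

from l2m2.client import LLMClient
from l2m2.tools.json_mode_strategies import JsonModeStrategy

client = LLMClient()
client.add_provider("replicate", "<replicate-api-token>")  # assumed setup API

response = client.call(
    model="llama3-8b",                            # illustrative model alias
    prompt="Return a JSON object with keys 'name' and 'score'.",
    prefer_provider="replicate",                  # in case the model is also active elsewhere
    json_mode=True,
    json_mode_strategy=JsonModeStrategy.strip(),  # prepend would raise ValueError here
)
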
