Add extra_params

pkelaita committed Dec 17, 2024
1 parent df9aa25 commit b80f485
Showing 6 changed files with 86 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ _Current version: 0.0.39_

- Support for [Llama 3.3 70b](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3/) via [Groq](https://console.groq.com/docs/models) and [Cerebras](https://inference-docs.cerebras.ai/introduction).
- Support for OpenAI's [o1 series](https://openai.com/o1/): `o1`, `o1-preview`, and `o1-mini`.
- The `extra_params` parameter to `call` and `call_custom`.

> [!NOTE]
> At the time of this release, you must be on OpenAI's [usage tier](https://platform.openai.com/docs/guides/rate-limits) 5 to use `o1` and tier 1+ to use `o1-preview` and `o1-mini`.
38 changes: 31 additions & 7 deletions README.md
@@ -1,6 +1,6 @@
# L2M2: A Simple Python LLM Manager 💬👍

[![Tests](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml/badge.svg?timestamp=1734470191)](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml) [![codecov](https://codecov.io/github/pkelaita/l2m2/graph/badge.svg?token=UWIB0L9PR8)](https://codecov.io/github/pkelaita/l2m2) [![PyPI version](https://badge.fury.io/py/l2m2.svg?timestamp=1734470191)](https://badge.fury.io/py/l2m2)
[![Tests](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml/badge.svg?timestamp=1734476927)](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml) [![codecov](https://codecov.io/github/pkelaita/l2m2/graph/badge.svg?token=UWIB0L9PR8)](https://codecov.io/github/pkelaita/l2m2) [![PyPI version](https://badge.fury.io/py/l2m2.svg?timestamp=1734476927)](https://badge.fury.io/py/l2m2)

**L2M2** ("LLM Manager" → "LLMM" → "L2M2") is a tiny, very simple LLM manager for Python that exposes many models through a unified API. This is useful for evaluations, demos, and production applications that need to be easily model-agnostic.

@@ -59,9 +59,6 @@ L2M2 currently supports the following models:
| `llama-3.2-3b` | [Groq](https://wow.groq.com/) | `llama-3.2-3b-preview` |
| `llama-3.3-70b` | [Groq](https://wow.groq.com/), [Cerebras](https://cerebras.ai/) | `llama-3.3-70b-versatile`, `llama3.3-70b` |

> [!NOTE]
> Currently, you must be on OpenAI's [usage tier](https://platform.openai.com/docs/guides/rate-limits) 5 to use `o1` and tier 1+ to use `o1-preview`, `o1-mini`, and `gpt-4o`.
<!--end-model-table-->

## Table of Contents
@@ -78,6 +75,7 @@ L2M2 currently supports the following models:
- **Tools**
- [JSON Mode](#tools-json-mode)
- [Prompt Loader](#tools-prompt-loader)
- [Other Capabilities](#other-capabilities)
- [Planned Features](#planned-features)
- [Contributing](#contributing)
- [Contact](#contact)
@@ -154,9 +152,7 @@ response = client.call(
)
```

If you'd like to call a language model from one of the supported providers that isn't officially supported by L2M2 (for example, older models such as `gpt-4-0125-preview`), you can similarly `call_custom` with the additional required parameter `provider`, and pass in the model name expected by the provider's API. Unlike `call`, `call_custom` doesn't guarantee correctness or well-defined behavior.

### Example
#### Example

```python
# example.py
@@ -654,6 +650,34 @@ print(prompt)
Your name is Pierce and you are a software engineer.
```

## Other Capabilities

#### Call Custom

If you'd like to call a language model from one of the supported providers that isn't officially supported by L2M2 (for example, an older model such as `gpt-4-0125-preview`), you can similarly use `call_custom`, passing the additional required parameter `provider` along with the model name expected by the provider's API. Unlike `call`, `call_custom` doesn't guarantee correctness or well-defined behavior.

```python
response = client.call_custom(
    provider="<provider name>",
    model_id="<model id for given provider>",
    prompt="<prompt>",
    ...
)
```
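
For instance, a concrete invocation of the older model mentioned above might look like this (client setup assumed):

```python
response = client.call_custom(
    provider="openai",
    model_id="gpt-4-0125-preview",
    prompt="What's the capital of France?",
)
```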

#### Extra Parameters

You can pass extra parameters through to the provider's API (for example, [reasoning_effort](https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort) on OpenAI's o1 series) via the `extra_params` parameter to `call` or `call_custom`. These parameters are passed as a dictionary of key-value pairs, where the values are of type `str`, `int`, or `float`. As with `call_custom`, using `extra_params` does not guarantee correctness or well-defined behavior; refer to the provider's documentation for correct usage.

```python
response = client.call(
    model="<model name>",
    prompt="<prompt>",
    extra_params={"foo": "bar", "baz": 123},
    ...
)
```
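
For instance, here's a sketch using the `reasoning_effort` parameter linked above with an o1-series model (model access and client setup assumed):

```python
response = client.call(
    model="o1",
    prompt="Prove that the square root of 2 is irrational.",
    extra_params={"reasoning_effort": "high"},
)
```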

## Planned Features

- Support for structured outputs where available (just OpenAI, as far as I know)
7 changes: 6 additions & 1 deletion l2m2/_internal/http.py
@@ -1,4 +1,4 @@
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, Union
import httpx

from l2m2.exceptions import LLMTimeoutError, LLMRateLimitError
@@ -43,12 +43,17 @@ async def llm_post(
    api_key: str,
    data: Dict[str, Any],
    timeout: Optional[int],
    extra_params: Optional[Dict[str, Union[str, int, float]]],
) -> Any:
    endpoint = PROVIDER_INFO[provider]["endpoint"]
    if API_KEY in endpoint:
        endpoint = endpoint.replace(API_KEY, api_key)
    if MODEL_ID in endpoint and model_id is not None:
        endpoint = endpoint.replace(MODEL_ID, model_id)

    if extra_params:
        data.update(extra_params)

    try:
        response = await client.post(
            endpoint,
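In effect, any caller-supplied extras are merged into the JSON request body just before the POST, and since `dict.update` is used, extras win on key collisions. A minimal sketch of these semantics (values hypothetical):

```python
# Sketch of the merge performed inside llm_post (values hypothetical).
data = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}
extra_params = {"reasoning_effort": "high", "model": "o1"}

if extra_params:
    data.update(extra_params)  # extras override colliding keys

assert data["reasoning_effort"] == "high"
assert data["model"] == "o1"  # collision resolved in favor of extra_params
```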
20 changes: 19 additions & 1 deletion l2m2/client/base_llm_client.py
@@ -1,4 +1,4 @@
from typing import Any, List, Set, Dict, Optional, Tuple
from typing import Any, List, Set, Dict, Optional, Tuple, Union
import httpx
import os

@@ -262,6 +262,7 @@ async def call(
timeout: Optional[int] = DEFAULT_TIMEOUT_SECONDS,
bypass_memory: bool = False,
alt_memory: Optional[BaseMemory] = None,
extra_params: Optional[Dict[str, Union[str, int, float]]] = None,
) -> str:
"""Performs inference on any active model.
@@ -290,6 +291,8 @@
alt_memory (BaseMemory, optional): An alternative memory object to use for this call only. This
is very useful for asynchronous workflows where you want to keep track of multiple memory
streams in parallel without risking race conditions. Defaults to `None`.
extra_params (Dict[str, Union[str, int, float]], optional): Extra parameters to pass to the model.
Defaults to `None`.
Raises:
ValueError: If the provided model is not active and/or not available.
@@ -347,6 +350,7 @@
timeout,
bypass_memory,
alt_memory,
extra_params,
)

async def call_custom(
@@ -363,6 +367,7 @@
timeout: Optional[int] = DEFAULT_TIMEOUT_SECONDS,
bypass_memory: bool = False,
alt_memory: Optional[BaseMemory] = None,
extra_params: Optional[Dict[str, Union[str, int, float]]] = None,
) -> str:
"""Performs inference on any model from an active provider that is not officially supported
by L2M2. This method does not guarantee correctness.
@@ -430,6 +435,7 @@ async def call_custom(
timeout,
bypass_memory,
alt_memory,
extra_params,
)

async def _call_impl(
@@ -445,6 +451,7 @@ async def _call_impl(
timeout: Optional[int],
bypass_memory: bool,
alt_memory: Optional[BaseMemory],
extra_params: Optional[Dict[str, Union[str, int, float]]],
) -> str:
# Prepare memory
memory = alt_memory if alt_memory is not None else self.memory
@@ -486,6 +493,7 @@
params,
timeout,
memory,
extra_params,
json_mode,
json_mode_strategy,
model_info["extras"],
@@ -516,6 +524,7 @@ async def _call_google(
params: Dict[str, Any],
timeout: Optional[int],
memory: Optional[BaseMemory],
extra_params: Optional[Dict[str, Union[str, int, float]]],
*_: Any,  # json_mode, json_mode_strategy, and extras are not used here
) -> str:
data: Dict[str, Any] = {}
@@ -541,6 +550,7 @@
api_key=self.api_keys["google"],
data=data,
timeout=timeout,
extra_params=extra_params,
)
result = result["candidates"][0]

@@ -558,6 +568,7 @@ async def _call_anthropic(
params: Dict[str, Any],
timeout: Optional[int],
memory: Optional[BaseMemory],
extra_params: Optional[Dict[str, Union[str, int, float]]],
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
_: Dict[str, Any], # extras is not used here
@@ -581,6 +592,7 @@
api_key=self.api_keys["anthropic"],
data={"model": model_id, "messages": messages, **params},
timeout=timeout,
extra_params=extra_params,
)
return str(result["content"][0]["text"])

@@ -592,6 +604,7 @@ async def _call_cohere(
params: Dict[str, Any],
timeout: Optional[int],
memory: Optional[BaseMemory],
extra_params: Optional[Dict[str, Union[str, int, float]]],
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
_: Dict[str, Any], # extras is not used here
@@ -614,6 +627,7 @@
api_key=self.api_keys["cohere"],
data={"model": model_id, "message": prompt, **params},
timeout=timeout,
extra_params=extra_params,
)
return str(result["text"])

@@ -637,6 +651,7 @@ async def _call_replicate(
params: Dict[str, Any],
timeout: Optional[int],
memory: Optional[BaseMemory],
extra_params: Optional[Dict[str, Union[str, int, float]]],
_: bool, # json_mode is not used here
json_mode_strategy: JsonModeStrategy,
__: Dict[str, Any], # extras is not used here
@@ -662,6 +677,7 @@
api_key=self.api_keys["replicate"],
data={"input": {"prompt": prompt, **params}},
timeout=timeout,
extra_params=extra_params,
)
return "".join(result["output"])

@@ -680,6 +696,7 @@ async def _generic_openai_spec_call(
params: Dict[str, Any],
timeout: Optional[int],
memory: Optional[BaseMemory],
extra_params: Optional[Dict[str, Union[str, int, float]]],
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
extras: Dict[str, Any],
@@ -711,6 +728,7 @@
api_key=self.api_keys[provider],
data={"model": model_id, "messages": messages, **params},
timeout=timeout,
extra_params=extra_params,
)
return str(result["choices"][0]["message"]["content"])

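Taken together, these changes thread a single `extra_params` argument from the public `call`/`call_custom` surface through `_call_impl` into every provider handler's `llm_post` call. A hedged end-to-end sketch with the async client (the `AsyncLLMClient` name and `add_provider` setup call are assumed from the library's public API):

```python
import asyncio

from l2m2.client import AsyncLLMClient  # import path assumed

async def main() -> None:
    client = AsyncLLMClient()
    client.add_provider("openai", "<your-openai-api-key>")  # assumed setup call
    response = await client.call(
        model="o1-mini",
        prompt="Name three uses of a unified LLM API.",
        extra_params={"max_completion_tokens": 100},  # hypothetical extra parameter
    )
    print(response)

asyncio.run(main())
```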
6 changes: 5 additions & 1 deletion l2m2/client/llm_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Union
import asyncio
import httpx

@@ -38,6 +38,7 @@ def call( # type: ignore
timeout: Optional[int] = DEFAULT_TIMEOUT_SECONDS,
bypass_memory: bool = False,
alt_memory: Optional[BaseMemory] = None,
extra_params: Optional[Dict[str, Union[str, int, float]]] = None,
) -> str:
result = asyncio.run(
self._sync_fn_wrapper(
@@ -53,6 +54,7 @@ def call( # type: ignore
timeout=timeout,
bypass_memory=bypass_memory,
alt_memory=alt_memory,
extra_params=extra_params,
)
)
return str(result)
@@ -71,6 +73,7 @@ def call_custom( # type: ignore
timeout: Optional[int] = DEFAULT_TIMEOUT_SECONDS,
bypass_memory: bool = False,
alt_memory: Optional[BaseMemory] = None,
extra_params: Optional[Dict[str, Union[str, int, float]]] = None,
) -> str:
result = asyncio.run(
self._sync_fn_wrapper(
@@ -86,6 +89,7 @@ def call_custom( # type: ignore
timeout=timeout,
bypass_memory=bypass_memory,
alt_memory=alt_memory,
extra_params=extra_params,
)
)
return str(result)
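The sync client changes are pure plumbing: `extra_params` is forwarded unchanged through the `asyncio.run` wrapper into the async implementation. A minimal standalone sketch of that facade pattern (names hypothetical):

```python
import asyncio
from typing import Dict, Optional, Union

async def _call_async(
    prompt: str,
    extra_params: Optional[Dict[str, Union[str, int, float]]] = None,
) -> str:
    # Stand-in for the real async provider call.
    return f"echo: {prompt} {extra_params or {}}"

def call(
    prompt: str,
    extra_params: Optional[Dict[str, Union[str, int, float]]] = None,
) -> str:
    # Sync facade: drive the async implementation to completion.
    return asyncio.run(_call_async(prompt, extra_params=extra_params))

print(call("hello", extra_params={"foo": 1}))
```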
24 changes: 24 additions & 0 deletions tests/l2m2/_internal/test_http.py
@@ -126,6 +125,25 @@ async def test_llm_post_success():
"test_key",
{"prompt": "test"},
timeout=10,
extra_params={},
)
assert result == {"result": "success"}


@pytest.mark.asyncio
@pytest.mark.parametrize("extra_param_value", ["bar", 123, 0.0])
async def test_llm_post_success_with_extra_params(extra_param_value):
responses = [httpx.Response(200, json={"result": "success"})]

async with httpx.AsyncClient(transport=MockTransport(responses)) as client:
result = await llm_post(
client,
"openai",
"gpt-4",
"test_key",
{"prompt": "test"},
timeout=10,
extra_params={"foo": extra_param_value},
)
assert result == {"result": "success"}

@@ -143,6 +162,7 @@ async def test_llm_post_timeout():
"test_key",
{"prompt": "test"},
timeout=10,
extra_params={},
)


@@ -161,6 +181,7 @@ async def test_llm_post_rate_limit():
"test_key",
{"prompt": "test"},
timeout=10,
extra_params={},
)


@@ -179,6 +200,7 @@ async def test_llm_post_error():
"test_key",
{"prompt": "test"},
timeout=10,
extra_params={},
)
assert str(exc_info.value) == "Bad request"

@@ -210,6 +232,7 @@ async def test_llm_post_replicate_success():
"test_key",
{"prompt": "test"},
timeout=10,
extra_params={},
)
assert result["status"] == "succeeded"
assert result["output"] == "test output"
@@ -256,6 +279,7 @@ async def test_llm_post_with_api_key_in_endpoint():
"test_key_123",
{"prompt": "test"},
timeout=10,
extra_params={},
)
assert result == {"result": "success"}
finally:
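These tests exercise `llm_post` against a transport that replays canned responses. A self-contained analogue using httpx's built-in `MockTransport` (the repo's own `MockTransport` helper may differ):

```python
import httpx

def handler(request: httpx.Request) -> httpx.Response:
    # Return a canned success payload regardless of the request contents.
    return httpx.Response(200, json={"result": "success"})

client = httpx.Client(transport=httpx.MockTransport(handler))
resp = client.post("https://api.example.com/v1/chat", json={"prompt": "test", "foo": "bar"})
assert resp.json() == {"result": "success"}
```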
