-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #104 from PrefectHQ/mypy-and-imports
mypy init + refactor llm import logic + tests
- Loading branch information
Showing
6 changed files
with
148 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,3 +29,4 @@ jobs: | |
python-version: "3.9" | ||
- name: Run pre-commit | ||
uses: pre-commit/[email protected] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[mypy] | ||
strict=true | ||
follow_imports = skip | ||
files = 'src/controlflow/llm/models.py', 'tests/llm/test_models.py' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,61 @@ | ||
from importlib import import_module | ||
from typing import TYPE_CHECKING, Any, Optional, Union | ||
|
||
from langchain_core.language_models import BaseChatModel | ||
|
||
import controlflow | ||
|
||
if TYPE_CHECKING: | ||
from langchain_anthropic import ChatAnthropic | ||
from langchain_google_genai import ChatGoogleGenerativeAI | ||
from langchain_openai import AzureChatOpenAI, ChatOpenAI | ||
|
||
_model_registry: dict[str, tuple[str, str]] = { | ||
"openai": ("langchain_openai", "ChatOpenAI"), | ||
"azure_openai": ("langchain_openai", "AzureChatOpenAI"), | ||
"anthropic": ("langchain_anthropic", "ChatAnthropic"), | ||
"google": ("langchain_google_genai", "ChatGoogleGenerativeAI"), | ||
} | ||
|
||
|
||
def get_provider_from_string( | ||
provider: str, | ||
) -> Union[ | ||
type["ChatOpenAI"], | ||
type["AzureChatOpenAI"], | ||
type["ChatAnthropic"], | ||
type["ChatGoogleGenerativeAI"], | ||
]: | ||
module_name, class_name = _model_registry.get(provider, ("openai", "")) | ||
if not class_name: | ||
raise ValueError( | ||
f"Could not load provider automatically: {provider}. Please create your model manually." | ||
) | ||
try: | ||
module = import_module(module_name) | ||
except ImportError: | ||
raise ImportError( | ||
f"To use {provider} models, please install the `{module_name}` package." | ||
) | ||
return getattr(module, class_name) # type: ignore[no-any-return] | ||
|
||
|
||
def get_model_from_string( | ||
model: Optional[str] = None, temperature: Optional[float] = None, **kwargs: Any | ||
) -> BaseChatModel: | ||
provider, _, model = (model or controlflow.settings.llm_model).partition("/") | ||
return get_provider_from_string(provider=provider)( | ||
name=model or controlflow.settings.llm_model, | ||
temperature=temperature or controlflow.settings.llm_temperature, | ||
**kwargs, | ||
) | ||
|
||
|
||
def get_default_model() -> BaseChatModel: | ||
if controlflow.default_model is None: | ||
return model_from_string(controlflow.settings.llm_model) | ||
return get_model_from_string(controlflow.settings.llm_model) | ||
else: | ||
return controlflow.default_model | ||
|
||
|
||
def model_from_string(model: str, temperature: float = None, **kwargs) -> BaseChatModel: | ||
if "/" not in model: | ||
provider, model = "openai", model | ||
provider, model = model.split("/") | ||
|
||
if temperature is None: | ||
temperature = controlflow.settings.llm_temperature | ||
|
||
if provider == "openai": | ||
try: | ||
from langchain_openai import ChatOpenAI | ||
except ImportError: | ||
raise ImportError( | ||
"To use OpenAI models, please install the `langchain-openai` package." | ||
) | ||
cls = ChatOpenAI | ||
elif provider == "azure-openai": | ||
try: | ||
from langchain_openai import AzureChatOpenAI | ||
except ImportError: | ||
raise ImportError( | ||
"To use Azure OpenAI models, please install the `langchain-openai` package." | ||
) | ||
cls = AzureChatOpenAI | ||
elif provider == "anthropic": | ||
try: | ||
from langchain_anthropic import ChatAnthropic | ||
except ImportError: | ||
raise ImportError( | ||
"To use Anthropic models, please install the `langchain-anthropic` package." | ||
) | ||
cls = ChatAnthropic | ||
elif provider == "google": | ||
try: | ||
from langchain_google_genai import ChatGoogleGenerativeAI | ||
except ImportError: | ||
raise ImportError( | ||
"To use Google models, please install the `langchain_google_genai` package." | ||
) | ||
cls = ChatGoogleGenerativeAI | ||
else: | ||
raise ValueError( | ||
f"Could not load provider automatically: {provider}. Please create your model manually." | ||
) | ||
|
||
return cls(model=model, temperature=temperature, **kwargs) | ||
|
||
|
||
DEFAULT_MODEL = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import pytest | ||
import pytest_mock | ||
from controlflow.llm.models import ( | ||
get_default_model, | ||
get_model_from_string, | ||
get_provider_from_string, | ||
) | ||
|
||
|
||
def test_get_provider_from_string_openai() -> None: | ||
provider = get_provider_from_string("openai") | ||
assert provider.__name__ == "ChatOpenAI" | ||
|
||
|
||
def test_get_provider_from_string_azure_openai() -> None: | ||
provider = get_provider_from_string("azure_openai") | ||
assert provider.__name__ == "AzureChatOpenAI" | ||
|
||
|
||
def test_get_provider_from_string_anthropic() -> None: | ||
pytest.importorskip("langchain_anthropic") | ||
provider = get_provider_from_string("anthropic") | ||
assert provider.__name__ == "ChatAnthropic" | ||
|
||
|
||
def test_get_provider_from_string_google() -> None: | ||
pytest.importorskip("langchain_google_genai") | ||
provider = get_provider_from_string("google") | ||
assert provider.__name__ == "ChatGoogleGenerativeAI" | ||
|
||
|
||
def test_get_provider_from_string_invalid_provider() -> None: | ||
with pytest.raises(ValueError): | ||
get_provider_from_string("invalid_provider.gpt4") | ||
|
||
|
||
def test_get_provider_from_string_missing_module() -> None: | ||
with pytest.raises(ImportError): | ||
get_provider_from_string("openai.missing_module") | ||
|
||
|
||
def test_get_model_from_string(mocker: pytest_mock.MockFixture) -> None: | ||
# Test getting a model from string | ||
mock_provider_class = mocker.Mock() | ||
mock_provider_instance = mocker.Mock() | ||
mock_provider_class.return_value = mock_provider_instance | ||
mocker.patch( | ||
"controlflow.llm.models.get_provider_from_string", | ||
return_value=mock_provider_class, | ||
) | ||
model = get_model_from_string("openai/davinci", temperature=0.5) | ||
assert model == mock_provider_instance | ||
mock_provider_class.assert_called_once_with( | ||
name="davinci", | ||
temperature=0.5, | ||
) | ||
|
||
# Test getting a model with default settings | ||
mock_provider_class.reset_mock() | ||
mocker.patch("controlflow.settings.settings.llm_model", "anthropic/claude") | ||
mocker.patch("controlflow.settings.settings.llm_temperature", 0.7) | ||
pytest.importorskip( | ||
"langchain_anthropic" | ||
) # Skip if langchain_anthropic is not installed | ||
model = get_model_from_string() | ||
assert model == mock_provider_instance | ||
mock_provider_class.assert_called_once_with( | ||
name="claude", | ||
temperature=0.7, | ||
) | ||
|
||
|
||
def test_get_default_model(mocker: pytest_mock.MockFixture) -> None: | ||
# Test getting the default model | ||
mock_get_model_from_string = mocker.patch( | ||
"controlflow.llm.models.get_model_from_string" | ||
) | ||
get_default_model() | ||
mock_get_model_from_string.assert_called_once_with() |