Skip to content

Commit

Permalink
Merge pull request #104 from PrefectHQ/mypy-and-imports
Browse files Browse the repository at this point in the history
mypy init + refactor llm import logic + tests
  • Loading branch information
aaazzam authored Jun 13, 2024
2 parents 0c749a3 + a067e29 commit 1b51249
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 50 deletions.
1 change: 1 addition & 0 deletions .github/workflows/static-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ jobs:
python-version: "3.9"
- name: Run pre-commit
uses: pre-commit/[email protected]

14 changes: 13 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,16 @@ repos:
- id: ruff
args: [ --fix ]
# Run the formatter.
      - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
- id: mypy
additional_dependencies:
- pydantic>=2,<3.0.0
- langchain_core
- langchain_anthropic
- langchain_openai
- langchain_google_genai
files: ^(src/controlflow/llm/models.py)$
args: [--strict]
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[mypy]
strict = true
follow_imports = skip
files = 'src/controlflow/llm/models.py', 'tests/llm/test_models.py'
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ tests = [
"pytest-sugar>=0.9,<2.0",
"pytest>=7.0",
"pytest-timeout",
"pytest_mock",
"pytest-xdist",
"pandas",
]
Expand All @@ -57,6 +58,7 @@ dev = [
"pre-commit",
"ruff>=0.3.4",
"textual-dev",
"mypy",
]

[build-system]
Expand Down
98 changes: 49 additions & 49 deletions src/controlflow/llm/models.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,61 @@
from importlib import import_module
from typing import TYPE_CHECKING, Any, Optional, Union

from langchain_core.language_models import BaseChatModel

import controlflow

if TYPE_CHECKING:
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI

# Registry of auto-loadable providers: key -> (module name, class name).
_model_registry: dict[str, tuple[str, str]] = {
    "openai": ("langchain_openai", "ChatOpenAI"),
    "azure_openai": ("langchain_openai", "AzureChatOpenAI"),
    "anthropic": ("langchain_anthropic", "ChatAnthropic"),
    "google": ("langchain_google_genai", "ChatGoogleGenerativeAI"),
}


def get_provider_from_string(
    provider: str,
) -> Union[
    type["ChatOpenAI"],
    type["AzureChatOpenAI"],
    type["ChatAnthropic"],
    type["ChatGoogleGenerativeAI"],
]:
    """Return the langchain chat-model class registered under ``provider``.

    Raises:
        ValueError: ``provider`` is not a known registry key.
        ImportError: the provider's integration package is not installed.
    """
    entry = _model_registry.get(provider)
    if entry is None or not entry[1]:
        raise ValueError(
            f"Could not load provider automatically: {provider}. Please create your model manually."
        )
    module_name, class_name = entry
    try:
        provider_module = import_module(module_name)
    except ImportError:
        raise ImportError(
            f"To use {provider} models, please install the `{module_name}` package."
        )
    return getattr(provider_module, class_name)  # type: ignore[no-any-return]


def get_model_from_string(
    model: Optional[str] = None, temperature: Optional[float] = None, **kwargs: Any
) -> BaseChatModel:
    """Instantiate a chat model from a ``"provider/model"`` spec string.

    Falls back to ``controlflow.settings.llm_model`` when no spec is given
    and to ``controlflow.settings.llm_temperature`` when no temperature is
    given. Extra keyword arguments are forwarded to the provider class.
    """
    spec = model or controlflow.settings.llm_model
    provider, _, model_name = spec.partition("/")
    model_cls = get_provider_from_string(provider=provider)
    # NOTE(review): `name=` sets the Runnable name rather than the provider's
    # `model=` kwarg — confirm this is intended. Also note: a temperature of
    # 0 is falsy and therefore falls back to the settings default.
    return model_cls(
        name=model_name or controlflow.settings.llm_model,
        temperature=temperature or controlflow.settings.llm_temperature,
        **kwargs,
    )


def get_default_model() -> BaseChatModel:
    """Return the process-wide default chat model.

    Prefers an explicitly configured ``controlflow.default_model``;
    otherwise builds one from the current settings.
    """
    if controlflow.default_model is None:
        # Fix: a stale, unreachable call to the renamed `model_from_string`
        # helper was left behind here. Delegate to get_model_from_string,
        # whose default already reads settings.llm_model.
        return get_model_from_string()
    else:
        return controlflow.default_model


def model_from_string(
    model: str, temperature: Optional[float] = None, **kwargs: Any
) -> BaseChatModel:
    """Build a chat model from a ``"provider/model"`` string (legacy helper).

    Args:
        model: A ``"provider/model"`` spec; a bare model name implies the
            ``"openai"`` provider.
        temperature: Sampling temperature; defaults to
            ``controlflow.settings.llm_temperature`` when ``None``.
        **kwargs: Forwarded to the provider class constructor.

    Raises:
        ValueError: unknown provider prefix.
        ImportError: the provider's integration package is not installed.
    """
    # Fix: the original split on "/" unconditionally after the bare-name
    # branch, which raised ValueError for names without a slash (and for
    # model ids that themselves contain "/"). Split at most once instead;
    # also annotate temperature as Optional[float] to match its None default.
    if "/" in model:
        provider, model = model.split("/", 1)
    else:
        provider = "openai"

    if temperature is None:
        temperature = controlflow.settings.llm_temperature

    cls: type[BaseChatModel]
    if provider == "openai":
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "To use OpenAI models, please install the `langchain-openai` package."
            )
        cls = ChatOpenAI
    elif provider == "azure-openai":
        try:
            from langchain_openai import AzureChatOpenAI
        except ImportError:
            raise ImportError(
                "To use Azure OpenAI models, please install the `langchain-openai` package."
            )
        cls = AzureChatOpenAI
    elif provider == "anthropic":
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError:
            raise ImportError(
                "To use Anthropic models, please install the `langchain-anthropic` package."
            )
        cls = ChatAnthropic
    elif provider == "google":
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
        except ImportError:
            raise ImportError(
                "To use Google models, please install the `langchain_google_genai` package."
            )
        cls = ChatGoogleGenerativeAI
    else:
        raise ValueError(
            f"Could not load provider automatically: {provider}. Please create your model manually."
        )

    return cls(model=model, temperature=temperature, **kwargs)


# Legacy module-level default; presumably superseded by
# controlflow.default_model — TODO confirm before removing.
DEFAULT_MODEL = None
79 changes: 79 additions & 0 deletions tests/llm/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
import pytest_mock
from controlflow.llm.models import (
get_default_model,
get_model_from_string,
get_provider_from_string,
)


def test_get_provider_from_string_openai() -> None:
    """The "openai" key resolves to langchain's ChatOpenAI class."""
    cls = get_provider_from_string("openai")
    assert cls.__name__ == "ChatOpenAI"


def test_get_provider_from_string_azure_openai() -> None:
    """The "azure_openai" key resolves to AzureChatOpenAI."""
    cls = get_provider_from_string("azure_openai")
    assert cls.__name__ == "AzureChatOpenAI"


def test_get_provider_from_string_anthropic() -> None:
    """The "anthropic" key resolves to ChatAnthropic."""
    # Skip when the optional langchain_anthropic package is absent.
    pytest.importorskip("langchain_anthropic")
    cls = get_provider_from_string("anthropic")
    assert cls.__name__ == "ChatAnthropic"


def test_get_provider_from_string_google() -> None:
    """The "google" key resolves to ChatGoogleGenerativeAI."""
    # Skip when the optional langchain_google_genai package is absent.
    pytest.importorskip("langchain_google_genai")
    cls = get_provider_from_string("google")
    assert cls.__name__ == "ChatGoogleGenerativeAI"


def test_get_provider_from_string_invalid_provider() -> None:
    """Unknown provider keys raise ValueError rather than guessing."""
    with pytest.raises(ValueError):
        get_provider_from_string("invalid_provider.gpt4")


def test_get_provider_from_string_missing_module(
    mocker: pytest_mock.MockFixture,
) -> None:
    """A missing provider integration package surfaces as ImportError.

    Fix: the original passed "openai.missing_module", which is not a
    registry key, so get_provider_from_string raises ValueError (via the
    empty-class-name default) and the ImportError expectation could never
    pass. Simulate the package being uninstalled instead, exercising the
    intended import-failure path.
    """
    mocker.patch(
        "controlflow.llm.models.import_module",
        side_effect=ImportError,
    )
    with pytest.raises(ImportError):
        get_provider_from_string("openai")


def test_get_model_from_string(mocker: pytest_mock.MockFixture) -> None:
    """get_model_from_string parses "provider/model" and forwards kwargs.

    Two scenarios share one mocked provider class: an explicit spec with an
    explicit temperature, then a call relying entirely on patched settings.
    """
    # Test getting a model from string
    mock_provider_class = mocker.Mock()
    mock_provider_instance = mocker.Mock()
    mock_provider_class.return_value = mock_provider_instance
    # Replace provider resolution so no langchain package is required.
    mocker.patch(
        "controlflow.llm.models.get_provider_from_string",
        return_value=mock_provider_class,
    )
    model = get_model_from_string("openai/davinci", temperature=0.5)
    assert model == mock_provider_instance
    # The part after "/" becomes the name; temperature passes through.
    mock_provider_class.assert_called_once_with(
        name="davinci",
        temperature=0.5,
    )

    # Test getting a model with default settings
    mock_provider_class.reset_mock()
    mocker.patch("controlflow.settings.settings.llm_model", "anthropic/claude")
    mocker.patch("controlflow.settings.settings.llm_temperature", 0.7)
    pytest.importorskip(
        "langchain_anthropic"
    )  # Skip if langchain_anthropic is not installed
    # With no arguments, both name and temperature come from settings.
    model = get_model_from_string()
    assert model == mock_provider_instance
    mock_provider_class.assert_called_once_with(
        name="claude",
        temperature=0.7,
    )


def test_get_default_model(mocker: pytest_mock.MockFixture) -> None:
    """get_default_model defers to get_model_from_string with no arguments."""
    model_factory = mocker.patch(
        "controlflow.llm.models.get_model_from_string"
    )
    get_default_model()
    # Default path: no explicit model was set, so the factory is called bare.
    model_factory.assert_called_once_with()

0 comments on commit 1b51249

Please sign in to comment.