Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rename model_name or model_name_or_path to model in generators #6715

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion haystack/components/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.model_name: str = azure_deployment or "gpt-35-turbo"
self.model: str = azure_deployment or "gpt-35-turbo"

self.client = AzureOpenAI(
api_version=api_version,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.model_name = azure_deployment or "gpt-35-turbo"
self.model = azure_deployment or "gpt-35-turbo"

self.client = AzureOpenAI(
api_version=api_version,
Expand Down
18 changes: 9 additions & 9 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,19 @@ class OpenAIChatGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.

:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param api_base_url: An optional base URL.
Expand All @@ -101,7 +101,7 @@ def __init__(
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
self.model_name = model_name
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
Expand All @@ -112,7 +112,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -122,7 +122,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
organization=self.organization,
Expand Down Expand Up @@ -162,7 +162,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
openai_formatted_messages = self._convert_to_openai_format(messages)

chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model_name,
model=self.model,
messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types
stream=self.streaming_callback is not None,
**generation_kwargs,
Expand Down Expand Up @@ -335,7 +335,7 @@ class GPTChatGenerator(OpenAIChatGenerator):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
Expand All @@ -349,7 +349,7 @@ def __init__(
)
super().__init__(
api_key=api_key,
model_name=model_name,
model=model,
streaming_callback=streaming_callback,
api_base_url=api_base_url,
organization=organization,
Expand Down
10 changes: 5 additions & 5 deletions haystack/components/generators/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class HuggingFaceLocalGenerator:
```python
from haystack.components.generators import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-large",
generator = HuggingFaceLocalGenerator(model="google/flan-t5-large",
task="text2text-generation",
generation_kwargs={
"max_new_tokens": 100,
Expand All @@ -81,7 +81,7 @@ class HuggingFaceLocalGenerator:

def __init__(
self,
model_name_or_path: str = "google/flan-t5-base",
model: str = "google/flan-t5-base",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
Expand All @@ -90,7 +90,7 @@ def __init__(
stop_words: Optional[List[str]] = None,
):
"""
:param model_name_or_path: The name or path of a Hugging Face model for text generation,
:param model: The name or path of a Hugging Face model for text generation,
for example, "google/flan-t5-large".
If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param task: The task for the Hugging Face pipeline.
Expand All @@ -114,7 +114,7 @@ def __init__(
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for text generation.
These keyword arguments provide fine-grained control over the Hugging Face pipeline.
In case of duplication, these kwargs override `model_name_or_path`, `task`, `device`, and `token` init parameters.
In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
for more information on the available kwargs.
In this dictionary, you can also include `model_kwargs` to specify the kwargs
Expand All @@ -132,7 +132,7 @@ def __init__(

# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model_name_or_path)
huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token)
if (
device is not None
Expand Down
18 changes: 9 additions & 9 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ class OpenAIGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
system_prompt: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.

:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param api_base_url: An optional base URL.
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
self.model_name = model_name
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.system_prompt = system_prompt
self.streaming_callback = streaming_callback
Expand All @@ -105,7 +105,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -115,7 +115,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
Expand Down Expand Up @@ -161,7 +161,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
openai_formatted_messages = self._convert_to_openai_format(messages)

completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model_name,
model=self.model,
messages=openai_formatted_messages, # type: ignore
stream=self.streaming_callback is not None,
**generation_kwargs,
Expand Down Expand Up @@ -280,7 +280,7 @@ class GPTGenerator(OpenAIGenerator):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
Expand All @@ -295,7 +295,7 @@ def __init__(
)
super().__init__(
api_key=api_key,
model_name=model_name,
model=model,
streaming_callback=streaming_callback,
api_base_url=api_base_url,
organization=organization,
Expand Down
2 changes: 1 addition & 1 deletion haystack/pipeline_utils/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class _OpenAIResolved(_GeneratorResolver):
def resolve(self, model_key: str, api_key: str) -> Any:
# does the model_key match the pattern OpenAI GPT pattern?
if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key):
return OpenAIGenerator(model_name=model_key, api_key=api_key)
return OpenAIGenerator(model=model_key, api_key=api_key)
return None


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
upgrade:
- Rename the generator parameters `model_name` and `model_name_or_path` to `model`. This change affects all Generator classes.
26 changes: 12 additions & 14 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TestOpenAIChatGenerator:
def test_init_default(self):
component = OpenAIChatGenerator(api_key="test-api-key")
assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.model == "gpt-3.5-turbo"
assert component.streaming_callback is None
assert not component.generation_kwargs

Expand All @@ -32,13 +32,13 @@ def test_init_fail_wo_api_key(self, monkeypatch):
def test_init_with_parameters(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
model="gpt-4",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-4"
assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

Expand All @@ -48,7 +48,7 @@ def test_to_dict_default(self):
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"model": "gpt-3.5-turbo",
"organization": None,
"streaming_callback": None,
"api_base_url": None,
Expand All @@ -59,7 +59,7 @@ def test_to_dict_default(self):
def test_to_dict_with_parameters(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
model="gpt-4",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -68,7 +68,7 @@ def test_to_dict_with_parameters(self):
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
Expand All @@ -79,7 +79,7 @@ def test_to_dict_with_parameters(self):
def test_to_dict_with_lambda_streaming_callback(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
model="gpt-4",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -88,7 +88,7 @@ def test_to_dict_with_lambda_streaming_callback(self):
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "chat.test_openai.<lambda>",
Expand All @@ -100,14 +100,14 @@ def test_from_dict(self):
data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = OpenAIChatGenerator.from_dict(data)
assert component.model_name == "gpt-4"
assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
Expand All @@ -117,7 +117,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
Expand Down Expand Up @@ -222,9 +222,7 @@ def test_live_run(self):
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = OpenAIChatGenerator(
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")
)
component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
with pytest.raises(OpenAIError):
component.run(chat_messages)

Expand Down
Loading