diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 4927839d2..4811d3581 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -import sys from typing import Any, Callable, Dict, List, Optional, cast -from haystack import DeserializationError, component, default_from_dict, default_to_dict +from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_secrets_inplace @@ -89,19 +89,10 @@ def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. """ - if self.streaming_callback: - module = self.streaming_callback.__module__ - if module == "builtins": - callback_name = self.streaming_callback.__name__ - else: - callback_name = f"{module}.{self.streaming_callback.__name__}" - else: - callback_name = None - return default_to_dict( self, model=self.model, - streaming_callback=callback_name, + streaming_callback=serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None, api_base_url=self.api_base_url, api_key=self.api_key.to_dict(), **self.model_parameters, @@ -114,20 +105,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) - streaming_callback = None if "streaming_callback" in init_params and init_params["streaming_callback"] is not None: - parts = init_params["streaming_callback"].split(".") - module_name = ".".join(parts[:-1]) - function_name = parts[-1] - module = sys.modules.get(module_name, None) - if not module: - msg = f"Could not locate the module of the streaming callback: {module_name}" - raise DeserializationError(msg) - streaming_callback = getattr(module, function_name, None) - if not streaming_callback: - msg = f"Could not locate the streaming callback: {function_name}" - raise DeserializationError(msg) - data["init_parameters"]["streaming_callback"] = streaming_callback + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler( + init_params["streaming_callback"] + ) return default_from_dict(cls, data) @component.output_types(replies=List[str], meta=List[Dict[str, Any]])