Skip to content

Commit

Permalink
chore: Update Cohere integration to use new generic callable (de)seri…
Browse files Browse the repository at this point in the history
…alizers for their callback handlers (#453)

* (de)serialize callback properly

* lint

* check for None
  • Loading branch information
ZanSara authored Feb 21, 2024
1 parent ad5a290 commit af346c5
Showing 1 changed file with 6 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
import sys
from typing import Any, Callable, Dict, List, Optional, cast

from haystack import DeserializationError, component, default_from_dict, default_to_dict
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_secrets_inplace

Expand Down Expand Up @@ -89,19 +89,10 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
if self.streaming_callback:
module = self.streaming_callback.__module__
if module == "builtins":
callback_name = self.streaming_callback.__name__
else:
callback_name = f"{module}.{self.streaming_callback.__name__}"
else:
callback_name = None

return default_to_dict(
self,
model=self.model,
streaming_callback=callback_name,
streaming_callback=serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None,
api_base_url=self.api_base_url,
api_key=self.api_key.to_dict(),
**self.model_parameters,
Expand All @@ -114,20 +105,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])
streaming_callback = None
if "streaming_callback" in init_params and init_params["streaming_callback"] is not None:
parts = init_params["streaming_callback"].split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
msg = f"Could not locate the module of the streaming callback: {module_name}"
raise DeserializationError(msg)
streaming_callback = getattr(module, function_name, None)
if not streaming_callback:
msg = f"Could not locate the streaming callback: {function_name}"
raise DeserializationError(msg)
data["init_parameters"]["streaming_callback"] = streaming_callback
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(
init_params["streaming_callback"]
)
return default_from_dict(cls, data)

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
Expand Down

0 comments on commit af346c5

Please sign in to comment.