diff --git a/integrations/google_vertex/example_assets/robot1.jpg b/integrations/google_vertex/example_assets/robot1.jpg new file mode 100644 index 000000000..a3962db1b Binary files /dev/null and b/integrations/google_vertex/example_assets/robot1.jpg differ diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py index 83322b33b..14102eb4b 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py @@ -12,16 +12,44 @@ @component class VertexAIImageCaptioner: + """ + `VertexAIImageCaptioner` enables text generation using Google Vertex AI imagetext generative model. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). + + Usage example: + ```python + import requests + + from haystack.dataclasses.byte_stream import ByteStream + from haystack_integrations.components.generators.google_vertex import VertexAIImageCaptioner + + captioner = VertexAIImageCaptioner(project_id=project_id) + + image = ByteStream( + data=requests.get( + "https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/main/integrations/google_vertex/example_assets/robot1.jpg" + ).content + ) + result = captioner.run(image=image) + + for caption in result["captions"]: + print(caption) + + >>> two gold robots are standing next to each other in the desert + ``` + """ + def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): """ Generate image captions using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "imagetext". + :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. :param kwargs: Additional keyword arguments to pass to the model. @@ -39,15 +67,35 @@ def __init__(self, *, model: str = "imagetext", project_id: str, location: Optio self._model = ImageTextModel.from_pretrained(self._model_name) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageCaptioner": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ return default_from_dict(cls, data) @component.output_types(captions=List[str]) def run(self, image: ByteStream): + """Prompts the model to generate captions for the given image. + + :param image: The image to generate captions for. + :returns: A dictionary with the following keys: + - `captions`: A list of captions generated by the model. + """ captions = self._model.get_captions(image=Image(image.data), **self._kwargs) return {"captions": captions} diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 5a6137765..f08a69b5f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -22,6 +22,30 @@ @component class VertexAIGeminiChatGenerator: + """ + `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models. + + `VertexAIGeminiChatGenerator` supports both `gemini-pro` and `gemini-pro-vision` models. + Prompting with images requires `gemini-pro-vision`. Function calling, instead, requires `gemini-pro`. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). + + Usage example: + ```python + from haystack.dataclasses import ChatMessage + from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator + + gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id) + + messages = [ChatMessage.from_user("Tell me the name of a movie")] + res = gemini_chat.run(messages) + + print(res["replies"][0].content) + >>> The Shawshank Redemption + ``` + """ + def __init__( self, *, @@ -33,18 +57,25 @@ def __init__( tools: Optional[List[Tool]] = None, ): """ - Multi modal generator using Gemini model via Google Vertex AI. + `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. :param model: Name of the model to use, defaults to "gemini-pro-vision". :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. - :param kwargs: Additional keyword arguments to pass to the model. - For a list of supported arguments see the `GenerativeModel.generate_content()` documentation. + :param generation_config: Configuration for the generation process. + See the [GenerationConfig documentation](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.GenerationConfig + for a list of supported arguments. + :param safety_settings: Safety settings to use when generating content. See the documentation + for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmBlockThreshold) + and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmCategory) + for more details. + :param tools: List of tools to use when generating content. See the documentation for + [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) + the list of supported arguments. """ # Login to GCP. This will fail if user has not set up their gcloud SDK @@ -84,6 +115,12 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A } def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ data = default_to_dict( self, model=self._model_name, @@ -101,6 +138,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: @@ -151,6 +196,12 @@ def _message_to_content(self, message: ChatMessage) -> Content: @component.output_types(replies=List[ChatMessage]) def run(self, messages: List[ChatMessage]): + """Prompts Google Vertex AI Gemini model to generate a response to a list of messages. + + :param messages: The last message is the prompt, the rest are the history. + :returns: A dictionary with the following keys: + - `replies`: A list of ChatMessage objects representing the model's replies. + """ history = [self._message_to_content(m) for m in messages[:-1]] session = self._model.start_chat(history=history) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py index 1914af289..f8889373c 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py @@ -11,18 +11,50 @@ @component class VertexAICodeGenerator: + """ + This component enables code generation using Google Vertex AI generative model. + + `VertexAICodeGenerator` supports `code-bison`, `code-bison-32k`, and `code-gecko`. + + Usage example: + ```python + from haystack_integrations.components.generators.google_vertex import VertexAICodeGenerator + + generator = VertexAICodeGenerator(project_id=project_id) + + result = generator.run(prefix="def to_json(data):") + + for answer in result["answers"]: + print(answer) + + >>> ```python + >>> import json + >>> + >>> def to_json(data): + >>> \"\"\"Converts a Python object to a JSON string. + >>> + >>> Args: + >>> data: The Python object to convert. + >>> + >>> Returns: + >>> A JSON string representing the Python object. + >>> \"\"\" + >>> + >>> return json.dumps(data) + >>> ``` + ``` + """ + def __init__(self, *, model: str = "code-bison", project_id: str, location: Optional[str] = None, **kwargs): """ Generate code using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "text-bison". + :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. - Defaults to None. :param kwargs: Additional keyword arguments to pass to the model. For a list of supported arguments see the `TextGenerationModel.predict()` documentation. """ @@ -38,16 +70,38 @@ def __init__(self, *, model: str = "code-bison", project_id: str, location: Opti self._model = CodeGenerationModel.from_pretrained(self._model_name) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAICodeGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ return default_from_dict(cls, data) @component.output_types(answers=List[str]) def run(self, prefix: str, suffix: Optional[str] = None): + """ + Generate code using a Google Vertex AI model. + + :param prefix: Code before the current point. + :param suffix: Code after the current point. + :returns: A dictionary with the following keys: + - `answers`: A list of generated code snippets. + """ res = self._model.predict(prefix=prefix, suffix=suffix, **self._kwargs) # Handle the case where the model returns multiple candidates answers = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 698b07b01..1383f176d 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -22,6 +22,35 @@ @component class VertexAIGeminiGenerator: + """ + `VertexAIGeminiGenerator` enables text generation using Google Gemini models. + + `VertexAIGeminiGenerator` supports both `gemini-pro` and `gemini-pro-vision` models. + Prompting with images requires `gemini-pro-vision`. Function calling, instead, requires `gemini-pro`. + + Usage example: + ```python + from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator + + + gemini = VertexAIGeminiGenerator(project_id=project_id) + result = gemini.run(parts = ["What is the most interesting thing you know?"]) + for answer in result["answers"]: + print(answer) + + >>> 1. **The Origin of Life:** How and where did life begin? The answers to this ... + >>> 2. **The Unseen Universe:** The vast majority of the universe is ... + >>> 3. **Quantum Entanglement:** This eerie phenomenon in quantum mechanics allows ... + >>> 4. **Time Dilation:** Einstein's theory of relativity revealed that time can ... + >>> 5. **The Fermi Paradox:** Despite the vastness of the universe and the ... + >>> 6. **Biological Evolution:** The idea that life evolves over time through natural ... + >>> 7. **Neuroplasticity:** The brain's ability to adapt and change throughout life, ... + >>> 8. **The Goldilocks Zone:** The concept of the habitable zone, or the Goldilocks zone, ... + >>> 9. **String Theory:** This theoretical framework in physics aims to unify all ... + >>> 10. **Consciousness:** The nature of human consciousness and how it arises ... + ``` + """ + def __init__( self, *, @@ -33,18 +62,17 @@ def __init__( tools: Optional[List[Tool]] = None, ): """ - Multi modal generator using Gemini model via Google Vertex AI. + Multi-modal generator using Gemini model via Google Vertex AI. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "gemini-pro-vision". + :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. - Defaults to None. - :param generation_config: The generation config to use, defaults to None. - Can either be a GenerationConfig object or a dictionary of parameters. + :param generation_config: The generation config to use. + Can either be a [`GenerationConfig`](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.GenerationConfig) + object or a dictionary of parameters. Accepted fields are: - temperature - top_p @@ -52,10 +80,13 @@ def __init__( - candidate_count - max_output_tokens - stop_sequences - :param safety_settings: The safety settings to use, defaults to None. - A dictionary of HarmCategory to HarmBlockThreshold. - :param tools: The tools to use, defaults to None. - A list of Tool objects that can be used to modify the generation process. + :param safety_settings: The safety settings to use. See the documentation + for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmBlockThreshold) + and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmCategory) + for more details. + :param tools: List of tools to use when generating content. See the documentation for + [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) + the list of supported arguments. """ # Login to GCP. This will fail if user has not set up their gcloud SDK @@ -95,6 +126,12 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A } def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ data = default_to_dict( self, model=self._model_name, @@ -112,6 +149,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: @@ -132,6 +177,13 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @component.output_types(answers=List[Union[str, Dict[str, str]]]) def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + """ + Generates content using the Gemini model. + + :param parts: Prompt for the model. + :returns: A dictionary with the following keys: + - `answers`: A list of generated content. + """ converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py index c81c88fe8..422e1cfe6 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py @@ -12,18 +12,34 @@ @component class VertexAIImageGenerator: + """ + This component enables image generation using Google Vertex AI generative model. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). + + Usage example: + ```python + from pathlib import Path + + from haystack_integrations.components.generators.google_vertex import VertexAIImageGenerator + + generator = VertexAIImageGenerator(project_id=project_id) + result = generator.run(prompt="Generate an image of a cute cat") + result["images"][0].to_file(Path("my_image.png")) + ``` + """ + def __init__(self, *, model: str = "imagegeneration", project_id: str, location: Optional[str] = None, **kwargs): """ Generates images using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "imagegeneration". + :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. - Defaults to None. :param kwargs: Additional keyword arguments to pass to the model. For a list of supported arguments see the `ImageGenerationModel.generate_images()` documentation. """ @@ -39,16 +55,38 @@ def __init__(self, *, model: str = "imagegeneration", project_id: str, location: self._model = ImageGenerationModel.from_pretrained(self._model_name) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ return default_from_dict(cls, data) @component.output_types(images=List[ByteStream]) def run(self, prompt: str, negative_prompt: Optional[str] = None): + """Produces images based on the given prompt. + + :param prompt: The prompt to generate images from. + :param negative_prompt: A description of what you want to omit in + the generated images. + :returns: A dictionary with the following keys: + - images: A list of ByteStream objects, each containing an image. + """ negative_prompt = negative_prompt or self._kwargs.get("negative_prompt") res = self._model.generate_images(prompt=prompt, negative_prompt=negative_prompt, **self._kwargs) images = [ByteStream(data=i._image_bytes, meta=i.generation_parameters) for i in res.images] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py index 276364227..79c343b02 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py @@ -12,18 +12,39 @@ @component class VertexAIImageQA: + """ + This component enables text generation (image captioning) using Google Vertex AI generative models. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). + + Usage example: + ```python + from haystack.dataclasses.byte_stream import ByteStream + from haystack_integrations.components.generators.google_vertex import VertexAIImageQA + + qa = VertexAIImageQA(project_id=project_id) + + image = ByteStream.from_file_path("dog.jpg") + + res = qa.run(image=image, question="What color is this dog") + + print(res["answers"][0]) + + >>> white + ``` + """ + def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): """ Answers questions about an image using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "imagetext". + :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. - Defaults to None. :param kwargs: Additional keyword arguments to pass to the model. For a list of supported arguments see the `ImageTextModel.ask_question()` documentation. """ @@ -39,15 +60,36 @@ def __init__(self, *, model: str = "imagetext", project_id: str, location: Optio self._model = ImageTextModel.from_pretrained(self._model_name) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageQA": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ return default_from_dict(cls, data) @component.output_types(answers=List[str]) def run(self, image: ByteStream, question: str): + """Prompts model to answer a question about an image. + + :param image: The image to ask the question about. + :param question: The question to ask. + :returns: A dictionary with the following keys: + - answers: A list of answers to the question. + """ answers = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs) return {"answers": answers} diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py index 6022bcf4f..e16954f8f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py @@ -13,18 +13,48 @@ @component class VertexAITextGenerator: + """ + This component enables text generation using Google Vertex AI generative models. + + `VertexAITextGenerator` supports `text-bison`, `text-unicorn` and `text-bison-32k` models. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). + + Usage example: + ```python + from haystack_integrations.components.generators.google_vertex import VertexAITextGenerator + + generator = VertexAITextGenerator(project_id=project_id) + res = generator.run("Tell me a good interview question for a software engineer.") + + print(res["answers"][0]) + + >>> **Question:** + >>> You are given a list of integers and a target sum. + >>> Find all unique combinations of numbers in the list that add up to the target sum. + >>> + >>> **Example:** + >>> + >>> ``` + >>> Input: [1, 2, 3, 4, 5], target = 7 + >>> Output: [[1, 2, 4], [3, 4]] + >>> ``` + >>> + >>> **Follow-up:** What if the list contains duplicate numbers? + ``` + """ + def __init__(self, *, model: str = "text-bison", project_id: str, location: Optional[str] = None, **kwargs): """ Generate text using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). - For more information see the official Google documentation: - https://cloud.google.com/docs/authentication/provide-credentials-adc + For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "text-bison". + :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. - Defaults to None. :param kwargs: Additional keyword arguments to pass to the model. For a list of supported arguments see the `TextGenerationModel.predict()` documentation. """ @@ -40,6 +70,12 @@ def __init__(self, *, model: str = "text-bison", project_id: str, location: Opti self._model = TextGenerationModel.from_pretrained(self._model_name) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ data = default_to_dict( self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @@ -57,6 +93,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ if (grounding_source := data["init_parameters"].get("grounding_source")) is not None: module_name, class_name = grounding_source["type"].rsplit(".", 1) module = importlib.import_module(module_name) @@ -67,6 +111,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator": @component.output_types(answers=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]]) def run(self, prompt: str): + """Prompts the model to generate text. + + :param prompt: The prompt to use for text generation. + :returns: A dictionary with the following keys: + - answers: A list of generated answers. + - safety_attributes: A dictionary with the [safety scores](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/responsible-ai#safety_attribute_descriptions) + of each answer. + - citations: A list of citations for each answer. + """ res = self._model.predict(prompt=prompt, **self._kwargs) answers = []