Skip to content

Commit

Permalink
change answers to replies
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Mar 27, 2024
1 parent 66d0304 commit 89f68be
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class GoogleAIGeminiGenerator:
gemini = GoogleAIGeminiGenerator(model="gemini-pro", api_key=Secret.from_token("<MY_API_KEY>"))
res = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in res["answers"]:
for answer in res["replies"]:
print(answer)
```
Expand Down Expand Up @@ -55,7 +55,7 @@ class GoogleAIGeminiGenerator:
gemini = GoogleAIGeminiGenerator(model="gemini-pro-vision", api_key=Secret.from_token("<MY_API_KEY>"))
result = gemini.run(parts = ["What can you tell me about this robots?", *images])
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
```
"""
Expand Down Expand Up @@ -173,7 +173,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(answers=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generates text based on the given input parts.
Expand All @@ -182,7 +182,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
A heterogeneous list of strings, `ByteStream` or `Part` objects.
:returns:
A dictionary containing the following key:
- `answers`: A list of strings or dictionaries with function calls.
- `replies`: A list of strings or dictionaries with function calls.
"""

converted_parts = [self._convert_part(p) for p in parts]
Expand All @@ -194,16 +194,16 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
safety_settings=self._safety_settings,
)
self._model.start_chat()
answers = []
replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part.text != "":
answers.append(part.text)
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
answers.append(function_call)
replies.append(function_call)

return {"answers": answers}
return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class VertexAICodeGenerator:
result = generator.run(prefix="def to_json(data):")
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
>>> ```python
Expand Down Expand Up @@ -92,17 +92,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAICodeGenerator":
"""
return default_from_dict(cls, data)

@component.output_types(answers=List[str])
@component.output_types(replies=List[str])
def run(self, prefix: str, suffix: Optional[str] = None):
"""
Generate code using a Google Vertex AI model.
:param prefix: Code before the current point.
:param suffix: Code after the current point.
:returns: A dictionary with the following keys:
- `answers`: A list of generated code snippets.
- `replies`: A list of generated code snippets.
"""
res = self._model.predict(prefix=prefix, suffix=suffix, **self._kwargs)
# Handle the case where the model returns multiple candidates
answers = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text]
return {"answers": answers}
replies = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text]
return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class VertexAIGeminiGenerator:
gemini = VertexAIGeminiGenerator(project_id=project_id)
result = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
>>> 1. **The Origin of Life:** How and where did life begin? The answers to this ...
Expand Down Expand Up @@ -175,14 +175,14 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(answers=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generates content using the Gemini model.
:param parts: Prompt for the model.
:returns: A dictionary with the following keys:
- `answers`: A list of generated content.
- `replies`: A list of generated content.
"""
converted_parts = [self._convert_part(p) for p in parts]

Expand All @@ -194,16 +194,16 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
tools=self._tools,
)
self._model.start_chat()
answers = []
replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
answers.append(part.text)
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
answers.append(function_call)
replies.append(function_call)

return {"answers": answers}
return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class VertexAITextGenerator:
generator = VertexAITextGenerator(project_id=project_id)
res = generator.run("Tell me a good interview question for a software engineer.")
print(res["answers"][0])
print(res["replies"][0])
>>> **Question:**
>>> You are given a list of integers and a target sum.
Expand Down Expand Up @@ -109,26 +109,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator":
)
return default_from_dict(cls, data)

@component.output_types(answers=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
@component.output_types(replies=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
def run(self, prompt: str):
"""Prompts the model to generate text.
:param prompt: The prompt to use for text generation.
:returns: A dictionary with the following keys:
- `answers`: A list of generated answers.
- `replies`: A list of generated replies.
- `safety_attributes`: A dictionary with the [safety scores](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/responsible-ai#safety_attribute_descriptions)
of each answer.
- `citations`: A list of citations for each answer.
"""
res = self._model.predict(prompt=prompt, **self._kwargs)

answers = []
replies = []
safety_attributes = []
citations = []

for prediction in res.raw_prediction_response.predictions:
answers.append(prediction["content"])
replies.append(prediction["content"])
safety_attributes.append(prediction["safetyAttributes"])
citations.append(prediction["citationMetadata"]["citations"])

return {"answers": answers, "safety_attributes": safety_attributes, "citations": citations}
return {"replies": replies, "safety_attributes": safety_attributes, "citations": citations}

0 comments on commit 89f68be

Please sign in to comment.