Skip to content

Commit

Permalink
Update pydantic_prompt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kishan42 authored Dec 25, 2024
1 parent 7b32aba commit a5b2c63
Showing 1 changed file with 27 additions and 57 deletions.
84 changes: 27 additions & 57 deletions src/ragas/prompt/pydantic_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,38 +142,8 @@ async def generate_multiple(
temperature: t.Optional[float] = None,
stop: t.Optional[t.List[str]] = None,
callbacks: t.Optional[Callbacks] = None,
retries_left: int = 3,
retries_left: int = 3,
) -> t.List[OutputModel]:
"""
Generate multiple outputs using the provided language model and input data.
Parameters
----------
llm : BaseRagasLLM
The language model to use for generation.
data : InputModel
The input data for generation.
n : int, optional
The number of outputs to generate. Default is 1.
temperature : float, optional
The temperature parameter for controlling randomness in generation.
stop : List[str], optional
A list of stop sequences to end generation.
callbacks : Callbacks, optional
Callback functions to be called during the generation process.
retries_left : int, optional
Number of retry attempts for an invalid LLM response
Returns
-------
List[OutputModel]
A list of generated outputs.
Raises
------
RagasOutputParserException
If there's an error parsing the output.
"""
callbacks = callbacks or []

processed_data = self.process_input(data)
Expand All @@ -182,37 +152,37 @@ async def generate_multiple(
inputs={"data": processed_data},
callbacks=callbacks,
metadata={"type": ChainType.RAGAS_PROMPT},
)
prompt_value = PromptValue(text=self.to_string(processed_data))
)
prompt_text = self.to_string(processed_data)
resp = await llm.generate(
prompt_value,
[prompt_text], # Pass as a list of strings
n=n,
temperature=temperature,
stop=stop,
callbacks=prompt_cb,
)

output_models = []
parser = RagasOutputParser(pydantic_object=self.output_model)
for i in range(n):
output_string = resp.generations[0][i].text
try:
answer = await parser.parse_output_string(
output_string=output_string,
prompt_value=prompt_value,
llm=llm,
callbacks=prompt_cb,
retries_left=retries_left,
)
processed_output = self.process_output(answer, data) # type: ignore
output_models.append(processed_output)
except RagasOutputParserException as e:
prompt_rm.on_chain_error(error=e)
logger.error("Prompt %s failed to parse output: %s", self.name, e)
raise e

prompt_rm.on_chain_end({"output": output_models})
return output_models
)

output_models = []
parser = RagasOutputParser(pydantic_object=self.output_model)
for i in range(n):
output_string = resp.generations[0][i].text
try:
answer = await parser.parse_output_string(
output_string=output_string,
prompt_value=PromptValue(text=prompt_text),
llm=llm,
callbacks=prompt_cb,
retries_left=retries_left,
)
processed_output = self.process_output(answer, data) # type: ignore
output_models.append(processed_output)
except RagasOutputParserException as e:
prompt_rm.on_chain_error(error=e)
logger.error("Prompt %s failed to parse output: %s", self.name, e)
raise e

prompt_rm.on_chain_end({"output": output_models})
return output_models

def process_input(self, input: InputModel) -> InputModel:
return input
Expand Down

0 comments on commit a5b2c63

Please sign in to comment.