Skip to content

Commit

Permalink
fix: temperature parameter in generate_text not ignored. (#887)
Browse files Browse the repository at this point in the history
Addresses #886. 

The temperature parameter is no longer overwritten by the call to
`self.get_temperature(n=n)`. That method is now only called when no
temperature parameter was given.

---------

Co-authored-by: Ivan Herreros <[email protected]>
  • Loading branch information
HerrIvan and Ivan Herreros authored Jul 30, 2024
1 parent 1c3b6d3 commit c4ed989
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/ragas/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def agenerate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
temperature: t.Optional[float] = None,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
Expand All @@ -80,11 +80,13 @@ async def generate(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
temperature: t.Optional[float] = None,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
is_async: bool = True,
) -> LLMResult:
if temperature is None:
temperature = 1e-8
"""Generate text using the given event loop."""
if is_async:
agenerate_text_with_retry = add_async_retry(
Expand Down Expand Up @@ -131,11 +133,14 @@ def generate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
temperature: t.Optional[float] = None,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
temperature = self.get_temperature(n=n)
# figure out the temperature to set
if temperature is None:
temperature = self.get_temperature(n=n)

if is_multiple_completion_supported(self.langchain_llm):
return self.langchain_llm.generate_prompt(
prompts=[prompt],
Expand All @@ -161,11 +166,12 @@ async def agenerate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
temperature: t.Optional[float] = None,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
temperature = self.get_temperature(n=n)
if temperature is None:
temperature = self.get_temperature(n=n)
if is_multiple_completion_supported(self.langchain_llm):
return await self.langchain_llm.agenerate_prompt(
prompts=[prompt],
Expand Down

0 comments on commit c4ed989

Please sign in to comment.