From 41ee3be95f51d18b51f5f05874e8bcef0f673e47 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Sat, 2 Dec 2023 12:27:18 +0900 Subject: [PATCH] langchain[patch]: Support passing parameters to `llms.Databricks` and `llms.Mlflow` (#14100) Before, we need to use `params` to pass extra parameters: ```python from langchain.llms import Databricks Databricks(..., params={"temperature": 0.0}) ``` Now, we can directly specify extra params: ```python from langchain.llms import Databricks Databricks(..., temperature=0.0) ``` --- .../langchain/langchain/chat_models/mlflow.py | 4 +- libs/langchain/langchain/llms/databricks.py | 77 +++++++++++++------ libs/langchain/langchain/llms/mlflow.py | 34 +++++--- 3 files changed, 82 insertions(+), 33 deletions(-) diff --git a/libs/langchain/langchain/chat_models/mlflow.py b/libs/langchain/langchain/chat_models/mlflow.py index 4b40286c7ebc4..e1c1ad1542bd7 100644 --- a/libs/langchain/langchain/chat_models/mlflow.py +++ b/libs/langchain/langchain/chat_models/mlflow.py @@ -118,8 +118,10 @@ def _generate( "stop": stop or self.stop, "max_tokens": self.max_tokens, **self.extra_params, + **kwargs, } - + if stop := self.stop or stop: + data["stop"] = stop resp = self._client.predict(endpoint=self.endpoint, inputs=data) return ChatMlflow._create_chat_result(resp) diff --git a/libs/langchain/langchain/llms/databricks.py b/libs/langchain/langchain/llms/databricks.py index 55dff84391718..d83e67a6cc8bf 100644 --- a/libs/langchain/langchain/llms/databricks.py +++ b/libs/langchain/langchain/llms/databricks.py @@ -46,6 +46,10 @@ def post( ) -> Any: ... + @property + def llm(self) -> bool: + return False + def _transform_completions(response: Dict[str, Any]) -> str: return response["choices"][0]["text"] @@ -85,6 +89,10 @@ def __init__(self, **data: Any): ) self.task = endpoint.get("task") + @property + def llm(self) -> bool: + return self.task in ("llm/v1/chat", "llm/v1/completions") + @root_validator(pre=True) def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]: if "api_url" not in values: @@ -137,8 +145,11 @@ def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["api_url"] = api_url return values - def post(self, request: Any, transform: Optional[Callable[..., str]] = None) -> Any: - return self._post(self.api_url, request) + def post( + self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None + ) -> Any: + resp = self._post(self.api_url, request) + return transform_output_fn(resp) if transform_output_fn else resp def get_repl_context() -> Any: @@ -285,12 +296,10 @@ class Databricks(LLM): We recommend the server using a port number between ``[3000, 8000]``. """ - params: Optional[Dict[str, Any]] = None - """Extra parameters to pass to the endpoint.""" - model_kwargs: Optional[Dict[str, Any]] = None """ - Deprecated. Please use ``params`` instead. Extra parameters to pass to the endpoint. + Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to + the endpoint. """ transform_input_fn: Optional[Callable] = None @@ -306,12 +315,34 @@ class Databricks(LLM): databricks_uri: str = "databricks" """The databricks URI. 
Only used when using a serving endpoint."""
+    temperature: float = 0.0
+    """The sampling temperature."""
+    n: int = 1
+    """The number of completion choices to generate."""
+    stop: Optional[List[str]] = None
+    """The stop sequence."""
+    max_tokens: Optional[int] = None
+    """The maximum number of tokens to generate."""
+    extra_params: Dict[str, Any] = Field(default_factory=dict)
+    """Any extra parameters to pass to the endpoint."""
+
     _client: _DatabricksClientBase = PrivateAttr()
 
     class Config:
         extra = Extra.forbid
         underscore_attrs_are_private = True
 
+    @property
+    def _llm_params(self) -> Dict[str, Any]:
+        params = {
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            **(self.model_kwargs or self.extra_params),
+        }
+        return params
+
     @validator("cluster_id", always=True)
     def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
         if v and values["endpoint_name"]:
@@ -356,11 +387,11 @@ def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any
 
     def __init__(self, **data: Any):
         super().__init__(**data)
-        if self.model_kwargs is not None and self.params is not None:
-            raise ValueError("Cannot set both model_kwargs and params.")
+        if self.model_kwargs is not None and self.extra_params:
+            raise ValueError("Cannot set both model_kwargs and extra_params.")
         elif self.model_kwargs is not None:
             warnings.warn(
-                "model_kwargs is deprecated. Please use params instead.",
+                "model_kwargs is deprecated. Please use extra_params instead.",
                 DeprecationWarning,
             )
         if self.endpoint_name:
@@ -382,10 +413,6 @@ def __init__(self, **data: Any):
                 "Must specify either endpoint_name or cluster_id/cluster_driver_port."
             )
 
-    @property
-    def _params(self) -> Optional[Dict[str, Any]]:
-        return self.model_kwargs or self.params
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Return default params."""
@@ -397,7 +424,11 @@ def _default_params(self) -> Dict[str, Any]:
             "cluster_driver_port": self.cluster_driver_port,
             "databricks_uri": self.databricks_uri,
             "model_kwargs": self.model_kwargs,
-            "params": self.params,
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "extra_params": self.extra_params,
             # TODO: Support saving transform_input_fn and transform_output_fn
             # "transform_input_fn": self.transform_input_fn,
             # "transform_output_fn": self.transform_output_fn,
@@ -423,17 +454,17 @@ def _call(
 
         # TODO: support callbacks
 
-        request = {"prompt": prompt, "stop": stop}
+        request: Dict[str, Any] = {"prompt": prompt}
+        if self._client.llm:
+            request.update(self._llm_params)
+            request.update(self.model_kwargs or self.extra_params)
+        else:
+            request.update(self.model_kwargs or self.extra_params)
         request.update(kwargs)
-        if self._params:
-            request.update(self._params)
+        if stop := self.stop or stop:
+            request["stop"] = stop
         if self.transform_input_fn:
             request = self.transform_input_fn(**request)
 
-        response = self._client.post(request)
-
-        if self.transform_output_fn:
-            response = self.transform_output_fn(response)
-
-        return response
+        return self._client.post(request, transform_output_fn=self.transform_output_fn)
diff --git a/libs/langchain/langchain/llms/mlflow.py b/libs/langchain/langchain/llms/mlflow.py
index 7f77fe12efff5..565a4b3a363f0 100644
--- a/libs/langchain/langchain/llms/mlflow.py
+++ b/libs/langchain/langchain/llms/mlflow.py
@@ -5,7 +5,7 @@
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LLM
-from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr
 
 
 # Ignoring type because below is valid pydantic code
@@ -41,7 +41,17 @@ class Mlflow(LLM):
     """The endpoint to use."""
     target_uri: str
     """The target URI to use."""
-    params: Optional[Params] = None
+    temperature: float = 0.0
+    """The sampling temperature."""
+    n: int = 1
+    """The number of completion choices to generate."""
+    stop: Optional[List[str]] = None
+    """The stop sequence."""
+    max_tokens: Optional[int] = None
+    """The maximum number of tokens to generate."""
+    extra_params: Dict[str, Any] = Field(default_factory=dict)
+    """Any extra parameters to pass to the endpoint."""
+
     """Extra parameters such as `temperature`."""
 
     _client: Any = PrivateAttr()
@@ -71,13 +81,15 @@ def _validate_uri(self) -> None:
 
     @property
     def _default_params(self) -> Dict[str, Any]:
-        params: Dict[str, Any] = {
+        return {
             "target_uri": self.target_uri,
             "endpoint": self.endpoint,
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "extra_params": self.extra_params,
         }
-        if self.params:
-            params["params"] = self.params.dict()
-        return params
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -92,10 +104,14 @@ def _call(
     ) -> str:
         data: Dict[str, Any] = {
             "prompt": prompt,
-            **(self.params.dict() if self.params else {}),
+            "temperature": self.temperature,
+            "n": self.n,
+            "max_tokens": self.max_tokens,
+            **self.extra_params,
+            **kwargs,
         }
-        if s := (stop or (self.params.stop if self.params else None)):
-            data["stop"] = s
+        if stop := self.stop or stop:
+            data["stop"] = stop
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return resp["choices"][0]["text"]
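
For reference, a minimal usage sketch of the interface this patch enables (not part of the diff). The endpoint names, target URI, and parameter values below are placeholders; it assumes a reachable Databricks serving endpoint and a local MLflow deployments server:

```python
from langchain.llms import Databricks, Mlflow

# Endpoint names and URIs are illustrative placeholders.
databricks_llm = Databricks(
    endpoint_name="my-llm-endpoint",  # hypothetical Databricks serving endpoint
    temperature=0.1,                  # now a first-class field, no `params` wrapper
    max_tokens=256,
    extra_params={"top_p": 0.95},     # anything without a dedicated field
)

mlflow_llm = Mlflow(
    target_uri="http://localhost:5000",  # assumes a local MLflow deployments server
    endpoint="completions",              # hypothetical endpoint name
    temperature=0.0,
    stop=["\n\n"],
)

print(databricks_llm.invoke("What is Databricks?"))
print(mlflow_llm.invoke("What is MLflow?"))
```

Parameters without a dedicated field still go through `extra_params`, and `model_kwargs` keeps working on `Databricks` but now emits a `DeprecationWarning`.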