From 1f4da4be7c727abeb4c187c41a1885a614678709 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Thu, 26 Dec 2024 17:25:37 +0800
Subject: [PATCH] add langchain import

---
 .../chat_models/modelscope_endpoint.py        |   7 +-
 .../llms/modelscope_endpoint.py               |  11 +-
 .../llms/modelscope_pipeline.py               | 151 ++++++++++++++++++
 .../llms/test_modelscope_pipeline.py          |   0
 .../chat_models/modelscope_endpoint.py        |  23 +++
 .../langchain/llms/modelscope_endpoint.py     |  23 +++
 6 files changed, 206 insertions(+), 9 deletions(-)
 create mode 100644 libs/community/langchain_community/llms/modelscope_pipeline.py
 create mode 100644 libs/community/tests/integration_tests/llms/test_modelscope_pipeline.py
 create mode 100644 libs/langchain/langchain/chat_models/modelscope_endpoint.py
 create mode 100644 libs/langchain/langchain/llms/modelscope_endpoint.py

diff --git a/libs/community/langchain_community/chat_models/modelscope_endpoint.py b/libs/community/langchain_community/chat_models/modelscope_endpoint.py
index f48b290d38692..2dd41cc150bc2 100644
--- a/libs/community/langchain_community/chat_models/modelscope_endpoint.py
+++ b/libs/community/langchain_community/chat_models/modelscope_endpoint.py
@@ -2,17 +2,14 @@
 
 from typing import Dict
 
-from langchain_core.utils import (
-    convert_to_secret_str,
-    get_from_dict_or_env
-)
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from pydantic import model_validator
 
 from langchain_community.chat_models import ChatOpenAI
 from langchain_community.llms.modelscope_endpoint import (
     MODELSCOPE_SERVICE_URL_BASE,
     ModelScopeCommon,
 )
-from pydantic import model_validator
 
 
 class ModelScopeChatEndpoint(ModelScopeCommon, ChatOpenAI):  # type: ignore[misc, override, override]
diff --git a/libs/community/langchain_community/llms/modelscope_endpoint.py b/libs/community/langchain_community/llms/modelscope_endpoint.py
index 9930400c00698..d93793a02dbe9 100644
--- a/libs/community/langchain_community/llms/modelscope_endpoint.py
+++ b/libs/community/langchain_community/llms/modelscope_endpoint.py
@@ -10,7 +10,7 @@
 )
 from langchain_core.language_models import LLM
 from langchain_core.outputs.generation import GenerationChunk
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -109,7 +109,8 @@ class ModelScopeCommon(BaseModel):
     base_url: str = MODELSCOPE_SERVICE_URL_BASE
     modelscope_sdk_token: Optional[SecretStr] = Field(default=None, alias="api_key")
     model_name: str = Field(default="Qwen/Qwen2.5-Coder-32B-Instruct", alias="model")
-    """Model name. Available models listed here: https://modelscope.cn/docs/model-service/API-Inference/intro """
+    """Model name. Available models listed here:
+    https://modelscope.cn/docs/model-service/API-Inference/intro """
     max_tokens: int = 1024
     """Maximum number of tokens to generate."""
     temperature: float = 0.3
@@ -161,7 +162,9 @@ def validate_environment(cls, values: Dict) -> Dict:
 
         values["client"] = ModelScopeClient(
             api_key=values["modelscope_sdk_token"],
-            base_url=values["base_url"] if "base_url" in values else MODELSCOPE_SERVICE_URL_BASE,
+            base_url=values["base_url"]
+            if "base_url" in values
+            else MODELSCOPE_SERVICE_URL_BASE,  # noqa: E501
             timeout=values["timeout"] if "timeout" in values else 60,
         )
         return values
@@ -191,7 +194,7 @@ class ModelScopeEndpoint(ModelScopeCommon, LLM):
             async for chunk in llm.astream("write a quick sort in python"):
                 print(chunk, end='', flush=True)
 
-    """
+    """  # noqa: E501
 
     model_config = ConfigDict(
         populate_by_name=True,
diff --git a/libs/community/langchain_community/llms/modelscope_pipeline.py b/libs/community/langchain_community/llms/modelscope_pipeline.py
new file mode 100644
index 0000000000000..b2e992036a3e1
--- /dev/null
+++ b/libs/community/langchain_community/llms/modelscope_pipeline.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, Iterator, List, Mapping, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+
+DEFAULT_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
+DEFAULT_TASK = "chat"
+VALID_TASKS = (
+    "chat",
+    "text-generation",
+)
+DEFAULT_BATCH_SIZE = 4
+
+logger = logging.getLogger(__name__)
+
+
+class ModelScopePipeline(BaseLLM):
+    """ModelScope Pipeline API.
+
+    To use, you should have the ``modelscope[framework]`` and ``ms-swift[llm]`` python package installed,
+    you can install with ``pip install 'ms-swift[llm]' 'modelscope[framework]' -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html``.
+
+    Only supports `chat` task for now.
+
+    Example using from_model_id:
+        .. code-block:: python
+
+            from langchain_community.llms.modelscope_pipeline import ModelScopePipeline
+            llm = ModelScopePipeline.from_model_id(
+                model_id="Qwen/Qwen2.5-0.5B-Instruct",
+                task="chat",
+                generate_kwargs={'do_sample': True, 'max_new_tokens': 128},
+            )
+            llm.invoke("Hello, how are you?")
+    """  # noqa: E501
+
+    pipeline: Any  #: :meta private:
+    task: str = DEFAULT_TASK
+    model_id: str = DEFAULT_MODEL_ID
+    model_revision: Optional[str] = None
+    generate_kwargs: Optional[Dict[Any, Any]] = None
+    """Keyword arguments passed to the pipeline."""
+    batch_size: int = DEFAULT_BATCH_SIZE
+    """Batch size to use when passing multiple documents to generate."""
+
+    @classmethod
+    def from_model_id(
+        cls,
+        model_id: str = DEFAULT_MODEL_ID,
+        model_revision: Optional[str] = None,
+        task: str = DEFAULT_TASK,
+        device_map: Optional[str] = None,
+        generate_kwargs: Optional[Dict[Any, Any]] = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        **kwargs: Any,
+    ) -> ModelScopePipeline:
+        """Construct the pipeline object from model_id and task."""
+        try:
+            from modelscope import pipeline  # type: ignore[import]
+        except ImportError:
+            raise ValueError(
+                "Could not import modelscope python package. "
+                "Please install it with `pip install 'ms-swift[llm]' 'modelscope[framework]' -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`."  # noqa: E501
+            )
+        modelscope_pipeline = pipeline(
+            task=task,
+            model=model_id,
+            model_revision=model_revision,
+            device_map="auto" if device_map is None else device_map,
+            llm_framework="swift",
+            **kwargs,
+        )
+        return cls(
+            pipeline=modelscope_pipeline,
+            task=task,
+            model_id=model_id,
+            model_revision=model_revision,
+            generate_kwargs=generate_kwargs,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model_id": self.model_id,
+            "generate_kwargs": self.generate_kwargs,
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        return "modelscope_pipeline"
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        if self.generate_kwargs is not None:
+            gen_cfg = {**self.generate_kwargs, **kwargs}
+        else:
+            gen_cfg = {**kwargs}
+
+        for stream_output in self.pipeline.stream_generate(prompt, **gen_cfg):
+            text = stream_output["text"]
+            chunk = GenerationChunk(text=text)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        # List to hold all results
+        text_generations: List[str] = []
+        if self.generate_kwargs is not None:
+            gen_cfg = {**self.generate_kwargs, **kwargs}
+        else:
+            gen_cfg = {**kwargs}
+
+        for i in range(0, len(prompts), self.batch_size):
+            batch_prompts = prompts[i : i + self.batch_size]
+
+            # Process batch of prompts
+            responses = self.pipeline(
+                batch_prompts,
+                **gen_cfg,
+            )
+            # Process each response in the batch
+            for j, response in enumerate(responses):
+                if isinstance(response, list):
+                    # if model returns multiple generations, pick the top one
+                    response = response[0]
+                text = response["text"]
+                # Append the processed text to results
+                text_generations.append(text)
+
+        return LLMResult(
+            generations=[[Generation(text=text)] for text in text_generations]
+        )
diff --git a/libs/community/tests/integration_tests/llms/test_modelscope_pipeline.py b/libs/community/tests/integration_tests/llms/test_modelscope_pipeline.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/libs/langchain/langchain/chat_models/modelscope_endpoint.py b/libs/langchain/langchain/chat_models/modelscope_endpoint.py
new file mode 100644
index 0000000000000..df5e3a865ada2
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/modelscope_endpoint.py
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING, Any
+
+from langchain._api import create_importer
+
+if TYPE_CHECKING:
+    from langchain_community.chat_models import ModelScopeChatEndpoint
+
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {"ModelScopeChatEndpoint": "langchain_community.chat_models.modelscope_endpoint"}
+
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
+
+
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
+
+
+__all__ = [
+    "ModelScopeChatEndpoint",
+]
diff --git a/libs/langchain/langchain/llms/modelscope_endpoint.py b/libs/langchain/langchain/llms/modelscope_endpoint.py
new file mode 100644
index 0000000000000..bdc9f134f7e8d
--- /dev/null
+++ b/libs/langchain/langchain/llms/modelscope_endpoint.py
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING, Any
+
+from langchain._api import create_importer
+
+if TYPE_CHECKING:
+    from langchain_community.llms import ModelScopeEndpoint
+
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {"ModelScopeEndpoint": "langchain_community.llms"}
+
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
+
+
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
+
+
+__all__ = [
+    "ModelScopeEndpoint",
+]
\ No newline at end of file