Showing 6 changed files with 206 additions and 9 deletions.
151 changes: 151 additions & 0 deletions
libs/community/langchain_community/llms/modelscope_pipeline.py
@@ -0,0 +1,151 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult

DEFAULT_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_TASK = "chat"
VALID_TASKS = (
    "chat",
    "text-generation",
)
DEFAULT_BATCH_SIZE = 4

logger = logging.getLogger(__name__)


class ModelScopePipeline(BaseLLM):
    """ModelScope Pipeline API.

    To use, you should have the ``modelscope[framework]`` and ``ms-swift[llm]``
    Python packages installed. You can install them with
    ``pip install 'ms-swift[llm]' 'modelscope[framework]' -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html``.

    Only the ``chat`` task is supported for now.

    Example using from_model_id:
        .. code-block:: python

            from langchain_community.llms.modelscope_pipeline import ModelScopePipeline

            llm = ModelScopePipeline.from_model_id(
                model_id="Qwen/Qwen2.5-0.5B-Instruct",
                task="chat",
                generate_kwargs={"do_sample": True, "max_new_tokens": 128},
            )
            llm.invoke("Hello, how are you?")
    """  # noqa: E501

    pipeline: Any  #: :meta private:
    task: str = DEFAULT_TASK
    model_id: str = DEFAULT_MODEL_ID
    model_revision: Optional[str] = None
    generate_kwargs: Optional[Dict[Any, Any]] = None
    """Keyword arguments passed to the pipeline."""
    batch_size: int = DEFAULT_BATCH_SIZE
    """Batch size to use when passing multiple documents to generate."""

    @classmethod
    def from_model_id(
        cls,
        model_id: str = DEFAULT_MODEL_ID,
        model_revision: Optional[str] = None,
        task: str = DEFAULT_TASK,
        device_map: Optional[str] = None,
        generate_kwargs: Optional[Dict[Any, Any]] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        **kwargs: Any,
    ) -> ModelScopePipeline:
        """Construct the pipeline object from model_id and task."""
        # Validate the requested task against the tasks this wrapper supports.
        if task not in VALID_TASKS:
            raise ValueError(
                f"Got invalid task {task}, currently only {VALID_TASKS} are supported"
            )
        try:
            from modelscope import pipeline  # type: ignore[import]
        except ImportError:
            raise ValueError(
                "Could not import modelscope python package. "
                "Please install it with `pip install 'ms-swift[llm]' 'modelscope[framework]' -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`."  # noqa: E501
            )
        modelscope_pipeline = pipeline(
            task=task,
            model=model_id,
            model_revision=model_revision,
            device_map="auto" if device_map is None else device_map,
            llm_framework="swift",
            **kwargs,
        )
        return cls(
            pipeline=modelscope_pipeline,
            task=task,
            model_id=model_id,
            model_revision=model_revision,
            generate_kwargs=generate_kwargs,
            batch_size=batch_size,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "generate_kwargs": self.generate_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        return "modelscope_pipeline"

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # Note: `stop` is accepted for interface compatibility but is not
        # applied by the underlying pipeline.
        if self.generate_kwargs is not None:
            gen_cfg = {**self.generate_kwargs, **kwargs}
        else:
            gen_cfg = {**kwargs}

        for stream_output in self.pipeline.stream_generate(prompt, **gen_cfg):
            text = stream_output["text"]
            chunk = GenerationChunk(text=text)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # List to hold all results
        text_generations: List[str] = []
        if self.generate_kwargs is not None:
            gen_cfg = {**self.generate_kwargs, **kwargs}
        else:
            gen_cfg = {**kwargs}

        for i in range(0, len(prompts), self.batch_size):
            batch_prompts = prompts[i : i + self.batch_size]

            # Process batch of prompts
            responses = self.pipeline(
                batch_prompts,
                **gen_cfg,
            )
            # Process each response in the batch
            for response in responses:
                if isinstance(response, list):
                    # If the model returns multiple generations, pick the top one
                    response = response[0]
                text = response["text"]
                # Append the processed text to results
                text_generations.append(text)

        return LLMResult(
            generations=[[Generation(text=text)] for text in text_generations]
        )
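Taken together, `_generate` batches prompts through the pipeline `batch_size` at a time, while `_stream` yields `GenerationChunk` objects from the pipeline's `stream_generate`, so the class plugs into the standard `BaseLLM` interface. A minimal usage sketch, assuming the packages listed in the docstring are installed and the default Qwen checkpoint can be downloaded from ModelScope:

from langchain_community.llms.modelscope_pipeline import ModelScopePipeline

llm = ModelScopePipeline.from_model_id(
    model_id="Qwen/Qwen2.5-0.5B-Instruct",
    task="chat",
    generate_kwargs={"do_sample": True, "max_new_tokens": 128},
)

# Batched generation: prompts are processed batch_size at a time.
result = llm.generate(["What is 2 + 2?", "Name a prime number."])
for generation in result.generations:
    print(generation[0].text)

# Token-by-token streaming, backed by the pipeline's stream_generate.
for token in llm.stream("Tell me a short joke."):
    print(token, end="", flush=True)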
Empty file.
23 changes: 23 additions & 0 deletions
libs/langchain/langchain/chat_models/modelscope_endpoint.py
@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any

from langchain._api import create_importer

if TYPE_CHECKING:
    from langchain_community.chat_models import ModelScopeChatEndpoint

# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
    "ModelScopeChatEndpoint": "langchain_community.chat_models.modelscope_endpoint"
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Look up attributes dynamically."""
    return _import_attribute(name)


__all__ = [
    "ModelScopeChatEndpoint",
]
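The module above is a thin forwarding shim built on PEP 562's module-level `__getattr__`: attribute access on the old `langchain` path is routed through `create_importer`, which emits a deprecation warning and returns the class from `langchain_community`. A minimal sketch of what a caller observes, assuming both `langchain` and `langchain-community` are installed:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old path: resolved lazily through the shim's __getattr__.
    from langchain.chat_models.modelscope_endpoint import ModelScopeChatEndpoint

# New, preferred path.
from langchain_community.chat_models import ModelScopeChatEndpoint as Community

# The shim forwards rather than copies, so both names are one object,
# and the deprecated access should have triggered a warning.
assert ModelScopeChatEndpoint is Community
assert any(issubclass(w.category, DeprecationWarning) for w in caught)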
@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any

from langchain._api import create_importer

if TYPE_CHECKING:
    from langchain_community.llms import ModelScopeEndpoint

# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ModelScopeEndpoint": "langchain_community.llms"}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Look up attributes dynamically."""
    return _import_attribute(name)


__all__ = [
    "ModelScopeEndpoint",
]
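The `langchain.llms` shim works the same way. As a quick sanity check (hypothetical, since the diff header above truncates this file's path; it is presumably `langchain/llms/modelscope_endpoint.py`), both import paths should resolve to the same class object:

from langchain.llms.modelscope_endpoint import ModelScopeEndpoint as OldPath
from langchain_community.llms import ModelScopeEndpoint as NewPath

# create_importer forwards the lookup, so the deprecated path returns
# the community implementation itself, not a copy.
assert OldPath is NewPath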