Commit

add langchain import
Yunnglin committed Dec 26, 2024
1 parent c1c0d75 commit 1f4da4b
Showing 6 changed files with 206 additions and 9 deletions.
7 changes: 2 additions & 5 deletions libs/community/langchain_community/chat_models/modelscope_endpoint.py
@@ -2,17 +2,14 @@

 from typing import Dict
 
-from langchain_core.utils import (
-    convert_to_secret_str,
-    get_from_dict_or_env
-)
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from pydantic import model_validator
 
 from langchain_community.chat_models import ChatOpenAI
 from langchain_community.llms.modelscope_endpoint import (
     MODELSCOPE_SERVICE_URL_BASE,
     ModelScopeCommon,
 )
-from pydantic import model_validator
 
 
 class ModelScopeChatEndpoint(ModelScopeCommon, ChatOpenAI):  # type: ignore[misc, override, override]
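For context, a minimal usage sketch of the chat endpoint this file defines (not part of the diff; the token value is a placeholder, and "model" / "api_key" are the field aliases defined on ModelScopeCommon):

    from langchain_community.chat_models import ModelScopeChatEndpoint

    # api_key populates modelscope_sdk_token via its alias; the value is a placeholder
    chat = ModelScopeChatEndpoint(
        model="Qwen/Qwen2.5-Coder-32B-Instruct",
        api_key="your-modelscope-sdk-token",
    )
    print(chat.invoke("Write a bubble sort in Python.").content)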
11 changes: 7 additions & 4 deletions libs/community/langchain_community/llms/modelscope_endpoint.py
@@ -10,7 +10,7 @@
 )
 from langchain_core.language_models import LLM
 from langchain_core.outputs.generation import GenerationChunk
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -109,7 +109,8 @@ class ModelScopeCommon(BaseModel):
     base_url: str = MODELSCOPE_SERVICE_URL_BASE
     modelscope_sdk_token: Optional[SecretStr] = Field(default=None, alias="api_key")
     model_name: str = Field(default="Qwen/Qwen2.5-Coder-32B-Instruct", alias="model")
-    """Model name. Available models listed here: https://modelscope.cn/docs/model-service/API-Inference/intro """
+    """Model name. Available models listed here:
+    https://modelscope.cn/docs/model-service/API-Inference/intro """
     max_tokens: int = 1024
     """Maximum number of tokens to generate."""
     temperature: float = 0.3
@@ -161,7 +162,9 @@ def validate_environment(cls, values: Dict) -> Dict:

values["client"] = ModelScopeClient(
api_key=values["modelscope_sdk_token"],
base_url=values["base_url"] if "base_url" in values else MODELSCOPE_SERVICE_URL_BASE,
base_url=values["base_url"]
if "base_url" in values
else MODELSCOPE_SERVICE_URL_BASE, # noqa: E501
timeout=values["timeout"] if "timeout" in values else 60,
)
return values
@@ -191,7 +194,7 @@ class ModelScopeEndpoint(ModelScopeCommon, LLM):
             async for chunk in llm.astream("write a quick sort in python"):
                 print(chunk, end='', flush=True)
-    """
+    """  # noqa: E501
 
     model_config = ConfigDict(
         populate_by_name=True,
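A minimal sketch of the completion endpoint touched above (not part of the diff; with no base_url or timeout supplied, the validate_environment defaults MODELSCOPE_SERVICE_URL_BASE and 60 seconds apply, and the token value is a placeholder):

    from langchain_community.llms import ModelScopeEndpoint

    # api_key populates modelscope_sdk_token via its field alias
    llm = ModelScopeEndpoint(api_key="your-modelscope-sdk-token")
    for chunk in llm.stream("write a quick sort in python"):
        print(chunk, end="", flush=True)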
151 changes: 151 additions & 0 deletions libs/community/langchain_community/llms/modelscope_pipeline.py
@@ -0,0 +1,151 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult

DEFAULT_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_TASK = "chat"
VALID_TASKS = (
"chat",
"text-generation",
)
DEFAULT_BATCH_SIZE = 4

logger = logging.getLogger(__name__)


class ModelScopePipeline(BaseLLM):
    """ModelScope Pipeline API.

    To use, you should have the ``modelscope[framework]`` and ``ms-swift[llm]``
    python packages installed. You can install them with
    ``pip install 'ms-swift[llm]' 'modelscope[framework]' -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html``.
    Only supports the `chat` task for now.

    Example using from_model_id:
        .. code-block:: python

            from langchain_community.llms.modelscope_pipeline import ModelScopePipeline

            llm = ModelScopePipeline.from_model_id(
                model_id="Qwen/Qwen2.5-0.5B-Instruct",
                task="chat",
                generate_kwargs={'do_sample': True, 'max_new_tokens': 128},
            )
            llm.invoke("Hello, how are you?")
    """  # noqa: E501

    pipeline: Any  #: :meta private:
    task: str = DEFAULT_TASK
    model_id: str = DEFAULT_MODEL_ID
    model_revision: Optional[str] = None
    generate_kwargs: Optional[Dict[Any, Any]] = None
    """Keyword arguments passed to the pipeline."""
    batch_size: int = DEFAULT_BATCH_SIZE
    """Batch size to use when passing multiple documents to generate."""

    @classmethod
    def from_model_id(
        cls,
        model_id: str = DEFAULT_MODEL_ID,
        model_revision: Optional[str] = None,
        task: str = DEFAULT_TASK,
        device_map: Optional[str] = None,
        generate_kwargs: Optional[Dict[Any, Any]] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        **kwargs: Any,
    ) -> ModelScopePipeline:
        """Construct the pipeline object from model_id and task."""
        try:
            from modelscope import pipeline  # type: ignore[import]
        except ImportError:
            raise ValueError(
                "Could not import modelscope python package. "
                "Please install it with `pip install 'ms-swift[llm]' 'modelscope[framework]' -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`."  # noqa: E501
            )
        modelscope_pipeline = pipeline(
            task=task,
            model=model_id,
            model_revision=model_revision,
            device_map="auto" if device_map is None else device_map,
            llm_framework="swift",
            **kwargs,
        )
        return cls(
            pipeline=modelscope_pipeline,
            task=task,
            model_id=model_id,
            model_revision=model_revision,
            generate_kwargs=generate_kwargs,
            batch_size=batch_size,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "generate_kwargs": self.generate_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        return "modelscope_pipeline"

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        if self.generate_kwargs is not None:
            gen_cfg = {**self.generate_kwargs, **kwargs}
        else:
            gen_cfg = {**kwargs}

        for stream_output in self.pipeline.stream_generate(prompt, **gen_cfg):
            text = stream_output["text"]
            chunk = GenerationChunk(text=text)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # List to hold all results
        text_generations: List[str] = []
        if self.generate_kwargs is not None:
            gen_cfg = {**self.generate_kwargs, **kwargs}
        else:
            gen_cfg = {**kwargs}

        for i in range(0, len(prompts), self.batch_size):
            batch_prompts = prompts[i : i + self.batch_size]

            # Process batch of prompts
            responses = self.pipeline(
                batch_prompts,
                **gen_cfg,
            )
            # Process each response in the batch
            for response in responses:
                if isinstance(response, list):
                    # if model returns multiple generations, pick the top one
                    response = response[0]
                text = response["text"]
                # Append the processed text to results
                text_generations.append(text)

        return LLMResult(
            generations=[[Generation(text=text)] for text in text_generations]
        )
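A short sketch of driving the new pipeline class (not part of the diff; assumes modelscope[framework] and ms-swift[llm] are installed, and that BaseLLM.stream and BaseLLM.generate route into the _stream and _generate methods above):

    from langchain_community.llms.modelscope_pipeline import ModelScopePipeline

    llm = ModelScopePipeline.from_model_id(
        model_id="Qwen/Qwen2.5-0.5B-Instruct",
        task="chat",
        generate_kwargs={"do_sample": True, "max_new_tokens": 128},
    )

    # Streaming goes through _stream, which wraps pipeline.stream_generate
    for chunk in llm.stream("Explain list comprehensions in one sentence."):
        print(chunk, end="", flush=True)

    # Multiple prompts are processed batch_size at a time inside _generate
    result = llm.generate(["What is 2 + 2?", "Name a prime number."])
    print([gen[0].text for gen in result.generations])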
Empty file.
23 changes: 23 additions & 0 deletions libs/langchain/langchain/chat_models/modelscope_endpoint.py
@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any

from langchain._api import create_importer

if TYPE_CHECKING:
from langchain_community.chat_models import ModelScopeChatEndpoint

# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ModelScopeChatEndpoint": "langchain_community.chat_models.modelscope_endpoint"}

Check failure on line 11 (GitHub Actions / cd libs/langchain / make lint, Python 3.9 and 3.13): Ruff (E501) langchain/chat_models/modelscope_endpoint.py:11:89 Line too long (101 > 88)

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)


__all__ = [
"ModelScopeChatEndpoint",
]
23 changes: 23 additions & 0 deletions libs/langchain/langchain/llms/modelscope_endpoint.py
@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any

from langchain._api import create_importer

if TYPE_CHECKING:
from langchain_community.llms import ModelScopeEndpoint

# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ModelScopeEndpoint": "langchain_community.llms"}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)


__all__ = [
"ModelScopeEndpoint",
]
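Both shim modules behave the same way: importing from the legacy langchain path resolves, through create_importer, to the langchain_community implementation. A small illustrative check (not part of the commit; a deprecation warning may be emitted on import):

    from langchain.llms.modelscope_endpoint import ModelScopeEndpoint
    from langchain_community.llms import ModelScopeEndpoint as CommunityEndpoint

    # create_importer returns the community object itself, so the two names match
    assert ModelScopeEndpoint is CommunityEndpoint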
