From 653a7027480fedd0a166b9708d3b6e4fdae1d8b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=BC=E6=AC=A3?= Date: Tue, 3 Sep 2024 14:55:51 +0800 Subject: [PATCH] fix bugs in qwen2vl's openai-compatible client --- examples/qwen2vl_assistant_tooluse.py | 6 ++ examples/qwen2vl_function_calling.py | 6 ++ qwen_agent/llm/__init__.py | 31 +++++++---- qwen_agent/llm/azure.py | 32 +++-------- qwen_agent/llm/oai.py | 52 ++++++----------- qwen_agent/llm/qwenvl_oai.py | 80 +++++++++++++++------------ qwen_agent/utils/utils.py | 2 +- 7 files changed, 101 insertions(+), 108 deletions(-) diff --git a/examples/qwen2vl_assistant_tooluse.py b/examples/qwen2vl_assistant_tooluse.py index 8d0f6d8..dbfac68 100644 --- a/examples/qwen2vl_assistant_tooluse.py +++ b/examples/qwen2vl_assistant_tooluse.py @@ -279,6 +279,12 @@ def init_agent_service(): # 'model_server': 'http://localhost:8000/v1', # api_base # 'api_key': 'EMPTY', + # Using Qwen2-VL provided by Alibaba Cloud DashScope's openai-compatible service: + # 'model_type': 'qwenvl_oai', + # 'model': 'qwen-vl-max-0809', + # 'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + # 'api_key': os.getenv('DASHSCOPE_API_KEY'), + # Using Qwen2-VL provided by Alibaba Cloud DashScope: 'model_type': 'qwenvl_dashscope', 'model': 'qwen-vl-max-0809', diff --git a/examples/qwen2vl_function_calling.py b/examples/qwen2vl_function_calling.py index 565f7a0..e5f88b9 100644 --- a/examples/qwen2vl_function_calling.py +++ b/examples/qwen2vl_function_calling.py @@ -23,6 +23,12 @@ def test(): # 'model_server': 'http://localhost:8000/v1', # api_base # 'api_key': 'EMPTY', + # Using Qwen2-VL provided by Alibaba Cloud DashScope's openai-compatible service: + # 'model_type': 'qwenvl_oai', + # 'model': 'qwen-vl-max-0809', + # 'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + # 'api_key': os.getenv('DASHSCOPE_API_KEY'), + # Using Qwen2-VL provided by Alibaba Cloud DashScope: 'model_type': 'qwenvl_dashscope', 'model': 'qwen-vl-max-0809', diff --git a/qwen_agent/llm/__init__.py b/qwen_agent/llm/__init__.py index c8e873f..dbfe12e 100644 --- a/qwen_agent/llm/__init__.py +++ b/qwen_agent/llm/__init__.py @@ -1,3 +1,4 @@ +import copy from typing import Union from .azure import TextChatAtAzure @@ -14,17 +15,21 @@ def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel: Args: cfg: The LLM configuration, one example is: - llm_cfg = { - # Use the model service provided by DashScope: - 'model': 'qwen-max', - 'model_server': 'dashscope', - # Use your own model service compatible with OpenAI API: - # 'model': 'Qwen', - # 'model_server': 'http://127.0.0.1:7905/v1', - # (Optional) LLM hyper-parameters: - 'generate_cfg': { - 'top_p': 0.8 - } + cfg = { + # Use the model service provided by DashScope: + 'model': 'qwen-max', + 'model_server': 'dashscope', + + # Use your own model service compatible with OpenAI API: + # 'model': 'Qwen', + # 'model_server': 'http://127.0.0.1:7905/v1', + + # (Optional) LLM hyper-parameters: + 'generate_cfg': { + 'top_p': 0.8, + 'max_input_tokens': 6500, + 'max_retries': 10, + } } Returns: @@ -36,6 +41,10 @@ def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel: if 'model_type' in cfg: model_type = cfg['model_type'] if model_type in LLM_REGISTRY: + if model_type in ('oai', 'qwenvl_oai'): + if cfg.get('model_server', '').strip() == 'dashscope': + cfg = copy.deepcopy(cfg) + cfg['model_server'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1' return LLM_REGISTRY[model_type](cfg) else: raise ValueError(f'Please set model_type from {str(LLM_REGISTRY.keys())}') diff --git a/qwen_agent/llm/azure.py b/qwen_agent/llm/azure.py index fc30d4f..bb6b87a 100644 --- a/qwen_agent/llm/azure.py +++ b/qwen_agent/llm/azure.py @@ -1,4 +1,3 @@ -import copy import os from typing import Dict, Optional @@ -15,18 +14,15 @@ def __init__(self, cfg: Optional[Dict] = None): super().__init__(cfg) cfg = cfg or {} - api_base = cfg.get( - 'api_base', - cfg.get( - 'base_url', - cfg.get('model_server', cfg.get('azure_endpoint', '')), - ), - ).strip() + api_base = cfg.get('api_base') + api_base = api_base or cfg.get('base_url') + api_base = api_base or cfg.get('model_server') + api_base = api_base or cfg.get('azure_endpoint') + api_base = (api_base or '').strip() - api_key = cfg.get('api_key', '') - if not api_key: - api_key = os.getenv('OPENAI_API_KEY', 'EMPTY') - api_key = api_key.strip() + api_key = cfg.get('api_key') + api_key = api_key or os.getenv('OPENAI_API_KEY') + api_key = (api_key or 'EMPTY').strip() api_version = cfg.get('api_version', '2024-06-01') @@ -39,19 +35,7 @@ def __init__(self, cfg: Optional[Dict] = None): api_kwargs['api_version'] = api_version def _chat_complete_create(*args, **kwargs): - # OpenAI API v1 does not allow the following args, must pass by extra_body - extra_params = ['top_k', 'repetition_penalty'] - if any((k in kwargs) for k in extra_params): - kwargs['extra_body'] = copy.deepcopy(kwargs.get('extra_body', {})) - for k in extra_params: - if k in kwargs: - kwargs['extra_body'][k] = kwargs.pop(k) - if 'request_timeout' in kwargs: - kwargs['timeout'] = kwargs.pop('request_timeout') - client = openai.AzureOpenAI(**api_kwargs) - - # client = openai.OpenAI(**api_kwargs) return client.chat.completions.create(*args, **kwargs) self._chat_complete_create = _chat_complete_create diff --git a/qwen_agent/llm/oai.py b/qwen_agent/llm/oai.py index 9ad3f59..d3408c6 100644 --- a/qwen_agent/llm/oai.py +++ b/qwen_agent/llm/oai.py @@ -22,21 +22,17 @@ class TextChatAtOAI(BaseFnCallModel): def __init__(self, cfg: Optional[Dict] = None): super().__init__(cfg) - self.model = self.model or 'gpt-3.5-turbo' + self.model = self.model or 'gpt-4o-mini' cfg = cfg or {} - api_base = cfg.get( - 'api_base', - cfg.get( - 'base_url', - cfg.get('model_server', ''), - ), - ).strip() + api_base = cfg.get('api_base') + api_base = api_base or cfg.get('base_url') + api_base = api_base or cfg.get('model_server') + api_base = (api_base or '').strip() - api_key = cfg.get('api_key', '') - if not api_key: - api_key = os.getenv('OPENAI_API_KEY', 'EMPTY') - api_key = api_key.strip() + api_key = cfg.get('api_key') + api_key = api_key or os.getenv('OPENAI_API_KEY') + api_key = (api_key or 'EMPTY').strip() if openai.__version__.startswith('0.'): if api_base: @@ -73,9 +69,7 @@ def _chat_stream( delta_stream: bool, generate_cfg: dict, ) -> Iterator[List[Message]]: - messages = [msg.model_dump() for msg in messages] - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}') + messages = self.convert_messages_to_dicts(messages) try: response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg) if delta_stream: @@ -96,30 +90,16 @@ def _chat_no_stream( messages: List[Message], generate_cfg: dict, ) -> List[Message]: - messages = [msg.model_dump() for msg in messages] - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}') + messages = self.convert_messages_to_dicts(messages) try: response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg) return [Message(ASSISTANT, response.choices[0].message.content)] except OpenAIError as ex: raise ModelServiceError(exception=ex) - -def pretty_format_messages(messages: List[dict], indent: int = 2) -> str: - messages_show = [] - for msg in messages: - assert isinstance(msg, dict) - msg_show = copy.deepcopy(msg) - if isinstance(msg['content'], list): - content = [] - for item in msg['content']: - (t, v), = item.items() - if (t != 'text') and v.startswith('data:'): - v = v[:64] + ' ...' - content.append({t: v}) - else: - content = msg['content'] - msg_show['content'] = content - messages_show.append(msg_show) - return pformat(messages_show, indent=indent) + @staticmethod + def convert_messages_to_dicts(messages: List[Message]) -> List[dict]: + messages = [msg.model_dump() for msg in messages] + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}') + return messages diff --git a/qwen_agent/llm/qwenvl_oai.py b/qwen_agent/llm/qwenvl_oai.py index 1751d53..0d636a6 100644 --- a/qwen_agent/llm/qwenvl_oai.py +++ b/qwen_agent/llm/qwenvl_oai.py @@ -1,31 +1,17 @@ import copy +import logging import os -from typing import Iterator, List +from pprint import pformat +from typing import List +from qwen_agent.llm import ModelServiceError from qwen_agent.llm.base import register_llm from qwen_agent.llm.oai import TextChatAtOAI -from qwen_agent.llm.schema import Message +from qwen_agent.llm.schema import ContentItem, Message +from qwen_agent.log import logger from qwen_agent.utils.utils import encode_image_as_base64 -def _convert_local_images_to_base64(messages: List[Message]) -> List[Message]: - messages_new = [] - for msg in messages: - if isinstance(msg.content, list): - msg = copy.deepcopy(msg) - for item in msg.content: - t, v = item.get_type_and_value() - if t == 'image': - if v.startswith('file://'): - v = v[len('file://'):] - if (not v.startswith(('http://', 'https://', 'data:'))) and os.path.exists(v): - item.image = encode_image_as_base64(v, max_short_side_length=1080) - else: - assert isinstance(msg.content, str) - messages_new.append(msg) - return messages_new - - @register_llm('qwenvl_oai') class QwenVLChatAtOAI(TextChatAtOAI): @@ -33,19 +19,41 @@ class QwenVLChatAtOAI(TextChatAtOAI): def support_multimodal_input(self) -> bool: return True - def _chat_stream( - self, - messages: List[Message], - delta_stream: bool, - generate_cfg: dict, - ) -> Iterator[List[Message]]: - messages = _convert_local_images_to_base64(messages) - return super()._chat_stream(messages=messages, delta_stream=delta_stream, generate_cfg=generate_cfg) - - def _chat_no_stream( - self, - messages: List[Message], - generate_cfg: dict, - ) -> List[Message]: - messages = _convert_local_images_to_base64(messages) - return super()._chat_no_stream(messages=messages, generate_cfg=generate_cfg) + @staticmethod + def convert_messages_to_dicts(messages: List[Message]) -> List[dict]: + new_messages = [] + + for msg in messages: + content = msg.content + if isinstance(content, str): + content = [ContentItem(text=content)] + assert isinstance(content, list) + + new_content = [] + for item in content: + t, v = item.get_type_and_value() + if t == 'text': + new_content.append({'type': 'text', 'text': v}) + if t == 'image': + if v.startswith('file://'): + v = v[len('file://'):] + if not v.startswith(('http://', 'https://', 'data:')): + if os.path.exists(v): + v = encode_image_as_base64(v, max_short_side_length=1080) + else: + raise ModelServiceError(f'Local image "{v}" does not exist.') + new_content.append({'type': 'image_url', 'image_url': {'url': v}}) + + new_msg = msg.model_dump() + new_msg['content'] = new_content + new_messages.append(new_msg) + + if logger.isEnabledFor(logging.DEBUG): + lite_messages = copy.deepcopy(new_messages) + for msg in lite_messages: + for item in msg['content']: + if item.get('image_url', {}).get('url', '').startswith('data:'): + item['image_url']['url'] = item['image_url']['url'][:64] + '...' + logger.debug(f'LLM Input:\n{pformat(lite_messages, indent=2)}') + + return new_messages diff --git a/qwen_agent/utils/utils.py b/qwen_agent/utils/utils.py index c74b5fe..4c3b448 100644 --- a/qwen_agent/utils/utils.py +++ b/qwen_agent/utils/utils.py @@ -471,7 +471,7 @@ def encode_image_as_base64(path: str, max_short_side_length: int = -1) -> str: image = image.convert(mode='RGB') buffered = BytesIO() image.save(buffered, format='JPEG') - return 'data:image/jpg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8') + return 'data:image/jpeg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8') def load_image_from_base64(image_base64: Union[bytes, str]):