fix bugs in qwen2vl's openai-compatible client
JianxinMa committed Sep 3, 2024
1 parent cb86572 commit 653a702
Showing 7 changed files with 101 additions and 108 deletions.
6 changes: 6 additions & 0 deletions examples/qwen2vl_assistant_tooluse.py
@@ -279,6 +279,12 @@ def init_agent_service():
         # 'model_server': 'http://localhost:8000/v1',  # api_base
         # 'api_key': 'EMPTY',
 
+        # Using Qwen2-VL provided by Alibaba Cloud DashScope's openai-compatible service:
+        # 'model_type': 'qwenvl_oai',
+        # 'model': 'qwen-vl-max-0809',
+        # 'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
+        # 'api_key': os.getenv('DASHSCOPE_API_KEY'),
+
         # Using Qwen2-VL provided by Alibaba Cloud DashScope:
         'model_type': 'qwenvl_dashscope',
         'model': 'qwen-vl-max-0809',
6 changes: 6 additions & 0 deletions examples/qwen2vl_function_calling.py
@@ -23,6 +23,12 @@ def test():
         # 'model_server': 'http://localhost:8000/v1',  # api_base
         # 'api_key': 'EMPTY',
 
+        # Using Qwen2-VL provided by Alibaba Cloud DashScope's openai-compatible service:
+        # 'model_type': 'qwenvl_oai',
+        # 'model': 'qwen-vl-max-0809',
+        # 'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
+        # 'api_key': os.getenv('DASHSCOPE_API_KEY'),
+
         # Using Qwen2-VL provided by Alibaba Cloud DashScope:
         'model_type': 'qwenvl_dashscope',
         'model': 'qwen-vl-max-0809',
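Note: both example files document the same new option. For reference, a minimal sketch of the configuration with the new block uncommented (the surrounding agent setup is omitted; model name and env var come from the lines above):

    import os

    from qwen_agent.llm import get_chat_model

    # Qwen2-VL served through DashScope's OpenAI-compatible endpoint,
    # exactly as the commented-out block above describes.
    llm = get_chat_model({
        'model_type': 'qwenvl_oai',
        'model': 'qwen-vl-max-0809',
        'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        'api_key': os.getenv('DASHSCOPE_API_KEY'),
    })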
31 changes: 20 additions & 11 deletions qwen_agent/llm/__init__.py
@@ -1,3 +1,4 @@
+import copy
 from typing import Union
 
 from .azure import TextChatAtAzure
@@ -14,17 +15,21 @@ def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel:
     Args:
         cfg: The LLM configuration, one example is:
-            llm_cfg = {
-                # Use the model service provided by DashScope:
-                'model': 'qwen-max',
-                'model_server': 'dashscope',
-                # Use your own model service compatible with OpenAI API:
-                # 'model': 'Qwen',
-                # 'model_server': 'http://127.0.0.1:7905/v1',
-                # (Optional) LLM hyper-parameters:
-                'generate_cfg': {
-                    'top_p': 0.8
-                }
+            cfg = {
+                # Use the model service provided by DashScope:
+                'model': 'qwen-max',
+                'model_server': 'dashscope',
+                # Use your own model service compatible with OpenAI API:
+                # 'model': 'Qwen',
+                # 'model_server': 'http://127.0.0.1:7905/v1',
+                # (Optional) LLM hyper-parameters:
+                'generate_cfg': {
+                    'top_p': 0.8,
+                    'max_input_tokens': 6500,
+                    'max_retries': 10,
+                }
+            }
 
     Returns:
@@ -36,6 +41,10 @@ def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel:
     if 'model_type' in cfg:
         model_type = cfg['model_type']
         if model_type in LLM_REGISTRY:
+            if model_type in ('oai', 'qwenvl_oai'):
+                if cfg.get('model_server', '').strip() == 'dashscope':
+                    cfg = copy.deepcopy(cfg)
+                    cfg['model_server'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
             return LLM_REGISTRY[model_type](cfg)
         else:
             raise ValueError(f'Please set model_type from {str(LLM_REGISTRY.keys())}')
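Note: the new branch means the 'dashscope' shorthand now also resolves for the OpenAI-compatible client types, and the config is deep-copied before the rewrite so the caller's dict is never mutated. A sketch of the resulting behavior (hypothetical call):

    import os

    from qwen_agent.llm import get_chat_model

    llm = get_chat_model({
        'model_type': 'qwenvl_oai',
        'model': 'qwen-vl-max-0809',
        # Shorthand: rewritten internally to
        # 'https://dashscope.aliyuncs.com/compatible-mode/v1'.
        'model_server': 'dashscope',
        'api_key': os.getenv('DASHSCOPE_API_KEY'),
    })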
32 changes: 8 additions & 24 deletions qwen_agent/llm/azure.py
@@ -1,4 +1,3 @@
-import copy
 import os
 from typing import Dict, Optional
 
@@ -15,18 +14,15 @@ def __init__(self, cfg: Optional[Dict] = None):
         super().__init__(cfg)
         cfg = cfg or {}
 
-        api_base = cfg.get(
-            'api_base',
-            cfg.get(
-                'base_url',
-                cfg.get('model_server', cfg.get('azure_endpoint', '')),
-            ),
-        ).strip()
+        api_base = cfg.get('api_base')
+        api_base = api_base or cfg.get('base_url')
+        api_base = api_base or cfg.get('model_server')
+        api_base = api_base or cfg.get('azure_endpoint')
+        api_base = (api_base or '').strip()
 
-        api_key = cfg.get('api_key', '')
-        if not api_key:
-            api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
-        api_key = api_key.strip()
+        api_key = cfg.get('api_key')
+        api_key = api_key or os.getenv('OPENAI_API_KEY')
+        api_key = (api_key or 'EMPTY').strip()
 
         api_version = cfg.get('api_version', '2024-06-01')
 
@@ -39,19 +35,7 @@ def __init__(self, cfg: Optional[Dict] = None):
         api_kwargs['api_version'] = api_version
 
         def _chat_complete_create(*args, **kwargs):
-            # OpenAI API v1 does not allow the following args, must pass by extra_body
-            extra_params = ['top_k', 'repetition_penalty']
-            if any((k in kwargs) for k in extra_params):
-                kwargs['extra_body'] = copy.deepcopy(kwargs.get('extra_body', {}))
-                for k in extra_params:
-                    if k in kwargs:
-                        kwargs['extra_body'][k] = kwargs.pop(k)
-            if 'request_timeout' in kwargs:
-                kwargs['timeout'] = kwargs.pop('request_timeout')
-
             client = openai.AzureOpenAI(**api_kwargs)
-
-            # client = openai.OpenAI(**api_kwargs)
             return client.chat.completions.create(*args, **kwargs)
 
         self._chat_complete_create = _chat_complete_create
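Note: the flattened fallback chain fixes a subtle bug in the nested cfg.get(...) form: dict.get only falls back when a key is absent, so an explicit None (or empty) value skipped the remaining fallbacks, and None even crashed the trailing .strip(). A minimal illustration with hypothetical values:

    cfg = {'api_base': None, 'model_server': 'https://example.azure.com'}

    # Old form: the 'api_base' key exists, so cfg.get('api_base', <fallback>)
    # returns None and None.strip() raises AttributeError.
    # New form: `or` treats falsy values as missing and keeps falling through.
    api_base = cfg.get('api_base')
    api_base = api_base or cfg.get('base_url')
    api_base = api_base or cfg.get('model_server')
    api_base = api_base or cfg.get('azure_endpoint')
    api_base = (api_base or '').strip()
    assert api_base == 'https://example.azure.com'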
52 changes: 16 additions & 36 deletions qwen_agent/llm/oai.py
@@ -22,21 +22,17 @@ class TextChatAtOAI(BaseFnCallModel):
 
     def __init__(self, cfg: Optional[Dict] = None):
         super().__init__(cfg)
-        self.model = self.model or 'gpt-3.5-turbo'
+        self.model = self.model or 'gpt-4o-mini'
         cfg = cfg or {}
 
-        api_base = cfg.get(
-            'api_base',
-            cfg.get(
-                'base_url',
-                cfg.get('model_server', ''),
-            ),
-        ).strip()
+        api_base = cfg.get('api_base')
+        api_base = api_base or cfg.get('base_url')
+        api_base = api_base or cfg.get('model_server')
+        api_base = (api_base or '').strip()
 
-        api_key = cfg.get('api_key', '')
-        if not api_key:
-            api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
-        api_key = api_key.strip()
+        api_key = cfg.get('api_key')
+        api_key = api_key or os.getenv('OPENAI_API_KEY')
+        api_key = (api_key or 'EMPTY').strip()
 
         if openai.__version__.startswith('0.'):
             if api_base:
@@ -73,9 +69,7 @@ def _chat_stream(
         delta_stream: bool,
         generate_cfg: dict,
     ) -> Iterator[List[Message]]:
-        messages = [msg.model_dump() for msg in messages]
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}')
+        messages = self.convert_messages_to_dicts(messages)
         try:
             response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg)
             if delta_stream:
@@ -96,30 +90,16 @@ def _chat_no_stream(
         messages: List[Message],
         generate_cfg: dict,
     ) -> List[Message]:
-        messages = [msg.model_dump() for msg in messages]
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}')
+        messages = self.convert_messages_to_dicts(messages)
         try:
             response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg)
             return [Message(ASSISTANT, response.choices[0].message.content)]
         except OpenAIError as ex:
             raise ModelServiceError(exception=ex)
 
 
-def pretty_format_messages(messages: List[dict], indent: int = 2) -> str:
-    messages_show = []
-    for msg in messages:
-        assert isinstance(msg, dict)
-        msg_show = copy.deepcopy(msg)
-        if isinstance(msg['content'], list):
-            content = []
-            for item in msg['content']:
-                (t, v), = item.items()
-                if (t != 'text') and v.startswith('data:'):
-                    v = v[:64] + ' ...'
-                content.append({t: v})
-        else:
-            content = msg['content']
-        msg_show['content'] = content
-        messages_show.append(msg_show)
-    return pformat(messages_show, indent=indent)
+    @staticmethod
+    def convert_messages_to_dicts(messages: List[Message]) -> List[dict]:
+        messages = [msg.model_dump() for msg in messages]
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
+        return messages
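Note: moving serialization behind a convert_messages_to_dicts static method turns it into an overridable hook, which is what lets the multimodal subclass below customize the wire format and the debug logging in one place. A sketch of the pattern (hypothetical subclass and field name):

    from qwen_agent.llm.oai import TextChatAtOAI


    class RedactingChat(TextChatAtOAI):

        @staticmethod
        def convert_messages_to_dicts(messages):
            # Subclasses now control both what is sent and what is logged.
            dicts = [msg.model_dump() for msg in messages]
            for d in dicts:
                d.pop('internal_notes', None)  # hypothetical field to strip
            return dicts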
80 changes: 44 additions & 36 deletions qwen_agent/llm/qwenvl_oai.py
@@ -1,51 +1,59 @@
 import copy
+import logging
 import os
-from typing import Iterator, List
+from pprint import pformat
+from typing import List
 
+from qwen_agent.llm import ModelServiceError
 from qwen_agent.llm.base import register_llm
 from qwen_agent.llm.oai import TextChatAtOAI
-from qwen_agent.llm.schema import Message
+from qwen_agent.llm.schema import ContentItem, Message
+from qwen_agent.log import logger
 from qwen_agent.utils.utils import encode_image_as_base64
 
 
-def _convert_local_images_to_base64(messages: List[Message]) -> List[Message]:
-    messages_new = []
-    for msg in messages:
-        if isinstance(msg.content, list):
-            msg = copy.deepcopy(msg)
-            for item in msg.content:
-                t, v = item.get_type_and_value()
-                if t == 'image':
-                    if v.startswith('file://'):
-                        v = v[len('file://'):]
-                    if (not v.startswith(('http://', 'https://', 'data:'))) and os.path.exists(v):
-                        item.image = encode_image_as_base64(v, max_short_side_length=1080)
-        else:
-            assert isinstance(msg.content, str)
-        messages_new.append(msg)
-    return messages_new
-
-
 @register_llm('qwenvl_oai')
 class QwenVLChatAtOAI(TextChatAtOAI):
 
     @property
     def support_multimodal_input(self) -> bool:
         return True
 
-    def _chat_stream(
-        self,
-        messages: List[Message],
-        delta_stream: bool,
-        generate_cfg: dict,
-    ) -> Iterator[List[Message]]:
-        messages = _convert_local_images_to_base64(messages)
-        return super()._chat_stream(messages=messages, delta_stream=delta_stream, generate_cfg=generate_cfg)
-
-    def _chat_no_stream(
-        self,
-        messages: List[Message],
-        generate_cfg: dict,
-    ) -> List[Message]:
-        messages = _convert_local_images_to_base64(messages)
-        return super()._chat_no_stream(messages=messages, generate_cfg=generate_cfg)
+    @staticmethod
+    def convert_messages_to_dicts(messages: List[Message]) -> List[dict]:
+        new_messages = []
+
+        for msg in messages:
+            content = msg.content
+            if isinstance(content, str):
+                content = [ContentItem(text=content)]
+            assert isinstance(content, list)
+
+            new_content = []
+            for item in content:
+                t, v = item.get_type_and_value()
+                if t == 'text':
+                    new_content.append({'type': 'text', 'text': v})
+                if t == 'image':
+                    if v.startswith('file://'):
+                        v = v[len('file://'):]
+                    if not v.startswith(('http://', 'https://', 'data:')):
+                        if os.path.exists(v):
+                            v = encode_image_as_base64(v, max_short_side_length=1080)
+                        else:
+                            raise ModelServiceError(f'Local image "{v}" does not exist.')
+                    new_content.append({'type': 'image_url', 'image_url': {'url': v}})
+
+            new_msg = msg.model_dump()
+            new_msg['content'] = new_content
+            new_messages.append(new_msg)
+
+        if logger.isEnabledFor(logging.DEBUG):
+            lite_messages = copy.deepcopy(new_messages)
+            for msg in lite_messages:
+                for item in msg['content']:
+                    if item.get('image_url', {}).get('url', '').startswith('data:'):
+                        item['image_url']['url'] = item['image_url']['url'][:64] + '...'
+            logger.debug(f'LLM Input:\n{pformat(lite_messages, indent=2)}')
+
+        return new_messages
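Note: the net effect of this rewrite is that the subclass now emits the OpenAI multimodal wire format itself, raises ModelServiceError for a missing local image instead of silently passing the path through, and truncates base64 payloads in debug logs. Roughly, assuming a local file /tmp/example.jpg exists (hypothetical path):

    from qwen_agent.llm.qwenvl_oai import QwenVLChatAtOAI
    from qwen_agent.llm.schema import ContentItem, Message

    msgs = [Message(role='user', content=[
        ContentItem(text='Describe this image.'),
        ContentItem(image='file:///tmp/example.jpg'),
    ])]
    # QwenVLChatAtOAI.convert_messages_to_dicts(msgs) returns, roughly:
    # [{'role': 'user', 'content': [
    #     {'type': 'text', 'text': 'Describe this image.'},
    #     {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,...'}},
    # ]}]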
2 changes: 1 addition & 1 deletion qwen_agent/utils/utils.py
@@ -471,7 +471,7 @@ def encode_image_as_base64(path: str, max_short_side_length: int = -1) -> str:
         image = image.convert(mode='RGB')
     buffered = BytesIO()
     image.save(buffered, format='JPEG')
-    return 'data:image/jpg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8')
+    return 'data:image/jpeg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8')
 
 
 def load_image_from_base64(image_base64: Union[bytes, str]):
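Note: 'image/jpg' is not a registered MIME type ('image/jpeg' is), and strict OpenAI-compatible endpoints can reject data URLs that carry the nonstandard prefix. The stdlib agrees:

    import mimetypes

    assert mimetypes.types_map['.jpg'] == 'image/jpeg'  # 'image/jpg' is nonstandard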
