Skip to content

Commit

Permalink
support optional image prefix for vl fn-call
Browse files Browse the repository at this point in the history
  • Loading branch information
tuhahaha committed Dec 13, 2024
1 parent 5aa0070 commit c0349d0
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 48 deletions.
38 changes: 38 additions & 0 deletions examples/qwen2vl_assistant_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from qwen_agent.agents import Assistant


def test():
    """Demo: ask a Qwen-VL model to describe a video supplied as a list of frame URLs."""
    assistant = Assistant(llm={'model': 'qwen-vl-max-latest'})

    frame_urls = [
        'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/xzsgiz/football1.jpg',
        'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/tdescd/football2.jpg',
        'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/zefdja/football3.jpg',
        'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/aedbqh/football4.jpg',
    ]
    messages = [{
        'role': 'user',
        'content': [
            {'video': frame_urls},
            {'text': 'Describe the specific process of this video'},
        ],
    }]

    # A single video file also works, but uploading video files requires
    # applying for permission on DashScope, e.g.:
    #   messages = [{
    #       'role': 'user',
    #       'content': [
    #           {'video': 'https://www.runoob.com/try/demo_source/mov_bbb.mp4'},
    #           {'text': 'Describe the specific process of this video'},
    #       ],
    #   }]

    for response in assistant.run(messages):
        print(response)


if __name__ == '__main__':
    test()
2 changes: 1 addition & 1 deletion qwen_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.0.14'
__version__ = '0.0.15'
from .agent import Agent
from .multi_agent_hub import MultiAgentHub

Expand Down
12 changes: 10 additions & 2 deletions qwen_agent/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ def _preprocess_messages(
generate_cfg: dict,
functions: Optional[List[Dict]] = None,
) -> List[Message]:
messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages]
messages = [
format_as_multimodal_message(msg,
add_upload_info=True,
add_multimodel_upload_info=(functions is not None),
lang=lang) for msg in messages
]
return messages

def _postprocess_messages(
Expand All @@ -307,7 +312,10 @@ def _postprocess_messages(
fncall_mode: bool,
generate_cfg: dict,
) -> List[Message]:
messages = [format_as_multimodal_message(msg, add_upload_info=False) for msg in messages]
messages = [
format_as_multimodal_message(msg, add_upload_info=False, add_multimodel_upload_info=False)
for msg in messages
]
if not generate_cfg.get('skip_stopword_postproc', False):
stop = generate_cfg.get('stop', [])
messages = _postprocess_stop_words(messages, stop=stop)
Expand Down
5 changes: 4 additions & 1 deletion qwen_agent/llm/fncall_prompts/base_fncall_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def format_plaintext_train_samples(
if has_para:
raise ValueError('This sample requires parallel_function_calls=True.')

messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages]
messages = [
format_as_multimodal_message(msg, add_upload_info=True, add_multimodel_upload_info=True, lang=lang)
for msg in messages
]
for m in messages:
for item in m.content:
if item.type != 'text':
Expand Down
2 changes: 1 addition & 1 deletion qwen_agent/llm/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _preprocess_messages(
generate_cfg: dict,
functions: Optional[List[Dict]] = None,
) -> List[Message]:
messages = super()._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg)
messages = super()._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg, functions=functions)
if (not functions) or (generate_cfg.get('function_choice', 'auto') == 'none'):
messages = self._remove_fncall_messages(messages, lang=lang)
else:
Expand Down
61 changes: 31 additions & 30 deletions qwen_agent/llm/qwenvl_dashscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,40 +78,41 @@ def _format_local_files(messages: List[Message]) -> List[Message]:
if isinstance(msg.content, list):
for item in msg.content:
if item.image:
fname = item.image
if not fname.startswith((
'http://',
'https://',
'file://',
'data:', # base64 such as f"data:image/jpg;base64,{image_base64}"
)):
if fname.startswith('~'):
fname = os.path.expanduser(fname)
fname = os.path.abspath(fname)
if os.path.isfile(fname):
if re.match(r'^[A-Za-z]:\\', fname):
fname = fname.replace('\\', '/')
fname = 'file://' + fname
item.image = fname
item.image = _conv_fname(item.image)
if item.audio:
fname = item.audio
if not fname.startswith((
'http://',
'https://',
'file://',
'data:', # base64 such as f"data:image/jpg;base64,{image_base64}"
)):
if fname.startswith('~'):
fname = os.path.expanduser(fname)
fname = os.path.abspath(fname)
if os.path.isfile(fname):
if re.match(r'^[A-Za-z]:\\', fname):
fname = fname.replace('\\', '/')
fname = 'file://' + fname
item.audio = fname
item.audio = _conv_fname(item.audio)
if item.video:
if isinstance(item.video, str):
item.video = _conv_fname(item.video)
else:
assert isinstance(item.video, list)
new_url = []
for fname in item.video:
new_url.append(_conv_fname(fname))
item.video = new_url
return messages


def _conv_fname(fname: str) -> str:
ori_fname = fname
if not fname.startswith((
'http://',
'https://',
'file://',
'data:', # base64 such as f"data:image/jpg;base64,{image_base64}"
)):
if fname.startswith('~'):
fname = os.path.expanduser(fname)
fname = os.path.abspath(fname)
if os.path.isfile(fname):
if re.match(r'^[A-Za-z]:\\', fname):
fname = fname.replace('\\', '/')
fname = 'file://' + fname
return fname

return ori_fname


def _extract_vl_response(response) -> List[Message]:
output = response.output.choices[0].message
text_content = []
Expand Down
17 changes: 11 additions & 6 deletions qwen_agent/llm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
FILE = 'file'
IMAGE = 'image'
AUDIO = 'audio'
VIDEO = 'video'


class BaseModelCompatibleDict(BaseModel):
Expand Down Expand Up @@ -66,13 +67,15 @@ class ContentItem(BaseModelCompatibleDict):
image: Optional[str] = None
file: Optional[str] = None
audio: Optional[str] = None
video: Optional[Union[str, list]] = None

def __init__(self,
text: Optional[str] = None,
image: Optional[str] = None,
file: Optional[str] = None,
audio: Optional[str] = None):
super().__init__(text=text, image=image, file=file, audio=audio)
audio: Optional[str] = None,
video: Optional[Union[str, list]] = None):
super().__init__(text=text, image=image, file=file, audio=audio, video=video)

@model_validator(mode='after')
def check_exclusivity(self):
Expand All @@ -85,21 +88,23 @@ def check_exclusivity(self):
provided_fields += 1
if self.audio:
provided_fields += 1
if self.video:
provided_fields += 1

if provided_fields != 1:
raise ValueError("Exactly one of 'text', 'image', 'file', or 'audio' must be provided.")
raise ValueError("Exactly one of 'text', 'image', 'file', 'audio', or 'video' must be provided.")
return self

def __repr__(self):
return f'ContentItem({self.model_dump()})'

def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file', 'audio'], str]:
def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file', 'audio', 'video'], str]:
(t, v), = self.model_dump().items()
assert t in ('text', 'image', 'file', 'audio')
assert t in ('text', 'image', 'file', 'audio', 'video')
return t, v

@property
def type(self) -> Literal['text', 'image', 'file', 'audio']:
def type(self) -> Literal['text', 'image', 'file', 'audio', 'video']:
t, v = self.get_type_and_value()
return t

Expand Down
26 changes: 19 additions & 7 deletions qwen_agent/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def json_dumps_compact(obj: dict, ensure_ascii=False, indent=None, **kwargs) ->
def format_as_multimodal_message(
msg: Message,
add_upload_info: bool,
add_multimodel_upload_info: bool,
lang: Literal['auto', 'en', 'zh'] = 'auto',
) -> Message:
assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
Expand All @@ -324,24 +325,32 @@ def format_as_multimodal_message(
files = []
for item in msg.content:
k, v = item.get_type_and_value()
if k == 'text':
content.append(ContentItem(text=v))
if k in ('image', 'audio'):
if k in ('text', 'image', 'audio', 'video'):
content.append(item)
if k == 'file':
# Move 'file' out of 'content' since it's not natively supported by models
files.append(v)
if add_multimodel_upload_info and k == 'image':
# Indicate the image name
# Not considering audio and video for now
files.append(v)
if add_upload_info and files and (msg.role in (SYSTEM, USER)):
if lang == 'auto':
has_zh = has_chinese_chars(msg)
else:
has_zh = (lang == 'zh')
upload = []
for f in [get_basename_from_url(f) for f in files]:
if has_zh:
upload.append(f'[文件]({f})')
if is_image(f):
if has_zh:
upload.append(f'![图片]({f})')
else:
upload.append(f'![image]({f})')
else:
upload.append(f'[file]({f})')
if has_zh:
upload.append(f'[文件]({f})')
else:
upload.append(f'[file]({f})')
upload = ' '.join(upload)
if has_zh:
upload = f'(上传了 {upload}\n\n'
Expand Down Expand Up @@ -372,7 +381,10 @@ def format_as_text_message(
add_upload_info: bool,
lang: Literal['auto', 'en', 'zh'] = 'auto',
) -> Message:
msg = format_as_multimodal_message(msg, add_upload_info=add_upload_info, lang=lang)
msg = format_as_multimodal_message(msg,
add_upload_info=add_upload_info,
add_multimodel_upload_info=add_upload_info,
lang=lang)
text = ''
for item in msg.content:
if item.type == 'text':
Expand Down
5 changes: 5 additions & 0 deletions tests/examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from examples.llm_vl_mix_text import test as llm_vl_mix_text # noqa
from examples.multi_agent_router import test as multi_agent_router # noqa
from examples.qwen2vl_assistant_tooluse import test as qwen2vl_assistant_tooluse # noqa
from examples.qwen2vl_assistant_video import test as test_video # noqa
from examples.react_data_analysis import test as react_data_analysis # noqa
from examples.visual_storytelling import test as visual_storytelling # noqa

Expand Down Expand Up @@ -86,3 +87,7 @@ def test_group_chat_demo():

def test_qwen2vl_assistant_tooluse():
    # Smoke test: run the Qwen2-VL tool-use example end to end.
    qwen2vl_assistant_tooluse()


def test_video_understanding():
    # Smoke test: run the video-understanding example
    # (examples.qwen2vl_assistant_video.test, imported as test_video).
    test_video()

0 comments on commit c0349d0

Please sign in to comment.