From c0349d064a1924866788ab9aa04a4e4ecea8d041 Mon Sep 17 00:00:00 2001 From: "tujianhong.tjh" Date: Fri, 13 Dec 2024 14:50:24 +0800 Subject: [PATCH] support optional image prefix for vl fn-call --- examples/qwen2vl_assistant_video.py | 38 ++++++++++++ qwen_agent/__init__.py | 2 +- qwen_agent/llm/base.py | 12 +++- .../llm/fncall_prompts/base_fncall_prompt.py | 5 +- qwen_agent/llm/function_calling.py | 2 +- qwen_agent/llm/qwenvl_dashscope.py | 61 ++++++++++--------- qwen_agent/llm/schema.py | 17 ++++-- qwen_agent/utils/utils.py | 26 +++++--- tests/examples/test_examples.py | 5 ++ 9 files changed, 120 insertions(+), 48 deletions(-) create mode 100644 examples/qwen2vl_assistant_video.py diff --git a/examples/qwen2vl_assistant_video.py b/examples/qwen2vl_assistant_video.py new file mode 100644 index 0000000..b0ae88b --- /dev/null +++ b/examples/qwen2vl_assistant_video.py @@ -0,0 +1,38 @@ +from qwen_agent.agents import Assistant + + +def test(): + bot = Assistant(llm={'model': 'qwen-vl-max-latest'}) + + messages = [{ + 'role': + 'user', + 'content': [{ + 'video': [ + 'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/xzsgiz/football1.jpg', + 'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/tdescd/football2.jpg', + 'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/zefdja/football3.jpg', + 'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/aedbqh/football4.jpg' + ] + }, { + 'text': 'Describe the specific process of this video' + }] + }] + + # Uploading video files requires applying for permission on DashScope + # messages = [{ + # 'role': + # 'user', + # 'content': [{ + # 'video': 'https://www.runoob.com/try/demo_source/mov_bbb.mp4' + # }, { + # 'text': 'Describe the specific process of this video' + # }] + # }] + + for rsp in bot.run(messages): + print(rsp) + + +if __name__ == '__main__': + test() diff --git a/qwen_agent/__init__.py b/qwen_agent/__init__.py index a929830..0e43d38 100644 --- a/qwen_agent/__init__.py +++ b/qwen_agent/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.0.14' +__version__ = '0.0.15' from .agent import Agent from .multi_agent_hub import MultiAgentHub diff --git a/qwen_agent/llm/base.py b/qwen_agent/llm/base.py index f6cd4a8..dee43aa 100644 --- a/qwen_agent/llm/base.py +++ b/qwen_agent/llm/base.py @@ -298,7 +298,12 @@ def _preprocess_messages( generate_cfg: dict, functions: Optional[List[Dict]] = None, ) -> List[Message]: - messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages] + messages = [ + format_as_multimodal_message(msg, + add_upload_info=True, + add_multimodel_upload_info=(functions is not None), + lang=lang) for msg in messages + ] return messages def _postprocess_messages( @@ -307,7 +312,10 @@ def _postprocess_messages( fncall_mode: bool, generate_cfg: dict, ) -> List[Message]: - messages = [format_as_multimodal_message(msg, add_upload_info=False) for msg in messages] + messages = [ + format_as_multimodal_message(msg, add_upload_info=False, add_multimodel_upload_info=False) + for msg in messages + ] if not generate_cfg.get('skip_stopword_postproc', False): stop = generate_cfg.get('stop', []) messages = _postprocess_stop_words(messages, stop=stop) diff --git a/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py b/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py index 83d9194..2f073f6 100644 --- a/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py +++ b/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py @@ -52,7 +52,10 @@ def format_plaintext_train_samples( if has_para: raise ValueError('This sample requires parallel_function_calls=True.') - messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages] + messages = [ + format_as_multimodal_message(msg, add_upload_info=True, add_multimodel_upload_info=True, lang=lang) + for msg in messages + ] for m in messages: for item in m.content: if item.type != 'text': diff --git a/qwen_agent/llm/function_calling.py b/qwen_agent/llm/function_calling.py index 84cc082..f7644c8 100644 --- a/qwen_agent/llm/function_calling.py +++ b/qwen_agent/llm/function_calling.py @@ -29,7 +29,7 @@ def _preprocess_messages( generate_cfg: dict, functions: Optional[List[Dict]] = None, ) -> List[Message]: - messages = super()._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg) + messages = super()._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg, functions=functions) if (not functions) or (generate_cfg.get('function_choice', 'auto') == 'none'): messages = self._remove_fncall_messages(messages, lang=lang) else: diff --git a/qwen_agent/llm/qwenvl_dashscope.py b/qwen_agent/llm/qwenvl_dashscope.py index e4642a2..726680a 100644 --- a/qwen_agent/llm/qwenvl_dashscope.py +++ b/qwen_agent/llm/qwenvl_dashscope.py @@ -78,40 +78,41 @@ def _format_local_files(messages: List[Message]) -> List[Message]: if isinstance(msg.content, list): for item in msg.content: if item.image: - fname = item.image - if not fname.startswith(( - 'http://', - 'https://', - 'file://', - 'data:', # base64 such as f"data:image/jpg;base64,{image_base64}" - )): - if fname.startswith('~'): - fname = os.path.expanduser(fname) - fname = os.path.abspath(fname) - if os.path.isfile(fname): - if re.match(r'^[A-Za-z]:\\', fname): - fname = fname.replace('\\', '/') - fname = 'file://' + fname - item.image = fname + item.image = _conv_fname(item.image) if item.audio: - fname = item.audio - if not fname.startswith(( - 'http://', - 'https://', - 'file://', - 'data:', # base64 such as f"data:image/jpg;base64,{image_base64}" - )): - if fname.startswith('~'): - fname = os.path.expanduser(fname) - fname = os.path.abspath(fname) - if os.path.isfile(fname): - if re.match(r'^[A-Za-z]:\\', fname): - fname = fname.replace('\\', '/') - fname = 'file://' + fname - item.audio = fname + item.audio = _conv_fname(item.audio) + if item.video: + if isinstance(item.video, str): + item.video = _conv_fname(item.video) + else: + assert isinstance(item.video, list) + new_url = [] + for fname in item.video: + new_url.append(_conv_fname(fname)) + item.video = new_url return messages +def _conv_fname(fname: str) -> str: + ori_fname = fname + if not fname.startswith(( + 'http://', + 'https://', + 'file://', + 'data:', # base64 such as f"data:image/jpg;base64,{image_base64}" + )): + if fname.startswith('~'): + fname = os.path.expanduser(fname) + fname = os.path.abspath(fname) + if os.path.isfile(fname): + if re.match(r'^[A-Za-z]:\\', fname): + fname = fname.replace('\\', '/') + fname = 'file://' + fname + return fname + + return ori_fname + + def _extract_vl_response(response) -> List[Message]: output = response.output.choices[0].message text_content = [] diff --git a/qwen_agent/llm/schema.py b/qwen_agent/llm/schema.py index 3886669..137c51b 100644 --- a/qwen_agent/llm/schema.py +++ b/qwen_agent/llm/schema.py @@ -16,6 +16,7 @@ FILE = 'file' IMAGE = 'image' AUDIO = 'audio' +VIDEO = 'video' class BaseModelCompatibleDict(BaseModel): @@ -66,13 +67,15 @@ class ContentItem(BaseModelCompatibleDict): image: Optional[str] = None file: Optional[str] = None audio: Optional[str] = None + video: Optional[Union[str, list]] = None def __init__(self, text: Optional[str] = None, image: Optional[str] = None, file: Optional[str] = None, - audio: Optional[str] = None): - super().__init__(text=text, image=image, file=file, audio=audio) + audio: Optional[str] = None, + video: Optional[Union[str, list]] = None): + super().__init__(text=text, image=image, file=file, audio=audio, video=video) @model_validator(mode='after') def check_exclusivity(self): @@ -85,21 +88,23 @@ def check_exclusivity(self): provided_fields += 1 if self.audio: provided_fields += 1 + if self.video: + provided_fields += 1 if provided_fields != 1: - raise ValueError("Exactly one of 'text', 'image', 'file', or 'audio' must be provided.") + raise ValueError("Exactly one of 'text', 'image', 'file', 'audio', or 'video' must be provided.") return self def __repr__(self): return f'ContentItem({self.model_dump()})' - def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file', 'audio'], str]: + def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file', 'audio', 'video'], str]: (t, v), = self.model_dump().items() - assert t in ('text', 'image', 'file', 'audio') + assert t in ('text', 'image', 'file', 'audio', 'video') return t, v @property - def type(self) -> Literal['text', 'image', 'file', 'audio']: + def type(self) -> Literal['text', 'image', 'file', 'audio', 'video']: t, v = self.get_type_and_value() return t diff --git a/qwen_agent/utils/utils.py b/qwen_agent/utils/utils.py index e97e07d..1faa45d 100644 --- a/qwen_agent/utils/utils.py +++ b/qwen_agent/utils/utils.py @@ -313,6 +313,7 @@ def json_dumps_compact(obj: dict, ensure_ascii=False, indent=None, **kwargs) -> def format_as_multimodal_message( msg: Message, add_upload_info: bool, + add_multimodel_upload_info: bool, lang: Literal['auto', 'en', 'zh'] = 'auto', ) -> Message: assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) @@ -324,13 +325,15 @@ def format_as_multimodal_message( files = [] for item in msg.content: k, v = item.get_type_and_value() - if k == 'text': - content.append(ContentItem(text=v)) - if k in ('image', 'audio'): + if k in ('text', 'image', 'audio', 'video'): content.append(item) if k == 'file': # Move 'file' out of 'content' since it's not natively supported by models files.append(v) + if add_multimodel_upload_info and k == 'image': + # Indicate the image name + # Not considering audio and video for now + files.append(v) if add_upload_info and files and (msg.role in (SYSTEM, USER)): if lang == 'auto': has_zh = has_chinese_chars(msg) @@ -338,10 +341,16 @@ def format_as_multimodal_message( has_zh = (lang == 'zh') upload = [] for f in [get_basename_from_url(f) for f in files]: - if has_zh: - upload.append(f'[文件]({f})') + if is_image(f): + if has_zh: + upload.append(f'![图片]({f})') + else: + upload.append(f'![image]({f})') else: - upload.append(f'[file]({f})') + if has_zh: + upload.append(f'[文件]({f})') + else: + upload.append(f'[file]({f})') upload = ' '.join(upload) if has_zh: upload = f'(上传了 {upload})\n\n' @@ -372,7 +381,10 @@ def format_as_text_message( add_upload_info: bool, lang: Literal['auto', 'en', 'zh'] = 'auto', ) -> Message: - msg = format_as_multimodal_message(msg, add_upload_info=add_upload_info, lang=lang) + msg = format_as_multimodal_message(msg, + add_upload_info=add_upload_info, + add_multimodel_upload_info=add_upload_info, + lang=lang) text = '' for item in msg.content: if item.type == 'text': diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 94b811e..a0d24f5 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -17,6 +17,7 @@ from examples.llm_vl_mix_text import test as llm_vl_mix_text # noqa from examples.multi_agent_router import test as multi_agent_router # noqa from examples.qwen2vl_assistant_tooluse import test as qwen2vl_assistant_tooluse # noqa +from examples.qwen2vl_assistant_video import test as test_video # noqa from examples.react_data_analysis import test as react_data_analysis # noqa from examples.visual_storytelling import test as visual_storytelling # noqa @@ -86,3 +87,7 @@ def test_group_chat_demo(): def test_qwen2vl_assistant_tooluse(): qwen2vl_assistant_tooluse() + + +def test_video_understanding(): + test_video()