From cb86572c761ddf570d40eece71c7814dafaaaba4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=BC=E6=AC=A3?= Date: Mon, 2 Sep 2024 19:18:26 +0800 Subject: [PATCH] refactor function calling to allow for flexible prompt modification --- examples/qwen2vl_assistant_tooluse.py | 25 +- examples/qwen2vl_function_calling.py | 19 +- qwen_agent/__init__.py | 2 +- qwen_agent/agent.py | 11 +- qwen_agent/agents/__init__.py | 3 +- qwen_agent/agents/fncall_agent.py | 1 + qwen_agent/llm/__init__.py | 20 +- qwen_agent/llm/azure.py | 87 +-- qwen_agent/llm/base.py | 24 +- .../llm/fncall_prompts/base_fncall_prompt.py | 69 +++ .../llm/fncall_prompts/nous_fncall_prompt.py | 91 +++ .../llm/fncall_prompts/qwen_fncall_prompt.py | 387 +++++++++++++ qwen_agent/llm/function_calling.py | 521 +++--------------- qwen_agent/llm/oai.py | 4 +- qwen_agent/tools/base.py | 87 ++- qwen_agent/utils/tokenization_qwen.py | 6 - qwen_agent/utils/utils.py | 67 ++- setup.py | 5 + tests/agents/test_react_chat.py | 4 - tests/llm/test_function_content.py | 8 +- tests/llm/test_oai.py | 6 +- 21 files changed, 849 insertions(+), 598 deletions(-) create mode 100644 qwen_agent/llm/fncall_prompts/base_fncall_prompt.py create mode 100644 qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py create mode 100644 qwen_agent/llm/fncall_prompts/qwen_fncall_prompt.py diff --git a/examples/qwen2vl_assistant_tooluse.py b/examples/qwen2vl_assistant_tooluse.py index 5606c32..8d0f6d8 100644 --- a/examples/qwen2vl_assistant_tooluse.py +++ b/examples/qwen2vl_assistant_tooluse.py @@ -46,17 +46,12 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s host = 'https://wuliu.market.alicloudapi.com' path = '/kdi' - method = 'GET' appcode = os.environ['AppCode_ExpressTracking'] # 开通服务后 买家中心-查看AppCode querys = f'no={id}&type={company}' - bodys = {} url = host + path + '?' + querys header = {'Authorization': 'APPCODE ' + appcode} - try: - res = requests.get(url, headers=header) - except: - return 'URL错误' + res = requests.get(url, headers=header) httpStatusCode = res.status_code if (httpStatusCode == 200): @@ -64,8 +59,9 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s import json try: out = json.loads(res.text) - except: - out = eval(res.text) + except json.decoder.JSONDecodeError: + import json5 + out = json5.loads(res.text) return '```json' + json.dumps(out, ensure_ascii=False, indent=4) + '\n```' else: httpReason = res.headers['X-Ca-Error-Message'] @@ -143,10 +139,8 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s host = 'https://ali-weather.showapi.com' path = '/spot-to-weather' - method = 'GET' appcode = os.environ['AppCode_Area2Weather'] # 开通服务后 买家中心-查看AppCode querys = f'area={area}&needMoreDay={needMoreDay}&needIndex={needIndex}&needHourData={needHourData}&need3HourForcast={need3HourForcast}&needAlarm={needAlarm}' - bodys = {} url = host + path + '?' + querys request = urllib.request.Request(url) @@ -181,10 +175,8 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s host = 'https://ali-weather.showapi.com' path = '/hour24' - method = 'GET' appcode = os.environ('AppCode_weather_hour24') # 开通服务后 买家中心-查看AppCode querys = f'area={area}&areaCode=' - bodys = {} url = host + path + '?' 
+ querys request = urllib.request.Request(url) @@ -283,18 +275,13 @@ def init_agent_service(): llm_cfg_vl = { # Using Qwen2-VL deployed at any openai-compatible service such as vLLM: # 'model_type': 'qwenvl_oai', - # 'model': 'Qwen/Qwen2-VL-72B-Instruct', + # 'model': 'Qwen2-VL-7B-Instruct', # 'model_server': 'http://localhost:8000/v1', # api_base # 'api_key': 'EMPTY', # Using Qwen2-VL provided by Alibaba Cloud DashScope: - # 'model_type': 'qwenvl_dashscope', - # 'model': 'qwen2-vl-72b-instruct', - # 'api_key': os.getenv('DASHSCOPE_API_KEY'), - - # TODO: Use qwen2-vl instead once qwen2-vl is released. 'model_type': 'qwenvl_dashscope', - 'model': 'qwen-vl-max', + 'model': 'qwen-vl-max-0809', 'api_key': os.getenv('DASHSCOPE_API_KEY'), 'generate_cfg': dict(max_retries=10,) } diff --git a/examples/qwen2vl_function_calling.py b/examples/qwen2vl_function_calling.py index 9fda76c..565f7a0 100644 --- a/examples/qwen2vl_function_calling.py +++ b/examples/qwen2vl_function_calling.py @@ -1,13 +1,16 @@ import json +import os import urllib.parse from qwen_agent.llm import get_chat_model from qwen_agent.llm.schema import ContentItem +from qwen_agent.utils.utils import save_url_to_local_work_dir def image_gen(prompt: str) -> str: prompt = urllib.parse.quote(prompt) image_url = f'https://image.pollinations.ai/prompt/{prompt}' + image_url = save_url_to_local_work_dir(image_url, save_dir='./', save_filename='pic.jpg') return image_url @@ -15,10 +18,16 @@ def test(): # Config for the model llm_cfg_oai = { # Using Qwen2-VL deployed at any openai-compatible service such as vLLM: - 'model_type': 'qwenvl_oai', - 'model': 'Qwen/Qwen2-VL-72B-Instruct', - 'model_server': 'http://localhost:8000/v1', # api_base - 'api_key': 'EMPTY', + # 'model_type': 'qwenvl_oai', + # 'model': 'Qwen2-VL-7B-Instruct', + # 'model_server': 'http://localhost:8000/v1', # api_base + # 'api_key': 'EMPTY', + + # Using Qwen2-VL provided by Alibaba Cloud DashScope: + 'model_type': 'qwenvl_dashscope', + 'model': 'qwen-vl-max-0809', + 'api_key': os.getenv('DASHSCOPE_API_KEY'), + 'generate_cfg': dict(max_retries=10,) } llm = get_chat_model(llm_cfg_oai) @@ -29,7 +38,7 @@ def test(): 'content': [{ 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg' }, { - 'text': '图片中的内容是什么?请画一张内容相同,风格类似的图片。' + 'text': '图片中的内容是什么?请画一张内容相同,风格类似的图片。把女人换成男人' }] }] diff --git a/qwen_agent/__init__.py b/qwen_agent/__init__.py index 1a2eda2..d625c9a 100644 --- a/qwen_agent/__init__.py +++ b/qwen_agent/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.0.7' +__version__ = '0.0.8' from .agent import Agent from .multi_agent_hub import MultiAgentHub diff --git a/qwen_agent/agent.py b/qwen_agent/agent.py index 9cfe330..6b3deb5 100644 --- a/qwen_agent/agent.py +++ b/qwen_agent/agent.py @@ -21,7 +21,7 @@ class Agent(ABC): def __init__(self, function_list: Optional[List[Union[str, Dict, BaseTool]]] = None, - llm: Optional[Union[Dict, BaseChatModel]] = None, + llm: Optional[Union[dict, BaseChatModel]] = None, system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, name: Optional[str] = None, description: Optional[str] = None, @@ -48,7 +48,7 @@ def __init__(self, for tool in function_list: self._init_tool(tool) - self.system_message = system_message + self.system_message = system_message or self.SYSTEM_MESSAGE self.name = name self.description = description @@ -226,3 +226,10 @@ def _detect_tool(self, message: Message) -> Tuple[bool, str, str, str]: text = '' return (func_name is not None), func_name, func_args, text + + +# The most basic form of an agent 
is just an LLM, not augmented with any tool or workflow. +class BasicAgent(Agent): + + def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]: + return self._call_llm(messages) diff --git a/qwen_agent/agents/__init__.py b/qwen_agent/agents/__init__.py index d71c318..b09e889 100644 --- a/qwen_agent/agents/__init__.py +++ b/qwen_agent/agents/__init__.py @@ -1,4 +1,4 @@ -from qwen_agent.agent import Agent +from qwen_agent.agent import Agent, BasicAgent from qwen_agent.multi_agent_hub import MultiAgentHub from .article_agent import ArticleAgent @@ -20,6 +20,7 @@ __all__ = [ 'Agent', + 'BasicAgent', 'MultiAgentHub', 'DocQAAgent', 'ParallelDocQA', diff --git a/qwen_agent/agents/fncall_agent.py b/qwen_agent/agents/fncall_agent.py index 83035c8..50e650a 100644 --- a/qwen_agent/agents/fncall_agent.py +++ b/qwen_agent/agents/fncall_agent.py @@ -75,6 +75,7 @@ def _run(self, messages: List[Message], lang: Literal['en', 'zh'] = 'en', **kwar used_any_tool = True if not used_any_tool: break + yield response def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str: if tool_name not in self.function_map: diff --git a/qwen_agent/llm/__init__.py b/qwen_agent/llm/__init__.py index 2676cb6..c8e873f 100644 --- a/qwen_agent/llm/__init__.py +++ b/qwen_agent/llm/__init__.py @@ -1,15 +1,15 @@ -from typing import Dict, Optional +from typing import Union +from .azure import TextChatAtAzure from .base import LLM_REGISTRY, BaseChatModel, ModelServiceError from .oai import TextChatAtOAI from .openvino import OpenVINO from .qwen_dashscope import QwenChatAtDS from .qwenvl_dashscope import QwenVLChatAtDS from .qwenvl_oai import QwenVLChatAtOAI -from .azure import TextChatAtAZURE -def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel: +def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel: """The interface of instantiating LLM objects. Args: @@ -30,7 +30,9 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel: Returns: LLM object.
""" - cfg = cfg or {} + if isinstance(cfg, str): + cfg = {'model': cfg} + if 'model_type' in cfg: model_type = cfg['model_type'] if model_type in LLM_REGISTRY: @@ -40,15 +42,15 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel: # Deduce model_type from model and model_server if model_type is not provided: + if 'azure_endpoint' in cfg: + model_type = 'azure' + return LLM_REGISTRY[model_type](cfg) + if 'model_server' in cfg: if cfg['model_server'].strip().startswith('http'): model_type = 'oai' return LLM_REGISTRY[model_type](cfg) - if 'azure_endpoint' in cfg: - model_type = 'azure' - return LLM_REGISTRY[model_type](cfg) - model = cfg.get('model', '') if 'qwen-vl' in model: @@ -66,7 +68,7 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel: 'BaseChatModel', 'QwenChatAtDS', 'TextChatAtOAI', - 'TextChatAtAZURE', + 'TextChatAtAzure', 'QwenVLChatAtDS', 'QwenVLChatAtOAI', 'OpenVINO', diff --git a/qwen_agent/llm/azure.py b/qwen_agent/llm/azure.py index 2e8e0bb..fc30d4f 100644 --- a/qwen_agent/llm/azure.py +++ b/qwen_agent/llm/azure.py @@ -1,32 +1,25 @@ import copy -import logging import os -from pprint import pformat -from typing import Dict, Iterator, List, Optional +from typing import Dict, Optional import openai -from openai import OpenAIError - -from qwen_agent.llm.base import ModelServiceError, register_llm -from qwen_agent.llm.function_calling import BaseFnCallModel -from qwen_agent.llm.schema import ASSISTANT, Message -from qwen_agent.log import logger +from qwen_agent.llm.base import register_llm +from qwen_agent.llm.oai import TextChatAtOAI @register_llm('azure') -class TextChatAtAZURE(BaseFnCallModel): +class TextChatAtAzure(TextChatAtOAI): def __init__(self, cfg: Optional[Dict] = None): super().__init__(cfg) - self.model = self.model or 'gpt-3.5-turbo' cfg = cfg or {} api_base = cfg.get( 'api_base', cfg.get( 'base_url', - cfg.get('model_server', cfg.get('azure_endpoint','')), + cfg.get('model_server', cfg.get('azure_endpoint', '')), ), ).strip() @@ -34,9 +27,9 @@ def __init__(self, cfg: Optional[Dict] = None): if not api_key: api_key = os.getenv('OPENAI_API_KEY', 'EMPTY') api_key = api_key.strip() - - api_version = cfg.get('api_version','2024-06-01') - + + api_version = cfg.get('api_version', '2024-06-01') + api_kwargs = {} if api_base: api_kwargs['azure_endpoint'] = api_base @@ -44,7 +37,6 @@ def __init__(self, cfg: Optional[Dict] = None): api_kwargs['api_key'] = api_key if api_version: api_kwargs['api_version'] = api_version - def _chat_complete_create(*args, **kwargs): # OpenAI API v1 does not allow the following args, must pass by extra_body @@ -56,71 +48,10 @@ def _chat_complete_create(*args, **kwargs): kwargs['extra_body'][k] = kwargs.pop(k) if 'request_timeout' in kwargs: kwargs['timeout'] = kwargs.pop('request_timeout') - + client = openai.AzureOpenAI(**api_kwargs) # client = openai.OpenAI(**api_kwargs) return client.chat.completions.create(*args, **kwargs) self._chat_complete_create = _chat_complete_create - - def _chat_stream( - self, - messages: List[Message], - delta_stream: bool, - generate_cfg: dict, - ) -> Iterator[List[Message]]: - messages = [msg.model_dump() for msg in messages] - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}') - try: - response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg) - if delta_stream: - for chunk in response: - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f'Chunk received: {chunk}') - if 
chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: - yield [Message(ASSISTANT, chunk.choices[0].delta.content)] - else: - full_response = '' - for chunk in response: - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f'Chunk received: {chunk}') - if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: - full_response += chunk.choices[0].delta.content - yield [Message(ASSISTANT, full_response)] - except OpenAIError as ex: - raise ModelServiceError(exception=ex) - - def _chat_no_stream( - self, - messages: List[Message], - generate_cfg: dict, - ) -> List[Message]: - messages = [msg.model_dump() for msg in messages] - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}') - try: - response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg) - return [Message(ASSISTANT, response.choices[0].message.content)] - except OpenAIError as ex: - raise ModelServiceError(exception=ex) - - -def pretty_format_messages(messages: List[dict], indent: int = 2) -> str: - messages_show = [] - for msg in messages: - assert isinstance(msg, dict) - msg_show = copy.deepcopy(msg) - if isinstance(msg['content'], list): - content = [] - for item in msg['content']: - (t, v), = item.items() - if (t != 'text') and v.startswith('data:'): - v = v[:64] + ' ...' - content.append({t: v}) - else: - content = msg['content'] - msg_show['content'] = content - messages_show.append(msg_show) - return pformat(messages_show, indent=indent) \ No newline at end of file diff --git a/qwen_agent/llm/base.py b/qwen_agent/llm/base.py index 913ddd3..b3a4ae4 100644 --- a/qwen_agent/llm/base.py +++ b/qwen_agent/llm/base.py @@ -10,7 +10,7 @@ from qwen_agent.settings import DEFAULT_MAX_INPUT_TOKENS from qwen_agent.utils.tokenization_qwen import tokenizer from qwen_agent.utils.utils import (extract_text_from_message, format_as_multimodal_message, format_as_text_message, - has_chinese_messages, merge_generate_cfgs, print_traceback) + has_chinese_messages, merge_generate_cfgs) LLM_REGISTRY = {} @@ -59,6 +59,13 @@ def __init__(self, cfg: Optional[Dict] = None): self.max_retries = generate_cfg.pop('max_retries', 0) self.generate_cfg = generate_cfg + def quick_chat(self, prompt: str) -> str: + responses = self.chat(messages=[Message(role=USER, content=prompt)], stream=False) + assert len(responses) == 1 + assert not responses[0].function_call + assert isinstance(responses[0].content, str) + return responses[0].content + def chat( self, messages: List[Union[Message, Dict]], @@ -90,6 +97,8 @@ def chat( ) generate_cfg = merge_generate_cfgs(base_generate_cfg=self.generate_cfg, new_generate_cfg=extra_generate_cfg) + if 'seed' not in generate_cfg: + generate_cfg['seed'] = random.randint(a=0, b=2**30) if 'lang' in generate_cfg: lang: Literal['en', 'zh'] = generate_cfg.pop('lang') else: @@ -133,7 +142,7 @@ def chat( fncall_mode = False # Note: the preprocessor's behavior could change if it receives function_choice="none" - messages = self._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg) + messages = self._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg, functions=functions) if not self.support_multimodal_input: messages = [format_as_text_message(msg, add_upload_info=False) for msg in messages] @@ -227,8 +236,13 @@ def _chat_no_stream( ) -> List[Message]: raise NotImplementedError - def _preprocess_messages(self, messages: 
List[Message], lang: Literal['en', 'zh'], - generate_cfg: dict) -> List[Message]: + def _preprocess_messages( + self, + messages: List[Message], + lang: Literal['en', 'zh'], + generate_cfg: dict, + functions: Optional[List[Dict]] = None, + ) -> List[Message]: messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages] return messages @@ -457,7 +471,7 @@ def _raise_or_delay( if 'maximum context length' in str(e): raise e - print_traceback(is_error=False) + logger.warning('ModelServiceError - ' + str(e).strip('\n')) if num_retries >= max_retries: raise ModelServiceError(exception=Exception(f'Maximum number of retries ({max_retries}) exceeded.')) diff --git a/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py b/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py new file mode 100644 index 0000000..83d9194 --- /dev/null +++ b/qwen_agent/llm/fncall_prompts/base_fncall_prompt.py @@ -0,0 +1,69 @@ +from typing import List, Literal, Union + +from qwen_agent.llm.schema import FUNCTION, Message +from qwen_agent.utils.utils import format_as_multimodal_message, format_as_text_message, has_chinese_messages + + +class BaseFnCallPrompt(object): + + @staticmethod + def preprocess_fncall_messages( + messages: List[Message], + functions: List[dict], + lang: Literal['en', 'zh'], + parallel_function_calls: bool = True, + function_choice: Union[Literal['auto'], str] = 'auto', + ) -> List[Message]: + """ + Preprocess the messages and add the function calling prompt, + assuming the input and output messages are in the multimodal format. + """ + assert function_choice != 'none' + raise NotImplementedError + + @staticmethod + def postprocess_fncall_messages( + messages: List[Message], + parallel_function_calls: bool = True, + function_choice: Union[Literal['auto'], str] = 'auto', + ) -> List[Message]: + """ + Transform the plaintext model output into structured function call messages, + and return them in the multimodal format for consistency.
+ """ + raise NotImplementedError + + def format_plaintext_train_samples( + self, + messages: List[Union[Message, dict]], + functions: List[dict], + lang: Literal['auto', 'en', 'zh'] = 'auto', + parallel_function_calls: bool = True, + ) -> List[Message]: + messages = [m if isinstance(m, Message) else Message(**m) for m in messages] + + if lang == 'auto': + lang = 'zh' if has_chinese_messages(messages) else 'en' + + if not parallel_function_calls: + for i in range(len(messages) - 1): + has_para = (messages[i].function_call and messages[i + 1].function_call) + has_para = has_para or ((messages[i].role == FUNCTION) and (messages[i + 1].role == FUNCTION)) + if has_para: + raise ValueError('This sample requires parallel_function_calls=True.') + + messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages] + for m in messages: + for item in m.content: + if item.type != 'text': + raise NotImplementedError('Support for multimodal samples not implemented yet.') + + messages = self.preprocess_fncall_messages( + messages=messages, + functions=functions, + lang=lang, + parallel_function_calls=parallel_function_calls, + ) + + messages = [format_as_text_message(msg, add_upload_info=False) for msg in messages] + return messages diff --git a/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py b/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py new file mode 100644 index 0000000..a48e1cf --- /dev/null +++ b/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py @@ -0,0 +1,91 @@ +import copy +import json +from typing import List, Literal, Union + +from qwen_agent.llm.fncall_prompts.base_fncall_prompt import BaseFnCallPrompt +from qwen_agent.llm.schema import ASSISTANT, FUNCTION, SYSTEM, USER, ContentItem, Message + + +class NousFnCallPrompt(BaseFnCallPrompt): + + @staticmethod + def preprocess_fncall_messages( + messages: List[Message], + functions: List[dict], + lang: Literal['en', 'zh'], + parallel_function_calls: bool = True, + function_choice: Union[Literal['auto'], str] = 'auto', + ) -> List[Message]: + del lang # ignored + del parallel_function_calls # ignored + if function_choice != 'auto': + raise NotImplementedError + + ori_messages = messages + + # Change function_call responses to plaintext responses: + messages = [] + for msg in copy.deepcopy(ori_messages): + role, content = msg.role, msg.content + if role in (SYSTEM, USER): + messages.append(msg) + elif role == ASSISTANT: + content = (content or []) + fn_call = msg.function_call + if fn_call: + fc = {'name': fn_call.name, 'arguments': json.loads(fn_call.arguments)} + fc = json.dumps(fc, ensure_ascii=False) + fc = f'\n{fc}\n' + content.append(ContentItem(text=fc)) + if messages[-1].role == ASSISTANT: + messages[-1].content.append(ContentItem(text='\n')) + messages[-1].content.extend(content) + else: + messages.append(Message(role=role, content=content)) + elif role == FUNCTION: + assert isinstance(content, list) + assert len(content) == 1 + assert content[0].text + fc = f'\n{content[0].text}\n' + content = [ContentItem(text=fc)] + if messages[-1].role == USER: + messages[-1].content.append(ContentItem(text='\n')) + messages[-1].content.extend(content) + else: + messages.append(Message(role=USER, content=content)) + else: + raise TypeError + + tool_descs = [{'type': 'function', 'function': f} for f in functions] + tool_descs = '\n'.join([json.dumps(f, ensure_ascii=False) for f in tool_descs]) + tool_system = FN_CALL_TEMPLATE.format(tool_descs=tool_descs) + if messages[0].role == SYSTEM: + 
messages[0].content.append(ContentItem(text='\n\n' + tool_system)) + else: + messages = [Message(role=SYSTEM, content=[ContentItem(text=tool_system)])] + messages + return messages + + @staticmethod + def postprocess_fncall_messages( + messages: List[Message], + parallel_function_calls: bool = True, + function_choice: Union[Literal['auto'], str] = 'auto', + ) -> List[Message]: + if function_choice != 'auto': + raise NotImplementedError + raise NotImplementedError + + +FN_CALL_TEMPLATE = """# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within <tools></tools> XML tags: +<tools> +{tool_descs} +</tools> + +For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: +<tool_call> +{{"name": <function-name>, "arguments": <args-json-object>}} +</tool_call>""" diff --git a/qwen_agent/llm/fncall_prompts/qwen_fncall_prompt.py b/qwen_agent/llm/fncall_prompts/qwen_fncall_prompt.py new file mode 100644 index 0000000..2d4a339 --- /dev/null +++ b/qwen_agent/llm/fncall_prompts/qwen_fncall_prompt.py @@ -0,0 +1,387 @@ +import copy +import json +from typing import Dict, List, Literal, Union + +from qwen_agent.llm.fncall_prompts.base_fncall_prompt import BaseFnCallPrompt +from qwen_agent.llm.schema import ASSISTANT, FUNCTION, SYSTEM, USER, ContentItem, FunctionCall, Message +from qwen_agent.utils.utils import extract_text_from_message + + +class QwenFnCallPrompt(BaseFnCallPrompt): + + @staticmethod + def preprocess_fncall_messages( + messages: List[Message], + functions: List[dict], + lang: Literal['en', 'zh'], + parallel_function_calls: bool = True, + function_choice: Union[Literal['auto'], str] = 'auto', + ) -> List[Message]: + ori_messages = messages + + # Change function_call responses to plaintext responses: + messages = [] + for msg in copy.deepcopy(ori_messages): + role, content = msg.role, msg.content + if role in (SYSTEM, USER): + messages.append(msg) + elif role == ASSISTANT: + content = (content or []) + fn_call = msg.function_call + if fn_call: + f_name = fn_call.name + f_args = fn_call.arguments + if f_args.startswith('```'): # if code snippet + f_args = '\n' + f_args # for markdown rendering + func_content = '\n' if messages[-1].role == ASSISTANT else '' + func_content += f'{FN_NAME}: {f_name}' + func_content += f'\n{FN_ARGS}: {f_args}' + content.append(ContentItem(text=func_content)) + if messages[-1].role == ASSISTANT: + messages[-1].content += content + else: + messages.append(Message(role=role, content=content)) + elif role == FUNCTION: + assert messages[-1].role == ASSISTANT + assert isinstance(content, list) + assert all(isinstance(item, ContentItem) for item in content) + if content: + f_result = copy.deepcopy(content) + else: + f_result = [ContentItem(text='')] + f_exit = f'\n{FN_EXIT}: ' + last_text_content = messages[-1].content[-1].text + if last_text_content.endswith(f_exit): + messages[-1].content[-1].text = last_text_content[:-len(f_exit)] + f_result = [ContentItem(text=f'\n{FN_RESULT}: ')] + f_result + [ContentItem(text=f_exit)] + messages[-1].content += f_result + else: + raise TypeError + + # Add a system prompt for function calling: + tool_desc_template = FN_CALL_TEMPLATE[lang + ('_parallel' if parallel_function_calls else '')] + tool_descs = '\n\n'.join(get_function_description(function, lang=lang) for function in functions) + tool_names = ','.join(function.get('name_for_model', function.get('name', '')) for function in functions) + tool_system = tool_desc_template.format(tool_descs=tool_descs, tool_names=tool_names) + if messages[0].role == SYSTEM: +
messages[0].content.append(ContentItem(text='\n\n' + tool_system)) + else: + messages = [Message(role=SYSTEM, content=[ContentItem(text=tool_system)])] + messages + + # Remove ': ' for continued generation of function calling, + # because ': ' may form a single token with its following words: + if messages[-1].role == ASSISTANT: + last_msg = messages[-1].content + for i in range(len(last_msg) - 1, -1, -1): + item_type, item_text = last_msg[i].get_type_and_value() + if item_type == 'text': + if item_text.endswith(f'{FN_EXIT}: '): + last_msg[i].text = item_text[:-2] + break + + # Add the function_choice prefix: + if function_choice not in ('auto', 'none'): + if messages[-1].role == ASSISTANT: + last_msg = messages[-1] + if last_msg.content: + if extract_text_from_message(last_msg, add_upload_info=False).endswith(FN_EXIT): + last_msg.content.append(ContentItem(text=': \n')) + else: + last_msg.content.append(ContentItem(text='\n')) + messages = messages[:-1] + else: + last_msg = Message(role=ASSISTANT, content=[]) + last_msg.content.append(ContentItem(text=f'{FN_NAME}: {function_choice}')) + messages = messages + [last_msg] + + return messages + + @staticmethod + def postprocess_fncall_messages( + messages: List[Message], + parallel_function_calls: bool = True, + function_choice: Union[Literal['auto'], str] = 'auto', + ) -> List[Message]: + messages = copy.deepcopy(messages) + + # Prepend a prefix for function_choice: + if function_choice not in ('auto', 'none'): + output = messages[0].content[0].text + if output.lstrip().startswith(FN_ARGS): + # Prepend this prefix only if the model correctly completes it + output = f'{FN_NAME}: {function_choice}\n' + output + messages[0].content[0].text = output + + # Remove ': ' brought by continued generation of function calling + last_msg = messages[-1].content + for i in range(len(last_msg)): + item_type, item_text = last_msg[i].get_type_and_value() + if item_type == 'text': + if item_text.startswith(': '): + last_msg[i].text = item_text[2:] + elif item_text.startswith(':'): + last_msg[i].text = item_text[1:] + break + + # Convert plaintext responses to function_call responses: + new_messages = [] + for msg in messages: + role, content = msg.role, msg.content + assert isinstance(content, list) + + if role in (SYSTEM, USER): + new_messages.append(Message(role=role, content=content)) + continue + + new_content = [] + for item in content: + item_type, item_text = item.get_type_and_value() + + if item_type != 'text': # multimodal + new_content.append(item) + continue + + for stop_word in FN_STOP_WORDS: + assert stop_word not in item_text, 'Something wrong, stop words are expected to be excluded.' 
+ + i = item_text.find(f'{FN_NAME}:') + + # If no function call: + if i < 0: + show_text = remove_incomplete_special_tokens(item_text) + if show_text: + new_content.append(ContentItem(text=show_text)) + continue + + # If it says something before function call: + if i > 0: + answer = item_text[:i].lstrip('\n').rstrip() + if answer.endswith('\n'): + answer = answer[:-1] + show_text = remove_incomplete_special_tokens(answer) + if show_text: + new_content.append(ContentItem(text=show_text)) + if new_content: + new_messages.append(Message( + role=role, + content=new_content, + )) # split thought and function call + new_content = [] + item_text = item_text[i:] + + # If has function call: + for part in item_text.split(f'{FN_NAME}:'): + if not part: + continue + if part.endswith('\n'): + part = part[:-1] + + arg_sep = f'{FN_ARGS}:' + i = part.find(arg_sep) + if i < 0: + fn_name = part.strip() + list_of_fn_args = [''] + else: + fn_name = part[:i].strip() + list_of_fn_args = [_.strip() for _ in part[i + len(arg_sep):].split(arg_sep)] + fn_name = remove_incomplete_special_tokens(fn_name) + for fn_args in list_of_fn_args: + fn_args = remove_incomplete_special_tokens(fn_args) + fn_args = remove_trailing_comment_of_fn_args(fn_args) + new_messages.append( + Message( + role=ASSISTANT, + content=[], + function_call=FunctionCall( + name=fn_name, + arguments=fn_args, + ), + )) + + # Keep only one function call if parallelism is disabled + if not parallel_function_calls: + tmp_messages = [] + for tmp_m in new_messages: + tmp_messages.append(tmp_m) + if tmp_m.function_call: + break + new_messages = tmp_messages + + # Break here and discard the text after function call + return new_messages + + if new_content: + new_messages.append(Message(role=role, content=new_content)) + return new_messages + + +FN_NAME = '✿FUNCTION✿' +FN_ARGS = '✿ARGS✿' +FN_RESULT = '✿RESULT✿' +FN_EXIT = '✿RETURN✿' +FN_STOP_WORDS = [FN_RESULT, FN_EXIT] + +FN_CALL_TEMPLATE_INFO_ZH = """# 工具 + +## 你拥有如下工具: + +{tool_descs}""" + +FN_CALL_TEMPLATE_INFO_EN = """# Tools + +## You have access to the following tools: + +{tool_descs}""" + +FN_CALL_TEMPLATE_FMT_ZH = """## 你可以在回复中插入零次、一次或多次以下命令以调用工具: + +%s: 工具名称,必须是[{tool_names}]之一。 +%s: 工具输入 +%s: 工具结果 +%s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % ( + FN_NAME, + FN_ARGS, + FN_RESULT, + FN_EXIT, +) + +FN_CALL_TEMPLATE_FMT_EN = """## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs: + +%s: The tool to use, should be one of [{tool_names}] +%s: The input of the tool +%s: Tool results +%s: Reply based on tool results. Images need to be rendered as ![](url)""" % ( + FN_NAME, + FN_ARGS, + FN_RESULT, + FN_EXIT, +) + +FN_CALL_TEMPLATE_FMT_PARA_ZH = """## 你可以在回复中插入以下命令以并行调用N个工具: + +%s: 工具1的名称,必须是[{tool_names}]之一 +%s: 工具1的输入 +%s: 工具2的名称 +%s: 工具2的输入 +... +%s: 工具N的名称 +%s: 工具N的输入 +%s: 工具1的结果 +%s: 工具2的结果 +... +%s: 工具N的结果 +%s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % ( + FN_NAME, + FN_ARGS, + FN_NAME, + FN_ARGS, + FN_NAME, + FN_ARGS, + FN_RESULT, + FN_RESULT, + FN_RESULT, + FN_EXIT, +) + +FN_CALL_TEMPLATE_FMT_PARA_EN = """## Insert the following command in your reply when you need to call N tools in parallel: + +%s: The name of tool 1, should be one of [{tool_names}] +%s: The input of tool 1 +%s: The name of tool 2 +%s: The input of tool 2 +... +%s: The name of tool N +%s: The input of tool N +%s: The result of tool 1 +%s: The result of tool 2 +... +%s: The result of tool N +%s: Reply based on tool results. 
Images need to be rendered as ![](url)""" % ( + FN_NAME, + FN_ARGS, + FN_NAME, + FN_ARGS, + FN_NAME, + FN_ARGS, + FN_RESULT, + FN_RESULT, + FN_RESULT, + FN_EXIT, +) + +FN_CALL_TEMPLATE = { + 'zh': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_ZH, + 'en': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_EN, + 'zh_parallel': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_ZH, + 'en_parallel': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_EN, +} + + +def get_function_description(function: Dict, lang: Literal['en', 'zh']) -> str: + """ + Text description of function + """ + tool_desc_template = { + 'zh': '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}', + 'en': '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}' + } + tool_desc = tool_desc_template[lang] + name = function.get('name', None) + name_for_human = function.get('name_for_human', name) + name_for_model = function.get('name_for_model', name) + assert name_for_human and name_for_model + + if name_for_model == 'code_interpreter': + args_format = { + 'zh': '此工具的输入应为Markdown代码块。', + 'en': 'Enclose the code within triple backticks (`) at the beginning and end of the code.', + } + else: + args_format = { + 'zh': '此工具的输入应为JSON对象。', + 'en': 'Format the arguments as a JSON object.', + } + args_format = function.get('args_format', args_format[lang]) + + return tool_desc.format(name_for_human=name_for_human, + name_for_model=name_for_model, + description_for_model=function['description'], + parameters=json.dumps(function['parameters'], ensure_ascii=False), + args_format=args_format).rstrip() + + +# Mainly for removing incomplete trailing special tokens when streaming the output +def remove_incomplete_special_tokens(text: str) -> str: + special_tokens = (FN_NAME, FN_ARGS, FN_RESULT, FN_EXIT) + text = text.rstrip() + if text.endswith(special_tokens): + for s in special_tokens: + if text.endswith(s): + text = text[:-len(s)] + break + else: + trail_start = text.rfind('✿') + trail_token = text[trail_start:] + for s in special_tokens: + if s.startswith(trail_token): + text = text[:trail_start] + break + text = text.lstrip('\n').rstrip() + return text + + +# For hotfix badcases such as `{"arg1": "value1"} `. 
+def remove_trailing_comment_of_fn_args(fn_args: str): + fn_args = fn_args.strip() + + if fn_args.startswith('{'): + k = fn_args.rfind('}') + if k > 0: + fn_args = fn_args[:k + 1] + + if fn_args.startswith('```'): + k = fn_args.rfind('\n```') + if k > 0: + fn_args = fn_args[:k + 4] + + return fn_args diff --git a/qwen_agent/llm/function_calling.py b/qwen_agent/llm/function_calling.py index 4bf7d7e..9b065fe 100644 --- a/qwen_agent/llm/function_calling.py +++ b/qwen_agent/llm/function_calling.py @@ -1,49 +1,61 @@ import copy -import json from abc import ABC from typing import Dict, Iterator, List, Literal, Optional, Union from qwen_agent.llm.base import BaseChatModel -from qwen_agent.llm.schema import ASSISTANT, FUNCTION, SYSTEM, USER, ContentItem, FunctionCall, Message - - -def validate_num_fncall_results(messages: List[Message]): - fn_results = [] - i = len(messages) - 1 - while messages[i].role == FUNCTION: - fn_results = [messages[i].name] + fn_results - i -= 1 - - fn_calls = [] - while messages[i].function_call: - fn_calls = [messages[i].function_call.name] + fn_calls - i -= 1 - - if len(fn_calls) != len(fn_results): - raise ValueError(f'Expecting {len(fn_calls)} function results (i.e., messages with role="function") ' - f'but received {len(fn_results)} function results. ' - 'The number of function results must match that of the function_call messages.') - for fc_name, fr_name in zip(fn_calls, fn_results): - if fr_name and (fc_name != fr_name): - raise ValueError('The function results (i.e., the messages with role="function" ) must be ' - 'put in the same order as the function_call messages. And the function names must match.' - f'The function results are currently {fn_results}. But {fn_calls} are expected.') +from qwen_agent.llm.schema import ASSISTANT, FUNCTION, USER, ContentItem, Message class BaseFnCallModel(BaseChatModel, ABC): def __init__(self, cfg: Optional[Dict] = None): super().__init__(cfg) - stop = self.generate_cfg.get('stop', []) - self.generate_cfg['stop'] = stop + [x for x in FN_STOP_WORDS if x not in stop] + fncall_prompt_type = self.generate_cfg.get('fncall_prompt_type', 'qwen') + if fncall_prompt_type == 'qwen': + from qwen_agent.llm.fncall_prompts.qwen_fncall_prompt import FN_STOP_WORDS, QwenFnCallPrompt + self.fncall_prompt = QwenFnCallPrompt() + stop = self.generate_cfg.get('stop', []) + self.generate_cfg['stop'] = stop + [x for x in FN_STOP_WORDS if x not in stop] + else: + raise NotImplementedError - def _preprocess_messages(self, messages: List[Message], lang: Literal['en', 'zh'], - generate_cfg: dict) -> List[Message]: + def _preprocess_messages( + self, + messages: List[Message], + lang: Literal['en', 'zh'], + generate_cfg: dict, + functions: Optional[List[Dict]] = None, + ) -> List[Message]: messages = super()._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg) - if generate_cfg.get('function_choice', 'auto') == 'none': + if (not functions) or (generate_cfg.get('function_choice', 'auto') == 'none'): messages = self._remove_fncall_messages(messages, lang=lang) else: - messages = self._preprocess_fncall_messages(messages) + validate_num_fncall_results( + messages=messages, + support_multimodal_input=self.support_multimodal_input, + ) + messages = self.fncall_prompt.preprocess_fncall_messages( + messages=messages, + functions=functions, + lang=lang, + parallel_function_calls=generate_cfg.get('parallel_function_calls', False), + function_choice=generate_cfg.get('function_choice', 'auto'), + ) + return messages + + def _postprocess_messages( + 
self, + messages: List[Message], + fncall_mode: bool, + generate_cfg: dict, + ) -> List[Message]: + messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode, generate_cfg=generate_cfg) + if fncall_mode: + messages = self.fncall_prompt.postprocess_fncall_messages( + messages=messages, + parallel_function_calls=generate_cfg.get('parallel_function_calls', False), + function_choice=generate_cfg.get('function_choice', 'auto'), + ) return messages def _remove_fncall_messages(self, messages: List[Message], lang: Literal['en', 'zh']) -> List[Message]: @@ -82,69 +94,6 @@ def _remove_fncall_messages(self, messages: List[Message], lang: Literal['en', ' new_messages.append(msg) return new_messages - def _preprocess_fncall_messages(self, messages: List[Message]) -> List[Message]: - """Convert messages with function_call key and function role to assistant's content, which is - for chat interface or text_completion interface that do not support functions. - """ - validate_num_fncall_results(messages) - new_messages = [] - for msg in copy.deepcopy(messages): - role, content = msg.role, msg.content - if role in (SYSTEM, USER): - new_messages.append(msg) - - elif role == ASSISTANT: - content = (content or []) - fn_call = msg.function_call - if fn_call: - f_name = fn_call.name - f_args = fn_call.arguments - if f_args.startswith('```'): # if code snippet - f_args = '\n' + f_args # for markdown rendering - func_content = '\n' if new_messages[-1].role == ASSISTANT else '' - func_content += f'{FN_NAME}: {f_name}' - func_content += f'\n{FN_ARGS}: {f_args}' - content.append(ContentItem(text=func_content)) - if new_messages[-1].role == ASSISTANT: - new_messages[-1].content += content - else: - new_messages.append(Message(role=role, content=content)) - - elif role == FUNCTION: - assert new_messages[-1].role == ASSISTANT - assert isinstance(content, list) - assert all(isinstance(item, ContentItem) for item in content) - if not self.support_multimodal_input: - assert all(item.type == 'text' for item in content) - - if content: - f_result = copy.deepcopy(content) - else: - f_result = [ContentItem(text='')] - - f_exit = f'\n{FN_EXIT}: ' - last_text_content = new_messages[-1].content[-1].text - if last_text_content.endswith(f_exit): - new_messages[-1].content[-1].text = last_text_content[:-len(f_exit)] - - f_result = [ContentItem(text=f'\n{FN_RESULT}: ')] + f_result + [ContentItem(text=f_exit)] - new_messages[-1].content += f_result - - else: - raise TypeError - - # Remove ': ' for continued generation of function calling, - # because ': ' may form a single token with its following words - if new_messages[-1].role == ASSISTANT: - last_msg = new_messages[-1].content - for i in range(len(last_msg) - 1, -1, -1): - item_type, item_text = last_msg[i].get_type_and_value() - if item_type == 'text': - if item_text.endswith(f'{FN_EXIT}: '): - last_msg[i].text = item_text[:-2] - break - return new_messages - def _chat_with_functions( self, messages: List[Message], @@ -157,364 +106,66 @@ def _chat_with_functions( if delta_stream: raise NotImplementedError('Please use stream=True with delta_stream=False, because delta_stream=True' ' is not implemented for function calling due to some technical reasons.') - parallel_function_calls = generate_cfg.get('parallel_function_calls', False) - messages = self._prepend_fncall_system( - messages=messages, - functions=functions, - lang=lang, - parallel_function_calls=parallel_function_calls, - ) - - fn_choice = generate_cfg.get('function_choice', 'auto') - if fn_choice not in 
('auto', 'none'): - if messages[-1].role == ASSISTANT: - msg_to_cont = copy.deepcopy(messages[-1]) - if msg_to_cont.content.endswith(FN_EXIT): - msg_to_cont.content += ': ' - msg_to_cont.content += '\n' - messages = messages[:-1] - else: - msg_to_cont = Message(role=ASSISTANT, content='') - msg_to_cont.content += f'{FN_NAME}: {fn_choice}' - messages = messages + [msg_to_cont] - generate_cfg = copy.deepcopy(generate_cfg) for k in ['parallel_function_calls', 'function_choice']: if k in generate_cfg: del generate_cfg[k] - return self._continue_assistant_response(messages, generate_cfg=generate_cfg, stream=stream) - def _prepend_fncall_system( - self, - messages: List[Message], - functions: List[Dict], - lang: Literal['en', 'zh'], - parallel_function_calls: bool = False, - ) -> List[Message]: - tool_desc_template = FN_CALL_TEMPLATE[lang + ('_parallel' if parallel_function_calls else '')] - tool_descs = '\n\n'.join(get_function_description(function, lang=lang) for function in functions) - tool_names = ','.join(function.get('name', function.get('name_for_model', '')) for function in functions) - tool_system = tool_desc_template.format(tool_descs=tool_descs, tool_names=tool_names) - - assert messages[0].role == SYSTEM - messages = copy.deepcopy(messages[:1]) + messages[1:] - if isinstance(messages[0].content, str): - messages[0].content += '\n\n' + tool_system - else: - messages[0].content.append(ContentItem(text='\n\n' + tool_system)) - - return messages - def _continue_assistant_response( self, messages: List[Message], generate_cfg: dict, stream: bool, ) -> Iterator[List[Message]]: - # Simulate text completion with chat completion - if messages and messages[-1].role == ASSISTANT: - assert len(messages) > 1 and messages[-2].role == USER - assert messages[-1].function_call is None - usr = messages[-2].content - bot = messages[-1].content - sep = '\n\n' - if isinstance(usr, str) and isinstance(bot, str): - usr = usr + sep + bot - elif isinstance(usr, list) and isinstance(bot, list): - usr = usr + [ContentItem(text=sep)] + bot - else: - raise NotImplementedError - text_to_complete = copy.deepcopy(messages[-2]) - text_to_complete.content = usr - messages = messages[:-2] + [text_to_complete] + messages = simulate_response_completion_with_chat(messages) return self._chat(messages, stream=stream, delta_stream=False, generate_cfg=generate_cfg) - def _postprocess_messages( - self, - messages: List[Message], - fncall_mode: bool, - generate_cfg: dict, - ) -> List[Message]: - messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode, generate_cfg=generate_cfg) - if fncall_mode: - fn_choice = generate_cfg.get('function_choice', 'auto') - if fn_choice not in ('auto', 'none'): - messages = copy.deepcopy(messages) - output = messages[0].content[0].text - if output.lstrip().startswith(FN_ARGS): - # Prepend this fn_choice prefix only if the model correctly completes it - output = f'{FN_NAME}: {fn_choice}\n' + output - messages[0].content[0].text = output - messages = self._postprocess_fncall_messages(messages) - return messages - - def _postprocess_fncall_messages(self, messages: List[Message]) -> List[Message]: - """ - If the model calls function by built-in function call template, - convert and display it in function_call format. 
- """ - # Remove ': ' brought by continued generation of function calling - last_msg = messages[-1].content - for i in range(len(last_msg)): - item_type, item_text = last_msg[i].get_type_and_value() - if item_type == 'text': - if item_text.startswith(': '): - last_msg[i].text = item_text[2:] - elif item_text.startswith(':'): - last_msg[i].text = item_text[1:] - break - - new_messages = [] - for msg in messages: - role, content = msg.role, msg.content - assert isinstance(content, list) +def simulate_response_completion_with_chat(messages: List[Message]) -> List[Message]: + if messages and (messages[-1].role == ASSISTANT): + assert (len(messages) > 1) and (messages[-2].role == USER) + assert messages[-1].function_call is None + usr = messages[-2].content + bot = messages[-1].content + sep = '\n\n' + if isinstance(usr, str) and isinstance(bot, str): + usr = usr + sep + bot + elif isinstance(usr, list) and isinstance(bot, list): + usr = usr + [ContentItem(text=sep)] + bot + else: + raise NotImplementedError + text_to_complete = copy.deepcopy(messages[-2]) + text_to_complete.content = usr + messages = messages[:-2] + [text_to_complete] + return messages - if role in (SYSTEM, USER): - new_messages.append(Message(role=role, content=content)) - continue - new_content = [] +def validate_num_fncall_results(messages: List[Message], support_multimodal_input: bool): + fn_results = [] + i = len(messages) - 1 + while messages[i].role == FUNCTION: + fn_results = [messages[i].name] + fn_results + content = messages[i].content + if isinstance(content, list): for item in content: - item_type, item_text = item.get_type_and_value() - - if item_type != 'text': # multimodal - new_content.append(item) - continue - - for stop_word in [FN_RESULT, FN_EXIT]: - assert stop_word in FN_STOP_WORDS - assert stop_word not in item_text, 'Something wrong, stop words are expected to be excluded.' 
- - i = item_text.find(f'{FN_NAME}:') - - # If no function call: - if i < 0: - show_text = remove_incomplete_special_tokens(item_text) - if show_text: - new_content.append(ContentItem(text=show_text)) - continue - - # If it says something before function call: - if i > 0: - answer = item_text[:i].lstrip('\n').rstrip() - if answer.endswith('\n'): - answer = answer[:-1] - show_text = remove_incomplete_special_tokens(answer) - if show_text: - new_content.append(ContentItem(text=show_text)) - if new_content: - new_messages.append(Message( - role=role, - content=new_content, - )) # split thought and function call - new_content = [] - item_text = item_text[i:] - - # If has function call: - for part in item_text.split(f'{FN_NAME}:'): - if not part: - continue - if part.endswith('\n'): - part = part[:-1] - - arg_sep = f'{FN_ARGS}:' - i = part.find(arg_sep) - if i < 0: - fn_name = part.strip() - list_of_fn_args = [''] - else: - fn_name = part[:i].strip() - list_of_fn_args = [_.strip() for _ in part[i + len(arg_sep):].split(arg_sep)] - fn_name = remove_incomplete_special_tokens(fn_name) - for fn_args in list_of_fn_args: - fn_args = remove_incomplete_special_tokens(fn_args) - fn_args = remove_trailing_comment_of_fn_args(fn_args) - new_messages.append( - Message( - role=ASSISTANT, - content=[], - function_call=FunctionCall( - name=fn_name, - arguments=fn_args, - ), - )) - # Break here and discard the text after function call - return new_messages - - if new_content: - new_messages.append(Message(role=role, content=new_content)) - return new_messages - - -FN_NAME = '✿FUNCTION✿' -FN_ARGS = '✿ARGS✿' -FN_RESULT = '✿RESULT✿' -FN_EXIT = '✿RETURN✿' -FN_STOP_WORDS = [FN_RESULT, FN_EXIT] - -FN_CALL_TEMPLATE_INFO_ZH = """# 工具 - -## 你拥有如下工具: - -{tool_descs}""" - -FN_CALL_TEMPLATE_INFO_EN = """# Tools - -## You have access to the following tools: - -{tool_descs}""" - -FN_CALL_TEMPLATE_FMT_ZH = """## 你可以在回复中插入零次、一次或多次以下命令以调用工具: - -%s: 工具名称,必须是[{tool_names}]之一。 -%s: 工具输入 -%s: 工具结果 -%s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % ( - FN_NAME, - FN_ARGS, - FN_RESULT, - FN_EXIT, -) - -FN_CALL_TEMPLATE_FMT_EN = """## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs: - -%s: The tool to use, should be one of [{tool_names}] -%s: The input of the tool -%s: Tool results -%s: Reply based on tool results. Images need to be rendered as ![](url)""" % ( - FN_NAME, - FN_ARGS, - FN_RESULT, - FN_EXIT, -) - -FN_CALL_TEMPLATE_FMT_PARA_ZH = """## 你可以在回复中插入以下命令以并行调用N个工具: - -%s: 工具1的名称,必须是[{tool_names}]之一 -%s: 工具1的输入 -%s: 工具2的名称 -%s: 工具2的输入 -... -%s: 工具N的名称 -%s: 工具N的输入 -%s: 工具1的结果 -%s: 工具2的结果 -... -%s: 工具N的结果 -%s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % ( - FN_NAME, - FN_ARGS, - FN_NAME, - FN_ARGS, - FN_NAME, - FN_ARGS, - FN_RESULT, - FN_RESULT, - FN_RESULT, - FN_EXIT, -) - -FN_CALL_TEMPLATE_FMT_PARA_EN = """## Insert the following command in your reply when you need to call N tools in parallel: - -%s: The name of tool 1, should be one of [{tool_names}] -%s: The input of tool 1 -%s: The name of tool 2 -%s: The input of tool 2 -... -%s: The name of tool N -%s: The input of tool N -%s: The result of tool 1 -%s: The result of tool 2 -... -%s: The result of tool N -%s: Reply based on tool results. 
Images need to be rendered as ![](url)""" % ( - FN_NAME, - FN_ARGS, - FN_NAME, - FN_ARGS, - FN_NAME, - FN_ARGS, - FN_RESULT, - FN_RESULT, - FN_RESULT, - FN_EXIT, -) - -FN_CALL_TEMPLATE = { - 'zh': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_ZH, - 'en': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_EN, - 'zh_parallel': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_ZH, - 'en_parallel': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_EN, -} - - -def get_function_description(function: Dict, lang: Literal['en', 'zh']) -> str: - """ - Text description of function - """ - tool_desc_template = { - 'zh': '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}', - 'en': '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}' - } - tool_desc = tool_desc_template[lang] - name = function.get('name', None) - name_for_human = function.get('name_for_human', name) - name_for_model = function.get('name_for_model', name) - assert name_for_human and name_for_model - - if name_for_model == 'code_interpreter': - args_format = { - 'zh': '此工具的输入应为Markdown代码块。', - 'en': 'Enclose the code within triple backticks (`) at the beginning and end of the code.', - } - else: - args_format = { - 'zh': '此工具的输入应为JSON对象。', - 'en': 'Format the arguments as a JSON object.', - } - args_format = function.get('args_format', args_format[lang]) - - return tool_desc.format(name_for_human=name_for_human, - name_for_model=name_for_model, - description_for_model=function['description'], - parameters=json.dumps(function['parameters'], ensure_ascii=False), - args_format=args_format).rstrip() - - -# Mainly for removing incomplete trailing special tokens when streaming the output -def remove_incomplete_special_tokens(text: str) -> str: - special_tokens = (FN_NAME, FN_ARGS, FN_RESULT, FN_EXIT) - text = text.rstrip() - if text.endswith(special_tokens): - for s in special_tokens: - if text.endswith(s): - text = text[:-len(s)] - break - else: - trail_start = text.rfind('✿') - trail_token = text[trail_start:] - for s in special_tokens: - if s.startswith(trail_token): - text = text[:trail_start] - break - text = text.lstrip('\n').rstrip() - return text - - -# For hotfix badcases such as `{"arg1": "value1"} `. -def remove_trailing_comment_of_fn_args(fn_args: str): - fn_args = fn_args.strip() - - if fn_args.startswith('{'): - k = fn_args.rfind('}') - if k > 0: - fn_args = fn_args[:k + 1] + if item.file: + raise ValueError('Tool call results with content type="file" are not supported.') + if item.image and (not support_multimodal_input): + raise ValueError('The current model service does not accept images as tool results.') + i -= 1 - if fn_args.startswith('```'): - k = fn_args.rfind('\n```') - if k > 0: - fn_args = fn_args[:k + 4] + fn_calls = [] + while messages[i].function_call: + fn_calls = [messages[i].function_call.name] + fn_calls + i -= 1 - return fn_args + if len(fn_calls) != len(fn_results): + raise ValueError(f'Expecting {len(fn_calls)} function results (i.e., messages with role="function") ' + f'but received {len(fn_results)} function results. ' + 'The number of function results must match that of the function_call messages.') + for fc_name, fr_name in zip(fn_calls, fn_results): + if fr_name and (fc_name != fr_name): + raise ValueError('The function results (i.e., the messages with role="function" ) must be ' + 'put in the same order as the function_call messages. 
And the function names must match.' + f'The function results are currently {fn_results}. But {fn_calls} are expected.') diff --git a/qwen_agent/llm/oai.py b/qwen_agent/llm/oai.py index 4769fb5..9ad3f59 100644 --- a/qwen_agent/llm/oai.py +++ b/qwen_agent/llm/oai.py @@ -80,12 +80,12 @@ def _chat_stream( response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg) if delta_stream: for chunk in response: - if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: + if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: yield [Message(ASSISTANT, chunk.choices[0].delta.content)] else: full_response = '' for chunk in response: - if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: + if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: full_response += chunk.choices[0].delta.content yield [Message(ASSISTANT, full_response)] except OpenAIError as ex: diff --git a/qwen_agent/tools/base.py b/qwen_agent/tools/base.py index 39b4b5b..fc4f704 100644 --- a/qwen_agent/tools/base.py +++ b/qwen_agent/tools/base.py @@ -1,12 +1,11 @@ +import json import os from abc import ABC, abstractmethod from typing import Dict, List, Optional, Union -import json5 - from qwen_agent.llm.schema import ContentItem from qwen_agent.settings import DEFAULT_WORKSPACE -from qwen_agent.utils.utils import has_chinese_chars, logger, print_traceback, save_url_to_local_work_dir +from qwen_agent.utils.utils import has_chinese_chars, json_loads, logger, print_traceback, save_url_to_local_work_dir TOOL_REGISTRY = {} @@ -29,17 +28,68 @@ def decorator(cls): return decorator +def is_tool_schema(obj: dict) -> bool: + """ + Check if obj is a valid JSON schema describing a tool compatible with OpenAI's tool calling. + Example valid schema: + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + """ + import jsonschema + try: + assert set(obj.keys()) == {'name', 'description', 'parameters'} + assert isinstance(obj['name'], str) + assert obj['name'].strip() + assert isinstance(obj['description'], str) + assert isinstance(obj['parameters'], dict) + + assert set(obj['parameters'].keys()) == {'type', 'properties', 'required'} + assert obj['parameters']['type'] == 'object' + assert isinstance(obj['parameters']['properties'], dict) + assert isinstance(obj['parameters']['required'], list) + assert set(obj['parameters']['required']).issubset(set(obj['parameters']['properties'].keys())) + except AssertionError: + return False + try: + jsonschema.validate(instance={}, schema=obj['parameters']) + except jsonschema.exceptions.SchemaError: + return False + except jsonschema.exceptions.ValidationError: + pass + return True + + class BaseTool(ABC): name: str = '' description: str = '' - parameters: List[Dict] = [] + parameters: Union[List[dict], dict] = [] - def __init__(self, cfg: Optional[Dict] = None): + def __init__(self, cfg: Optional[dict] = None): self.cfg = cfg or {} if not self.name: raise ValueError( f'You must set {self.__class__.__name__}.name, either by @register_tool(name=...) 
or explicitly setting {self.__class__.__name__}.name' ) + if isinstance(self.parameters, dict): + if not is_tool_schema({'name': self.name, 'description': self.description, 'parameters': self.parameters}): + raise ValueError( + 'The parameters, when provided as a dict, must conform to a valid openai-compatible JSON schema.') @abstractmethod def call(self, params: Union[str, dict], **kwargs) -> Union[str, list, dict, List[ContentItem]]: @@ -56,20 +106,29 @@ def call(self, params: Union[str, dict], **kwargs) -> Union[str, list, dict, Lis """ raise NotImplementedError - def _verify_json_format_args(self, params: Union[str, dict]) -> Union[str, dict]: + def _verify_json_format_args(self, params: Union[str, dict], strict_json: bool = False) -> dict: """Verify the parameters of the function call""" - try: - if isinstance(params, str): - params_json = json5.loads(params) - else: - params_json = params + if isinstance(params, str): + try: + if strict_json: + params_json: dict = json.loads(params) + else: + params_json: dict = json_loads(params) + except json.decoder.JSONDecodeError: + raise ValueError('Parameters must be formatted as a valid JSON!') + else: + params_json: dict = params + if isinstance(self.parameters, list): for param in self.parameters: if 'required' in param and param['required']: if param['name'] not in params_json: raise ValueError('Parameters %s is required!' % param['name']) - return params_json - except Exception: - raise ValueError('Parameters cannot be converted to Json Format!') + elif isinstance(self.parameters, dict): + import jsonschema + jsonschema.validate(instance=params_json, schema=self.parameters) + else: + raise ValueError + return params_json @property def function(self) -> dict: # Bad naming. It should be `function_info`. diff --git a/qwen_agent/utils/tokenization_qwen.py b/qwen_agent/utils/tokenization_qwen.py index b633f10..5d2b6ff 100644 --- a/qwen_agent/utils/tokenization_qwen.py +++ b/qwen_agent/utils/tokenization_qwen.py @@ -50,7 +50,6 @@ def __init__( vocab_file=None, errors='replace', extra_vocab_file=None, - **kwargs, ): if not vocab_file: vocab_file = VOCAB_FILES_NAMES['vocab_file'] @@ -138,7 +137,6 @@ def tokenize( text: str, allowed_special: Union[Set, str] = 'all', disallowed_special: Union[Collection, str] = (), - **kwargs, ) -> List[Union[bytes, str]]: """ Converts a string in a sequence of tokens. @@ -153,9 +151,6 @@ def tokenize( The surface forms of the tokens that should not be in regular texts and trigger errors. Default to an empty tuple. - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific encode method. - Returns: `List[bytes|str]`: The list of tokens.
""" @@ -196,7 +191,6 @@ def _decode( token_ids: Union[int, List[int]], skip_special_tokens: bool = False, errors: str = None, - **kwargs, ) -> str: if isinstance(token_ids, int): token_ids = [token_ids] diff --git a/qwen_agent/utils/utils.py b/qwen_agent/utils/utils.py index 0dcec6f..c74b5fe 100644 --- a/qwen_agent/utils/utils.py +++ b/qwen_agent/utils/utils.py @@ -1,6 +1,7 @@ import base64 import copy import hashlib +import json import os import re import shutil @@ -259,6 +260,12 @@ def extract_urls(text: str) -> List[str]: return urls +def extract_markdown_urls(md_text: str) -> List[str]: + pattern = r'!?\[[^\]]*\]\(([^\)]+)\)' + urls = re.findall(pattern, md_text) + return urls + + def extract_code(text: str) -> str: # Match triple backtick blocks first triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) @@ -273,6 +280,23 @@ def extract_code(text: str) -> str: return text +def json_loads(text: str) -> dict: + text = text.strip('\n') + if text.startswith('```') and text.endswith('\n```'): + text = '\n'.join(text.split('\n')[1:-1]) + try: + return json.loads(text) + except json.decoder.JSONDecodeError as json_err: + try: + return json5.loads(text) + except ValueError: + raise json_err + + +def json_dumps(obj: dict) -> str: + return json.dumps(obj, ensure_ascii=False, indent=2) + + def format_as_multimodal_message( msg: Message, add_upload_info: bool, @@ -360,7 +384,7 @@ def extract_text_from_message( elif isinstance(msg.content, str): text = msg.content else: - raise TypeError + raise TypeError(f'List of str or str expected, but received {type(msg.content).__name__}.') return text.strip() @@ -389,7 +413,11 @@ def merge_generate_cfgs(base_generate_cfg: Optional[dict], new_generate_cfg: Opt return generate_cfg -def build_text_completion_prompt(messages: List[Message]) -> str: +def build_text_completion_prompt( + messages: List[Message], + allow_special: bool = False, + default_system: str = DEFAULT_SYSTEM_MESSAGE, +) -> str: im_start = '<|im_start|>' im_end = '<|im_end|>' @@ -399,7 +427,7 @@ def build_text_completion_prompt(messages: List[Message]) -> str: prompt = f'{im_start}{SYSTEM}\n{sys}{im_end}' messages = messages[1:] else: - prompt = f'{im_start}{SYSTEM}\n{DEFAULT_SYSTEM_MESSAGE}{im_end}' + prompt = f'{im_start}{SYSTEM}\n{default_system}{im_end}' # Make sure we are completing the chat in the tone of the assistant if messages[-1].role != ASSISTANT: @@ -407,14 +435,24 @@ def build_text_completion_prompt(messages: List[Message]) -> str: for msg in messages: assert isinstance(msg.content, str) - if msg.role == USER: - query = msg.content.lstrip('\n').rstrip() - prompt += f'\n{im_start}{USER}\n{query}{im_end}' - elif msg.role == ASSISTANT: - response = msg.content.lstrip('\n').rstrip() - prompt += f'\n{im_start}{ASSISTANT}\n{response}{im_end}' + content = msg.content.lstrip('\n').rstrip() + if allow_special: + assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) + if msg.function_call: + assert msg.role == ASSISTANT + tool_call = msg.function_call.arguments + try: + tool_call = {'name': msg.function_call.name, 'arguments': json.loads(tool_call)} + tool_call = json.dumps(tool_call, ensure_ascii=False, indent=2) + except json.decoder.JSONDecodeError: + tool_call = '{"name": "' + msg.function_call.name + '", "arguments": ' + tool_call + '}' + if content: + content += '\n' + content += f'\n{tool_call}\n' else: - raise ValueError + assert msg.role in (USER, ASSISTANT) + assert msg.function_call is None + prompt += f'\n{im_start}{msg.role}\n{content}{im_end}' assert 
prompt.endswith(im_end) prompt = prompt[:-len(im_end)] @@ -458,3 +496,12 @@ def resize_image(img, short_side_length: int = 1080): resized_img = img.resize((new_width, new_height), resample=Image.Resampling.BILINEAR) return resized_img + + +def get_last_usr_msg_idx(messages: List[Union[dict, Message]]) -> int: + i = len(messages) - 1 + while (i >= 0) and (messages[i]['role'] != 'user'): + i -= 1 + assert i >= 0, messages + assert messages[i]['role'] == 'user' + return i diff --git a/setup.py b/setup.py index 2a67b41..7c2adb8 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,11 @@ def read_description() -> str: return long_description +# To update the package at PyPI: +# ```bash +# python setup.py sdist bdist_wheel +# twine upload dist/* +# ``` setup( name='qwen-agent', version=get_version(), diff --git a/tests/agents/test_react_chat.py b/tests/agents/test_react_chat.py index b7bc005..62228c1 100644 --- a/tests/agents/test_react_chat.py +++ b/tests/agents/test_react_chat.py @@ -29,10 +29,6 @@ def test_react_chat_with_file(): 'model': 'qwen-max', 'model_server': 'dashscope', 'api_key': os.getenv('DASHSCOPE_API_KEY'), - - # 'model': 'Qwen/Qwen1.5-72B-Chat', - # 'model_server': 'https://api.together.xyz', - # 'api_key': os.getenv('TOGETHER_API_KEY'), } tools = ['code_interpreter'] agent = ReActChat(llm=llm_cfg, function_list=tools) diff --git a/tests/llm/test_function_content.py b/tests/llm/test_function_content.py index 18e84ad..581e196 100644 --- a/tests/llm/test_function_content.py +++ b/tests/llm/test_function_content.py @@ -16,6 +16,7 @@ @pytest.mark.parametrize('gen_cfg2', [ None, dict(function_choice='none'), + dict(function_choice='get_current_weather'), ]) def test_function_content(cfg, gen_cfg1, gen_cfg2): if cfg == 0: @@ -27,10 +28,9 @@ def test_function_content(cfg, gen_cfg1, gen_cfg2): }) else: llm = get_chat_model({ - # Use the model service provided by Together.AI: - 'model': 'Qwen/Qwen1.5-7B-Chat', - 'model_server': 'https://api.together.xyz', # api_base - 'api_key': os.getenv('TOGETHER_API_KEY'), + 'model': 'qwen2-7b-instruct', + 'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'api_key': os.getenv('DASHSCOPE_API_KEY', 'none') }) # Step 1: send the conversation and available functions to the model diff --git a/tests/llm/test_oai.py b/tests/llm/test_oai.py index c2b378d..622f20b 100644 --- a/tests/llm/test_oai.py +++ b/tests/llm/test_oai.py @@ -33,9 +33,9 @@ def test_llm_oai(functions, stream, delta_stream): # settings llm_cfg = { - 'model': os.getenv('TEST_MODEL', 'Qwen/Qwen1.5-14B-Chat'), - 'model_server': os.getenv('TEST_MODEL_SERVER', 'https://api.together.xyz'), - 'api_key': os.getenv('TEST_MODEL_SERVER_API_KEY', 'none') + 'model': 'qwen2-7b-instruct', + 'model_server': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'api_key': os.getenv('DASHSCOPE_API_KEY', 'none') } llm = get_chat_model(llm_cfg)
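
As an illustrative sketch only (none of the code below is part of this patch, and the `get_current_weather` tool and its behaviour are hypothetical): assuming the `jsonschema` dependency referenced above is installed, a tool can now declare its `parameters` as a full JSON schema and rely on the refactored `_verify_json_format_args` for validation, with `json_loads` falling back to json5 for slightly malformed arguments.

```python
from typing import Union

from qwen_agent.tools.base import BaseTool, register_tool


@register_tool('get_current_weather')  # hypothetical example tool
class GetCurrentWeather(BaseTool):
    description = 'Get the current weather in a given location'
    # With this refactor, `parameters` may be a JSON schema dict (checked by
    # is_tool_schema at construction time) instead of the legacy list-of-dicts format.
    parameters = {
        'type': 'object',
        'properties': {
            'location': {
                'type': 'string',
                'description': 'The city and state, e.g. San Francisco, CA',
            },
        },
        'required': ['location'],
    }

    def call(self, params: Union[str, dict], **kwargs) -> str:
        # json_loads tolerates json5-style input; jsonschema.validate then
        # enforces the schema declared above.
        args = self._verify_json_format_args(params)
        return f"(fake weather report for {args['location']})"


if __name__ == '__main__':
    tool = GetCurrentWeather()
    # Unquoted keys are accepted via the json5 fallback in json_loads.
    print(tool.call('{location: "San Francisco, CA"}'))
```

Declaring `parameters` as a schema lets `_verify_json_format_args` delegate validation to `jsonschema` instead of the manual required-field loop kept for the legacy list format.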