diff --git a/.gitignore b/.gitignore index 1befe1a..86e7cb6 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,8 @@ log.jsonl ai_agent/debug.json ai_agent/local_prompts/* -*/debug.json +**/debug.json +**/debug.log* debug.json ai_agent/log.jsonl qwen_agent.egg-info/* diff --git a/docs/agent.md b/docs/agent.md index 365347f..bdb8018 100644 --- a/docs/agent.md +++ b/docs/agent.md @@ -8,7 +8,7 @@ The Agent receives a list of messages as input and produces a generator that yie Different Agent classes have various workflows. In the [agents](../qwen_agent/agents) directory, we provide several different fundamental Agent subclasses. For instance, the [ArticleAgent](../qwen_agent/agents/article_agent.py) returns a message that includes an article; -the [DocQAAgent](../qwen_agent/agents/docqa_agent.py) returns a message that contains the results of a document Q&A Results. +the [BasicDocQA](../qwen_agent/agents/doc_qa/basic_doc_qa.py) returns a message that contains the results of a document Q&A Results. These types of Agents have relatively fixed response patterns and are suited for fairly specific use cases. diff --git a/docs/agent_cn.md b/docs/agent_cn.md index 224619c..75a34cb 100644 --- a/docs/agent_cn.md +++ b/docs/agent_cn.md @@ -8,7 +8,7 @@ Agent接收一个消息列表输入,并返回一个消息列表的生成器, 不同Agent类具有不同的工作流程,我们在[agents](../qwen_agent/agents)目录提供了多个不同的基础的Agent子类, 例如[ArticleAgent](../qwen_agent/agents/article_agent.py)接收消息后,返回消息包含一篇文章; -[DocQAAgent](../qwen_agent/agents/docqa_agent.py)返回消息包含文档问答的结果。 +[BasicDocQA](../qwen_agent/agents/doc_qa/basic_doc_qa.py)返回消息包含文档问答的结果。 可以看出,这类Agent回复模式相对固定,使用场景也比较固定。 ### 1.1. Assistant 类 diff --git a/examples/assistant_add_custom_tool.py b/examples/assistant_add_custom_tool.py index 8951b06..ae4386b 100644 --- a/examples/assistant_add_custom_tool.py +++ b/examples/assistant_add_custom_tool.py @@ -19,26 +19,34 @@ class MyImageGen(BaseTool): 'name': 'prompt', 'type': 'string', 'description': 'Detailed description of the desired image content, in English', - 'required': True + 'required': True, }] def call(self, params: str, **kwargs) -> str: prompt = json5.loads(params)['prompt'] prompt = urllib.parse.quote(prompt) - return json.dumps({'image_url': f'https://image.pollinations.ai/prompt/{prompt}'}, ensure_ascii=False) + return json.dumps( + {'image_url': f'https://image.pollinations.ai/prompt/{prompt}'}, + ensure_ascii=False, + ) def init_agent_service(): llm_cfg = {'model': 'qwen-max'} - system = ('According to the user\'s request, you first draw a picture and then automatically ' + system = ("According to the user's request, you first draw a picture and then automatically " 'run code to download the picture and select an image operation from the given document ' 'to process the image') - tools = ['my_image_gen', 'code_interpreter'] # code_interpreter is a built-in tool in Qwen-Agent - bot = Assistant(llm=llm_cfg, - system_message=system, - function_list=tools, - files=[os.path.join(ROOT_RESOURCE, 'doc.pdf')]) + tools = [ + 'my_image_gen', + 'code_interpreter', + ] # code_interpreter is a built-in tool in Qwen-Agent + bot = Assistant( + llm=llm_cfg, + system_message=system, + function_list=tools, + files=[os.path.join(ROOT_RESOURCE, 'doc.pdf')], + ) return bot diff --git a/examples/assistant_growing_girl.py b/examples/assistant_growing_girl.py index 8c47fb5..7fb9643 100644 --- a/examples/assistant_growing_girl.py +++ b/examples/assistant_growing_girl.py @@ -10,10 +10,12 @@ def init_agent_service(): llm_cfg = {'model': 'qwen-max'} tools = ['image_gen'] - bot = Assistant(llm=llm_cfg, - function_list=tools, - system_message='你扮演一个漫画家,根据我给你的女孩的不同阶段,使用工具画出每个阶段女孩的的图片,' - '并串成一个故事讲述出来。要求图片背景丰富') + bot = Assistant( + llm=llm_cfg, + function_list=tools, + system_message='你扮演一个漫画家,根据我给你的女孩的不同阶段,使用工具画出每个阶段女孩的的图片,' + '并串成一个故事讲述出来。要求图片背景丰富', + ) return bot @@ -42,7 +44,10 @@ def app(): messages.extend(response) -def test(query='请用image_gen开始创作!', file: Optional[str] = os.path.join(ROOT_RESOURCE, 'growing_girl.pdf')): +def test( + query='请用image_gen开始创作!', + file: Optional[str] = os.path.join(ROOT_RESOURCE, 'growing_girl.pdf'), +): # Define the agent bot = init_agent_service() diff --git a/examples/assistant_weather_bot.py b/examples/assistant_weather_bot.py index 7ba3360..86cff21 100644 --- a/examples/assistant_weather_bot.py +++ b/examples/assistant_weather_bot.py @@ -13,7 +13,11 @@ def init_agent_service(): '你需要查询相应地区的天气,然后调用给你的画图工具绘制一张城市的图,并从给定的诗词文档中选一首相关的诗词来描述天气,不要说文档以外的诗词。') tools = ['image_gen', 'amap_weather'] - bot = Assistant(llm=llm_cfg, system_message=system, function_list=tools) + bot = Assistant( + llm=llm_cfg, + system_message=system, + function_list=tools, + ) return bot diff --git a/examples/gpt_mentions.py b/examples/gpt_mentions.py deleted file mode 100644 index 018e12a..0000000 --- a/examples/gpt_mentions.py +++ /dev/null @@ -1,143 +0,0 @@ -"""A gpt @mentions gradio demo""" -import gradio as gr - -from qwen_agent.agents import Assistant, DocQAAgent, ReActChat -from qwen_server import output_beautify - - -def init_agent_service(messages): - llm_cfg = {'model': 'qwen-max'} - - agent_list = { - 'code_interpreter': { - 'object': ReActChat, - 'params': { - 'system_message': - 'you are a programming expert, skilled in writing code to solve mathematical problems and data analysis problems.', - 'function_list': ['code_interpreter'], - 'llm': - llm_cfg - } - }, - 'doc_qa': { - 'object': DocQAAgent, - 'params': { - 'llm': llm_cfg - } - }, - 'assistant': { - 'object': Assistant, - 'params': { - 'llm': llm_cfg - } - } - } - - agent = messages[-1]['content'][0]['text'].split('@')[-1].strip() - selected_agent = agent_list[agent]['object'](**agent_list[agent]['params']) - - return selected_agent - - -# ========================================================= -# Below is the gradio service: front-end and back-end logic -# ========================================================= - -app_global_para = {'messages': [], 'is_first_upload': True, 'uploaded_file': ''} - -AGENT_LIST_NAME = ['code_interpreter', 'doc_qa', 'assistant'] - - -def app(history, chosen_plug): - if not history: - yield history - else: - if '@' not in history[-1][0]: - history[-1][0] += ('@' + chosen_plug) - content = [{'text': history[-1][0]}] - if app_global_para['uploaded_file'] and app_global_para['is_first_upload']: - app_global_para['is_first_upload'] = False # only send file when first upload - content.append({'file': app_global_para['uploaded_file']}) - app_global_para['messages'].append({'role': 'user', 'content': content}) - - # Define the agent - selected_agent = init_agent_service(messages=app_global_para['messages']) - - # Chat - history[-1][1] = '' - response = [] - try: - for response in selected_agent.run(messages=app_global_para['messages']): - if response: - display_response = output_beautify.convert_fncall_to_text(response) - history[-1][1] = display_response[-1]['content'] - yield history - except Exception as ex: - raise ValueError(ex) - - app_global_para['messages'].extend(response) - - -def test(history: list = [('你好', None)], chosen_plug: str = 'assistant'): - app(history=history, chosen_plug=chosen_plug) - - -def add_text(history, text): - history = history + [(text, None)] - return history, gr.update(value='', interactive=False) - - -def chat_clear(): - app_global_para['messages'] = [] - return None, None - - -def add_file(file): - app_global_para['uploaded_file'] = file.name - app_global_para['is_first_upload'] = True - return file.name - - -with gr.Blocks(theme='soft') as demo: - with gr.Tab('Chat', elem_id='chat-tab'): - with gr.Column(): - chatbot = gr.Chatbot( - [], - elem_id='chatbot', - height=750, - show_copy_button=True, - ) - with gr.Row(): - with gr.Column(scale=1, min_width=0): - file_btn = gr.UploadButton('Upload', file_types=['file']) - - with gr.Column(scale=13): - chat_txt = gr.Textbox( - show_label=False, - placeholder='Chat with Qwen...', - container=False, - ) - with gr.Column(scale=1, min_width=0): - chat_clr_bt = gr.Button('Clear') - - with gr.Row(): - with gr.Column(scale=2, min_width=0): - plug_bt = gr.Dropdown( - AGENT_LIST_NAME, - label='Mention List', - info='', - value='assistant', - ) - with gr.Column(scale=8, min_width=0): - hidden_file_path = gr.Textbox(interactive=False, label='The uploaded file is displayed here') - - txt_msg = chat_txt.submit(add_text, [chatbot, chat_txt], [chatbot, chat_txt], - queue=False).then(app, [chatbot, plug_bt], chatbot) - txt_msg.then(lambda: gr.update(interactive=True), None, [chat_txt], queue=False) - - file_msg = file_btn.upload(add_file, file_btn, [hidden_file_path], queue=False) - - chat_clr_bt.click(chat_clear, None, [chatbot, hidden_file_path], queue=False) - -if __name__ == '__main__': - demo.queue().launch() diff --git a/examples/group_chat_chess.py b/examples/group_chat_chess.py index 79ebb2d..4ec5147 100644 --- a/examples/group_chat_chess.py +++ b/examples/group_chat_chess.py @@ -9,26 +9,30 @@ CFGS = { 'background': f'一个五子棋群组,棋盘为5*5,黑棋玩家和白棋玩家交替下棋,每次玩家下棋后,棋盘进行更新并展示。{NPC_NAME}下白棋,{USER_NAME}下黑棋。', - 'agents': [{ - 'name': - '棋盘', - 'description': - '负责更新棋盘', - 'instructions': - '你扮演一个五子棋棋盘,你可以根据原始棋盘和玩家下棋的位置坐标,把新的棋盘用矩阵展示出来。棋盘中用0代表无棋子、用1表示黑棋、用-1表示白棋。用坐标表示位置,i代表行,j代表列,棋盘左上角位置为<0,0>。', - 'selected_tools': ['code_interpreter'] - }, { - 'name': - NPC_NAME, - 'description': - '白棋玩家', - 'instructions': - '你扮演一个玩五子棋的高手,你下白棋。棋盘中用0代表无棋子、用1黑棋、用-1白棋。用坐标表示位置,i代表行,j代表列,棋盘左上角位置为<0,0>,请决定你要下在哪里,你可以随意下到一个位置,不要说你是AI助手不会下!返回格式为坐标:\n\n除了这个坐标,不要返回其他任何内容', - }, { - 'name': USER_NAME, - 'description': '黑棋玩家', - 'is_human': True - }] + 'agents': [ + { + 'name': + '棋盘', + 'description': + '负责更新棋盘', + 'instructions': + '你扮演一个五子棋棋盘,你可以根据原始棋盘和玩家下棋的位置坐标,把新的棋盘用矩阵展示出来。棋盘中用0代表无棋子、用1表示黑棋、用-1表示白棋。用坐标表示位置,i代表行,j代表列,棋盘左上角位置为<0,0>。', + 'selected_tools': ['code_interpreter'], + }, + { + 'name': + NPC_NAME, + 'description': + '白棋玩家', + 'instructions': + '你扮演一个玩五子棋的高手,你下白棋。棋盘中用0代表无棋子、用1黑棋、用-1白棋。用坐标表示位置,i代表行,j代表列,棋盘左上角位置为<0,0>,请决定你要下在哪里,你可以随意下到一个位置,不要说你是AI助手不会下!返回格式为坐标:\n\n除了这个坐标,不要返回其他任何内容', + }, + { + 'name': USER_NAME, + 'description': '黑棋玩家', + 'is_human': True + }, + ], } diff --git a/examples/multi_agent_router.py b/examples/multi_agent_router.py index a862250..118ab72 100644 --- a/examples/multi_agent_router.py +++ b/examples/multi_agent_router.py @@ -1,4 +1,5 @@ """A multi-agent cooperation example implemented by router and assistant""" + import os from typing import Optional @@ -14,23 +15,21 @@ def init_agent_service(): tools = ['image_gen', 'code_interpreter'] # Define a vl agent - bot_vl = Assistant(llm=llm_cfg_vl) + bot_vl = Assistant(llm=llm_cfg_vl, name='多模态助手', description='可以理解图像内容。') # Define a tool agent - bot_tool = ReActChat(llm=llm_cfg, function_list=tools) + bot_tool = ReActChat( + llm=llm_cfg, + name='工具助手', + description='可以使用画图工具和运行代码来解决问题', + function_list=tools, + ) # Define a router (simultaneously serving as a text agent) - bot = Router(llm=llm_cfg, - agents={ - 'vl': { - 'obj': bot_vl, - 'desc': '多模态助手,可以理解图像内容。' - }, - 'tool': { - 'obj': bot_tool, - 'desc': '工具助手,可以使用画图工具和运行代码来解决问题' - } - }) + bot = Router( + llm=llm_cfg, + agents=[bot_vl, bot_tool], + ) return bot @@ -65,10 +64,10 @@ def app(): def test( - query: str = 'hello', - image: # noqa - str = 'https://img01.sc115.com/uploads/sc/jpgs/1505/apic11540_sc115.com.jpg', # noqa - file: Optional[str] = os.path.join(ROOT_RESOURCE, 'poem.pdf')): # noqa + query: str = 'hello', + image: str = 'https://img01.sc115.com/uploads/sc/jpgs/1505/apic11540_sc115.com.jpg', + file: Optional[str] = os.path.join(ROOT_RESOURCE, 'poem.pdf'), +): # Define the agent bot = init_agent_service() diff --git a/examples/react_data_analysis.py b/examples/react_data_analysis.py index a73f9b8..8cf50de 100644 --- a/examples/react_data_analysis.py +++ b/examples/react_data_analysis.py @@ -46,10 +46,8 @@ def app(): messages.extend(response) -def test( - query: # noqa - str = 'pd.head the file first and then help me draw a line chart to show the changes in stock prices', - file: Optional[str] = os.path.join(ROOT_RESOURCE, 'stock_prices.csv')): # noqa +def test(query: str = 'pd.head the file first and then help me draw a line chart to show the changes in stock prices', + file: Optional[str] = os.path.join(ROOT_RESOURCE, 'stock_prices.csv')): # Define the agent bot = init_agent_service() diff --git a/examples/visual_storytelling.py b/examples/visual_storytelling.py index d74bbcb..22c380a 100644 --- a/examples/visual_storytelling.py +++ b/examples/visual_storytelling.py @@ -34,8 +34,8 @@ def _run(self, **kwargs) -> Iterator[List[Message]]: """Define the workflow""" - assert (isinstance(messages[-1]['content'], list) and - any([item.image for item in messages[-1]['content']])), 'This agent requires input of images' + assert isinstance(messages[-1]['content'], list) + assert any([item.image for item in messages[-1]['content']]), 'This agent requires input of images' # Image understanding new_messages = copy.deepcopy(messages) diff --git a/qwen_agent/__init__.py b/qwen_agent/__init__.py index ef06910..a12f93c 100644 --- a/qwen_agent/__init__.py +++ b/qwen_agent/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.0.2' +__version__ = '0.0.3' from .agent import Agent __all__ = ['Agent'] diff --git a/qwen_agent/agent.py b/qwen_agent/agent.py index 2633853..a067921 100644 --- a/qwen_agent/agent.py +++ b/qwen_agent/agent.py @@ -9,7 +9,7 @@ from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM, ContentItem, Message from qwen_agent.log import logger from qwen_agent.tools import TOOL_REGISTRY, BaseTool -from qwen_agent.utils.utils import has_chinese_chars +from qwen_agent.utils.utils import has_chinese_chars, merge_generate_cfgs class Agent(ABC): @@ -41,6 +41,7 @@ def __init__(self, self.llm = get_chat_model(llm) else: self.llm = llm + self.extra_generate_cfg: dict = {} self.function_map = {} if function_list: @@ -114,6 +115,7 @@ def _call_llm( messages: List[Message], functions: Optional[List[Dict]] = None, stream: bool = True, + extra_generate_cfg: Optional[dict] = None, ) -> Iterator[List[Message]]: """The interface of calling LLM for the agent. @@ -136,7 +138,13 @@ def _call_llm( else: assert isinstance(messages[0][CONTENT], list) messages[0][CONTENT] = [ContentItem(text=self.system_message)] + messages[0][CONTENT] - return self.llm.chat(messages=messages, functions=functions, stream=stream) + return self.llm.chat(messages=messages, + functions=functions, + stream=stream, + extra_generate_cfg=merge_generate_cfgs( + base_generate_cfg=self.extra_generate_cfg, + new_generate_cfg=extra_generate_cfg, + )) def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str: """The interface of calling tools for the agent. @@ -160,6 +168,7 @@ def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwarg error_message = f'An error occurred when calling tool `{tool_name}`:\n' \ f'{exception_type}: {exception_message}\n' \ f'Traceback:\n{traceback_info}' + logger.warning(error_message) return error_message if isinstance(tool_result, str): diff --git a/qwen_agent/agents/__init__.py b/qwen_agent/agents/__init__.py index 0b5f107..c149b39 100644 --- a/qwen_agent/agents/__init__.py +++ b/qwen_agent/agents/__init__.py @@ -1,6 +1,10 @@ +from qwen_agent import Agent + from .article_agent import ArticleAgent from .assistant import Assistant -from .docqa_agent import DocQAAgent +# DocQAAgent is the default solution for long document question answering. +# The actual implementation of DocQAAgent may change with every release. +from .doc_qa.basic_doc_qa import BasicDocQA as DocQAAgent from .fncall_agent import FnCallAgent from .group_chat import GroupChat from .group_chat_auto_router import GroupChatAutoRouter @@ -11,6 +15,16 @@ from .write_from_scratch import WriteFromScratch __all__ = [ - 'DocQAAgent', 'Assistant', 'ArticleAgent', 'ReActChat', 'Router', 'UserAgent', 'GroupChat', 'WriteFromScratch', - 'GroupChatCreator', 'GroupChatAutoRouter', 'FnCallAgent' + 'Agent', + 'DocQAAgent', + 'Assistant', + 'ArticleAgent', + 'ReActChat', + 'Router', + 'UserAgent', + 'GroupChat', + 'WriteFromScratch', + 'GroupChatCreator', + 'GroupChatAutoRouter', + 'FnCallAgent', ] diff --git a/qwen_agent/agents/article_agent.py b/qwen_agent/agents/article_agent.py index be83fac..3a19fca 100644 --- a/qwen_agent/agents/article_agent.py +++ b/qwen_agent/agents/article_agent.py @@ -4,6 +4,7 @@ from qwen_agent.agents.write_from_scratch import WriteFromScratch from qwen_agent.llm.schema import ASSISTANT, CONTENT, Message from qwen_agent.prompts import ContinueWriting +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN class ArticleAgent(Assistant): @@ -15,7 +16,7 @@ class ArticleAgent(Assistant): def _run(self, messages: List[Message], lang: str = 'en', - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, full_article: bool = False, **kwargs) -> Iterator[List[Message]]: @@ -35,6 +36,6 @@ def _run(self, response.append(Message(ASSISTANT, '>\n> Writing Text: \n')) yield response - for trunk in writing_agent.run(messages=messages, lang=lang, knowledge=_ref): - if trunk: - yield response + trunk + for rsp in writing_agent.run(messages=messages, lang=lang, knowledge=_ref): + if rsp: + yield response + rsp diff --git a/qwen_agent/agents/assistant.py b/qwen_agent/agents/assistant.py index 0f0b5ff..30ac7f7 100644 --- a/qwen_agent/agents/assistant.py +++ b/qwen_agent/agents/assistant.py @@ -1,36 +1,69 @@ import copy -from typing import Iterator, List +import datetime +from typing import Iterator, List, Literal, Optional, Union + +import json5 from qwen_agent.llm.schema import CONTENT, ROLE, SYSTEM, Message from qwen_agent.log import logger -from qwen_agent.utils.utils import format_knowledge_to_source_and_content +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN +from qwen_agent.utils.utils import print_traceback from .fncall_agent import FnCallAgent -KNOWLEDGE_SNIPPET_ZH = """## 来自 {source} 的内容: - -``` -{content} -```""" KNOWLEDGE_TEMPLATE_ZH = """ # 知识库 {knowledge}""" -KNOWLEDGE_SNIPPET_EN = """## The content from {source}: +KNOWLEDGE_TEMPLATE_EN = """ + +# Knowledge Base + +{knowledge}""" + +KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN} + +KNOWLEDGE_SNIPPET_ZH = """## 来自 {source} 的内容: ``` {content} ```""" -KNOWLEDGE_TEMPLATE_EN = """ -# Knowledge Base +KNOWLEDGE_SNIPPET_EN = """## The content from {source}: -{knowledge}""" +``` +{content} +```""" KNOWLEDGE_SNIPPET = {'zh': KNOWLEDGE_SNIPPET_ZH, 'en': KNOWLEDGE_SNIPPET_EN} -KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN} + + +def format_knowledge_to_source_and_content(result: Union[str, List[dict]]) -> List[dict]: + knowledge = [] + if isinstance(result, str): + result = f'{result}'.strip() + try: + docs = json5.loads(result) + except Exception: + print_traceback() + knowledge.append({'source': '上传的文档', 'content': result}) + return knowledge + else: + docs = result + try: + _tmp_knowledge = [] + assert isinstance(docs, list) + for doc in docs: + url, snippets = doc['url'], doc['text'] + assert isinstance(snippets, list) + _tmp_knowledge.append({'source': f'[文件]({url})', 'content': '\n\n...\n\n'.join(snippets)}) + knowledge.extend(_tmp_knowledge) + except Exception: + print_traceback() + knowledge.append({'source': '上传的文档', 'content': result}) + return knowledge class Assistant(FnCallAgent): @@ -39,21 +72,35 @@ class Assistant(FnCallAgent): def _run(self, messages: List[Message], lang: str = 'en', - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, + knowledge: str = '', **kwargs) -> Iterator[List[Message]]: + """Q&A with RAG and tool use abilities. + + Args: + knowledge: If an external knowledge string is provided, + it will be used directly without retrieving information from files in messages. - new_messages = self._prepend_knowledge_prompt(messages, lang, max_ref_token, **kwargs) + """ + + new_messages = self._prepend_knowledge_prompt(messages=messages, + lang=lang, + max_ref_token=max_ref_token, + knowledge=knowledge, + **kwargs) return super()._run(messages=new_messages, lang=lang, max_ref_token=max_ref_token, **kwargs) def _prepend_knowledge_prompt(self, messages: List[Message], lang: str = 'en', - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, + knowledge: str = '', **kwargs) -> List[Message]: messages = copy.deepcopy(messages) - # Retrieval knowledge from files - *_, last = self.mem.run(messages=messages, max_ref_token=max_ref_token, lang=lang, **kwargs) - knowledge = last[-1][CONTENT] + if not knowledge: + # Retrieval knowledge from files + *_, last = self.mem.run(messages=messages, lang=lang, max_ref_token=max_ref_token, **kwargs) + knowledge = last[-1][CONTENT] logger.debug(f'Retrieved knowledge of type `{type(knowledge).__name__}`:\n{knowledge}') if knowledge: @@ -74,3 +121,23 @@ def _prepend_knowledge_prompt(self, else: messages = [Message(role=SYSTEM, content=knowledge_prompt)] + messages return messages + + +def get_current_date_str( + lang: Literal['en', 'zh'] = 'en', + hours_from_utc: Optional[int] = None, +) -> str: + if hours_from_utc is None: + cur_time = datetime.datetime.now() + else: + cur_time = datetime.datetime.utcnow() + datetime.timedelta(hours=hours_from_utc) + if lang == 'en': + date_str = 'Current date: ' + cur_time.strftime('%A, %B %d, %Y') + elif lang == 'zh': + cur_time = cur_time.timetuple() + date_str = f'当前时间:{cur_time.tm_year}年{cur_time.tm_mon}月{cur_time.tm_mday}日,星期' + date_str += ['一', '二', '三', '四', '五', '六', '日'][cur_time.tm_wday] + date_str += '。' + else: + raise NotImplementedError + return date_str diff --git a/qwen_agent/agents/doc_qa/__init__.py b/qwen_agent/agents/doc_qa/__init__.py new file mode 100644 index 0000000..d07ca25 --- /dev/null +++ b/qwen_agent/agents/doc_qa/__init__.py @@ -0,0 +1,5 @@ +from .basic_doc_qa import BasicDocQA + +__all__ = [ + 'BasicDocQA', +] diff --git a/qwen_agent/agents/docqa_agent.py b/qwen_agent/agents/doc_qa/basic_doc_qa.py similarity index 71% rename from qwen_agent/agents/docqa_agent.py rename to qwen_agent/agents/doc_qa/basic_doc_qa.py index d7a1899..b2a26ae 100644 --- a/qwen_agent/agents/docqa_agent.py +++ b/qwen_agent/agents/doc_qa/basic_doc_qa.py @@ -4,18 +4,22 @@ from qwen_agent.llm.base import BaseChatModel from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message from qwen_agent.prompts import DocQA +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN from qwen_agent.tools import BaseTool +DEFAULT_NAME = 'Basic DocQA' +DEFAULT_DESC = '可以根据问题,检索出知识库中的某个相关细节来回答。适用于需要定位到具体位置的问题,例如“介绍表1”等类型的问题' -class DocQAAgent(Assistant): + +class BasicDocQA(Assistant): """This is an agent for doc QA.""" def __init__(self, function_list: Optional[List[Union[str, Dict, BaseTool]]] = None, llm: Optional[Union[Dict, BaseChatModel]] = None, system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, - name: Optional[str] = None, - description: Optional[str] = None, + name: Optional[str] = DEFAULT_NAME, + description: Optional[str] = DEFAULT_DESC, files: Optional[List[str]] = None): super().__init__(function_list=function_list, llm=llm, @@ -23,20 +27,20 @@ def __init__(self, name=name, description=description, files=files) - self.doc_qa = DocQA(llm=self.llm) def _run(self, messages: List[Message], lang: str = 'en', - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, **kwargs) -> Iterator[List[Message]]: - + """This agent using different doc qa prompt with Assistant""" # Need to use Memory agent for data management *_, last = self.mem.run(messages=messages, max_ref_token=max_ref_token, **kwargs) _ref = last[-1][CONTENT] # Use RetrievalQA agent + # Todo: Prompt engineering response = self.doc_qa.run(messages=messages, lang=lang, knowledge=_ref) return response diff --git a/qwen_agent/agents/fncall_agent.py b/qwen_agent/agents/fncall_agent.py index 601e899..e22c2e5 100644 --- a/qwen_agent/agents/fncall_agent.py +++ b/qwen_agent/agents/fncall_agent.py @@ -5,9 +5,9 @@ from qwen_agent.llm import BaseChatModel from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, FUNCTION, Message from qwen_agent.memory import Memory +from qwen_agent.settings import MAX_LLM_CALL_PER_RUN from qwen_agent.tools import BaseTool - -MAX_LLM_CALL_PER_RUN = 8 +from qwen_agent.utils.utils import extract_files_from_messages class FnCallAgent(Agent): @@ -38,8 +38,9 @@ def __init__(self, name=name, description=description) - # Default to use Memory to manage files - self.mem = Memory(llm=self.llm, files=files) + if not hasattr(self, 'mem'): + # Default to use Memory to manage files + self.mem = Memory(llm=self.llm, files=files) def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]: messages = copy.deepcopy(messages) @@ -56,19 +57,19 @@ def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[ if output: response.extend(output) messages.extend(output) - use_tool, action, action_input, _ = self._detect_tool(response[-1]) - if use_tool: - observation = self._call_tool(action, action_input, messages=messages) - fn_msg = Message( - role=FUNCTION, - name=action, - content=observation, - ) - messages.append(fn_msg) - response.append(fn_msg) - yield response - else: - break + use_tool, tool_name, tool_args, _ = self._detect_tool(response[-1]) + if use_tool: + tool_result = self._call_tool(tool_name, tool_args, messages=messages) + fn_msg = Message( + role=FUNCTION, + name=tool_name, + content=tool_result, + ) + messages.append(fn_msg) + response.append(fn_msg) + yield response + else: + break def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str: if tool_name not in self.function_map: @@ -77,7 +78,7 @@ def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwarg # Todo: This should be changed to parameter passing, and the file URL should be determined by the model if self.function_map[tool_name].file_access: assert 'messages' in kwargs - files = self.mem.get_all_files_of_messages(kwargs['messages']) + self.mem.system_files + files = extract_files_from_messages(kwargs['messages']) + self.mem.system_files return super()._call_tool(tool_name, tool_args, files=files, **kwargs) else: return super()._call_tool(tool_name, tool_args, **kwargs) diff --git a/qwen_agent/agents/group_chat_auto_router.py b/qwen_agent/agents/group_chat_auto_router.py index fda73d3..fa2368d 100644 --- a/qwen_agent/agents/group_chat_auto_router.py +++ b/qwen_agent/agents/group_chat_auto_router.py @@ -6,32 +6,31 @@ from qwen_agent.tools import BaseTool from qwen_agent.utils.utils import has_chinese_chars -PROMPT_TEMPLATE_ZH = '''你扮演角色扮演游戏的上帝,你的任务是选择合适的发言角色。有如下角色: -{agent_descs} -角色间的对话历史格式如下,越新的对话越重要: -角色名: 说话内容 - -请阅读对话历史,并选择下一个合适的发言角色,从 [{agent_names}] 里选,当真实用户最近表明了停止聊天时,或话题应该终止时,请返回“[STOP]”,用户很懒,非必要不要选真实用户。 -仅返回角色名或“[STOP]”,不要返回其余内容。''' +class GroupChatAutoRouter(Agent): + PROMPT_TEMPLATE_ZH = '''你扮演角色扮演游戏的上帝,你的任务是选择合适的发言角色。有如下角色: + {agent_descs} -PROMPT_TEMPLATE_EN = '''You are in a role play game. The following roles are available: -{agent_descs} + 角色间的对话历史格式如下,越新的对话越重要: + 角色名: 说话内容 -The format of dialogue history between roles is as follows: -Role Name: Speech Content + 请阅读对话历史,并选择下一个合适的发言角色,从 [{agent_names}] 里选,当真实用户最近表明了停止聊天时,或话题应该终止时,请返回“[STOP]”,用户很懒,非必要不要选真实用户。 + 仅返回角色名或“[STOP]”,不要返回其余内容。''' -Please read the dialogue history and choose the next suitable role to speak. -When the user indicates to stop chatting or when the topic should be terminated, please return '[STOP]'. -Only return the role name from [{agent_names}] or '[STOP]'. Do not reply any other content.''' + PROMPT_TEMPLATE_EN = '''You are in a role play game. The following roles are available: + {agent_descs} -PROMPT_TEMPLATE = { - 'zh': PROMPT_TEMPLATE_ZH, - 'en': PROMPT_TEMPLATE_EN, -} + The format of dialogue history between roles is as follows: + Role Name: Speech Content + Please read the dialogue history and choose the next suitable role to speak. + When the user indicates to stop chatting or when the topic should be terminated, please return '[STOP]'. + Only return the role name from [{agent_names}] or '[STOP]'. Do not reply any other content.''' -class GroupChatAutoRouter(Agent): + PROMPT_TEMPLATE = { + 'zh': PROMPT_TEMPLATE_ZH, + 'en': PROMPT_TEMPLATE_EN, + } def __init__(self, function_list: Optional[List[Union[str, Dict, BaseTool]]] = None, @@ -45,8 +44,8 @@ def __init__(self, lang = 'en' if has_chinese_chars(agent_descs): lang = 'zh' - system_prompt = PROMPT_TEMPLATE[lang].format(agent_descs=agent_descs, - agent_names=', '.join([x.name for x in agents])) + system_prompt = self.PROMPT_TEMPLATE[lang].format(agent_descs=agent_descs, + agent_names=', '.join([x.name for x in agents])) super().__init__(function_list=function_list, llm=llm, diff --git a/qwen_agent/agents/react_chat.py b/qwen_agent/agents/react_chat.py index a4e675e..e384fb8 100644 --- a/qwen_agent/agents/react_chat.py +++ b/qwen_agent/agents/react_chat.py @@ -1,11 +1,13 @@ import copy from typing import Dict, Iterator, List, Optional, Tuple, Union -from qwen_agent.agents.fncall_agent import MAX_LLM_CALL_PER_RUN, FnCallAgent +from qwen_agent.agents.fncall_agent import FnCallAgent from qwen_agent.llm import BaseChatModel +from qwen_agent.llm.function_calling import get_function_description from qwen_agent.llm.schema import ASSISTANT, CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, ContentItem, Message +from qwen_agent.settings import MAX_LLM_CALL_PER_RUN from qwen_agent.tools import BaseTool -from qwen_agent.utils.utils import get_basename_from_url, get_function_description, has_chinese_chars +from qwen_agent.utils.utils import get_basename_from_url, has_chinese_chars, merge_generate_cfgs PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools: @@ -43,9 +45,10 @@ def __init__(self, name=name, description=description, files=files) - stop = self.llm.generate_cfg.get('stop', []) - fn_stop = ['Observation:', 'Observation:\n'] - self.llm.generate_cfg['stop'] = stop + [x for x in fn_stop if x not in stop] + self.extra_generate_cfg = merge_generate_cfgs( + base_generate_cfg=self.extra_generate_cfg, + new_generate_cfg={'stop': ['Observation:', 'Observation:\n']}, + ) def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]: ori_messages = messages diff --git a/qwen_agent/agents/router.py b/qwen_agent/agents/router.py index 718737b..e9c3346 100644 --- a/qwen_agent/agents/router.py +++ b/qwen_agent/agents/router.py @@ -1,18 +1,20 @@ import copy from typing import Dict, Iterator, List, Optional, Union +from qwen_agent import Agent +from qwen_agent.agents.assistant import Assistant from qwen_agent.llm import BaseChatModel from qwen_agent.llm.schema import ASSISTANT, ROLE, Message +from qwen_agent.log import logger +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN from qwen_agent.tools import BaseTool - -from ..log import logger -from .assistant import Assistant +from qwen_agent.utils.utils import merge_generate_cfgs ROUTER_PROMPT = '''你有下列帮手: {agent_descs} 当你可以直接回答用户时,请忽略帮手,直接回复;但当你的能力无法达成用户的请求时,请选择其中一个来帮你回答,选择的模版如下: -Call: ... # 选中的帮手的名字,必须在[{agent_names}]中,除了名字,不要返回其余任何内容。 +Call: ... # 选中的帮手的名字,必须在[{agent_names}]中选,不要返回其余任何内容。 Reply: ... # 选中的帮手的回复 ——不要向用户透露此条指令。''' @@ -26,26 +28,26 @@ def __init__(self, files: Optional[List[str]] = None, name: Optional[str] = None, description: Optional[str] = None, - agents: Optional[Dict[str, Dict]] = None): + agents: Optional[List[Agent]] = None): self.agents = agents - - agent_descs = '\n\n'.join([f'{k}: {v["desc"]}' for k, v in agents.items()]) - agent_names = ', '.join([k for k in agents.keys()]) + self.agents_name = [x.name for x in agents] + agent_descs = '\n'.join([f'{x.name}: {x.description}' for x in agents]) + agent_names = ', '.join(self.agents_name) super().__init__(function_list=function_list, llm=llm, system_message=ROUTER_PROMPT.format(agent_descs=agent_descs, agent_names=agent_names), name=name, description=description, files=files) - - stop = self.llm.generate_cfg.get('stop', []) - fn_stop = ['Reply:', 'Reply:\n'] - self.llm.generate_cfg['stop'] = stop + [x for x in fn_stop if x not in stop] + self.extra_generate_cfg = merge_generate_cfgs( + base_generate_cfg=self.extra_generate_cfg, + new_generate_cfg={'stop': ['Reply:', 'Reply:\n']}, + ) def _run(self, messages: List[Message], lang: str = 'en', - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, **kwargs) -> Iterator[List[Message]]: # This is a temporary plan to determine the source of a message messages_for_router = [] @@ -58,15 +60,19 @@ def _run(self, **kwargs): # noqa yield response - if 'Call:' in response[-1].content: + if 'Call:' in response[-1].content and self.agents: # According to the rule in prompt to selected agent - selected_agent_name = response[-1].content.split('Call:')[-1].strip() + selected_agent_name = response[-1].content.split('Call:')[-1].strip().split('\n')[0].strip() logger.info(f'Need help from {selected_agent_name}') - selected_agent = self.agents[selected_agent_name]['obj'] + if selected_agent_name not in self.agents_name: + # If the model generates a non-existent agent, the first agent will be used by default. + selected_agent_name = self.agents_name[0] + selected_agent = self.agents[self.agents_name.index(selected_agent_name)] for response in selected_agent.run(messages=messages, lang=lang, max_ref_token=max_ref_token, **kwargs): for i in range(len(response)): if response[i].role == ASSISTANT: response[i].name = selected_agent_name + # This new response will overwrite the above 'Call: xxx' message yield response @staticmethod diff --git a/qwen_agent/agents/user_agent.py b/qwen_agent/agents/user_agent.py index 334465f..b5ecffc 100644 --- a/qwen_agent/agents/user_agent.py +++ b/qwen_agent/agents/user_agent.py @@ -2,6 +2,7 @@ from qwen_agent.agents.assistant import Assistant from qwen_agent.llm.schema import Message +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN PENDING_USER_INPUT = '' @@ -11,7 +12,7 @@ class UserAgent(Assistant): def _run(self, messages: List[Message], lang: str = 'en', - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, **kwargs) -> Iterator[List[Message]]: yield [Message(role='user', content=PENDING_USER_INPUT, name=self.name)] diff --git a/qwen_agent/agents/write_from_scratch.py b/qwen_agent/agents/write_from_scratch.py index 5142b0b..a5fca07 100644 --- a/qwen_agent/agents/write_from_scratch.py +++ b/qwen_agent/agents/write_from_scratch.py @@ -4,8 +4,9 @@ import json5 from qwen_agent import Agent +from qwen_agent.agents.assistant import Assistant from qwen_agent.llm.schema import ASSISTANT, CONTENT, USER, Message -from qwen_agent.prompts import DocQA, ExpandWriting, OutlineWriting +from qwen_agent.prompts import ExpandWriting, OutlineWriting default_plan = """{"action1": "summarize", "action2": "outline", "action3": "expand"}""" @@ -38,26 +39,26 @@ def _run(self, messages: List[Message], knowledge: str = '', lang: str = 'en') - user_request = 'Summarize the main content of reference materials.' else: raise NotImplementedError - sum_agent = DocQA(llm=self.llm) + sum_agent = Assistant(llm=self.llm) res_sum = sum_agent.run(messages=[Message(USER, user_request)], knowledge=knowledge, lang=lang) - trunk = None - for trunk in res_sum: - yield response + trunk - if trunk: - response.extend(trunk) - summ = trunk[-1][CONTENT] + chunk = None + for chunk in res_sum: + yield response + chunk + if chunk: + response.extend(chunk) + summ = chunk[-1][CONTENT] elif plan == 'outline': response.append(Message(ASSISTANT, '>\n> Generate Outline: \n')) yield response otl_agent = OutlineWriting(llm=self.llm) res_otl = otl_agent.run(messages=messages, knowledge=summ, lang=lang) - trunk = None - for trunk in res_otl: - yield response + trunk - if trunk: - response.extend(trunk) - outline = trunk[-1][CONTENT] + chunk = None + for chunk in res_otl: + yield response + chunk + if chunk: + response.extend(chunk) + outline = chunk[-1][CONTENT] elif plan == 'expand': response.append(Message(ASSISTANT, '>\n> Writing Text: \n')) yield response @@ -88,10 +89,10 @@ def _run(self, messages: List[Message], knowledge: str = '', lang: str = 'en') - capture_later=capture_later, lang=lang, ) - trunk = None - for trunk in res_exp: - yield response + trunk - if trunk: - response.extend(trunk) + chunk = None + for chunk in res_exp: + yield response + chunk + if chunk: + response.extend(chunk) else: pass diff --git a/qwen_agent/gui/__init__.py b/qwen_agent/gui/__init__.py new file mode 100644 index 0000000..09eabad --- /dev/null +++ b/qwen_agent/gui/__init__.py @@ -0,0 +1,3 @@ +from qwen_agent.gui.gradio import gr + +__all__ = ['gr'] diff --git a/qwen_agent/gui/assets/logo.jpeg b/qwen_agent/gui/assets/logo.jpeg new file mode 100644 index 0000000..e511f7d Binary files /dev/null and b/qwen_agent/gui/assets/logo.jpeg differ diff --git a/qwen_agent/gui/assets/user.jpeg b/qwen_agent/gui/assets/user.jpeg new file mode 100644 index 0000000..536948b Binary files /dev/null and b/qwen_agent/gui/assets/user.jpeg differ diff --git a/qwen_agent/gui/gradio.py b/qwen_agent/gui/gradio.py new file mode 100644 index 0000000..d61f432 --- /dev/null +++ b/qwen_agent/gui/gradio.py @@ -0,0 +1,8 @@ +try: + import gradio as gr + if gr.__version__ < '4.0': + raise ImportError('Incompatible "gradio" version detected. ' + 'Please install the correct version with: pip install "gradio>=4.0"') +except (ModuleNotFoundError, AttributeError): + raise ImportError('Requirement "gradio" not installed. ' + 'Please install it by: pip install "gradio>=4.0"') diff --git a/qwen_agent/gui/utils.py b/qwen_agent/gui/utils.py new file mode 100644 index 0000000..ac8e39a --- /dev/null +++ b/qwen_agent/gui/utils.py @@ -0,0 +1,8 @@ +import os + + +def get_avatar_image(name: str = 'user') -> str: + if name == 'user': + return os.path.join(os.path.dirname(__file__), 'assets/user.jpeg') + + return os.path.join(os.path.dirname(__file__), 'assets/logo.jpeg') diff --git a/qwen_agent/llm/__init__.py b/qwen_agent/llm/__init__.py index e6e292d..7afbca5 100644 --- a/qwen_agent/llm/__init__.py +++ b/qwen_agent/llm/__init__.py @@ -20,7 +20,7 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel: # Use your own model service compatible with OpenAI API: # 'model': 'Qwen', # 'model_server': 'http://127.0.0.1:7905/v1', - # (Optional) LLM hyper-paramters: + # (Optional) LLM hyper-parameters: 'generate_cfg': { 'top_p': 0.8 } diff --git a/qwen_agent/llm/base.py b/qwen_agent/llm/base.py index f00f75e..c39030f 100644 --- a/qwen_agent/llm/base.py +++ b/qwen_agent/llm/base.py @@ -2,12 +2,11 @@ import random import time from abc import ABC, abstractmethod -from typing import Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, SYSTEM, Message from qwen_agent.utils.tokenization_qwen import tokenizer -from qwen_agent.utils.utils import get_basename_from_url, has_chinese_chars, is_image, print_traceback - -from .schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, FUNCTION, SYSTEM, USER, ContentItem, Message +from qwen_agent.utils.utils import format_as_multimodal_message, merge_generate_cfgs, print_traceback LLM_REGISTRY = {} @@ -52,20 +51,25 @@ def chat( functions: Optional[List[Dict]] = None, stream: bool = True, delta_stream: bool = False, + extra_generate_cfg: Optional[Dict] = None, ) -> Union[List[Message], List[Dict], Iterator[List[Message]], Iterator[List[Dict]]]: """LLM chat interface. Args: messages: Inputted messages. - functions: Inputted functions, which supports OpenAI format. + functions: Inputted functions for function calling. OpenAI format supported. stream: Whether to use streaming generation. - delta_stream: Whether to return incrementally. - (1) When False: Use full return. - (2) When True: Use incremental return. + delta_stream: Whether to stream the response incrementally. + (1) When False (recommended): Stream the full response every iteration. + (2) When True: Stream the chunked response, i.e, delta responses. + extra_generate_cfg: Extra LLM generation hyper-paramters. Returns: the generated message list response by llm. """ + + generate_cfg = merge_generate_cfgs(base_generate_cfg=self.generate_cfg, new_generate_cfg=extra_generate_cfg) + messages = copy.deepcopy(messages) _return_message_type = 'dict' @@ -95,12 +99,14 @@ def _call_model_service(): functions=functions, stream=stream, delta_stream=delta_stream, + generate_cfg=generate_cfg, ) else: return self._chat( messages, stream=stream, delta_stream=delta_stream, + generate_cfg=generate_cfg, ) if stream and delta_stream: @@ -112,36 +118,41 @@ def _call_model_service(): output = retry_model_service(_call_model_service, max_retries=self.max_retries) if isinstance(output, list): - output = self._postprocess_messages(output, fncall_mode=fncall_mode) + output = self._postprocess_messages(output, fncall_mode=fncall_mode, generate_cfg=generate_cfg) return self._convert_messages_to_target_type(output, _return_message_type) else: - output = self._postprocess_messages_iterator(output, fncall_mode=fncall_mode) + output = self._postprocess_messages_iterator(output, fncall_mode=fncall_mode, generate_cfg=generate_cfg) return self._convert_messages_iterator_to_target_type(output, _return_message_type) def _chat( self, messages: List[Union[Message, Dict]], - stream: bool = True, - delta_stream: bool = False, + stream: bool, + delta_stream: bool, + generate_cfg: dict, ) -> Union[List[Message], Iterator[List[Message]]]: if stream: - return self._chat_stream(messages, delta_stream=delta_stream) + return self._chat_stream(messages, delta_stream=delta_stream, generate_cfg=generate_cfg) else: - return self._chat_no_stream(messages) + return self._chat_no_stream(messages, generate_cfg=generate_cfg) @abstractmethod - def _chat_with_functions(self, - messages: List[Union[Message, Dict]], - functions: List[Dict], - stream: bool = True, - delta_stream: bool = False) -> Union[List[Message], Iterator[List[Message]]]: + def _chat_with_functions( + self, + messages: List[Union[Message, Dict]], + functions: List[Dict], + stream: bool, + delta_stream: bool, + generate_cfg: dict, + ) -> Union[List[Message], Iterator[List[Message]]]: raise NotImplementedError @abstractmethod def _chat_stream( self, messages: List[Message], - delta_stream: bool = False, + delta_stream: bool, + generate_cfg: dict, ) -> Iterator[List[Message]]: raise NotImplementedError @@ -149,25 +160,34 @@ def _chat_stream( def _chat_no_stream( self, messages: List[Message], + generate_cfg: dict, ) -> List[Message]: raise NotImplementedError def _preprocess_messages(self, messages: List[Message]) -> List[Message]: - messages = self._format_as_multimodal_messages(messages) + messages = [format_as_multimodal_message(msg) for msg in messages] return messages - def _postprocess_messages(self, messages: List[Message], fncall_mode: bool) -> List[Message]: - messages = self._format_as_multimodal_messages(messages) - messages = self._postprocess_stop_words(messages) + def _postprocess_messages( + self, + messages: List[Message], + fncall_mode: bool, + generate_cfg: dict, + ) -> List[Message]: + messages = [format_as_multimodal_message(msg) for msg in messages] + messages = self._postprocess_stop_words(messages, generate_cfg=generate_cfg) return messages def _postprocess_messages_iterator( self, messages: Iterator[List[Message]], fncall_mode: bool, + generate_cfg: dict, ) -> Iterator[List[Message]]: for m in messages: - m = self._postprocess_messages(m, fncall_mode=fncall_mode) + m = self._postprocess_messages(m, fncall_mode=fncall_mode, generate_cfg=generate_cfg) + # TODO: Postprocessing may be incorrect if delta_stream=True. + # TODO: Early break if truncated at stop words. if m: yield m @@ -186,62 +206,9 @@ def _convert_messages_iterator_to_target_type( for messages in messages_iter: yield self._convert_messages_to_target_type(messages, target_type) - def _format_as_multimodal_messages(self, messages: List[Message]) -> List[Message]: - - multimodal_messages = [] - for msg in messages: - assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) - - content = [] - if isinstance(msg.content, str): # if text content - if msg.content: - content = [ContentItem(text=msg.content)] - elif isinstance(msg.content, list): # if multimodal content - files = [] - for item in msg.content: - (k, v), = item.model_dump().items() - if k in ('box', 'text'): - content.append(ContentItem(text=v)) - if k == 'image': - content.append(item) - if k in ('file', 'image'): - files.append(v) - if (msg.role in (SYSTEM, USER)) and files: - has_zh = has_chinese_chars(content) - upload = [] - for f in [get_basename_from_url(f) for f in files]: - if is_image(f): - if has_zh: - upload.append(f'![图片]({f})') - else: - upload.append(f'![image]({f})') - else: - if has_zh: - upload.append(f'[文件]({f})') - else: - upload.append(f'[file]({f})') - upload = ' '.join(upload) - if has_zh: - upload = f'(上传了 {upload})\n\n' - else: - upload = f'(Uploaded {upload})\n\n' - content = [ContentItem(text=upload)] + content - else: - raise TypeError - - multimodal_messages.append( - Message( - role=msg.role, - content=content, - name=msg.name if msg.role == FUNCTION else None, - function_call=msg.function_call, - )) - - return multimodal_messages - - def _postprocess_stop_words(self, messages: List[Message]) -> List[Message]: + def _postprocess_stop_words(self, messages: List[Message], generate_cfg: dict) -> List[Message]: messages = copy.deepcopy(messages) - stop = self.generate_cfg.get('stop', []) + stop = generate_cfg.get('stop', []) # Make sure it stops before stop words. trunc_messages = [] @@ -264,7 +231,6 @@ def _postprocess_stop_words(self, messages: List[Message]) -> List[Message]: # It may ends with 'Observation' when the stop word is 'Observation:'. partial_stop = [] for s in stop: - # TODO: This tokenizer is Qwen-specific. s = tokenizer.tokenize(s)[:-1] if s: s = tokenizer.convert_tokens_to_string(s) @@ -295,50 +261,25 @@ def _truncate_at_stop_word(text: str, stop: List[str]): def retry_model_service( fn, max_retries: int = 10, - exponential_base: float = 1.0, -): - """Retry a function with exponential backoff""" +) -> Any: + """Retry a function""" - num_retries = 0 - delay = 2.0 + num_retries, delay = 0, 1.0 while True: try: return fn() except ModelServiceError as e: - if max_retries <= 0: # no retry - raise e - - # If harmful input or output detected, let it fail - if e.code == 'DataInspectionFailed': - raise e - if 'inappropriate content' in str(e): - raise e - - # Retry is meaningless if the input is too long - if 'maximum context length' in str(e): - raise e - - print_traceback(is_error=False) - - if num_retries >= max_retries: - raise ModelServiceError(exception=Exception(f'Maximum number of retries ({max_retries}) exceeded.')) - - num_retries += 1 - delay *= exponential_base * (1.0 + random.random()) - time.sleep(delay) + num_retries, delay = _raise_or_delay(e, num_retries, delay, max_retries) def retry_model_service_iterator( it_fn, max_retries: int = 10, - exponential_base: float = 1.0, -): - """Retry an iterator with exponential backoff""" - - num_retries = 0 - delay = 2.0 +) -> Iterator: + """Retry an iterator""" + num_retries, delay = 0, 1.0 while True: try: for rsp in it_fn(): @@ -346,24 +287,39 @@ def retry_model_service_iterator( break except ModelServiceError as e: - if max_retries <= 0: # no retry - raise e - - # If harmful input or output detected, let it fail - if e.code == 'DataInspectionFailed': - raise e - if 'inappropriate content' in str(e): - raise e + num_retries, delay = _raise_or_delay(e, num_retries, delay, max_retries) - # Retry is meaningless if the input is too long - if 'maximum context length' in str(e): - raise e - print_traceback(is_error=False) - - if num_retries >= max_retries: - raise ModelServiceError(exception=Exception(f'Maximum number of retries ({max_retries}) exceeded.')) - - num_retries += 1 - delay *= exponential_base * (1.0 + random.random()) - time.sleep(delay) +def _raise_or_delay( + e: ModelServiceError, + num_retries: int, + delay: float, + max_retries: int = 10, + max_delay: float = 300.0, + exponential_base: float = 2.0, +) -> Tuple[int, float]: + """Retry with exponential backoff""" + + if max_retries <= 0: # no retry + raise e + + # If harmful input or output detected, let it fail + if e.code == 'DataInspectionFailed': + raise e + if 'inappropriate content' in str(e): + raise e + + # Retry is meaningless if the input is too long + if 'maximum context length' in str(e): + raise e + + print_traceback(is_error=False) + + if num_retries >= max_retries: + raise ModelServiceError(exception=Exception(f'Maximum number of retries ({max_retries}) exceeded.')) + + num_retries += 1 + jittor = 1.0 + random.random() + delay = min(delay * exponential_base, max_delay) * jittor + time.sleep(delay) + return num_retries, delay diff --git a/qwen_agent/llm/function_calling.py b/qwen_agent/llm/function_calling.py index 9df1b7a..56fe192 100644 --- a/qwen_agent/llm/function_calling.py +++ b/qwen_agent/llm/function_calling.py @@ -1,10 +1,11 @@ import copy +import json from abc import ABC from typing import Dict, Iterator, List, Optional, Union from qwen_agent.llm.base import BaseChatModel from qwen_agent.llm.schema import ASSISTANT, FUNCTION, SYSTEM, USER, ContentItem, FunctionCall, Message -from qwen_agent.utils.utils import get_function_description, has_chinese_chars +from qwen_agent.utils.utils import has_chinese_chars class BaseFnCallModel(BaseChatModel, ABC): @@ -14,11 +15,14 @@ def __init__(self, cfg: Optional[Dict] = None): stop = self.generate_cfg.get('stop', []) self.generate_cfg['stop'] = stop + [x for x in FN_STOP_WORDS if x not in stop] - def _chat_with_functions(self, - messages: List[Union[Message, Dict]], - functions: List[Dict], - stream: bool = True, - delta_stream: bool = False) -> Union[List[Message], Iterator[List[Message]]]: + def _chat_with_functions( + self, + messages: List[Union[Message, Dict]], + functions: List[Dict], + stream: bool, + delta_stream: bool, + generate_cfg: dict, + ) -> Union[List[Message], Iterator[List[Message]]]: if delta_stream: raise NotImplementedError @@ -40,15 +44,20 @@ def _chat_with_functions(self, text_to_complete.content = usr messages = messages[:-2] + [text_to_complete] - return self._chat(messages, stream=stream, delta_stream=delta_stream) + return self._chat(messages, stream=stream, delta_stream=delta_stream, generate_cfg=generate_cfg) def _preprocess_messages(self, messages: List[Message]) -> List[Message]: messages = super()._preprocess_messages(messages) messages = self._preprocess_fncall_messages(messages) return messages - def _postprocess_messages(self, messages: List[Message], fncall_mode: bool) -> List[Message]: - messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode) + def _postprocess_messages( + self, + messages: List[Message], + fncall_mode: bool, + generate_cfg: dict, + ) -> List[Message]: + messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode, generate_cfg=generate_cfg) if fncall_mode: messages = self._postprocess_fncall_messages(messages) return messages @@ -160,7 +169,7 @@ def _postprocess_fncall_messages(self, messages: List[Message], stop_at_fncall: i = item_text.find(f'{FN_NAME}:') if i < 0: # no function call - show_text = remove_special_tokens(item_text) + show_text = remove_incomplete_special_tokens(item_text) if show_text: new_content.append(ContentItem(text=show_text)) continue @@ -169,7 +178,7 @@ def _postprocess_fncall_messages(self, messages: List[Message], stop_at_fncall: answer = item_text[:i].lstrip('\n').rstrip() if answer.endswith('\n'): answer = answer[:-1] - show_text = remove_special_tokens(answer) + show_text = remove_incomplete_special_tokens(answer) if show_text: new_content.append(ContentItem(text=show_text)) if new_content: @@ -207,8 +216,8 @@ def _postprocess_fncall_messages(self, messages: List[Message], stop_at_fncall: role=ASSISTANT, content=[], function_call=FunctionCall( - name=remove_special_tokens(fn_name), - arguments=remove_special_tokens(fn_args), + name=remove_incomplete_special_tokens(fn_name), + arguments=remove_incomplete_special_tokens(fn_args), ), )) @@ -218,17 +227,17 @@ def _postprocess_fncall_messages(self, messages: List[Message], stop_at_fncall: if (result and result[1:]) or answer: # result[1:] == '' is possible and allowed # rm the ' ' after ':' - show_text = remove_special_tokens(result[1:]) + show_text = remove_incomplete_special_tokens(result[1:]) new_messages.append( Message( role=FUNCTION, content=[ContentItem(text=show_text)], - name=remove_special_tokens(fn_name), + name=remove_incomplete_special_tokens(fn_name), )) if answer and answer[1:]: # rm the ' ' after ':' - show_text = remove_special_tokens(answer[1:]) + show_text = remove_incomplete_special_tokens(answer[1:]) if show_text: new_messages.append(Message( role=ASSISTANT, @@ -291,20 +300,46 @@ def _postprocess_fncall_messages(self, messages: List[Message], stop_at_fncall: } -# TODO: This affects users who use the ✿ character accidentally. +def get_function_description(function: Dict) -> str: + """ + Text description of function + """ + tool_desc_template = { + 'zh': '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}', + 'en': '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}' + } + if has_chinese_chars(function): + tool_desc = tool_desc_template['zh'] + else: + tool_desc = tool_desc_template['en'] + + name = function.get('name', None) + name_for_human = function.get('name_for_human', name) + name_for_model = function.get('name_for_model', name) + assert name_for_human and name_for_model + args_format = function.get('args_format', '') + return tool_desc.format(name_for_human=name_for_human, + name_for_model=name_for_model, + description_for_model=function['description'], + parameters=json.dumps(function['parameters'], ensure_ascii=False), + args_format=args_format).rstrip() + + # Mainly for removing incomplete trailing special tokens when streaming the output -def remove_special_tokens(text: str, strip: bool = True) -> str: - text = text.replace('✿:', '✿') - text = text.replace('✿:', '✿') - out = '' - is_special = False - for c in text: - if c == '✿': - is_special = not is_special - continue - if is_special: - continue - out += c - if strip: - out = out.lstrip('\n').rstrip() - return out +def remove_incomplete_special_tokens(text: str) -> str: + special_tokens = (FN_NAME, FN_ARGS, FN_RESULT, FN_EXIT) + text = text.rstrip() + if text.endswith(special_tokens): + for s in special_tokens: + if text.endswith(s): + text = text[:-len(s)] + break + else: + trail_start = text.rfind('✿') + trail_token = text[trail_start:] + for s in special_tokens: + if s.startswith(trail_token): + text = text[:trail_start] + break + text = text.lstrip('\n').rstrip() + return text diff --git a/qwen_agent/llm/oai.py b/qwen_agent/llm/oai.py index 1d58ad9..991e6d8 100644 --- a/qwen_agent/llm/oai.py +++ b/qwen_agent/llm/oai.py @@ -5,7 +5,7 @@ import openai if openai.__version__.startswith('0.'): - from openai.error import OpenAIError + from openai.error import OpenAIError # noqa else: from openai import OpenAIError @@ -50,17 +50,17 @@ def __init__(self, cfg: Optional[Dict] = None): if api_key: api_kwargs['api_key'] = api_key - # OpenAI API v1 does not allow the following args, must pass by extra_body - extra_params = ['top_k', 'repetition_penalty'] - if any((k in self.generate_cfg) for k in extra_params): - self.generate_cfg['extra_body'] = {} - for k in extra_params: - if k in self.generate_cfg: - self.generate_cfg['extra_body'][k] = self.generate_cfg.pop(k) - if 'request_timeout' in self.generate_cfg: - self.generate_cfg['timeout'] = self.generate_cfg.pop('request_timeout') - def _chat_complete_create(*args, **kwargs): + # OpenAI API v1 does not allow the following args, must pass by extra_body + extra_params = ['top_k', 'repetition_penalty'] + if any((k in kwargs) for k in extra_params): + kwargs['extra_body'] = {} + for k in extra_params: + if k in kwargs: + kwargs['extra_body'][k] = kwargs.pop(k) + if 'request_timeout' in kwargs: + kwargs['timeout'] = kwargs.pop('request_timeout') + client = openai.OpenAI(**api_kwargs) return client.chat.completions.create(*args, **kwargs) @@ -69,12 +69,13 @@ def _chat_complete_create(*args, **kwargs): def _chat_stream( self, messages: List[Message], - delta_stream: bool = False, + delta_stream: bool, + generate_cfg: dict, ) -> Iterator[List[Message]]: messages = [msg.model_dump() for msg in messages] logger.debug(f'*{pformat(messages, indent=2)}*') try: - response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **self.generate_cfg) + response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg) if delta_stream: for chunk in response: if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: @@ -88,14 +89,15 @@ def _chat_stream( except OpenAIError as ex: raise ModelServiceError(exception=ex) - def _chat_no_stream(self, messages: List[Message]) -> List[Message]: + def _chat_no_stream( + self, + messages: List[Message], + generate_cfg: dict, + ) -> List[Message]: messages = [msg.model_dump() for msg in messages] logger.debug(f'*{pformat(messages, indent=2)}*') try: - response = self._chat_complete_create(model=self.model, - messages=messages, - stream=False, - **self.generate_cfg) + response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg) return [Message(ASSISTANT, response.choices[0].message.content)] except OpenAIError as ex: raise ModelServiceError(exception=ex) diff --git a/qwen_agent/llm/qwen_dashscope.py b/qwen_agent/llm/qwen_dashscope.py index 8310466..9b78565 100644 --- a/qwen_agent/llm/qwen_dashscope.py +++ b/qwen_agent/llm/qwen_dashscope.py @@ -29,7 +29,8 @@ def __init__(self, cfg: Optional[Dict] = None): def _chat_stream( self, messages: List[Message], - delta_stream: bool = False, + delta_stream: bool, + generate_cfg: dict, ) -> Iterator[List[Message]]: messages = [msg.model_dump() for msg in messages] logger.debug(f'*{pformat(messages, indent=2)}*') @@ -38,7 +39,7 @@ def _chat_stream( messages=messages, # noqa result_format='message', stream=True, - **self.generate_cfg) + **generate_cfg) if delta_stream: return self._delta_stream_output(response) else: @@ -47,6 +48,7 @@ def _chat_stream( def _chat_no_stream( self, messages: List[Message], + generate_cfg: dict, ) -> List[Message]: messages = [msg.model_dump() for msg in messages] logger.debug(f'*{pformat(messages, indent=2)}*') @@ -55,17 +57,20 @@ def _chat_no_stream( messages=messages, # noqa result_format='message', stream=False, - **self.generate_cfg) + **generate_cfg) if response.status_code == HTTPStatus.OK: return [Message(ASSISTANT, response.output.choices[0].message.content)] else: raise ModelServiceError(code=response.code, message=response.message) - def _chat_with_functions(self, - messages: List[Message], - functions: List[Dict], - stream: bool = True, - delta_stream: bool = False) -> Union[List[Message], Iterator[List[Message]]]: + def _chat_with_functions( + self, + messages: List[Message], + functions: List[Dict], + stream: bool, + delta_stream: bool, + generate_cfg: dict, + ) -> Union[List[Message], Iterator[List[Message]]]: if delta_stream: raise NotImplementedError @@ -74,13 +79,14 @@ def _chat_with_functions(self, # Using text completion prompt = self._build_text_completion_prompt(messages) if stream: - return self._text_completion_stream(prompt, delta_stream) + return self._text_completion_stream(prompt, delta_stream, generate_cfg=generate_cfg) else: - return self._text_completion_no_stream(prompt) + return self._text_completion_no_stream(prompt, generate_cfg=generate_cfg) def _text_completion_no_stream( self, prompt: str, + generate_cfg: dict, ) -> List[Message]: logger.debug(f'*{prompt}*') response = dashscope.Generation.call(self.model, @@ -88,7 +94,7 @@ def _text_completion_no_stream( result_format='message', stream=False, use_raw_prompt=True, - **self.generate_cfg) + **generate_cfg) if response.status_code == HTTPStatus.OK: return [Message(ASSISTANT, response.output.choices[0].message.content)] else: @@ -97,7 +103,8 @@ def _text_completion_no_stream( def _text_completion_stream( self, prompt: str, - delta_stream: bool = False, + delta_stream: bool, + generate_cfg: dict, ) -> Iterator[List[Message]]: logger.debug(f'*{prompt}*') response = dashscope.Generation.call( @@ -106,7 +113,7 @@ def _text_completion_stream( result_format='message', stream=True, use_raw_prompt=True, - **self.generate_cfg) + **generate_cfg) if delta_stream: return self._delta_stream_output(response) else: @@ -142,9 +149,9 @@ def _delta_stream_output(response) -> Iterator[List[Message]]: delay_len = 5 in_delay = False text = '' - for trunk in response: - if trunk.status_code == HTTPStatus.OK: - text = trunk.output.choices[0].message.content + for chunk in response: + if chunk.status_code == HTTPStatus.OK: + text = chunk.output.choices[0].message.content if (len(text) - last_len) <= delay_len: in_delay = True continue @@ -155,14 +162,14 @@ def _delta_stream_output(response) -> Iterator[List[Message]]: yield [Message(ASSISTANT, now_rsp)] last_len = len(real_text) else: - raise ModelServiceError(code=trunk.code, message=trunk.message) + raise ModelServiceError(code=chunk.code, message=chunk.message) if text and (in_delay or (last_len != len(text))): yield [Message(ASSISTANT, text[last_len:])] @staticmethod def _full_stream_output(response) -> Iterator[List[Message]]: - for trunk in response: - if trunk.status_code == HTTPStatus.OK: - yield [Message(ASSISTANT, trunk.output.choices[0].message.content)] + for chunk in response: + if chunk.status_code == HTTPStatus.OK: + yield [Message(ASSISTANT, chunk.output.choices[0].message.content)] else: - raise ModelServiceError(code=trunk.code, message=trunk.message) + raise ModelServiceError(code=chunk.code, message=chunk.message) diff --git a/qwen_agent/llm/qwenvl_dashscope.py b/qwen_agent/llm/qwenvl_dashscope.py index 113f2ba..0adb1ac 100644 --- a/qwen_agent/llm/qwenvl_dashscope.py +++ b/qwen_agent/llm/qwenvl_dashscope.py @@ -32,7 +32,8 @@ def __init__(self, cfg: Optional[Dict] = None): def _chat_stream( self, messages: List[Message], - delta_stream: bool = False, + delta_stream: bool, + generate_cfg: dict, ) -> Iterator[List[Message]]: if delta_stream: raise NotImplementedError @@ -44,17 +45,18 @@ def _chat_stream( messages=messages, result_format='message', stream=True, - **self.generate_cfg) + **generate_cfg) - for trunk in response: - if trunk.status_code == HTTPStatus.OK: - yield _extract_vl_response(trunk) + for chunk in response: + if chunk.status_code == HTTPStatus.OK: + yield _extract_vl_response(chunk) else: - raise ModelServiceError(code=trunk.code, message=trunk.message) + raise ModelServiceError(code=chunk.code, message=chunk.message) def _chat_no_stream( self, messages: List[Message], + generate_cfg: dict, ) -> List[Message]: messages = _format_local_files(messages) messages = [msg.model_dump() for msg in messages] @@ -63,14 +65,14 @@ def _chat_no_stream( messages=messages, result_format='message', stream=False, - **self.generate_cfg) + **generate_cfg) if response.status_code == HTTPStatus.OK: return _extract_vl_response(response=response) else: raise ModelServiceError(code=response.code, message=response.message) - def _postprocess_messages(self, messages: List[Message], fncall_mode: bool) -> List[Message]: - messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode) + def _postprocess_messages(self, messages: List[Message], fncall_mode: bool, generate_cfg: dict) -> List[Message]: + messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode, generate_cfg=generate_cfg) # Make VL return the same format as text models for easy usage messages = format_as_text_messages(messages) return messages diff --git a/qwen_agent/llm/schema.py b/qwen_agent/llm/schema.py index 215537d..c107446 100644 --- a/qwen_agent/llm/schema.py +++ b/qwen_agent/llm/schema.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Literal, Optional, Tuple, Union from pydantic import BaseModel, field_validator, model_validator @@ -6,6 +6,7 @@ ROLE = 'role' CONTENT = 'content' +NAME = 'name' SYSTEM = 'system' USER = 'user' @@ -80,11 +81,21 @@ def check_exclusivity(self): def __repr__(self): return f'ContentItem({self.model_dump()})' - def get_type_and_value(self): + def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file'], str]: (t, v), = self.model_dump().items() assert t in ('text', 'image', 'file') return t, v + @property + def type(self) -> Literal['text', 'image', 'file']: + t, v = self.get_type_and_value() + return t + + @property + def value(self) -> str: + t, v = self.get_type_and_value() + return v + class Message(BaseModelCompatibleDict): role: str diff --git a/qwen_agent/llm/text_base.py b/qwen_agent/llm/text_base.py index 7f836d5..b80851c 100644 --- a/qwen_agent/llm/text_base.py +++ b/qwen_agent/llm/text_base.py @@ -13,8 +13,13 @@ def _preprocess_messages(self, messages: List[Message]) -> List[Message]: messages = format_as_text_messages(messages) return messages - def _postprocess_messages(self, messages: List[Message], fncall_mode: bool) -> List[Message]: - messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode) + def _postprocess_messages( + self, + messages: List[Message], + fncall_mode: bool, + generate_cfg: dict, + ) -> List[Message]: + messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode, generate_cfg=generate_cfg) messages = format_as_text_messages(messages) return messages diff --git a/qwen_agent/memory/memory.py b/qwen_agent/memory/memory.py index a8c4166..e47f7db 100644 --- a/qwen_agent/memory/memory.py +++ b/qwen_agent/memory/memory.py @@ -8,8 +8,10 @@ from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, USER, Message from qwen_agent.log import logger from qwen_agent.prompts import GenKeyword +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN, DEFAULT_PARSER_PAGE_SIZE from qwen_agent.tools import BaseTool -from qwen_agent.utils.utils import get_file_type +from qwen_agent.tools.simple_doc_parser import PARSER_SUPPORTED_FILE_TYPES +from qwen_agent.utils.utils import extract_files_from_messages, extract_text_from_message, get_file_type class Memory(Agent): @@ -26,20 +28,19 @@ def __init__(self, function_list = function_list or [] super().__init__(function_list=['retrieval'] + function_list, llm=llm, system_message=system_message) - self.keygen = GenKeyword(llm=llm) - self.system_files = files or [] def _run(self, messages: List[Message], - max_ref_token: int = 4000, + max_ref_token: int = DEFAULT_MAX_REF_TOKEN, + parser_page_size: int = DEFAULT_PARSER_PAGE_SIZE, lang: str = 'en', ignore_cache: bool = False) -> Iterator[List[Message]]: """This agent is responsible for processing the input files in the message. This method stores the files in the knowledge base, and retrievals the relevant parts based on the query and returning them. - The currently supported file types include: .pdf, .docx, .pptx, and html. + The currently supported file types include: .pdf, .docx, .pptx, .txt, and html. Args: messages: A list of messages. @@ -51,11 +52,12 @@ def _run(self, The message of retrieved documents. """ # process files in messages - session_files = self.get_all_files_of_messages(messages) + session_files = extract_files_from_messages(messages) files = self.system_files + session_files rag_files = [] for file in files: - if (file.split('.')[-1].lower() in ['pdf', 'docx', 'pptx']) or get_file_type(file) == 'html': + f_type = get_file_type(file) + if f_type in PARSER_SUPPORTED_FILE_TYPES and file not in rag_files: rag_files.append(file) if not rag_files: @@ -64,15 +66,12 @@ def _run(self, query = '' # Only retrieval content according to the last user query if exists if messages and messages[-1].role == USER: - if isinstance(messages[-1].content, str): - query = messages[-1].content - else: - for item in messages[-1].content: - if item.text: - query += item.text + query = extract_text_from_message(messages[-1], add_upload_info=False) if query: # Gen keyword - *_, last = self.keygen.run([Message(USER, query)]) + keygen = GenKeyword(llm=self.llm) + *_, last = keygen.run([Message(USER, query)]) + keyword = last[-1].content keyword = keyword.strip() if keyword.startswith('```json'): @@ -87,21 +86,15 @@ def _run(self, except Exception: query = query - content = self._call_tool('retrieval', { - 'query': query, - 'files': rag_files - }, - ignore_cache=ignore_cache, - max_token=max_ref_token) + content = self._call_tool( + 'retrieval', + { + 'query': query, + 'files': rag_files + }, + ignore_cache=ignore_cache, + max_token=max_ref_token, + parser_page_size=parser_page_size, + ) yield [Message(role=ASSISTANT, content=content, name='memory')] - - @staticmethod - def get_all_files_of_messages(messages: List[Message]): - files = [] - for msg in messages: - if isinstance(msg.content, list): - for item in msg.content: - if item.file and item.file not in files: - files.append(item.file) - return files diff --git a/qwen_agent/prompts/__init__.py b/qwen_agent/prompts/__init__.py index 10f3d61..f13ac1d 100644 --- a/qwen_agent/prompts/__init__.py +++ b/qwen_agent/prompts/__init__.py @@ -6,4 +6,10 @@ from .gen_keyword import GenKeyword from .outline_writing import OutlineWriting -__all__ = ['DocQA', 'ContinueWriting', 'OutlineWriting', 'ExpandWriting', 'GenKeyword'] +__all__ = [ + 'DocQA', + 'ContinueWriting', + 'OutlineWriting', + 'ExpandWriting', + 'GenKeyword', +] diff --git a/qwen_agent/prompts/gen_keyword.py b/qwen_agent/prompts/gen_keyword.py index 4268913..bf3ad9c 100644 --- a/qwen_agent/prompts/gen_keyword.py +++ b/qwen_agent/prompts/gen_keyword.py @@ -2,10 +2,10 @@ from typing import Dict, Iterator, List, Optional, Union from qwen_agent import Agent -from qwen_agent.llm import get_chat_model from qwen_agent.llm.base import BaseChatModel from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message from qwen_agent.tools import BaseTool +from qwen_agent.utils.utils import merge_generate_cfgs PROMPT_TEMPLATE_ZH = """请提取问题中的关键词,需要中英文均有,可以适量补充不在问题中但相关的关键词。关键词尽量切分为动词/名词/形容词等类型,不要长词组。关键词以JSON的格式给出,比如{{"keywords_zh": ["关键词1", "关键词2"], "keywords_en": ["keyword 1", "keyword 2"]}} @@ -52,20 +52,16 @@ class GenKeyword(Agent): - # TODO: Adding a stop word is not conveient! We should fix this later. def __init__(self, function_list: Optional[List[Union[str, Dict, BaseTool]]] = None, llm: Optional[Union[Dict, BaseChatModel]] = None, system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, **kwargs): - if llm is not None: # TODO: Why this happens? - llm = copy.deepcopy(llm) - if isinstance(llm, dict): - llm = get_chat_model(llm) - stop = llm.generate_cfg.get('stop', []) - key_stop = ['Observation:', 'Observation:\n'] - llm.generate_cfg['stop'] = stop + [x for x in key_stop if x not in stop] super().__init__(function_list, llm, system_message, **kwargs) + self.extra_generate_cfg = merge_generate_cfgs( + base_generate_cfg=self.extra_generate_cfg, + new_generate_cfg={'stop': ['Observation:', 'Observation:\n']}, + ) def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]: messages = copy.deepcopy(messages) diff --git a/qwen_agent/settings.py b/qwen_agent/settings.py new file mode 100644 index 0000000..3ab99cb --- /dev/null +++ b/qwen_agent/settings.py @@ -0,0 +1,11 @@ +# Settings for LLMs +DEFAULT_MAX_TOKEN = 6000 # It hasn't worked yet + +# Settings for agents +MAX_LLM_CALL_PER_RUN = 8 + +DEFAULT_WORKSPACE = 'workspace' + +# Settings for RAG +DEFAULT_MAX_REF_TOKEN = 4000 +DEFAULT_PARSER_PAGE_SIZE = 500 diff --git a/qwen_agent/tools/__init__.py b/qwen_agent/tools/__init__.py index 94724b6..11a51b4 100644 --- a/qwen_agent/tools/__init__.py +++ b/qwen_agent/tools/__init__.py @@ -5,6 +5,7 @@ from .image_gen import ImageGen from .retrieval import Retrieval from .similarity_search import SimilaritySearch +from .simple_doc_parser import SimpleDocParser from .storage import Storage from .web_extractor import WebExtractor @@ -18,5 +19,5 @@ def call_tool(plugin_name: str, plugin_args: str) -> str: __all__ = [ 'BaseTool', 'CodeInterpreter', 'ImageGen', 'AmapWeather', 'TOOL_REGISTRY', 'DocParser', 'SimilaritySearch', - 'Storage', 'Retrieval', 'WebExtractor' + 'Storage', 'Retrieval', 'WebExtractor', 'SimpleDocParser' ] diff --git a/qwen_agent/tools/base.py b/qwen_agent/tools/base.py index ff000dd..3874a64 100644 --- a/qwen_agent/tools/base.py +++ b/qwen_agent/tools/base.py @@ -38,12 +38,6 @@ def __init__(self, cfg: Optional[Dict] = None): f'You must set {self.__class__.__name__}.name, either by @register_tool(name=...) or explicitly setting {self.__class__.__name__}.name' ) - self.name_for_human = self.cfg.get('name_for_human', self.name) - if not hasattr(self, 'args_format'): - self.args_format = self.cfg.get('args_format', '此工具的输入应为JSON对象。') - self.function = self._build_function() - self.file_access = False - @abstractmethod def call(self, params: Union[str, dict], **kwargs) -> Union[str, list, dict]: """The interface for calling tools. @@ -74,7 +68,8 @@ def _verify_json_format_args(self, params: Union[str, dict]) -> Union[str, dict] except Exception: raise ValueError('Parameters cannot be converted to Json Format!') - def _build_function(self) -> dict: + @property + def function(self) -> dict: # Bad naming. It should be `function_info`. return { 'name_for_human': self.name_for_human, 'name': self.name, @@ -82,3 +77,15 @@ def _build_function(self) -> dict: 'parameters': self.parameters, 'args_format': self.args_format } + + @property + def name_for_human(self) -> str: + return self.cfg.get('name_for_human', self.name) + + @property + def args_format(self) -> str: + return self.cfg.get('args_format', '此工具的输入应为JSON对象。') + + @property + def file_access(self) -> bool: + return False diff --git a/qwen_agent/tools/code_interpreter.py b/qwen_agent/tools/code_interpreter.py index 19bb45f..f1ac6eb 100644 --- a/qwen_agent/tools/code_interpreter.py +++ b/qwen_agent/tools/code_interpreter.py @@ -15,7 +15,7 @@ import time import uuid from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import json5 import matplotlib @@ -23,28 +23,9 @@ from jupyter_client import BlockingKernelClient from qwen_agent.log import logger +from qwen_agent.settings import DEFAULT_WORKSPACE from qwen_agent.tools.base import BaseTool, register_tool -from qwen_agent.utils.utils import extract_code, print_traceback, save_url_to_local_work_dir - -WORK_DIR = os.getenv('M6_CODE_INTERPRETER_WORK_DIR', os.getcwd() + '/workspace/ci_workspace/') - - -def _fix_secure_write_for_code_interpreter(): - if 'linux' in sys.platform.lower(): - os.makedirs(WORK_DIR, exist_ok=True) - fname = os.path.join(WORK_DIR, f'test_file_permission_{os.getpid()}.txt') - if os.path.exists(fname): - os.remove(fname) - with os.fdopen(os.open(fname, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o0600), 'w') as f: - f.write('test') - file_mode = stat.S_IMODE(os.stat(fname).st_mode) & 0o6677 - if file_mode != 0o0600: - os.environ['JUPYTER_ALLOW_INSECURE_WRITES'] = '1' - if os.path.exists(fname): - os.remove(fname) - - -_fix_secure_write_for_code_interpreter() +from qwen_agent.utils.utils import append_signal_handler, extract_code, print_traceback, save_url_to_local_work_dir LAUNCH_KERNEL_PY = """ from ipykernel import kernelapp as app @@ -52,62 +33,13 @@ def _fix_secure_write_for_code_interpreter(): """ INIT_CODE_FILE = str(Path(__file__).absolute().parent / 'resource' / 'code_interpreter_init_kernel.py') - ALIB_FONT_FILE = str(Path(__file__).absolute().parent / 'resource' / 'AlibabaPuHuiTi-3-45-Light.ttf') -_KERNEL_CLIENTS: Dict[int, BlockingKernelClient] = {} +_KERNEL_CLIENTS: Dict[str, BlockingKernelClient] = {} _MISC_SUBPROCESSES: Dict[str, subprocess.Popen] = {} -def _start_kernel(pid) -> BlockingKernelClient: - connection_file = os.path.join(WORK_DIR, f'kernel_connection_file_{pid}.json') - launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py') - for f in [connection_file, launch_kernel_script]: - if os.path.exists(f): - logger.info(f'WARNING: {f} already exists') - os.remove(f) - - os.makedirs(WORK_DIR, exist_ok=True) - with open(launch_kernel_script, 'w') as fout: - fout.write(LAUNCH_KERNEL_PY) - - kernel_process = subprocess.Popen( - [ - sys.executable, - launch_kernel_script, - '--IPKernelApp.connection_file', - connection_file, - '--matplotlib=inline', - '--quiet', - ], - cwd=WORK_DIR, - ) - _MISC_SUBPROCESSES[f'kc_{kernel_process.pid}'] = kernel_process - logger.info(f"INFO: kernel process's PID = {kernel_process.pid}") - - # Wait for kernel connection file to be written - while True: - if not os.path.isfile(connection_file): - time.sleep(0.1) - else: - # Keep looping if JSON parsing fails, file may be partially written - try: - with open(connection_file, 'r') as fp: - json.load(fp) - break - except json.JSONDecodeError: - pass - - # Client - kc = BlockingKernelClient(connection_file=connection_file) - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - kc.load_connection_file() - kc.start_channels() - kc.wait_for_ready() - return kc - - -def _kill_kernels_and_subprocesses(sig_num=None, _frame=None): +def _kill_kernels_and_subprocesses(_sig_num=None, _frame=None): for v in _KERNEL_CLIENTS.values(): v.shutdown() for k in list(_KERNEL_CLIENTS.keys()): @@ -118,117 +50,11 @@ def _kill_kernels_and_subprocesses(sig_num=None, _frame=None): for k in list(_MISC_SUBPROCESSES.keys()): del _MISC_SUBPROCESSES[k] - if sig_num == signal.SIGINT: - raise KeyboardInterrupt() - +# Make sure all subprocesses are terminated even if killed abnormally: atexit.register(_kill_kernels_and_subprocesses) -signal.signal(signal.SIGTERM, _kill_kernels_and_subprocesses) -signal.signal(signal.SIGINT, _kill_kernels_and_subprocesses) - - -def _serve_image(image_base64: str) -> str: - image_file = f'{uuid.uuid4()}.png' - local_image_file = os.path.join(WORK_DIR, image_file) - - png_bytes = base64.b64decode(image_base64) - assert isinstance(png_bytes, bytes) - bytes_io = io.BytesIO(png_bytes) - PIL.Image.open(bytes_io).save(local_image_file, 'png') - - static_url = os.getenv('M6_CODE_INTERPRETER_STATIC_URL', 'http://127.0.0.1:7865/static') - - # Hotfix: Temporarily generate image URL proxies for code interpreter to display in gradio - # Todo: Generate real url - if static_url == 'http://127.0.0.1:7865/static': - if 'image_service' not in _MISC_SUBPROCESSES: - try: - # run a fastapi server for image show in gradio demo by http://127.0.0.1:7865/figure_name - _MISC_SUBPROCESSES['image_service'] = subprocess.Popen( - ['python', Path(__file__).absolute().parent / 'resource' / 'image_service.py']) - except Exception: - print_traceback() - - image_url = f'{static_url}/{image_file}' - - return image_url - - -def _escape_ansi(line: str) -> str: - ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') - return ansi_escape.sub('', line) - - -def _fix_matplotlib_cjk_font_issue(): - ttf_name = os.path.basename(ALIB_FONT_FILE) - local_ttf = os.path.join(os.path.abspath(os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), 'fonts', - 'ttf', ttf_name) - if not os.path.exists(local_ttf): - try: - shutil.copy(ALIB_FONT_FILE, local_ttf) - font_list_cache = os.path.join(matplotlib.get_cachedir(), 'fontlist-*.json') - for cache_file in glob.glob(font_list_cache): - with open(cache_file) as fin: - cache_content = fin.read() - if ttf_name not in cache_content: - os.remove(cache_file) - except Exception: - print_traceback() - - -def _execute_code(kc: BlockingKernelClient, code: str) -> str: - kc.wait_for_ready() - kc.execute(code) - result = '' - image_idx = 0 - while True: - text = '' - image = '' - finished = False - msg_type = 'error' - try: - msg = kc.get_iopub_msg() - msg_type = msg['msg_type'] - if msg_type == 'status': - if msg['content'].get('execution_state') == 'idle': - finished = True - elif msg_type == 'execute_result': - text = msg['content']['data'].get('text/plain', '') - if 'image/png' in msg['content']['data']: - image_b64 = msg['content']['data']['image/png'] - image_url = _serve_image(image_b64) - image_idx += 1 - image = '![fig-%03d](%s)' % (image_idx, image_url) - elif msg_type == 'display_data': - if 'image/png' in msg['content']['data']: - image_b64 = msg['content']['data']['image/png'] - image_url = _serve_image(image_b64) - image_idx += 1 - image = '![fig-%03d](%s)' % (image_idx, image_url) - else: - text = msg['content']['data'].get('text/plain', '') - elif msg_type == 'stream': - msg_type = msg['content']['name'] # stdout, stderr - text = msg['content']['text'] - elif msg_type == 'error': - text = _escape_ansi('\n'.join(msg['content']['traceback'])) - if 'M6_CODE_INTERPRETER_TIMEOUT' in text: - text = 'Timeout: Code execution exceeded the time limit.' - except queue.Empty: - text = 'Timeout: Code execution exceeded the time limit.' - finished = True - except Exception: - text = 'The code interpreter encountered an unexpected error.' - print_traceback() - finished = True - if text: - result += f'\n\n{msg_type}:\n\n```\n{text}\n```' - if image: - result += f'\n\n{image}' - if finished: - break - result = result.lstrip('\n') - return result +append_signal_handler(signal.SIGTERM, _kill_kernels_and_subprocesses) +append_signal_handler(signal.SIGINT, _kill_kernels_and_subprocesses) @register_tool('code_interpreter') @@ -237,9 +63,19 @@ class CodeInterpreter(BaseTool): parameters = [{'name': 'code', 'type': 'string', 'description': '待执行的代码', 'required': True}] def __init__(self, cfg: Optional[Dict] = None): - self.args_format = '此工具的输入应为Markdown代码块。' super().__init__(cfg) - self.file_access = True + default_work_dir = os.getenv('M6_CODE_INTERPRETER_WORK_DIR', + os.path.join(DEFAULT_WORKSPACE, 'tools', 'code_interpreter')) + self.work_dir: str = self.cfg.get('work_dir', default_work_dir) + self.instance_id: str = str(uuid.uuid4()) + + @property + def args_format(self) -> str: + return self.cfg.get('args_format', '此工具的输入应为Markdown代码块。') + + @property + def file_access(self) -> bool: + return True def call(self, params: Union[str, dict], files: List[str] = None, timeout: Optional[int] = 30, **kwargs) -> str: try: @@ -252,24 +88,26 @@ def call(self, params: Union[str, dict], files: List[str] = None, timeout: Optio return '' # download file if files: - os.makedirs(WORK_DIR, exist_ok=True) + os.makedirs(self.work_dir, exist_ok=True) for file in files: try: - save_url_to_local_work_dir(file, WORK_DIR) + save_url_to_local_work_dir(file, self.work_dir) except Exception: print_traceback() - pid: int = os.getpid() - if pid in _KERNEL_CLIENTS: - kc = _KERNEL_CLIENTS[pid] + kernel_id: str = f'{self.instance_id}_{os.getpid()}' + if kernel_id in _KERNEL_CLIENTS: + kc = _KERNEL_CLIENTS[kernel_id] else: _fix_matplotlib_cjk_font_issue() - kc = _start_kernel(pid) + self._fix_secure_write_for_code_interpreter() + kc, subproc = self._start_kernel(kernel_id) with open(INIT_CODE_FILE) as fin: start_code = fin.read() start_code = start_code.replace('{{M6_FONT_PATH}}', repr(ALIB_FONT_FILE)[1:-1]) - logger.info(_execute_code(kc, start_code)) - _KERNEL_CLIENTS[pid] = kc + logger.info(self._execute_code(kc, start_code)) + _KERNEL_CLIENTS[kernel_id] = kc + _MISC_SUBPROCESSES[kernel_id] = subproc if timeout: code = f'_M6CountdownTimer.start({timeout})\n{code}' @@ -281,13 +119,184 @@ def call(self, params: Union[str, dict], files: List[str] = None, timeout: Optio fixed_code.append('plt.rcParams["font.family"] = _m6_font_prop.get_name()') fixed_code = '\n'.join(fixed_code) fixed_code += '\n\n' # Prevent code not executing in notebook due to no line breaks at the end - result = _execute_code(kc, fixed_code) + result = self._execute_code(kc, fixed_code) if timeout: - _execute_code(kc, '_M6CountdownTimer.cancel()') + self._execute_code(kc, '_M6CountdownTimer.cancel()') return result if result.strip() else 'Finished execution.' + def __del__(self): + # Recycle the jupyter subprocess: + k: str = f'{self.instance_id}_{os.getpid()}' + if k in _KERNEL_CLIENTS: + _KERNEL_CLIENTS[k].shutdown() + del _KERNEL_CLIENTS[k] + if k in _MISC_SUBPROCESSES: + _MISC_SUBPROCESSES[k].terminate() + del _MISC_SUBPROCESSES[k] + + def _fix_secure_write_for_code_interpreter(self): + if 'linux' in sys.platform.lower(): + os.makedirs(self.work_dir, exist_ok=True) + fname = os.path.join(self.work_dir, f'test_file_permission_{os.getpid()}.txt') + if os.path.exists(fname): + os.remove(fname) + with os.fdopen(os.open(fname, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o0600), 'w') as f: + f.write('test') + file_mode = stat.S_IMODE(os.stat(fname).st_mode) & 0o6677 + if file_mode != 0o0600: + os.environ['JUPYTER_ALLOW_INSECURE_WRITES'] = '1' + if os.path.exists(fname): + os.remove(fname) + + def _start_kernel(self, kernel_id: str) -> Tuple[BlockingKernelClient, subprocess.Popen]: + connection_file = os.path.join(self.work_dir, f'kernel_connection_file_{kernel_id}.json') + launch_kernel_script = os.path.join(self.work_dir, f'launch_kernel_{kernel_id}.py') + for f in [connection_file, launch_kernel_script]: + if os.path.exists(f): + logger.info(f'WARNING: {f} already exists') + os.remove(f) + + os.makedirs(self.work_dir, exist_ok=True) + with open(launch_kernel_script, 'w') as fout: + fout.write(LAUNCH_KERNEL_PY) + + kernel_process = subprocess.Popen( + [ + sys.executable, + os.path.abspath(launch_kernel_script), + '--IPKernelApp.connection_file', + os.path.abspath(connection_file), + '--matplotlib=inline', + '--quiet', + ], + cwd=os.path.abspath(self.work_dir), + ) + logger.info(f"INFO: kernel process's PID = {kernel_process.pid}") + + # Wait for kernel connection file to be written + while True: + if not os.path.isfile(connection_file): + time.sleep(0.1) + else: + # Keep looping if JSON parsing fails, file may be partially written + try: + with open(connection_file, 'r') as fp: + json.load(fp) + break + except json.JSONDecodeError: + pass + + # Client + kc = BlockingKernelClient(connection_file=connection_file) + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + kc.load_connection_file() + kc.start_channels() + kc.wait_for_ready() + return kc, kernel_process + + def _execute_code(self, kc: BlockingKernelClient, code: str) -> str: + kc.wait_for_ready() + kc.execute(code) + result = '' + image_idx = 0 + while True: + text = '' + image = '' + finished = False + msg_type = 'error' + try: + msg = kc.get_iopub_msg() + msg_type = msg['msg_type'] + if msg_type == 'status': + if msg['content'].get('execution_state') == 'idle': + finished = True + elif msg_type == 'execute_result': + text = msg['content']['data'].get('text/plain', '') + if 'image/png' in msg['content']['data']: + image_b64 = msg['content']['data']['image/png'] + image_url = self._serve_image(image_b64) + image_idx += 1 + image = '![fig-%03d](%s)' % (image_idx, image_url) + elif msg_type == 'display_data': + if 'image/png' in msg['content']['data']: + image_b64 = msg['content']['data']['image/png'] + image_url = self._serve_image(image_b64) + image_idx += 1 + image = '![fig-%03d](%s)' % (image_idx, image_url) + else: + text = msg['content']['data'].get('text/plain', '') + elif msg_type == 'stream': + msg_type = msg['content']['name'] # stdout, stderr + text = msg['content']['text'] + elif msg_type == 'error': + text = _escape_ansi('\n'.join(msg['content']['traceback'])) + if 'M6_CODE_INTERPRETER_TIMEOUT' in text: + text = 'Timeout: Code execution exceeded the time limit.' + except queue.Empty: + text = 'Timeout: Code execution exceeded the time limit.' + finished = True + except Exception: + text = 'The code interpreter encountered an unexpected error.' + print_traceback() + finished = True + if text: + result += f'\n\n{msg_type}:\n\n```\n{text}\n```' + if image: + result += f'\n\n{image}' + if finished: + break + result = result.lstrip('\n') + return result + + # TODO: Remove this buggy image service and return local_image_file directly. + def _serve_image(self, image_base64: str) -> str: + image_file = f'{uuid.uuid4()}.png' + local_image_file = os.path.join(self.work_dir, image_file) + + png_bytes = base64.b64decode(image_base64) + assert isinstance(png_bytes, bytes) + bytes_io = io.BytesIO(png_bytes) + PIL.Image.open(bytes_io).save(local_image_file, 'png') + + static_url = os.getenv('M6_CODE_INTERPRETER_STATIC_URL', 'http://127.0.0.1:7865/static') + + # Hotfix: Temporarily generate image URL proxies for code interpreter to display in gradio + if static_url == 'http://127.0.0.1:7865/static': + if 'image_service' not in _MISC_SUBPROCESSES: + try: + # run a fastapi server for image show in gradio demo by http://127.0.0.1:7865/{image_file} + _MISC_SUBPROCESSES['image_service'] = subprocess.Popen( + ['python', Path(__file__).absolute().parent / 'resource' / 'image_service.py']) + except Exception: + print_traceback() + + image_url = f'{static_url}/{image_file}' + return image_url + + +def _fix_matplotlib_cjk_font_issue(): + ttf_name = os.path.basename(ALIB_FONT_FILE) + local_ttf = os.path.join(os.path.abspath(os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), 'fonts', + 'ttf', ttf_name) + if not os.path.exists(local_ttf): + try: + shutil.copy(ALIB_FONT_FILE, local_ttf) + font_list_cache = os.path.join(matplotlib.get_cachedir(), 'fontlist-*.json') + for cache_file in glob.glob(font_list_cache): + with open(cache_file) as fin: + cache_content = fin.read() + if ttf_name not in cache_content: + os.remove(cache_file) + except Exception: + print_traceback() + + +def _escape_ansi(line: str) -> str: + ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', line) + # # The _BasePolicy and AnyThreadEventLoopPolicy below are borrowed from Tornado. diff --git a/qwen_agent/tools/doc_parser.py b/qwen_agent/tools/doc_parser.py index ebda4b3..c1c0869 100644 --- a/qwen_agent/tools/doc_parser.py +++ b/qwen_agent/tools/doc_parser.py @@ -1,162 +1,301 @@ -import datetime import json import os import re -from typing import Dict, Optional, Union -from urllib.parse import unquote, urlparse +import time +from typing import Dict, List, Optional, Union import json5 from pydantic import BaseModel from qwen_agent.log import logger +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN, DEFAULT_PARSER_PAGE_SIZE, DEFAULT_WORKSPACE from qwen_agent.tools.base import BaseTool, register_tool -from qwen_agent.tools.storage import Storage -from qwen_agent.utils.doc_parser import parse_doc, parse_html_bs -from qwen_agent.utils.utils import (get_file_type, hash_sha256, is_local_path, print_traceback, - save_url_to_local_work_dir) +from qwen_agent.tools.simple_doc_parser import PARAGRAPH_SPLIT_SYMBOL, SimpleDocParser, get_plain_doc +from qwen_agent.tools.storage import KeyNotExistsError, Storage +from qwen_agent.utils.tokenization_qwen import count_tokens, tokenizer +from qwen_agent.utils.utils import get_basename_from_url, hash_sha256 -class FileTypeNotImplError(NotImplementedError): - pass +class Chunk(BaseModel): + content: str + metadata: dict + token: int + + def __init__(self, content: str, metadata: dict, token: int): + super().__init__(content=content, metadata=metadata, token=token) + + def to_dict(self) -> dict: + return {'content': self.content, 'metadata': self.metadata, 'token': self.token} class Record(BaseModel): url: str - time: str - source: str - raw: list + raw: List[Chunk] title: str - topic: str - checked: bool - session: list - def to_dict(self) -> dict: - return { - 'url': self.url, - 'time': self.time, - 'source': self.source, - 'raw': self.raw, - 'title': self.title, - 'topic': self.topic, - 'checked': self.checked, - 'session': self.session - } - - -def sanitize_chrome_file_path(file_path: str) -> str: - # For Linux and macOS. - if os.path.exists(file_path): - return file_path - - # For native Windows, drop the leading '/' in '/C:/' - win_path = file_path - if win_path.startswith('/'): - win_path = win_path[1:] - if os.path.exists(win_path): - return win_path - - # For Windows + WSL. - if re.match(r'^[A-Za-z]:/', win_path): - wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}' - if os.path.exists(wsl_path): - return wsl_path - - # For native Windows, replace / with \. - win_path = win_path.replace('/', '\\') - if os.path.exists(win_path): - return win_path - - return file_path - - -def process_file(url: str, db: Storage = None): - logger.info('Starting cache pages...') - url = url - if url.split('.')[-1].lower() in ['pdf', 'docx', 'pptx']: - date1 = datetime.datetime.now() - - if url.startswith('https://') or url.startswith('http://') or re.match(r'^[A-Za-z]:\\', url) or re.match( - r'^[A-Za-z]:/', url): - pdf_path = url - else: - parsed_url = urlparse(url) - pdf_path = unquote(parsed_url.path) - pdf_path = sanitize_chrome_file_path(pdf_path) - - try: - if not is_local_path(url): - # download - file_tmp_path = save_url_to_local_work_dir(pdf_path, - db.root, - new_name=hash_sha256(url) + '.' + - pdf_path.split('.')[-1].lower()) - pdf_content = parse_doc(file_tmp_path) - else: - pdf_content = parse_doc(pdf_path) - date2 = datetime.datetime.now() - logger.info('Parsing pdf time: ' + str(date2 - date1)) - content = pdf_content - source = 'doc' - title = pdf_path.split('/')[-1].split('\\')[-1].split('.')[0] - except Exception: - print_traceback() - return 'failed' - else: - if not is_local_path(url): - file_tmp_path = save_url_to_local_work_dir(url, db.root, new_name=hash_sha256(url)) - else: - file_tmp_path = url - file_source = get_file_type(file_tmp_path) - if file_source == 'html': - try: - content = parse_html_bs(file_tmp_path) - title = content[0]['metadata']['title'] - except Exception: - print_traceback() - return 'failed' - source = 'html' - else: - raise FileTypeNotImplError - - # save real data - now_time = str(datetime.date.today()) - new_record = Record(url=url, - time=now_time, - source=source, - raw=content, - title=title, - topic='', - checked=True, - session=[]).to_dict() - new_record_str = json.dumps(new_record, ensure_ascii=False) - db.put(hash_sha256(url), new_record_str) + def __init__(self, url: str, raw: List[Chunk], title: str): + super().__init__(url=url, raw=raw, title=title) - return new_record + def to_dict(self) -> dict: + return {'url': self.url, 'raw': [x.to_dict() for x in self.raw], 'title': self.title} @register_tool('doc_parser') class DocParser(BaseTool): - description = '解析并存储一个文件,返回解析后的文件内容' - parameters = [{'name': 'url', 'type': 'string', 'description': '待解析的文件的路径', 'required': True}] + description = '对一个文件进行内容提取和分块、返回分块后的文件内容' + parameters = [{ + 'name': 'url', + 'type': 'string', + 'description': '待解析的文件的路径,可以是一个本地路径或可下载的http(s)链接', + 'required': True + }] def __init__(self, cfg: Optional[Dict] = None): super().__init__(cfg) - self.data_root = self.cfg.get('path', 'workspace/default_doc_parser_data_path') + self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name)) self.db = Storage({'storage_root_path': self.data_root}) - def call(self, params: Union[str, dict], ignore_cache: bool = False) -> dict: - """Parse file by url, and return the formatted content.""" + self.doc_extractor = SimpleDocParser({'structured_doc': True}) - params = self._verify_json_format_args(params) + def call(self, + params: Union[str, dict], + ignore_cache: bool = False, + parser_page_size: int = DEFAULT_PARSER_PAGE_SIZE, + max_token: int = DEFAULT_MAX_REF_TOKEN) -> dict: + """Extracting and blocking - if ignore_cache: - record = process_file(url=params['url'], db=self.db) - else: + Returns: + Parse doc as the following chunks: + { + 'url': 'This is the url of this file', + 'title': 'This is the extracted title of this file', + 'raw': [ + { + 'content': 'This is one chunk', + 'token': 'The token number', + 'metadata': {} # some information of this chunk + }, + ..., + ] + } + """ + + params = self._verify_json_format_args(params) + url = params['url'] + cached_name_ori = f'{hash_sha256(url)}_ori' + cached_name_chunking = f'{hash_sha256(url)}_{str(parser_page_size)}' + doc = None + if not ignore_cache: try: - record = self.db.get(hash_sha256(params['url'])) + # Directly load the chunked doc + record = self.db.get(cached_name_chunking) record = json5.loads(record) + return record + except KeyNotExistsError: + try: + # Directly load the parsed doc + doc = json5.loads(self.db.get(cached_name_ori)) + except KeyNotExistsError: + pass + total_token = 0 + if not doc: + logger.info(f'Start parsing {url}...') + time1 = time.time() + doc = self.doc_extractor.call({'url': url}) + for page in doc: + for para in page['content']: + # Todo: More attribute types + para['token'] = count_tokens(para.get('text', para.get('table'))) + total_token += para['token'] + time2 = time.time() + logger.info(f'Finished parsing {url}. Time spent: {time2 - time1} seconds.') + # Cache the parsing doc + self.db.put(cached_name_ori, json.dumps(doc, ensure_ascii=False, indent=2)) + else: + for page in doc: + for para in page['content']: + total_token += para['token'] + + if doc and 'title' in doc[0]: + title = doc[0]['title'] + else: + title = get_basename_from_url(url) + + logger.info(f'Start chunking {url} ({title})...') + time2 = time.time() + if total_token <= max_token: + # The whole doc is one chunk + content = [ + Chunk(content=get_plain_doc(doc), + metadata={ + 'source': url, + 'title': title, + 'chunk_id': 0 + }, + token=total_token) + ] + cached_name_chunking = f'{hash_sha256(url)}_without_chunking' + else: + content = self.split_doc_to_chunk(doc, url, title=title, parser_page_size=parser_page_size) + + time3 = time.time() + logger.info(f'Finished chunking {url} ({title}). Time spent: {time3 - time2} seconds.') + + # save the document data + new_record = Record(url=url, raw=content, title=title).to_dict() + new_record_str = json.dumps(new_record, ensure_ascii=False) + self.db.put(cached_name_chunking, new_record_str) + return new_record + + def split_doc_to_chunk(self, + doc: List[dict], + path: str, + title: str = '', + parser_page_size: int = DEFAULT_PARSER_PAGE_SIZE) -> List[Chunk]: + res = [] + chunk = [] + available_token = parser_page_size + has_para = False + for page in doc: + page_num = page['page_num'] + if not chunk or f'[page: {str(page_num)}]' != chunk[0]: + chunk.append(f'[page: {str(page_num)}]') + idx = 0 + len_para = len(page['content']) + while idx < len_para: + if not chunk: + chunk.append(f'[page: {str(page_num)}]') + para = page['content'][idx] + txt = para.get('text', para.get('table')) + token = para['token'] + if token <= available_token: + available_token -= token + chunk.append([txt, page_num]) + has_para = True + idx += 1 + else: + if has_para: + # Record one chunk + if isinstance(chunk[-1], str) and re.fullmatch(r'^\[page: \d+\]$', chunk[-1]) is not None: + chunk.pop() # Redundant page information + res.append( + Chunk(content=PARAGRAPH_SPLIT_SYMBOL.join( + [x if isinstance(x, str) else x[0] for x in chunk]), + metadata={ + 'source': path, + 'title': title, + 'chunk_id': len(res) + }, + token=parser_page_size - available_token)) + + # Define new chunk + overlap_txt = self._get_last_part(chunk) + if overlap_txt.strip(): + chunk = [f'[page: {str(chunk[-1][1])}]', overlap_txt] + has_para = False + available_token = parser_page_size - count_tokens(overlap_txt) + else: + chunk = [] + has_para = False + available_token = parser_page_size + else: + # There are excessively long paragraphs present + # Split paragraph to sentences + _sentences = re.split(r'\. |。', txt) + sentences = [] + for s in _sentences: + token = count_tokens(s) + if not s.strip() or token == 0: + continue + if token <= available_token: + sentences.append([s, token]) + else: + # Limit the length of a sentence to chunk size + token_list = tokenizer.tokenize(s) + for si in range(0, len(token_list), available_token): + ss = tokenizer.convert_tokens_to_string( + token_list[si:min(len(token_list), si + available_token)]) + sentences.append([ss, min(available_token, len(token_list) - si)]) + for s, token in sentences: + if not chunk: + chunk.append(f'[page: {str(page_num)}]') + + if token <= available_token or (not has_para): + # Be sure to add at least one sentence + # (not has_para) is a patch of the previous sentence splitting + available_token -= token + chunk.append([s, page_num]) + has_para = True + else: + assert has_para + if isinstance(chunk[-1], str) and re.fullmatch(r'^\[page: \d+\]$', + chunk[-1]) is not None: + chunk.pop() # Redundant page information + res.append( + Chunk(content=PARAGRAPH_SPLIT_SYMBOL.join( + [x if isinstance(x, str) else x[0] for x in chunk]), + metadata={ + 'source': path, + 'title': title, + 'chunk_id': len(res) + }, + token=parser_page_size - available_token)) + + overlap_txt = self._get_last_part(chunk) + if overlap_txt.strip(): + chunk = [f'[page: {str(chunk[-1][1])}]', overlap_txt] + has_para = False + available_token = parser_page_size - count_tokens(overlap_txt) + else: + chunk = [] + has_para = False + available_token = parser_page_size + # Has split this paragraph by sentence + idx += 1 + if has_para: + if isinstance(chunk[-1], str) and re.fullmatch(r'^\[page: \d+\]$', chunk[-1]) is not None: + chunk.pop() # Redundant page information + res.append( + Chunk(content=PARAGRAPH_SPLIT_SYMBOL.join([x if isinstance(x, str) else x[0] for x in chunk]), + metadata={ + 'source': path, + 'title': title, + 'chunk_id': len(res) + }, + token=parser_page_size - available_token)) - except Exception: - record = process_file(url=params['url'], db=self.db) + return res - return record + def _get_last_part(self, chunk: list) -> str: + overlap = '' + need_page = chunk[-1][1] # Only need this page to prepend + available_len = 150 + for i in range(len(chunk) - 1, -1, -1): + if chunk[i][1] != need_page: + return overlap + para = chunk[i][0] + if len(para) <= available_len: + if overlap: + overlap = f'{para}{PARAGRAPH_SPLIT_SYMBOL}{overlap}' + else: + overlap = f'{para}' + available_len -= len(para) + continue + sentence_split_symbol = '. ' + if '。' in para: + sentence_split_symbol = '。' + sentences = re.split(r'\. |。', para) + sentences = [sentence.strip() for sentence in sentences if sentence] + for j in range(len(sentences) - 1, -1, -1): + sent = sentences[j] + if not sent.strip(): + continue + if len(sent) <= available_len: + if overlap: + overlap = f'{sent}{sentence_split_symbol}{overlap}' + else: + overlap = f'{sent}' + available_len -= len(sent) + else: + return overlap + return overlap diff --git a/qwen_agent/tools/resource/code_interpreter_init_kernel.py b/qwen_agent/tools/resource/code_interpreter_init_kernel.py index 258739f..193eff6 100644 --- a/qwen_agent/tools/resource/code_interpreter_init_kernel.py +++ b/qwen_agent/tools/resource/code_interpreter_init_kernel.py @@ -34,14 +34,14 @@ def start(cls, timeout: int): try: signal.alarm(timeout) except AttributeError: # windows - pass # TODO: I haven't found a solution that works with jupyter yet. + pass # I haven't found a timeout solution that works with windows + jupyter yet. @classmethod def cancel(cls): try: signal.alarm(0) except AttributeError: # windows - pass # TODO + pass sns.set_theme() diff --git a/qwen_agent/tools/resource/image_service.py b/qwen_agent/tools/resource/image_service.py index 482daa0..2d12580 100644 --- a/qwen_agent/tools/resource/image_service.py +++ b/qwen_agent/tools/resource/image_service.py @@ -17,7 +17,12 @@ allow_headers=['*'], ) -app.mount('/static', StaticFiles(directory=os.getcwd() + '/workspace/ci_workspace/'), name='static') +# TODO: This is buggy if workspace is modified. To be removed. +app.mount( + '/static', + StaticFiles(directory=os.path.abspath('workspace/tools/code_interpreter/')), + name='static', +) if __name__ == '__main__': uvicorn.run(app='image_service:app', port=7865) diff --git a/qwen_agent/tools/retrieval.py b/qwen_agent/tools/retrieval.py index 142ea81..c724887 100644 --- a/qwen_agent/tools/retrieval.py +++ b/qwen_agent/tools/retrieval.py @@ -1,23 +1,11 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, Optional, Union import json5 -from qwen_agent.log import logger +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN, DEFAULT_PARSER_PAGE_SIZE from qwen_agent.tools.base import BaseTool, register_tool -from qwen_agent.utils.utils import get_basename_from_url, print_traceback - -from .doc_parser import DocParser, FileTypeNotImplError -from .similarity_search import RefMaterialInput, RefMaterialInputItem, SimilaritySearch - - -def format_records(records: List[Dict]): - formatted_records = [] - for record in records: - formatted_records.append( - RefMaterialInput( - url=get_basename_from_url(record['url']), - text=[RefMaterialInputItem(content=x['page_content'], token=x['token']) for x in record['raw']])) - return formatted_records +from qwen_agent.tools.doc_parser import DocParser, Record +from qwen_agent.tools.similarity_search import SimilaritySearch @register_tool('retrieval') @@ -26,14 +14,14 @@ class Retrieval(BaseTool): parameters = [{ 'name': 'query', 'type': 'string', - 'description': '问题,需要从文档中检索和这个问题有关的内容' + 'description': '在这里列出关键词,用逗号分隔,目的是方便在文档中匹配到相关的内容,由于文档可能多语言,关键词最好中英文都有。', }, { 'name': 'files', 'type': 'array', 'items': { 'type': 'string' }, - 'description': '待解析的文件路径列表', + 'description': '待解析的文件路径列表,支持本地文件路径或可下载的http(s)链接。', 'required': True }] @@ -42,7 +30,12 @@ def __init__(self, cfg: Optional[Dict] = None): self.doc_parse = DocParser() self.search = SimilaritySearch() - def call(self, params: Union[str, dict], ignore_cache: bool = False, max_token: int = 4000) -> list: + def call(self, + params: Union[str, dict], + ignore_cache: bool = False, + max_token: int = DEFAULT_MAX_REF_TOKEN, + parser_page_size: int = DEFAULT_PARSER_PAGE_SIZE, + **kwargs) -> list: """RAG tool. Step1: Parse and save files @@ -52,9 +45,10 @@ def call(self, params: Union[str, dict], ignore_cache: bool = False, max_token: params: The files and query. ignore_cache: When set to True, overwrite the same documents that have been parsed before. max_token: Maximum retrieval length. + parser_page_size: The size of one page for doc parser. Returns: - The retrieved file list. + The parsed file list or retrieved file list. """ params = self._verify_json_format_args(params) @@ -63,28 +57,16 @@ def call(self, params: Union[str, dict], ignore_cache: bool = False, max_token: files = json5.loads(files) records = [] for file in files: - try: - _record = self.doc_parse.call(params={'url': file}, ignore_cache=ignore_cache) - records.append(_record) - except FileTypeNotImplError: - logger.warning( - 'Only Parsing the Following File Types: [\'web page\', \'.pdf\', \'.docx\', \'.pptx\'] to knowledge base!' - ) - except Exception: - print_traceback() + _record = self.doc_parse.call(params={'url': file}, + ignore_cache=ignore_cache, + parser_page_size=parser_page_size, + max_token=max_token) + records.append(_record) query = params.get('query', '') if query and records: - records = format_records(records) - return self._retrieve_content(query, records, max_token) + return self.search.call(params={'query': query}, + docs=[Record(**rec) for rec in records], + max_token=max_token) else: return records - - def _retrieve_content(self, query: str, records: List[RefMaterialInput], max_token=4000) -> List[Dict]: - single_max_token = int(max_token / len(records)) - _ref_list = [] - for record in records: - # Retrieval for query - now_ref_list = self.search.call(params={'query': query}, doc=record, max_token=single_max_token) - _ref_list.append(now_ref_list) - return _ref_list diff --git a/qwen_agent/tools/similarity_search.py b/qwen_agent/tools/similarity_search.py index 411faae..c952498 100644 --- a/qwen_agent/tools/similarity_search.py +++ b/qwen_agent/tools/similarity_search.py @@ -1,11 +1,15 @@ from typing import List, Union +import jieba +import json5 from pydantic import BaseModel from qwen_agent.log import logger +from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN from qwen_agent.tools.base import BaseTool, register_tool -from qwen_agent.utils.tokenization_qwen import count_tokens -from qwen_agent.utils.utils import get_split_word, parse_keyword +from qwen_agent.tools.doc_parser import DocParser, Record +from qwen_agent.utils.tokenization_qwen import count_tokens, tokenizer +from qwen_agent.utils.utils import has_chinese_chars class RefMaterialOutput(BaseModel): @@ -20,29 +24,14 @@ def to_dict(self) -> dict: } -class RefMaterialInputItem(BaseModel): - content: str - token: int - - def to_dict(self) -> dict: - return {'content': self.content, 'token': self.token} - - -class RefMaterialInput(BaseModel): - """The knowledge data format input to the retrieval""" - url: str - text: List[RefMaterialInputItem] - - def to_dict(self) -> dict: - return {'url': self.url, 'text': [x.to_dict() for x in self.text]} - - -def format_input_doc(doc: List[str]) -> RefMaterialInput: +def format_input_doc(doc: List[str], url: str = '') -> Record: new_doc = [] - for x in doc: - item = RefMaterialInputItem(content=x, token=count_tokens(x)) - new_doc.append(item) - return RefMaterialInput(url='', text=new_doc) + parser = DocParser() + for i, x in enumerate(doc): + page = {'page_num': i, 'content': [{'text': x, 'token': count_tokens(x)}]} + new_doc.append(page) + content = parser.split_doc_to_chunk(new_doc, path=url) + return Record(url=url, raw=content, title='') @register_tool('similarity_search') @@ -50,88 +39,187 @@ class SimilaritySearch(BaseTool): description = '从给定文档中检索和问题相关的部分' parameters = [{'name': 'query', 'type': 'string', 'description': '问题,需要从文档中检索和这个问题有关的内容', 'required': True}] - def call(self, - params: Union[str, dict], - doc: Union[RefMaterialInput, str, List[str]] = None, - max_token: int = 4000) -> dict: + def call( + self, + params: Union[str, dict], + docs: List[Union[Record, str, List[str]]] = None, + max_token: int = DEFAULT_MAX_REF_TOKEN, + ) -> list: params = self._verify_json_format_args(params) query = params['query'] - if not doc: - return {} - if isinstance(doc, str): - doc = [doc] - if isinstance(doc, list): - doc = format_input_doc(doc) - - tokens = [page.token for page in doc.text] - all_tokens = sum(tokens) - logger.info(f'all tokens of {doc.url}: {all_tokens}') + if not docs: + return [] + new_docs = [] + all_tokens = 0 + for i, doc in enumerate(docs): + if isinstance(doc, str): + doc = [doc] # Doc with one page + if isinstance(doc, list): + doc = format_input_doc(doc, f'doc_{str(i)}') + + if isinstance(doc, Record): + new_docs.append(doc) + all_tokens += sum([page.token for page in doc.raw]) + else: + raise TypeError + logger.info(f'all tokens: {all_tokens}') if all_tokens <= max_token: + # Todo: Whether to use full window logger.info('use full ref') - return RefMaterialOutput(url=doc.url, text=[x.content for x in doc.text]).to_dict() + return [ + RefMaterialOutput(url=doc.url, text=[page.content for page in doc.raw]).to_dict() for doc in new_docs + ] wordlist = parse_keyword(query) logger.info('wordlist: ' + ','.join(wordlist)) if not wordlist: - return self.get_top(doc, max_token) - - sims = [] - for i, page in enumerate(doc.text): - sim = self.filter_section(page.content, wordlist) - sims.append([i, sim]) + # Todo: This represents the queries that do not use retrieval: summarize, etc. + return self.get_top(new_docs, max_token) + + # Mix all chunks + docs_map = {} # {'text id': ['doc id', 'chunk id']} + text_list = [] # ['text 1', 'text 2', ...] + docs_retrieved = [] # [{'url': 'doc id', 'text': []}] + for i, doc in enumerate(new_docs): + docs_retrieved.append(RefMaterialOutput(url=doc.url, text=[''] * len(doc.raw))) + for j, page in enumerate(doc.raw): + docs_map[len(text_list)] = [i, j] + text_list.append(page.content) + assert len(docs_map) == len(text_list) + + # Using bm25 retrieval + from rank_bm25 import BM25Okapi + bm25 = BM25Okapi([split_text_into_keywords(x) for x in text_list]) + doc_scores = bm25.get_scores(wordlist) + sims = [[i, sim] for i, sim in enumerate(doc_scores)] sims.sort(key=lambda item: item[1], reverse=True) assert len(sims) > 0 - - res = [] max_sims = sims[0][1] + available_token = max_token if max_sims != 0: - manul = 0 - for i in range(min(manul, len(doc.text))): - if max_token >= tokens[i] * 2: # Ensure that the first two pages do not fill up the window - res.append(doc.text[i].content) - max_token -= tokens[i] - for i, x in enumerate(sims): - if x[0] < manul: + if len(new_docs) == 1: + # This is a trick for improving performance for one doc + manual = 2 + for doc_id, doc in enumerate(new_docs): + for chunk_id in range(min(manual, len(doc.raw))): + page = doc.raw[chunk_id] + if available_token >= page.token * manual * 2: # Ensure that the first two pages do not fill up the window + docs_retrieved[doc_id].text[chunk_id] = page.content + available_token -= page.token + else: + break + for (index, sim) in sims: + # Retrieval by BM25 + if available_token <= 0: + break + doc_id = docs_map[index][0] + chunk_id = docs_map[index][1] + page = new_docs[doc_id].raw[chunk_id] + if docs_retrieved[doc_id].text[chunk_id]: + # Has retrieved continue - page = doc.text[x[0]] - if max_token < page.token: - use_rate = (max_token / page.token) * 0.2 - res.append(page.content[:int(len(page.content) * use_rate)]) + if available_token < page.token: + docs_retrieved[doc_id].text[chunk_id] = tokenizer.truncate(page.content, max_token=available_token) break - - res.append(page.content) - max_token -= page.token - - logger.info(f'remaining slots: {max_token}') - return RefMaterialOutput(url=doc.url, text=res).to_dict() + docs_retrieved[doc_id].text[chunk_id] = page.content + available_token -= page.token + + res = [] + for x in docs_retrieved: + x.text = [trk for trk in x.text if trk] + if x.text: + res.append(x.to_dict()) + return res else: - return self.get_top(doc, max_token) - - def filter_section(self, text: str, wordlist: list) -> int: - page_list = get_split_word(text) - sim = self.jaccard_similarity(wordlist, page_list) - - return sim + return self.get_top(new_docs, max_token) @staticmethod - def jaccard_similarity(list1: list, list2: list) -> int: - s1 = set(list1) - s2 = set(list2) - return len(s1.intersection(s2)) # avoid text length impact - # return len(s1.intersection(s2)) / len(s1.union(s2)) # jaccard similarity - - @staticmethod - def get_top(doc: RefMaterialInput, max_token=4000) -> dict: - now_token = 0 - text = [] - for page in doc.text: - if (now_token + page.token) <= max_token: - text.append(page.content) - now_token += page.token - else: - use_rate = ((max_token - now_token) / page.token) * 0.2 - text.append(page.content[:int(len(page.content) * use_rate)]) - break - logger.info(f'remaining slots: {max_token-now_token}') - return RefMaterialOutput(url=doc.url, text=text).to_dict() + def get_top(docs: List[Record], max_token: int = DEFAULT_MAX_REF_TOKEN) -> list: + single_max_token = int(max_token / len(docs)) + _ref_list = [] + for doc in docs: + available_token = single_max_token + text = [] + for page in doc.raw: + if available_token <= 0: + break + if page.token <= available_token: + text.append(page.content) + available_token -= page.token + else: + text.append(tokenizer.truncate(page.content, max_token=available_token)) + break + logger.info(f'[Get top] Remaining slots: {available_token}') + now_ref_list = RefMaterialOutput(url=doc.url, text=text).to_dict() + _ref_list.append(now_ref_list) + return _ref_list + + +WORDS_TO_IGNORE = [ + '', '\\t', '\\n', '\\\\', '\\', '', '\n', '\t', '\\', ' ', ',', ',', ';', ';', '/', '.', '。', '-', 'is', 'are', + 'am', 'what', 'how', '的', '吗', '是', '了', '啊', '呢', '怎么', '如何', '什么', '(', ')', '(', ')', '【', '】', '[', ']', '{', + '}', '?', '?', '!', '!', '“', '”', '‘', '’', "'", "'", '"', '"', ':', ':', '讲了', '描述', '讲', '总结', 'summarize', + '总结下', '总结一下', '文档', '文章', 'article', 'paper', '文稿', '稿子', '论文', 'PDF', 'pdf', '这个', '这篇', '这', '我', '帮我', '那个', + '下', '翻译', 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", "you'll", + "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "she's", 'her', 'hers', + 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', + 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', + 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', + 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', + 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', + 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', + 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', + 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', + 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', + "hasn't", 'haven', "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", + 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn', + "wouldn't", '说说', '讲讲', '介绍', 'summary' +] + + +def string_tokenizer(text: str) -> List[str]: + text = text.lower() + if has_chinese_chars(text): + _wordlist = list(jieba.lcut(text.strip())) + else: + _wordlist = text.strip().split() + return _wordlist + + +def split_text_into_keywords(text: str) -> List[str]: + _wordlist = string_tokenizer(text) + wordlist = [] + for x in _wordlist: + if x in WORDS_TO_IGNORE or x in wordlist: + continue + wordlist.append(x) + return wordlist + + +def parse_keyword(text): + try: + res = json5.loads(text) + except Exception: + return split_text_into_keywords(text) + + # json format + _wordlist = [] + try: + if 'keywords_zh' in res and isinstance(res['keywords_zh'], list): + _wordlist.extend([kw.lower() for kw in res['keywords_zh']]) + if 'keywords_en' in res and isinstance(res['keywords_en'], list): + _wordlist.extend([kw.lower() for kw in res['keywords_en']]) + wordlist = [] + for x in _wordlist: + if x in WORDS_TO_IGNORE: + continue + wordlist.append(x) + split_wordlist = split_text_into_keywords(res['text']) + for x in split_wordlist: + if x in wordlist: + continue + wordlist.append(x) + return wordlist + except Exception: + return split_text_into_keywords(text) diff --git a/qwen_agent/tools/simple_doc_parser.py b/qwen_agent/tools/simple_doc_parser.py new file mode 100644 index 0000000..b6ed8c7 --- /dev/null +++ b/qwen_agent/tools/simple_doc_parser.py @@ -0,0 +1,367 @@ +import os +import re +import urllib.parse +from collections import Counter +from typing import Dict, List, Optional, Union + +from qwen_agent.settings import DEFAULT_WORKSPACE +from qwen_agent.tools.base import BaseTool, register_tool +from qwen_agent.utils.str_processing import rm_cid, rm_continuous_placeholders, rm_hexadecimal +from qwen_agent.utils.utils import (get_file_type, hash_sha256, is_http_url, read_text_from_file, + save_url_to_local_work_dir) + + +def clean_paragraph(text): + text = rm_cid(text) + text = rm_hexadecimal(text) + text = rm_continuous_placeholders(text) + return text + + +PARAGRAPH_SPLIT_SYMBOL = '\n' + + +def parse_word(docx_path: str, extract_image: bool = False): + if extract_image: + raise ValueError('Currently, extracting images is not supported!') + + from docx import Document + doc = Document(docx_path) + + content = [] + for para in doc.paragraphs: + if para.text.strip(): + content.append({'text': para.text}) + for table in doc.tables: + tbl = [] + for row in table.rows: + tbl.append('|' + '|'.join([cell.text for cell in row.cells]) + '|') + tbl = '\n'.join(tbl) + content.append({'table': tbl}) + + # Due to the pages in Word are not fixed, the entire document is returned as one page + return [{'page_num': 1, 'content': content}] + + +def parse_ppt(path: str, extract_image: bool = False): + if extract_image: + raise ValueError('Currently, extracting images is not supported!') + + from pptx import Presentation + ppt = Presentation(path) + doc = [] + for slide_number, slide in enumerate(ppt.slides): + page = {'page_num': slide_number + 1, 'content': []} + + for shape in slide.shapes: + if not shape.has_text_frame and not shape.has_table: + pass + + if shape.has_text_frame: + for paragraph in shape.text_frame.paragraphs: + paragraph_text = ''.join(run.text for run in paragraph.runs) + paragraph_text = clean_paragraph(paragraph_text) + if paragraph_text.strip(): + page['content'].append({'text': paragraph_text}) + + if shape.has_table: + tbl = [] + for row_number, row in enumerate(shape.table.rows): + tbl.append('|' + '|'.join([cell.text for cell in row.cells]) + '|') + tbl = '\n'.join(tbl) + page['content'].append({'table': tbl}) + doc.append(page) + return doc + + +def parse_txt(path: str): + text = read_text_from_file(path) + paras = text.split(PARAGRAPH_SPLIT_SYMBOL) + content = [] + for p in paras: + p = clean_paragraph(p) + if p.strip(): + content.append({'text': p}) + + # Due to the pages in txt are not fixed, the entire document is returned as one page + return [{'page_num': 1, 'content': content}] + + +def parse_html_bs(path: str, extract_image: bool = False): + if extract_image: + raise ValueError('Currently, extracting images is not supported!') + + def pre_process_html(s): + # replace multiple newlines + s = re.sub('\n+', '\n', s) + # replace special string + s = s.replace("Add to Qwen's Reading List", '') + return s + + try: + from bs4 import BeautifulSoup + except Exception: + raise ValueError('Please install bs4 by `pip install beautifulsoup4`') + bs_kwargs = {'features': 'lxml'} + with open(path, 'r', encoding='utf-8') as f: + soup = BeautifulSoup(f, **bs_kwargs) + + text = soup.get_text() + + if soup.title: + title = str(soup.title.string) + else: + title = '' + + text = pre_process_html(text) + paras = text.split(PARAGRAPH_SPLIT_SYMBOL) + content = [] + for p in paras: + p = clean_paragraph(p) + if p.strip(): + content.append({'text': p}) + + # The entire document is returned as one page + return [{'page_num': 1, 'content': content, 'title': title}] + + +def parse_pdf(pdf_path: str, extract_image: bool = False) -> List[dict]: + # Todo: header and footer + from pdfminer.high_level import extract_pages + from pdfminer.layout import LTImage, LTRect, LTTextContainer + + doc = [] + for i, page_layout in enumerate(extract_pages(pdf_path)): + page = {'page_num': page_layout.pageid, 'content': []} + + elements = [] + for element in page_layout: + elements.append(element) + + # Init params for table + table_num = 0 + tables = [] + + for element in elements: + if isinstance(element, LTRect): + if not tables: + tables = extract_tables(pdf_path, i) + if table_num < len(tables): + table_string = table_converter(tables[table_num]) + table_num += 1 + if table_string: + page['content'].append({'table': table_string, 'obj': element}) + elif isinstance(element, LTTextContainer): + # Delete line breaks in the same paragraph + text = element.get_text() + # Todo: Further analysis using font + font = get_font(element) + if text.strip(): + new_content_item = {'text': text, 'obj': element} + if font: + new_content_item['font-size'] = round(font[1]) + # new_content_item['font-name'] = font[0] + page['content'].append(new_content_item) + elif extract_image and isinstance(element, LTImage): + # Todo: ocr + raise ValueError('Currently, extracting images is not supported!') + else: + pass + + # merge elements + page['content'] = postprocess_page_content(page['content']) + doc.append(page) + + return doc + + +def postprocess_page_content(page_content: list) -> list: + # rm repetitive identification for table and text + # Some documents may repeatedly recognize LTRect and LTTextContainer + table_obj = [p['obj'] for p in page_content if 'table' in p] + tmp = [] + for p in page_content: + repetitive = False + if 'text' in p: + for t in table_obj: + if t.bbox[0] <= p['obj'].bbox[0] and p['obj'].bbox[1] <= t.bbox[1] and t.bbox[2] <= p['obj'].bbox[ + 2] and p['obj'].bbox[3] <= t.bbox[3]: + repetitive = True + break + + if not repetitive: + tmp.append(p) + page_content = tmp + + # merge paragraphs that have been separated by mistake + new_page_content = [] + for p in page_content: + if new_page_content and 'text' in new_page_content[-1] and 'text' in p and abs( + p.get('font-size', 12) - + new_page_content[-1].get('font-size', 12)) < 2 and p['obj'].height < p.get('font-size', 12) + 1: + # Merge those lines belonging to a paragraph + new_page_content[-1]['text'] += f' {p["text"]}' + # new_page_content[-1]['font-name'] = p.get('font-name', '') + new_page_content[-1]['font-size'] = p.get('font-size', 12) + else: + p.pop('obj') + new_page_content.append(p) + for i in range(len(new_page_content)): + if 'text' in new_page_content[i]: + new_page_content[i]['text'] = clean_paragraph(new_page_content[i]['text']) + return new_page_content + + +def get_font(element): + from pdfminer.layout import LTChar, LTTextContainer + + fonts_list = [] + for text_line in element: + if isinstance(text_line, LTTextContainer): + for character in text_line: + if isinstance(character, LTChar): + fonts_list.append((character.fontname, character.size)) + + fonts_list = list(set(fonts_list)) + if fonts_list: + counter = Counter(fonts_list) + most_common_fonts = counter.most_common(1)[0][0] + return most_common_fonts + else: + return [] + + +def extract_tables(pdf_path, page_num): + import pdfplumber + pdf = pdfplumber.open(pdf_path) + table_page = pdf.pages[page_num] + tables = table_page.extract_tables() + return tables + + +def table_converter(table): + table_string = '' + for row_num in range(len(table)): + row = table[row_num] + cleaned_row = [ + item.replace('\n', ' ') if item is not None and '\n' in item else 'None' if item is None else item + for item in row + ] + table_string += ('|' + '|'.join(cleaned_row) + '|' + '\n') + table_string = table_string[:-1] + return table_string + + +def sanitize_chrome_file_path(file_path: str) -> str: + # For Linux and macOS. + if os.path.exists(file_path): + return file_path + + # For native Windows, drop the leading '/' in '/C:/' + win_path = file_path + if win_path.startswith('/'): + win_path = win_path[1:] + if os.path.exists(win_path): + return win_path + + # For Windows + WSL. + if re.match(r'^[A-Za-z]:/', win_path): + wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}' + if os.path.exists(wsl_path): + return wsl_path + + # For native Windows, replace / with \. + win_path = win_path.replace('/', '\\') + if os.path.exists(win_path): + return win_path + + return file_path + + +PARSER_SUPPORTED_FILE_TYPES = ['pdf', 'docx', 'pptx', 'txt', 'html'] + + +def get_plain_doc(doc: list): + paras = [] + for page in doc: + for para in page['content']: + for k, v in para.items(): + if k in ['text', 'table', 'image']: + paras.append(v) + return PARAGRAPH_SPLIT_SYMBOL.join(paras) + + +@register_tool('simple_doc_parser') +class SimpleDocParser(BaseTool): + description = f'提取出一个文档的内容,支持类型包括:{"/".join(PARSER_SUPPORTED_FILE_TYPES)}' + parameters = [{ + 'name': 'url', + 'type': 'string', + 'description': '待提取的文件的路径,可以是一个本地路径或可下载的http(s)链接', + 'required': True + }] + + def __init__(self, cfg: Optional[Dict] = None): + super().__init__(cfg) + self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name)) + self.extract_image = self.cfg.get('extract_image', False) + self.structured_doc = self.cfg.get('structured_doc', False) + + def call(self, params: Union[str, dict], **kwargs) -> Union[str, list]: + """Parse pdf by url, and return the formatted content. + + Returns: + Extracted doc as plain text or the following list format: + [ + {'page_num': 1, + 'content': [ + {'text': 'This is one paragraph'}, + {'table': 'This is one table'} + ], + 'title': 'If extracted, this is the title of the doc.'}, + {'page_num': 2, + 'content': [ + {'text': 'This is one paragraph'}, + {'table': 'This is one table'} + ]} + ] + """ + + params = self._verify_json_format_args(params) + path = params['url'] + + f_type = get_file_type(path) + if f_type in PARSER_SUPPORTED_FILE_TYPES: + if path.startswith('https://') or path.startswith('http://') or re.match(r'^[A-Za-z]:\\', path) or re.match( + r'^[A-Za-z]:/', path): + path = path + else: + parsed_url = urllib.parse.urlparse(path) + path = urllib.parse.unquote(parsed_url.path) + path = sanitize_chrome_file_path(path) + + os.makedirs(self.data_root, exist_ok=True) + if is_http_url(path): + # download online url + tmp_file_root = os.path.join(self.data_root, hash_sha256(path)) + os.makedirs(tmp_file_root, exist_ok=True) + path = save_url_to_local_work_dir(path, tmp_file_root) + + if f_type == 'pdf': + parsed_file = parse_pdf(path, self.extract_image) + elif f_type == 'docx': + parsed_file = parse_word(path, self.extract_image) + elif f_type == 'pptx': + parsed_file = parse_ppt(path, self.extract_image) + elif f_type == 'txt': + parsed_file = parse_txt(path) + elif f_type == 'html': + parsed_file = parse_html_bs(path, self.extract_image) + else: + raise ValueError( + f'Failed: The current parser does not support this file type! Supported types: {"/".join(PARSER_SUPPORTED_FILE_TYPES)}' + ) + if not self.structured_doc: + return get_plain_doc(parsed_file) + else: + return parsed_file diff --git a/qwen_agent/tools/storage.py b/qwen_agent/tools/storage.py index bfc137f..66d31f3 100644 --- a/qwen_agent/tools/storage.py +++ b/qwen_agent/tools/storage.py @@ -1,11 +1,13 @@ import os from typing import Dict, Optional, Union +from qwen_agent.settings import DEFAULT_WORKSPACE from qwen_agent.tools.base import BaseTool, register_tool from qwen_agent.utils.utils import read_text_from_file, save_text_to_file -DEFAULT_STORAGE_PATH = 'workspace/default_data_path' -SUCCESS_MESSAGE = 'SUCCESS' + +class KeyNotExistsError(ValueError): + pass @register_tool('storage') @@ -32,7 +34,7 @@ class Storage(BaseTool): def __init__(self, cfg: Optional[Dict] = None): super().__init__(cfg) - self.root = self.cfg.get('storage_root_path', DEFAULT_STORAGE_PATH) + self.root = self.cfg.get('storage_root_path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name)) os.makedirs(self.root, exist_ok=True) def call(self, params: Union[str, dict], **kwargs) -> str: @@ -63,10 +65,12 @@ def put(self, key: str, value: str, path: Optional[str] = None) -> str: os.makedirs(path_dir, exist_ok=True) save_text_to_file(path, value) - return SUCCESS_MESSAGE + return f'Successfully saved {key}.' def get(self, key: str, path: Optional[str] = None) -> str: path = path or self.root + if not os.path.exists(os.path.join(path, key)): + raise KeyNotExistsError(f'Get Failed: {key} does not exist') return read_text_from_file(os.path.join(path, key)) def delete(self, key, path: Optional[str] = None) -> str: @@ -74,7 +78,7 @@ def delete(self, key, path: Optional[str] = None) -> str: path = os.path.join(path, key) if os.path.exists(path): os.remove(path) - return f'Successfully deleted{key}' + return f'Successfully deleted {key}' else: return f'Delete Failed: {key} does not exist' @@ -83,7 +87,7 @@ def scan(self, key: str, path: Optional[str] = None) -> str: path = os.path.join(path, key) if os.path.exists(path): if not os.path.isdir(path): - return 'Scan Failed: The scan operation requires passing in a key to a folder path' + return 'Scan Failed: The scan operation requires passing in a folder path as the key.' # All key-value pairs kvs = {} for root, dirs, files in os.walk(path): diff --git a/qwen_agent/tools/web_extractor.py b/qwen_agent/tools/web_extractor.py index 7becdf9..c58036a 100644 --- a/qwen_agent/tools/web_extractor.py +++ b/qwen_agent/tools/web_extractor.py @@ -1,8 +1,7 @@ from typing import Union -import requests - from qwen_agent.tools.base import BaseTool, register_tool +from qwen_agent.tools.simple_doc_parser import SimpleDocParser @register_tool('web_extractor') @@ -11,28 +10,7 @@ class WebExtractor(BaseTool): parameters = [{'name': 'url', 'type': 'string', 'description': '网页URL', 'required': True}] def call(self, params: Union[str, dict], **kwargs) -> str: - only_text = self.cfg.get('only_text', False) - params = self._verify_json_format_args(params) - url = params['url'] - headers = { - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' - } - response = requests.get(url, headers=headers) - if response.status_code == 200: - if only_text: - - import justext - - paragraphs = justext.justext(response.text, justext.get_stoplist('English')) - content = '\n\n'.join([paragraph.text for paragraph in paragraphs]).strip() - if content: - return content - else: - return response.text - else: - return response.text - else: - return '' + parsed_web = SimpleDocParser().call({'url': url}) + return parsed_web diff --git a/qwen_agent/utils/doc_parser.py b/qwen_agent/utils/doc_parser.py deleted file mode 100644 index cedafe8..0000000 --- a/qwen_agent/utils/doc_parser.py +++ /dev/null @@ -1,106 +0,0 @@ -import re - -from qwen_agent.utils.tokenization_qwen import count_tokens - -ONE_PAGE_TOKEN = 500 - - -def rm_newlines(text): - text = re.sub(r'(?<=[^\.。::])\n', ' ', text) - return text - - -def rm_cid(text): - text = re.sub(r'\(cid:\d+\)', '', text) - return text - - -def rm_hexadecimal(text): - text = re.sub(r'[0-9A-Fa-f]{21,}', '', text) - return text - - -def rm_continuous_placeholders(text): - text = re.sub(r'(\.|-|——|。|_|\*){7,}', '...', text) - return text - - -def deal(text): - text = rm_newlines(text) - text = rm_cid(text) - text = rm_hexadecimal(text) - text = rm_continuous_placeholders(text) - return text - - -def parse_doc(path): - if '.pdf' in path.lower(): - from pdfminer.high_level import extract_text - text = extract_text(path) - elif '.docx' in path.lower(): - import docx2txt - text = docx2txt.process(path) - elif '.pptx' in path.lower(): - from pptx import Presentation - ppt = Presentation(path) - text = [] - for slide in ppt.slides: - for shape in slide.shapes: - if hasattr(shape, 'text'): - text.append(shape.text) - text = '\n'.join(text) - else: - raise TypeError - - text = deal(text) - return split_text_to_trunk(text, path) - - -def pre_process_html(s): - # replace multiple newlines - s = re.sub('\n+', '\n', s) - # replace special string - s = s.replace("Add to Qwen's Reading List", '') - return s - - -def parse_html_bs(path): - try: - from bs4 import BeautifulSoup - except Exception: - raise ValueError('Please install bs4 by `pip install beautifulsoup4`') - bs_kwargs = {'features': 'lxml'} - with open(path, 'r', encoding='utf-8') as f: - soup = BeautifulSoup(f, **bs_kwargs) - - text = soup.get_text() - - if soup.title: - title = str(soup.title.string) - else: - title = '' - text = pre_process_html(text) - return split_text_to_trunk(text, path, title) - - -def split_text_to_trunk(content: str, path: str, title: str = ''): - all_tokens = count_tokens(content) - all_pages = round(all_tokens / ONE_PAGE_TOKEN) - if all_pages == 0: - all_pages = 1 - len_content = len(content) - len_one_page = int(len_content / all_pages) # Approximately equal to ONE_PAGE_TOKEN - - res = [] - for i in range(0, len_content, len_one_page): - text = content[i:min(i + len_one_page, len_content)] - res.append({ - 'page_content': text, - 'metadata': { - 'source': path, - 'title': title, - 'page': (i % len_one_page) - }, - 'token': count_tokens(text) - }) - return res diff --git a/qwen_agent/utils/str_processing.py b/qwen_agent/utils/str_processing.py new file mode 100644 index 0000000..e064096 --- /dev/null +++ b/qwen_agent/utils/str_processing.py @@ -0,0 +1,30 @@ +import re + +from qwen_agent.utils.utils import has_chinese_chars + + +def rm_newlines(text): + if text.endswith('-\n'): + text = text[:-2] + return text.strip() + rep_c = ' ' + if has_chinese_chars(text): + rep_c = '' + text = re.sub(r'(?<=[^\.。::\d])\n', rep_c, text) + return text.strip() + + +def rm_cid(text): + text = re.sub(r'\(cid:\d+\)', '', text) + return text + + +def rm_hexadecimal(text): + text = re.sub(r'[0-9A-Fa-f]{21,}', '', text) + return text + + +def rm_continuous_placeholders(text): + text = re.sub(r'[.\- —。_*]{7,}', '\t', text) + text = re.sub(r'\n{3,}', '\n\n', text) + return text diff --git a/qwen_agent/utils/tokenization_qwen.py b/qwen_agent/utils/tokenization_qwen.py index 5c330da..b633f10 100644 --- a/qwen_agent/utils/tokenization_qwen.py +++ b/qwen_agent/utils/tokenization_qwen.py @@ -1,20 +1,13 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. """Tokenization classes for QWen.""" import base64 -import logging -import os import unicodedata -from dataclasses import dataclass, field from pathlib import Path -from typing import Collection, Dict, List, Set, Tuple, Union +from typing import Collection, Dict, List, Set, Union import tiktoken -logger = logging.getLogger(__name__) +from qwen_agent.log import logger VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'} @@ -47,23 +40,6 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: } -@dataclass(frozen=True, eq=True) -class AddedToken: - """ - AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the - way it should behave. - """ - - content: str = field(default_factory=str) - single_word: bool = False - lstrip: bool = False - rstrip: bool = False - normalized: bool = True - - def __getstate__(self): - return self.__dict__ - - class QWenTokenizer: """QWen tokenizer.""" @@ -157,33 +133,6 @@ def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str] ids.append(self.mergeable_ranks.get(token)) return ids - def _add_tokens( - self, - new_tokens: Union[List[str], List[AddedToken]], - special_tokens: bool = False, - ) -> int: - if not special_tokens and new_tokens: - raise ValueError('Adding regular tokens is not supported') - for token in new_tokens: - surface_form = token.content if isinstance(token, AddedToken) else token - if surface_form not in SPECIAL_TOKENS_SET: - raise ValueError('Adding unknown special tokens is not supported') - return 0 - - def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: - """ - Save only the vocabulary of the tokenizer (vocabulary). - - Returns: - `Tuple(str)`: Paths to the files saved. - """ - file_path = os.path.join(save_directory, 'qwen.tiktoken') - with open(file_path, 'w', encoding='utf8') as w: - for k, v in self.mergeable_ranks.items(): - line = base64.b64encode(k).decode('utf8') + ' ' + str(v) + '\n' - w.write(line) - return (file_path,) - def tokenize( self, text: str, @@ -242,29 +191,6 @@ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: def vocab_size(self): return self.tokenizer.n_vocab - def _convert_id_to_token(self, index: int) -> Union[bytes, str]: - """Converts an id to a token, special tokens included""" - if index in self.decoder: - return self.decoder[index] - raise ValueError('unknown ids') - - def _convert_token_to_id(self, token: Union[bytes, str]) -> int: - """Converts a token to an id using the vocab, special tokens included""" - if token in self.special_tokens: - return self.special_tokens[token] - if token in self.mergeable_ranks: - return self.mergeable_ranks[token] - raise ValueError('unknown token') - - def _tokenize(self, text: str, **kwargs): - """ - Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based - vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). - - Do NOT take care of added tokens. - """ - raise NotImplementedError - def _decode( self, token_ids: Union[int, List[int]], @@ -278,10 +204,20 @@ def _decode( token_ids = [i for i in token_ids if i < self.eod_id] return self.tokenizer.decode(token_ids, errors=errors or self.errors) + def encode(self, text: str) -> List[int]: + return self.convert_tokens_to_ids(self.tokenize(text)) + + def count_tokens(self, text: str) -> int: + return len(self.tokenize(text)) + + def truncate(self, text: str, max_token: int, start_token: int = 0) -> str: + token_list = self.tokenize(text) + token_list = token_list[start_token:min(len(token_list), start_token + max_token)] + return self.convert_tokens_to_string(token_list) + tokenizer = QWenTokenizer(Path(__file__).resolve().parent / 'qwen.tiktoken') -def count_tokens(text): - tokens = tokenizer.tokenize(text) - return len(tokens) +def count_tokens(text: str) -> int: + return tokenizer.count_tokens(text) diff --git a/qwen_agent/utils/utils.py b/qwen_agent/utils/utils.py index 8bc5a14..28054c0 100644 --- a/qwen_agent/utils/utils.py +++ b/qwen_agent/utils/utils.py @@ -1,25 +1,50 @@ -import datetime +import copy import hashlib -import json import os import re import shutil +import signal import socket import sys +import time import traceback -import urllib -from typing import Dict, List, Literal, Optional, Union -from urllib.parse import urlparse +import urllib.parse +from typing import Any, List, Literal, Optional -import jieba import json5 import requests -from jieba import analyse +from qwen_agent.llm.schema import ASSISTANT, FUNCTION, SYSTEM, USER, ContentItem, Message from qwen_agent.log import logger -def get_local_ip(): +def append_signal_handler(sig, handler): + """ + Installs a new signal handler while preserving any existing handler. + If an existing handler is present, it will be called _after_ the new handler. + """ + + old_handler = signal.getsignal(sig) + if not callable(old_handler): + old_handler = None + if sig == signal.SIGINT: + + def old_handler(*args, **kwargs): + raise KeyboardInterrupt + elif sig == signal.SIGTERM: + + def old_handler(*args, **kwargs): + raise SystemExit + + def new_handler(*args, **kwargs): + handler(*args, **kwargs) + if old_handler is not None: + old_handler(*args, **kwargs) + + signal.signal(sig, new_handler) + + +def get_local_ip() -> str: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: # doesn't even have to be reachable @@ -32,45 +57,67 @@ def get_local_ip(): return ip -def hash_sha256(key): - hash_object = hashlib.sha256(key.encode()) +def hash_sha256(text: str) -> str: + hash_object = hashlib.sha256(text.encode()) key = hash_object.hexdigest() return key -def print_traceback(is_error=True): +def print_traceback(is_error: bool = True): if is_error: logger.error(''.join(traceback.format_exception(*sys.exc_info()))) else: logger.warning(''.join(traceback.format_exception(*sys.exc_info()))) -def has_chinese_chars(data) -> bool: +def has_chinese_chars(data: Any) -> bool: text = f'{data}' return len(re.findall(r'[\u4e00-\u9fff]+', text)) > 0 -def get_basename_from_url(url: str) -> str: - basename = os.path.basename(urlparse(url).path) +def get_basename_from_url(path_or_url: str) -> str: + if re.match(r'^[A-Za-z]:\\', path_or_url): + # "C:\\a\\b\\c" -> "C:/a/b/c" + path_or_url = path_or_url.replace('\\', '/') + + # "/mnt/a/b/c" -> "c" + # "https://github.com/here?k=v" -> "here" + # "https://github.com/" -> "" + basename = urllib.parse.urlparse(path_or_url).path + basename = os.path.basename(basename) basename = urllib.parse.unquote(basename) - return basename.strip() + basename = basename.strip() + + # "https://github.com/" -> "" -> "github.com" + if not basename: + basename = [x.strip() for x in path_or_url.split('/') if x.strip()][-1] + + return basename -def is_local_path(path): - if path.startswith('https://') or path.startswith('http://'): - return False - return True +def is_http_url(path_or_url: str) -> bool: + if path_or_url.startswith('https://') or path_or_url.startswith('http://'): + return True + return False + + +def is_image(path_or_url: str) -> bool: + filename = get_basename_from_url(path_or_url).lower() + for ext in ['jpg', 'jpeg', 'png', 'webp']: + if filename.endswith(ext): + return True + return False -def save_url_to_local_work_dir(url, base_dir, new_name=''): - if not new_name: - new_name = get_basename_from_url(url) - new_path = os.path.join(base_dir, new_name) +def save_url_to_local_work_dir(url: str, save_dir: str, save_filename: str = '') -> str: + if not save_filename: + save_filename = get_basename_from_url(url) + new_path = os.path.join(save_dir, save_filename) if os.path.exists(new_path): os.remove(new_path) - logger.info(f'download {url} to {new_path}') - start_time = datetime.datetime.now() - if is_local_path(url): + logger.info(f'Downloading {url} to {new_path}...') + start_time = time.time() + if not is_http_url(url): shutil.copy(url, new_path) else: headers = { @@ -83,162 +130,57 @@ def save_url_to_local_work_dir(url, base_dir, new_name=''): file.write(response.content) else: raise ValueError('Can not download this file. Please check your network or the file link.') - end_time = datetime.datetime.now() - logger.info(f'Time: {str(end_time - start_time)}') + end_time = time.time() + logger.info(f'Finished downloading {url} to {new_path}. Time spent: {end_time - start_time} seconds.') return new_path -def is_image(filename): - filename = filename.lower() - for ext in ['jpg', 'jpeg', 'png', 'webp']: - if filename.endswith(ext): - return True - return False - - -def get_current_date_str( - lang: Literal['en', 'zh'] = 'en', - hours_from_utc: Optional[int] = None, -) -> str: - if hours_from_utc is None: - cur_time = datetime.datetime.now() - else: - cur_time = datetime.datetime.utcnow() + datetime.timedelta(hours=hours_from_utc) - if lang == 'en': - date_str = 'Current date: ' + cur_time.strftime('%A, %B %d, %Y') - elif lang == 'zh': - cur_time = cur_time.timetuple() - date_str = f'当前时间:{cur_time.tm_year}年{cur_time.tm_mon}月{cur_time.tm_mday}日,星期' - date_str += ['一', '二', '三', '四', '五', '六', '日'][cur_time.tm_wday] - date_str += '。' - else: - raise NotImplementedError - return date_str - - -def save_text_to_file(path, text): +def save_text_to_file(path: str, text: str) -> None: with open(path, 'w', encoding='utf-8') as fp: fp.write(text) -def read_text_from_file(path): +def read_text_from_file(path: str) -> str: with open(path, 'r', encoding='utf-8') as file: file_content = file.read() return file_content -def contains_html_tags(text): +def contains_html_tags(text: str) -> bool: pattern = r'<(p|span|div|li|html|script)[^>]*?' return bool(re.search(pattern, text)) -def get_file_type(path): - # This is a temporary plan - if is_local_path(path): +def get_file_type(path: str) -> Literal['pdf', 'docx', 'pptx', 'txt', 'html', 'unk']: + f_type = get_basename_from_url(path).split('.')[-1].lower() + if f_type in ['pdf', 'docx', 'pptx', 'txt']: + # Specially supported file types + return f_type + + if is_http_url(path): + # Assuming that the URL is HTML by default + return 'html' + else: + # Determine by reading local HTML file try: content = read_text_from_file(path) except Exception: print_traceback() - return 'Unknown' + return 'unk' if contains_html_tags(content): return 'html' else: - return 'Unknown' - else: - headers = { - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' - } - response = requests.get(path, headers=headers) - if response.status_code == 200: - if contains_html_tags(response.text): - return 'html' - else: - return 'Unknown' - else: - print_traceback() - return 'Unknown' - - -ignore_words = [ - '', ' ', '\t', '\n', '\\', 'is', 'are', 'am', 'what', 'how', '的', '吗', '是', '了', '啊', '呢', '怎么', '如何', '什么', '?', - '?', '!', '!', '“', '”', '‘', '’', "'", "'", '"', '"', ':', ':', '讲了', '描述', '讲', '说说', '讲讲', '介绍', '总结下', '总结一下', - '文档', '文章', '文稿', '稿子', '论文', 'PDF', 'pdf', '这个', '这篇', '这', '我', '帮我', '那个', '下', '翻译' -] - - -def get_split_word(text): - text = text.lower() - _wordlist = jieba.lcut(text.strip()) - wordlist = [] - for x in _wordlist: - if x in ignore_words: - continue - wordlist.append(x) - return wordlist - - -def parse_keyword(text): - try: - res = json5.loads(text) - except Exception: - return get_split_word(text) + return 'unk' - # json format - _wordlist = [] - try: - if 'keywords_zh' in res and isinstance(res['keywords_zh'], list): - _wordlist.extend([kw.lower() for kw in res['keywords_zh']]) - if 'keywords_en' in res and isinstance(res['keywords_en'], list): - _wordlist.extend([kw.lower() for kw in res['keywords_en']]) - wordlist = [] - for x in _wordlist: - if x in ignore_words: - continue - wordlist.append(x) - wordlist.extend(get_split_word(res['text'])) - return wordlist - except Exception: - return get_split_word(text) - - -def get_key_word(text): - text = text.lower() - _wordlist = analyse.extract_tags(text) - wordlist = [] - for x in _wordlist: - if x in ignore_words: - continue - wordlist.append(x) - return wordlist - -def get_last_one_line_context(text): - lines = text.split('\n') - n = len(lines) - res = '' - for i in range(n - 1, -1, -1): - if lines[i].strip(): - res = lines[i] - break - return res - - -def extract_urls(text): +def extract_urls(text: str) -> List[str]: pattern = re.compile(r'https?://\S+') urls = re.findall(pattern, text) return urls -def extract_obs(text): - k = text.rfind('\nObservation:') - j = text.rfind('\nThought:') - obs = text[k + len('\nObservation:'):j] - return obs.strip() - - -def extract_code(text): +def extract_code(text: str) -> str: # Match triple backtick blocks first triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) if triple_match: @@ -247,69 +189,90 @@ def extract_code(text): try: text = json5.loads(text)['code'] except Exception: - print_traceback() + print_traceback(is_error=False) # If no code blocks found, return original text return text -def parse_latest_plugin_call(text): - plugin_name, plugin_args = '', '' - i = text.rfind('\nAction:') - j = text.rfind('\nAction Input:') - k = text.rfind('\nObservation:') - if 0 <= i < j: # If the text has `Action` and `Action input`, - if k < j: # but does not contain `Observation`, - # then it is likely that `Observation` is ommited by the LLM, - # because the output text may have discarded the stop word. - text = text.rstrip() + '\nObservation:' # Add it back. - k = text.rfind('\nObservation:') - plugin_name = text[i + len('\nAction:'):j].strip() - plugin_args = text[j + len('\nAction Input:'):k].strip() - text = text[:k] - return plugin_name, plugin_args, text - - -def get_function_description(function: Dict) -> str: - """ - Text description of function - """ - tool_desc_template = { - 'zh': '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}', - 'en': '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}' - } - if has_chinese_chars(function): - tool_desc = tool_desc_template['zh'] +def format_as_multimodal_message(msg: Message, add_upload_info: bool = True) -> Message: + assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) + content = [] + if isinstance(msg.content, str): # if text content + if msg.content: + content = [ContentItem(text=msg.content)] + elif isinstance(msg.content, list): # if multimodal content + files = [] + for item in msg.content: + k, v = item.get_type_and_value() + if k == 'text': + content.append(ContentItem(text=v)) + if k == 'image': + content.append(item) + if k in ('file', 'image'): + files.append(v) + if add_upload_info and files and (msg.role in (SYSTEM, USER)): + has_zh = has_chinese_chars(content) + upload = [] + for f in [get_basename_from_url(f) for f in files]: + if is_image(f): + if has_zh: + upload.append(f'![图片]({f})') + else: + upload.append(f'![image]({f})') + else: + if has_zh: + upload.append(f'[文件]({f})') + else: + upload.append(f'[file]({f})') + upload = ' '.join(upload) + if has_zh: + upload = f'(上传了 {upload})\n\n' + else: + upload = f'(Uploaded {upload})\n\n' + content = [ContentItem(text=upload)] + content else: - tool_desc = tool_desc_template['en'] - - name = function.get('name', None) - name_for_human = function.get('name_for_human', name) - name_for_model = function.get('name_for_model', name) - assert name_for_human and name_for_model - args_format = function.get('args_format', '') - return tool_desc.format(name_for_human=name_for_human, - name_for_model=name_for_model, - description_for_model=function['description'], - parameters=json.dumps(function['parameters'], ensure_ascii=False), - args_format=args_format).rstrip() - - -def format_knowledge_to_source_and_content(result: Union[str, List[dict]]) -> List[dict]: - knowledge = [] - if isinstance(result, str): - result = f'{result}'.strip() - docs = json5.loads(result) + raise TypeError + msg = Message( + role=msg.role, + content=content, + name=msg.name if msg.role == FUNCTION else None, + function_call=msg.function_call, + ) + return msg + + +def extract_text_from_message(msg: Message, add_upload_info: bool = True) -> str: + if isinstance(msg.content, list): + mm_msg = format_as_multimodal_message(msg, add_upload_info=add_upload_info) + text = '' + for item in mm_msg.content: + if item.type == 'text': + text += item.value + elif isinstance(msg.content, str): + text = msg.content else: - docs = result - try: - _tmp_knowledge = [] - assert isinstance(docs, list) - for doc in docs: - url, snippets = doc['url'], doc['text'] - assert isinstance(snippets, list) - _tmp_knowledge.append({'source': f'[文件]({url})', 'content': '\n\n...\n\n'.join(snippets)}) - knowledge.extend(_tmp_knowledge) - except Exception: - print_traceback() - knowledge.append({'source': '上传的文档', 'content': result}) - return knowledge + raise TypeError + return text.strip() + + +def extract_files_from_messages(messages: List[Message]) -> List[str]: + files = [] + for msg in messages: + if isinstance(msg.content, list): + for item in msg.content: + if item.file and item.file not in files: + files.append(item.file) + return files + + +def merge_generate_cfgs(base_generate_cfg: Optional[dict], new_generate_cfg: Optional[dict]) -> dict: + generate_cfg: dict = copy.deepcopy(base_generate_cfg or {}) + if new_generate_cfg: + for k, v in new_generate_cfg.items(): + if k == 'stop': + stop = generate_cfg.get('stop', []) + stop = stop + [s for s in v if s not in stop] + generate_cfg['stop'] = stop + else: + generate_cfg[k] = v + return generate_cfg diff --git a/qwen_server/assistant_server.py b/qwen_server/assistant_server.py index 3068172..2279e9c 100644 --- a/qwen_server/assistant_server.py +++ b/qwen_server/assistant_server.py @@ -3,24 +3,18 @@ import time from pathlib import Path +import jsonlines + try: import add_qwen_libs # NOQA except ImportError: pass -try: - import gradio as gr - if gr.__version__ < '3.50' or gr.__version__ >= '4.0': - raise ImportError('Incompatible gradio version detected. ' - 'Please install the correct version with: pip install "gradio>=3.50,<4.0"') -except (ModuleNotFoundError, AttributeError): - raise ImportError('Please install gradio by: pip install "gradio>=3.50,<4.0"') -import jsonlines - -from qwen_agent.agents import DocQAAgent +from qwen_agent.agents import Assistant +from qwen_agent.gui import gr +from qwen_agent.gui.utils import get_avatar_image from qwen_agent.llm.base import ModelServiceError from qwen_agent.log import logger -from qwen_server import output_beautify from qwen_server.schema import GlobalConfig from qwen_server.utils import read_history, read_meta_data_by_condition, save_history @@ -29,9 +23,7 @@ server_config = json.load(f) server_config = GlobalConfig(**server_config) -function_list = None llm_config = None -storage_path = None if hasattr(server_config.server, 'llm'): llm_config = { @@ -39,12 +31,8 @@ 'api_key': server_config.server.api_key, 'model_server': server_config.server.model_server } -if hasattr(server_config.server, 'functions'): - function_list = server_config.server.functions -if hasattr(server_config.path, 'database_root'): - storage_path = server_config.path.database_root -assistant = DocQAAgent(function_list=function_list, llm=llm_config) +assistant = Assistant(llm=llm_config) with open(Path(__file__).resolve().parent / 'css/main.css', 'r') as f: css = f.read() @@ -91,10 +79,10 @@ def bot(history): history[-1][1] = '' try: response = assistant.run(messages=messages, max_ref_token=server_config.server.max_ref_token) - - for chunk in output_beautify.convert_to_full_str_stream(response): - history[-1][1] = chunk - yield history + for rsp in response: + if rsp: + history[-1][1] = rsp[-1]['content'] + yield history except ModelServiceError as ex: history[-1][1] = str(ex) yield history @@ -124,10 +112,7 @@ def clear_session(): with gr.Blocks(css=css, theme='soft') as demo: - chatbot = gr.Chatbot([], - elem_id='chatbot', - height=480, - avatar_images=(None, (os.path.join(Path(__file__).resolve().parent, 'img/logo.png')))) + chatbot = gr.Chatbot([], elem_id='chatbot', height=480, avatar_images=(None, get_avatar_image('qwen'))) with gr.Row(): with gr.Column(scale=7): txt = gr.Textbox(show_label=False, placeholder='Chat with Qwen...', container=False) diff --git a/qwen_server/database_server.py b/qwen_server/database_server.py index adb9b3d..cb81743 100644 --- a/qwen_server/database_server.py +++ b/qwen_server/database_server.py @@ -3,10 +3,6 @@ import os from pathlib import Path -try: - import add_qwen_libs # NOQA -except ImportError: - pass import json5 import jsonlines import uvicorn @@ -15,9 +11,14 @@ from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles +try: + import add_qwen_libs # NOQA +except ImportError: + pass + from qwen_agent.log import logger from qwen_agent.memory import Memory -from qwen_agent.utils.utils import get_local_ip, hash_sha256, save_text_to_file +from qwen_agent.utils.utils import get_basename_from_url, get_local_ip, hash_sha256, save_text_to_file from qwen_server.schema import GlobalConfig from qwen_server.utils import rm_browsing_meta_data, save_browsing_meta_data, save_history @@ -55,8 +56,8 @@ def update_pop_url(url: str): - if not url.lower().endswith('.pdf'): - url = os.path.join(server_config.path.download_root, hash_sha256(url)) + if not get_basename_from_url(url).lower().endswith('.pdf'): + url = os.path.join(server_config.path.download_root, hash_sha256(url), get_basename_from_url(url)) new_line = {'url': url} with jsonlines.open(cache_file_popup_url, mode='w') as writer: @@ -78,9 +79,10 @@ def cache_page(**kwargs): url = kwargs.get('url', '') page_content = kwargs.get('content', '') - if page_content and not url.lower().endswith('.pdf'): + if page_content and not get_basename_from_url(url).lower().endswith('.pdf'): # map to local url - url = os.path.join(server_config.path.download_root, hash_sha256(url)) + os.makedirs(os.path.join(server_config.path.download_root, hash_sha256(url)), exist_ok=True) + url = os.path.join(server_config.path.download_root, hash_sha256(url), get_basename_from_url(url)) save_browsing_meta_data(url, '[CACHING]', meta_file) # rm history save_history(None, url, history_dir) diff --git a/qwen_server/img/logo.png b/qwen_server/img/logo.png deleted file mode 100644 index d80ed39..0000000 Binary files a/qwen_server/img/logo.png and /dev/null differ diff --git a/qwen_server/js/main.js b/qwen_server/js/main.js index d784584..498d95e 100644 --- a/qwen_server/js/main.js +++ b/qwen_server/js/main.js @@ -1,45 +1,46 @@ -window.onload = function() { - // autoTriggerFunction(); -}; +() => { + window.onload = function() { + // autoTriggerFunction(); + }; -function autoTriggerFunction() { - var button = document.getElementById("update_all_bt"); - button.click(); -} - -// const textbox = document.querySelector('#cmd label textarea'); + function autoTriggerFunction() { + var button = document.getElementById("update_all_bt"); + button.click(); + } -// textbox.addEventListener('input', () => { -// textbox.scrollTop = textbox.scrollHeight; -// console.log('input'); -// }); -// textbox.addEventListener('change', () => { -// textbox.scrollTop = textbox.scrollHeight; -// console.log('change'); -// }); + // const textbox = document.querySelector('#cmd label textarea'); -function scrollTextboxToBottom() { - var textbox = document.querySelector('.textbox_container label textarea'); - textbox.scrollTop = textbox.scrollHeight*10; -} -window.addEventListener('DOMContentLoaded', scrollTextboxToBottom); + // textbox.addEventListener('input', () => { + // textbox.scrollTop = textbox.scrollHeight; + // console.log('input'); + // }); + // textbox.addEventListener('change', () => { + // textbox.scrollTop = textbox.scrollHeight; + // console.log('change'); + // }); + function scrollTextboxToBottom() { + var textbox = document.querySelector('.textbox_container label textarea'); + textbox.scrollTop = textbox.scrollHeight*10; + } + window.addEventListener('DOMContentLoaded', scrollTextboxToBottom); -var checkboxes = document.querySelectorAll('input[type="checkbox"]'); -checkboxes.forEach(function(checkbox) { - checkbox.addEventListener("change", function() { - console.log(location.hostname) - var _server_url = "http://"+location.hostname+":7866/endpoint"; - fetch(_server_url, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({'task': 'change_checkbox', 'ckid': checkbox.id}), - }) - .then((response) => response.json()) - .then((data) => { - console.log(data.result) - }); + document.addEventListener('change', function(event) { + // Check if the changed element is a checkbox + if (event.target.type === 'checkbox') { + console.log(location.hostname); + var _server_url = "http://" + location.hostname + ":7866/endpoint"; + fetch(_server_url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({'task': 'change_checkbox', 'ckid': event.target.id}), + }) + .then((response) => response.json()) + .then((data) => { + console.log(data.result); + }); + } }); -}); +} diff --git a/qwen_server/output_beautify.py b/qwen_server/output_beautify.py index 5ef9f4f..38bf668 100644 --- a/qwen_server/output_beautify.py +++ b/qwen_server/output_beautify.py @@ -1,10 +1,7 @@ -from typing import Dict, Iterator, List - import json5 -from qwen_agent.llm.schema import ASSISTANT, CONTENT, FUNCTION, ROLE, SYSTEM, USER from qwen_agent.log import logger -from qwen_agent.utils.utils import extract_code, extract_obs, extract_urls, print_traceback +from qwen_agent.utils.utils import extract_code, extract_urls, print_traceback FN_NAME = 'Action' FN_ARGS = 'Action Input' @@ -12,33 +9,11 @@ FN_EXIT = 'Response' -def convert_fncall_to_text(messages: List[Dict]) -> List[Dict]: - new_messages = [] - for msg in messages: - role, content = msg[ROLE], msg[CONTENT] - content = (content or '').lstrip('\n').rstrip() - if role in (SYSTEM, USER): - new_messages.append({ROLE: role, CONTENT: content}) - elif role == ASSISTANT: - fn_call = msg.get(f'{FUNCTION}_call', {}) - if fn_call: - f_name = fn_call['name'] - f_args = fn_call['arguments'] - if f_args.startswith('```'): # if code snippet - f_args = '\n' + f_args # for markdown rendering - content += f'\n{FN_NAME}: {f_name}' - content += f'\n{FN_ARGS}: {f_args}' - if len(new_messages) > 0 and new_messages[-1][ROLE] == ASSISTANT: - new_messages[-1][CONTENT] += content - else: - content = content.lstrip('\n').rstrip() - new_messages.append({ROLE: role, CONTENT: content}) - elif role == FUNCTION: - assert new_messages[-1][ROLE] == ASSISTANT - new_messages[-1][CONTENT] += f'\n{FN_RESULT}: {content}\n{FN_EXIT}: ' - else: - raise TypeError - return new_messages +def extract_obs(text): + k = text.rfind('\nObservation:') + j = text.rfind('\nThought:') + obs = text[k + len('\nObservation:'):j] + return obs.strip() def format_answer(text): @@ -72,48 +47,3 @@ def format_answer(text): return rsp else: return text.split(f'{FN_EXIT}:')[-1].strip() - - -def convert_to_full_str_stream(message_list_stream: Iterator[List[Dict]]) -> Iterator[str]: - """ - output the full streaming str response - """ - for message_list in message_list_stream: - if not message_list: - continue - new_message_list = convert_fncall_to_text(message_list) - assert len(new_message_list) == 1 and new_message_list[0][ROLE] == ASSISTANT - yield new_message_list[0][CONTENT] - - -def convert_to_delta_str_stream(message_list_stream: Iterator[List[Dict]]) -> Iterator[str]: - """ - output the delta streaming str response - """ - last_len = 0 - delay_len = 20 - in_delay = False - text = '' - for text in convert_to_full_str_stream(message_list_stream): - if (len(text) - last_len) <= delay_len: - in_delay = True - continue - else: - in_delay = False - real_text = text[:-delay_len] - now_rsp = real_text[last_len:] - yield now_rsp - last_len = len(real_text) - - if text and (in_delay or (last_len != len(text))): - yield text[last_len:] - - -def convert_to_str(message_list_stream: Iterator[List[Dict]]) -> str: - """ - output the final full str response - """ - response = '' - for r in convert_to_full_str_stream(message_list_stream): - response = r - return response diff --git a/qwen_server/schema.py b/qwen_server/schema.py index f1ddf27..11651a0 100644 --- a/qwen_server/schema.py +++ b/qwen_server/schema.py @@ -3,7 +3,6 @@ class PathConfig(BaseModel): work_space_root: str - database_root: str download_root: str code_interpreter_ws: str @@ -18,7 +17,6 @@ class ServerConfig(BaseModel): llm: str max_ref_token: int max_days: int - functions: list class Config: protected_namespaces = () diff --git a/qwen_server/server_config.json b/qwen_server/server_config.json index 16a5acc..86f28da 100644 --- a/qwen_server/server_config.json +++ b/qwen_server/server_config.json @@ -1,9 +1,8 @@ { "path": { "work_space_root": "workspace/", - "database_root": "workspace/database/", "download_root": "workspace/download/", - "code_interpreter_ws": "workspace/ci_workspace/" + "code_interpreter_ws": "workspace/tools/code_interpreter/" }, "server": { "server_host": "127.0.0.1", @@ -14,10 +13,6 @@ "api_key": "", "llm": "qwen-plus", "max_ref_token": 4000, - "max_days": 7, - "functions": [ - "code_interpreter", - "image_gen" - ] + "max_days": 7 } } diff --git a/qwen_server/workstation_server.py b/qwen_server/workstation_server.py index f2b7bd0..6679005 100644 --- a/qwen_server/workstation_server.py +++ b/qwen_server/workstation_server.py @@ -3,25 +3,22 @@ import os from pathlib import Path -try: - import gradio as gr - if gr.__version__ < '3.50' or gr.__version__ >= '4.0': - raise ImportError('Incompatible gradio version detected. ' - 'Please install the correct version with: pip install "gradio>=3.50,<4.0"') -except (ModuleNotFoundError, AttributeError): - raise ImportError('Please install gradio by: pip install "gradio>=3.50,<4.0"') import json5 +from qwen_agent.tools.simple_doc_parser import PARSER_SUPPORTED_FILE_TYPES + try: import add_qwen_libs # NOQA except ImportError: pass -from qwen_agent.agents import ArticleAgent, DocQAAgent, ReActChat + +from qwen_agent.agents import ArticleAgent, Assistant, ReActChat +from qwen_agent.gui import gr +from qwen_agent.gui.utils import get_avatar_image from qwen_agent.llm import get_chat_model from qwen_agent.llm.base import ModelServiceError from qwen_agent.memory import Memory -from qwen_agent.utils.utils import (get_basename_from_url, get_last_one_line_context, has_chinese_chars, - save_text_to_file) +from qwen_agent.utils.utils import get_basename_from_url, get_file_type, has_chinese_chars, save_text_to_file from qwen_server import output_beautify from qwen_server.schema import GlobalConfig from qwen_server.utils import read_meta_data_by_condition, save_browsing_meta_data @@ -30,9 +27,7 @@ with open(Path(__file__).resolve().parent / 'server_config.json', 'r') as f: server_config = json.load(f) server_config = GlobalConfig(**server_config) -function_list = None llm_config = None -storage_path = None if hasattr(server_config.server, 'llm'): llm_config = { @@ -40,17 +35,15 @@ 'api_key': server_config.server.api_key, 'model_server': server_config.server.model_server } -if hasattr(server_config.server, 'functions'): - function_list = server_config.server.functions -if hasattr(server_config.path, 'database_root'): - storage_path = server_config.path.database_root app_global_para = { 'time': [str(datetime.date.today()), str(datetime.date.today())], 'messages': [], 'last_turn_msg_id': [], 'is_first_upload': True, - 'uploaded_ci_file': '' + 'uploaded_ci_file': '', + 'pure_messages': [], + 'pure_last_turn_msg_id': [], } DOC_OPTION = 'Document QA' @@ -75,6 +68,7 @@ def add_text(history, text): def pure_add_text(history, text): history = history + [(text, None)] + app_global_para['pure_last_turn_msg_id'] = [] return history, gr.update(value='', interactive=False) @@ -93,12 +87,23 @@ def chat_clear(): return None, None +def chat_clear_pure(): + app_global_para['pure_messages'] = [] + return None, None + + def chat_clear_last(): for index in app_global_para['last_turn_msg_id'][::-1]: del app_global_para['messages'][index] app_global_para['last_turn_msg_id'] = [] +def pure_chat_clear_last(): + for index in app_global_para['pure_last_turn_msg_id'][::-1]: + del app_global_para['pure_messages'][index] + app_global_para['pure_last_turn_msg_id'] = [] + + def add_file(file, chosen_plug): display_path = get_basename_from_url(file.name) @@ -106,10 +111,10 @@ def add_file(file, chosen_plug): app_global_para['uploaded_ci_file'] = file.name app_global_para['is_first_upload'] = True return display_path - - if not file.name.lower().endswith(('pdf', 'docx', 'pptx')): + f_type = get_file_type(file) + if f_type not in PARSER_SUPPORTED_FILE_TYPES: display_path = ( - 'Upload failed: only adding [\'.pdf\', \'.docx\', \'.pptx\'] documents as references is supported!') + f'Upload failed: only adding {", ".join(PARSER_SUPPORTED_FILE_TYPES)} as references is supported!') else: # cache file try: @@ -189,17 +194,22 @@ def pure_bot(history): yield history else: history[-1][1] = '' - messages = [] - for chat in history[:-1]: - messages.append({'role': 'user', 'content': chat[0]}) - messages.append({'role': 'assistant', 'content': chat[1]}) - messages.append({'role': 'user', 'content': history[-1][0]}) + message = [{'role': 'user', 'content': history[-1][0], 'name': 'pure_chat_user'}] try: llm = get_chat_model(llm_config) - response = llm.chat(messages=messages) - for chunk in output_beautify.convert_to_full_str_stream(response): - history[-1][1] = chunk - yield history + response = llm.chat(messages=app_global_para['pure_messages'] + message) + rsp = [] + for rsp in response: + if rsp: + history[-1][1] = rsp[-1]['content'] + yield history + + # Record the conversation history when the conversation succeeds + app_global_para['pure_last_turn_msg_id'].append(len(app_global_para['pure_messages'])) + app_global_para['pure_messages'].extend(message) # New user message + app_global_para['pure_last_turn_msg_id'].append(len(app_global_para['pure_messages'])) + app_global_para['pure_messages'].extend(rsp) # The response + except ModelServiceError as ex: history[-1][1] = str(ex) yield history @@ -231,7 +241,6 @@ def bot(history, chosen_plug): yield history else: history[-1][1] = '' - message = [] if chosen_plug == CI_OPTION: # use code interpreter if app_global_para['uploaded_ci_file'] and app_global_para['is_first_upload']: app_global_para['is_first_upload'] = False # only send file when first upload @@ -250,9 +259,16 @@ def bot(history, chosen_plug): func_assistant = ReActChat(function_list=['code_interpreter'], llm=llm_config) try: response = func_assistant.run(messages=messages) - for chunk in output_beautify.convert_to_full_str_stream(response): - history[-1][1] = chunk - yield history + rsp = [] + for rsp in response: + if rsp: + history[-1][1] = rsp[-1]['content'] + yield history + # append message + app_global_para['last_turn_msg_id'].append(len(app_global_para['messages'])) + app_global_para['messages'].extend(message) + app_global_para['last_turn_msg_id'].append(len(app_global_para['messages'])) + app_global_para['messages'].extend(rsp) except ModelServiceError as ex: history[-1][1] = str(ex) yield history @@ -264,25 +280,38 @@ def bot(history, chosen_plug): # checked files for record in read_meta_data_by_condition(meta_file, time_limit=app_global_para['time'], checked=True): content.append({'file': record['url']}) - qa_assistant = DocQAAgent(llm=llm_config) + qa_assistant = Assistant(llm=llm_config) message = [{'role': 'user', 'content': content}] - response = qa_assistant.run(messages=message, max_ref_token=server_config.server.max_ref_token) - for chunk in output_beautify.convert_to_full_str_stream(response): - history[-1][1] = chunk - yield history + # rm all files of history + messages = keep_only_files_for_name(app_global_para['messages'], 'None') + message + response = qa_assistant.run(messages=messages, max_ref_token=server_config.server.max_ref_token) + rsp = [] + for rsp in response: + if rsp: + history[-1][1] = rsp[-1]['content'] + yield history + # append message + app_global_para['last_turn_msg_id'].append(len(app_global_para['messages'])) + app_global_para['messages'].extend(message) + app_global_para['last_turn_msg_id'].append(len(app_global_para['messages'])) + app_global_para['messages'].extend(rsp) + except ModelServiceError as ex: history[-1][1] = str(ex) yield history except Exception as ex: raise ValueError(ex) - # append message - app_global_para['last_turn_msg_id'].append(len(app_global_para['messages'])) - app_global_para['messages'].extend(message) - message = {'role': 'assistant', 'content': history[-1][1]} - app_global_para['last_turn_msg_id'].append(len(app_global_para['messages'])) - app_global_para['messages'].append(message) +def get_last_one_line_context(text): + lines = text.split('\n') + n = len(lines) + res = '' + for i in range(n - 1, -1, -1): + if lines[i].strip(): + res = lines[i] + break + return res def generate(context): @@ -297,8 +326,9 @@ def generate(context): func_assistant = ReActChat(function_list=['code_interpreter'], llm=llm_config) try: response = func_assistant.run(messages=[{'role': 'user', 'content': sp_query}]) - for chunk in output_beautify.convert_to_full_str_stream(response): - yield chunk + for rsp in response: + if rsp: + yield rsp[-1]['content'] except ModelServiceError as ex: yield str(ex) except Exception as ex: @@ -309,8 +339,9 @@ def generate(context): func_assistant = ReActChat(function_list=['code_interpreter', 'image_gen'], llm=llm_config) try: response = func_assistant.run(messages=[{'role': 'user', 'content': sp_query}]) - for chunk in output_beautify.convert_to_full_str_stream(response): - yield chunk + for rsp in response: + if rsp: + yield rsp[-1]['content'] except ModelServiceError as ex: yield str(ex) except Exception as ex: @@ -338,8 +369,9 @@ def generate(context): }], max_ref_token=server_config.server.max_ref_token, full_article=full_article) - for chunk in output_beautify.convert_to_full_str_stream(response): - yield chunk + for rsp in response: + if rsp: + yield rsp[-1]['content'] except ModelServiceError as ex: yield str(ex) except Exception as ex: @@ -364,7 +396,7 @@ def format_generate(edit, context): yield res -with gr.Blocks(css=css, theme='soft') as demo: +with gr.Blocks(css=css, js=js, theme='soft') as demo: title = gr.Markdown('Qwen Agent: BrowserQwen', elem_classes='title') desc = gr.Markdown( 'This is the editing workstation of BrowserQwen, where Qwen has collected the browsing history. Qwen can assist you in completing your creative work!', @@ -480,10 +512,7 @@ def format_generate(edit, context): elem_id='chatbot', height=680, show_copy_button=True, - avatar_images=( - None, - (os.path.join(Path(__file__).resolve().parent, 'img/logo.png')), - ), + avatar_images=(None, get_avatar_image('qwen')), ) with gr.Row(): with gr.Column(scale=1, min_width=0): @@ -523,7 +552,7 @@ def format_generate(edit, context): re_txt_msg.then(lambda: gr.update(interactive=True), None, [chat_txt], queue=False) file_msg = file_btn.upload(add_file, [file_btn, plug_bt], [hidden_file_path], queue=False) - file_msg.then(update_browser_list, None, browser_list).then(lambda: None, None, None, _js=f'() => {{{js}}}') + file_msg.then(update_browser_list, None, browser_list) chat_clr_bt.click(chat_clear, None, [chatbot, hidden_file_path], queue=False) # re_bt.click(re_bot, chatbot, chatbot) @@ -539,10 +568,7 @@ def format_generate(edit, context): elem_id='pure_chatbot', height=680, show_copy_button=True, - avatar_images=( - None, - (os.path.join(Path(__file__).resolve().parent, 'img/logo.png')), - ), + avatar_images=(None, get_avatar_image('qwen')), ) with gr.Row(): with gr.Column(scale=13): @@ -563,27 +589,22 @@ def format_generate(edit, context): txt_msg.then(lambda: gr.update(interactive=True), None, [chat_txt], queue=False) re_txt_msg = chat_re_bt.click(rm_text, [pure_chatbot], [pure_chatbot, chat_txt], - queue=False).then(pure_bot, pure_chatbot, pure_chatbot) + queue=False).then(pure_chat_clear_last, None, + None).then(pure_bot, pure_chatbot, pure_chatbot) re_txt_msg.then(lambda: gr.update(interactive=True), None, [chat_txt], queue=False) - chat_clr_bt.click(lambda: None, None, pure_chatbot, queue=False) + chat_clr_bt.click(chat_clear_pure, None, pure_chatbot, queue=False) - chat_stop_bt.click(chat_clear_last, None, None, cancels=[txt_msg, re_txt_msg], queue=False) + chat_stop_bt.click(pure_chat_clear_last, None, None, cancels=[txt_msg, re_txt_msg], queue=False) date1.change(update_app_global_para, [date1, date2], - None).then(update_browser_list, None, - browser_list).then(lambda: None, None, None, - _js=f'() => {{{js}}}').then(chat_clear, None, - [chatbot, hidden_file_path]) + None).then(update_browser_list, None, browser_list).then(chat_clear, None, [chatbot, hidden_file_path]) date2.change(update_app_global_para, [date1, date2], - None).then(update_browser_list, None, - browser_list).then(lambda: None, None, None, - _js=f'() => {{{js}}}').then(chat_clear, None, - [chatbot, hidden_file_path]) - - demo.load(update_app_global_para, [date1, date2], None).then(refresh_date, None, [date1, date2]).then( - update_browser_list, None, browser_list).then(lambda: None, None, None, - _js=f'() => {{{js}}}').then(chat_clear, None, - [chatbot, hidden_file_path]) + None).then(update_browser_list, None, browser_list).then(chat_clear, None, [chatbot, hidden_file_path]) + + demo.load(update_app_global_para, [date1, date2], + None).then(refresh_date, None, + [date1, date2]).then(update_browser_list, None, + browser_list).then(chat_clear, None, [chatbot, hidden_file_path]) demo.queue().launch(server_name=server_config.server.server_host, server_port=server_config.server.workstation_port) diff --git a/requirements.txt b/requirements.txt index 5982a60..85c3b9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ anyio>=3.7.1 beautifulsoup4 dashscope>=1.11.0 -docx2txt +eval_type_backport fastapi>=0.103.1 html2text jieba @@ -13,9 +13,12 @@ numpy openai pandas pdfminer-six +pdfplumber pillow pydantic>=2.3.0 +python-docx python-pptx +rank_bm25 seaborn sympy tiktoken diff --git a/run_server.py b/run_server.py index 289e853..ddcaace 100644 --- a/run_server.py +++ b/run_server.py @@ -89,15 +89,15 @@ def main(): server_config = update_config(server_config, args, server_config_path) os.makedirs(server_config.path.work_space_root, exist_ok=True) - os.makedirs(server_config.path.database_root, exist_ok=True) os.makedirs(server_config.path.download_root, exist_ok=True) os.makedirs(server_config.path.code_interpreter_ws, exist_ok=True) code_interpreter_work_dir = str(Path(__file__).resolve().parent / server_config.path.code_interpreter_ws) + + # TODO: Remove these two hacky code interpreter env vars. os.environ['M6_CODE_INTERPRETER_WORK_DIR'] = code_interpreter_work_dir - from qwen_agent.log import logger - from qwen_agent.utils.utils import get_local_ip + from qwen_agent.utils.utils import append_signal_handler, get_local_ip, logger logger.info(server_config) if args.server_host == '0.0.0.0': @@ -109,14 +109,20 @@ def main(): servers = { 'database': - subprocess.Popen([sys.executable, - os.path.join(os.getcwd(), 'qwen_server/database_server.py')]), + subprocess.Popen([ + sys.executable, + os.path.join(os.getcwd(), 'qwen_server/database_server.py'), + ]), 'workstation': - subprocess.Popen([sys.executable, - os.path.join(os.getcwd(), 'qwen_server/workstation_server.py')]), + subprocess.Popen([ + sys.executable, + os.path.join(os.getcwd(), 'qwen_server/workstation_server.py'), + ]), 'assistant': - subprocess.Popen([sys.executable, - os.path.join(os.getcwd(), 'qwen_server/assistant_server.py')]), + subprocess.Popen([ + sys.executable, + os.path.join(os.getcwd(), 'qwen_server/assistant_server.py'), + ]), } def signal_handler(sig_num, _frame): @@ -127,8 +133,8 @@ def signal_handler(sig_num, _frame): if sig_num == signal.SIGINT: raise KeyboardInterrupt() - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) + append_signal_handler(signal.SIGINT, signal_handler) + append_signal_handler(signal.SIGTERM, signal_handler) for p in list(servers.values()): p.wait() diff --git a/tests/agents/test_doc_qa.py b/tests/agents/test_doc_qa.py index 60f3a78..fb54333 100644 --- a/tests/agents/test_doc_qa.py +++ b/tests/agents/test_doc_qa.py @@ -1,9 +1,9 @@ -from qwen_agent.agents import DocQAAgent +from qwen_agent.agents.doc_qa import BasicDocQA def test_doc_qa(): llm_cfg = {'model': 'qwen-max', 'api_key': '', 'model_server': 'dashscope'} - agent = DocQAAgent(llm=llm_cfg) + agent = BasicDocQA(llm=llm_cfg) messages = [{ 'role': 'user', 'content': [{ diff --git a/tests/agents/test_router.py b/tests/agents/test_router.py index c3e12de..343d17a 100644 --- a/tests/agents/test_router.py +++ b/tests/agents/test_router.py @@ -5,26 +5,21 @@ def test_router(): llm_cfg = {'model': 'qwen-max'} llm_cfg_vl = {'model': 'qwen-vl-max'} - tools = ['image_gen', 'amap_weather'] + tools = ['amap_weather'] - # define a vl agent - bot_vl = Assistant(llm=llm_cfg_vl) + # Define a vl agent + bot_vl = Assistant(llm=llm_cfg_vl, name='多模态助手', description='可以理解图像内容。') - # define a tool agent - bot_tool = Assistant(llm=llm_cfg, function_list=tools) + # Define a tool agent + bot_tool = Assistant( + llm=llm_cfg, + name='天气预报助手', + description='可以查询天气', + function_list=tools, + ) # define a router (Simultaneously serving as a text agent) - bot = Router(llm=llm_cfg, - agents={ - 'vl': { - 'obj': bot_vl, - 'desc': '多模态助手,可以理解图像内容。' - }, - 'tool': { - 'obj': bot_tool, - 'desc': '工具助手,可以使用天气查询工具和画图工具来解决问题' - } - }) + bot = Router(llm=llm_cfg, agents=[bot_vl, bot_tool]) messages = [ Message( 'user', diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index ada8230..d1b520a 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -12,7 +12,7 @@ from examples.assistant_growing_girl import test as assistant_growing_girl # noqa from examples.assistant_weather_bot import test as assistant_weather_bot # noqa from examples.function_calling import test as function_calling # noqa -from examples.gpt_mentions import test as gpt_mentions # noqa +# from examples.gpt_mentions import test as gpt_mentions # noqa from examples.group_chat_chess import test as group_chat_chess # noqa from examples.group_chat_demo import test as group_chat_demo # noqa from examples.llm_riddles import test as llm_riddles # noqa @@ -67,10 +67,11 @@ def test_function_calling(): function_calling() -@pytest.mark.parametrize('history', ['你能做什么?']) -@pytest.mark.parametrize('chosen_plug', ['code_interpreter', 'doc_qa', 'assistant']) -def test_gpt_mentions(history, chosen_plug): - gpt_mentions(history=history, chosen_plug=chosen_plug) +# @pytest.mark.parametrize('history', ['你能做什么?']) +# @pytest.mark.parametrize('chosen_plug', +# ['code_interpreter', 'doc_qa', 'assistant']) +# def test_gpt_mentions(history, chosen_plug): +# gpt_mentions(history=history, chosen_plug=chosen_plug) @pytest.mark.parametrize( diff --git a/tests/memory/test_memory.py b/tests/memory/test_memory.py index 0671c03..ad4d273 100644 --- a/tests/memory/test_memory.py +++ b/tests/memory/test_memory.py @@ -16,14 +16,18 @@ def test_memory(): mem = Memory(llm=llm_cfg) messages = [ Message('user', [ - ContentItem(text='总结'), - ContentItem(file='https://github.com/QwenLM/Qwen-Agent'), + ContentItem(text='女孩成长历程'), ContentItem(file=str(Path(__file__).resolve().parent.parent.parent / 'examples/resource/growing_girl.pdf')) ]) ] - *_, last = mem.run(messages) + *_, last = mem.run(messages, max_ref_token=4000, parser_page_size=500, ignore_cache=True) + print(last) assert isinstance(last[-1].content, str) assert len(last[-1].content) > 0 res = json5.loads(last[-1].content) assert isinstance(res, list) + + +if __name__ == '__main__': + test_memory() diff --git a/tests/qwen_server/test_database_server.py b/tests/qwen_server/test_database_server.py index 10b63ce..ed074d5 100644 --- a/tests/qwen_server/test_database_server.py +++ b/tests/qwen_server/test_database_server.py @@ -3,7 +3,7 @@ import shutil from pathlib import Path -from qwen_agent.utils.utils import hash_sha256 +from qwen_agent.utils.utils import get_basename_from_url, hash_sha256 from qwen_server.schema import GlobalConfig from qwen_server.utils import read_meta_data_by_condition @@ -16,7 +16,6 @@ def test_database_server(): if os.path.exists('workspace'): shutil.rmtree('workspace') os.makedirs(server_config.path.work_space_root) - os.makedirs(server_config.path.database_root) os.makedirs(server_config.path.download_root) os.makedirs(server_config.path.code_interpreter_ws) @@ -31,7 +30,8 @@ def test_database_server(): } cache_page(**data) - new_url = os.path.join(server_config.path.download_root, hash_sha256(data['url'])) + new_url = os.path.join(server_config.path.download_root, hash_sha256(data['url']), + get_basename_from_url(data['url'])) assert os.path.exists(new_url) meta_file = os.path.join(server_config.path.work_space_root, 'meta_data.jsonl') diff --git a/tests/tools/test_doc_parser.py b/tests/tools/test_doc_parser.py new file mode 100644 index 0000000..0f9f2a1 --- /dev/null +++ b/tests/tools/test_doc_parser.py @@ -0,0 +1,12 @@ +from qwen_agent.tools import DocParser + + +def test_doc_parser(): + tool = DocParser() + res = tool.call({'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf'}, + ignore_cache=True) + print(res) + + +if __name__ == '__main__': + test_doc_parser() diff --git a/tests/tools/test_similarity_search.py b/tests/tools/test_similarity_search.py new file mode 100644 index 0000000..49baa53 --- /dev/null +++ b/tests/tools/test_similarity_search.py @@ -0,0 +1,20 @@ +from qwen_agent.tools import SimilaritySearch + + +def test_similarity_search(): + tool = SimilaritySearch() + doc = ('主要序列转导模型基于复杂的循环或卷积神经网络,包括编码器和解码器。性能最好的模型还通过注意力机制连接编码器和解码器。' + '我们提出了一种新的简单网络架构——Transformer,它完全基于注意力机制,完全不需要递归和卷积。对两个机器翻译任务的实验表明,' + '这些模型在质量上非常出色,同时具有更高的并行性,并且需要的训练时间显着减少。' + '我们的模型在 WMT 2014 英语到德语翻译任务中取得了 28.4 BLEU,比现有的最佳结果(包括集成)提高了 2 BLEU 以上。' + '在 WMT 2014 英法翻译任务中,我们的模型在 8 个 GPU 上训练 3.5 天后,建立了新的单模型最先进 BLEU 分数 41.0,' + '这只是最佳模型训练成本的一小部分文献中的模型。') + res = tool.call({'query': '这个模型要训练多久?'}, docs=[doc]) + print(res) + + res = tool.call({'query': '这个模型要训练多久?'}, docs=[doc.split('。')]) + print(res) + + +if __name__ == '__main__': + test_similarity_search() diff --git a/tests/tools/test_simple_doc_parser.py b/tests/tools/test_simple_doc_parser.py new file mode 100644 index 0000000..ab288e9 --- /dev/null +++ b/tests/tools/test_simple_doc_parser.py @@ -0,0 +1,11 @@ +from qwen_agent.tools import SimpleDocParser + + +def test_simple_doc_parser(): + tool = SimpleDocParser() + res = tool.call({'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf'}) + print(res) + + +if __name__ == '__main__': + test_simple_doc_parser() diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index a6b5827..bc9c1e4 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -2,7 +2,7 @@ import pytest -from qwen_agent.tools import AmapWeather, CodeInterpreter, DocParser, ImageGen, Retrieval, SimilaritySearch, Storage +from qwen_agent.tools import AmapWeather, CodeInterpreter, ImageGen, Retrieval, Storage # [NOTE] 不带“市”会出错 @@ -17,11 +17,6 @@ def test_code_interpreter(): tool.call("print('hello qwen')") -def test_doc_parser(): - tool = DocParser() - tool.call({'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf'}) - - def test_image_gen(): tool = ImageGen() tool.call({'prompt': 'a dog'}) @@ -35,17 +30,6 @@ def test_retrieval(): }) -def test_similarity_search(): - tool = SimilaritySearch() - doc = ('主要序列转导模型基于复杂的循环或卷积神经网络,包括编码器和解码器。性能最好的模型还通过注意力机制连接编码器和解码器。' - '我们提出了一种新的简单网络架构——Transformer,它完全基于注意力机制,完全不需要递归和卷积。对两个机器翻译任务的实验表明,' - '这些模型在质量上非常出色,同时具有更高的并行性,并且需要的训练时间显着减少。' - '我们的模型在 WMT 2014 英语到德语翻译任务中取得了 28.4 BLEU,比现有的最佳结果(包括集成)提高了 2 BLEU 以上。' - '在 WMT 2014 英法翻译任务中,我们的模型在 8 个 GPU 上训练 3.5 天后,建立了新的单模型最先进 BLEU 分数 41.0,' - '这只是最佳模型训练成本的一小部分文献中的模型。') - tool.call({'query': '这个模型要训练多久?'}, doc=doc) - - @pytest.mark.parametrize('operate', ['put']) def test_storage_put(operate): tool = Storage()