Skip to content

Commit

Permalink
refactor function calling to allow for flexible prompt modification
Browse files Browse the repository at this point in the history
  • Loading branch information
JianxinMa committed Sep 2, 2024
1 parent c1ed1e5 commit cb86572
Show file tree
Hide file tree
Showing 21 changed files with 849 additions and 598 deletions.
25 changes: 6 additions & 19 deletions examples/qwen2vl_assistant_tooluse.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,22 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s

host = 'https://wuliu.market.alicloudapi.com'
path = '/kdi'
method = 'GET'
appcode = os.environ['AppCode_ExpressTracking'] # 开通服务后 买家中心-查看AppCode
querys = f'no={id}&type={company}'

bodys = {}
url = host + path + '?' + querys
header = {'Authorization': 'APPCODE ' + appcode}
try:
res = requests.get(url, headers=header)
except:
return 'URL错误'
res = requests.get(url, headers=header)
httpStatusCode = res.status_code

if (httpStatusCode == 200):
# print("正常请求计费(其他均不计费)")
import json
try:
out = json.loads(res.text)
except:
out = eval(res.text)
except json.decoder.JSONDecodeError:
import json5
out = json5.loads(res.text)
return '```json' + json.dumps(out, ensure_ascii=False, indent=4) + '\n```'
else:
httpReason = res.headers['X-Ca-Error-Message']
Expand Down Expand Up @@ -143,10 +139,8 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s

host = 'https://ali-weather.showapi.com'
path = '/spot-to-weather'
method = 'GET'
appcode = os.environ['AppCode_Area2Weather'] # 开通服务后 买家中心-查看AppCode
querys = f'area={area}&needMoreDay={needMoreDay}&needIndex={needIndex}&needHourData={needHourData}&need3HourForcast={need3HourForcast}&needAlarm={needAlarm}'
bodys = {}
url = host + path + '?' + querys

request = urllib.request.Request(url)
Expand Down Expand Up @@ -181,10 +175,8 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> s

host = 'https://ali-weather.showapi.com'
path = '/hour24'
method = 'GET'
appcode = os.environ('AppCode_weather_hour24') # 开通服务后 买家中心-查看AppCode
querys = f'area={area}&areaCode='
bodys = {}
url = host + path + '?' + querys

request = urllib.request.Request(url)
Expand Down Expand Up @@ -283,18 +275,13 @@ def init_agent_service():
llm_cfg_vl = {
# Using Qwen2-VL deployed at any openai-compatible service such as vLLM:
# 'model_type': 'qwenvl_oai',
# 'model': 'Qwen/Qwen2-VL-72B-Instruct',
# 'model': 'Qwen2-VL-7B-Instruct',
# 'model_server': 'http://localhost:8000/v1', # api_base
# 'api_key': 'EMPTY',

# Using Qwen2-VL provided by Alibaba Cloud DashScope:
# 'model_type': 'qwenvl_dashscope',
# 'model': 'qwen2-vl-72b-instruct',
# 'api_key': os.getenv('DASHSCOPE_API_KEY'),

# TODO: Use qwen2-vl instead once qwen2-vl is released.
'model_type': 'qwenvl_dashscope',
'model': 'qwen-vl-max',
'model': 'qwen-vl-max-0809',
'api_key': os.getenv('DASHSCOPE_API_KEY'),
'generate_cfg': dict(max_retries=10,)
}
Expand Down
19 changes: 14 additions & 5 deletions examples/qwen2vl_function_calling.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
import json
import os
import urllib.parse

from qwen_agent.llm import get_chat_model
from qwen_agent.llm.schema import ContentItem
from qwen_agent.utils.utils import save_url_to_local_work_dir


def image_gen(prompt: str) -> str:
prompt = urllib.parse.quote(prompt)
image_url = f'https://image.pollinations.ai/prompt/{prompt}'
image_url = save_url_to_local_work_dir(image_url, save_dir='./', save_filename='pic.jpg')
return image_url


def test():
# Config for the model
llm_cfg_oai = {
# Using Qwen2-VL deployed at any openai-compatible service such as vLLM:
'model_type': 'qwenvl_oai',
'model': 'Qwen/Qwen2-VL-72B-Instruct',
'model_server': 'http://localhost:8000/v1', # api_base
'api_key': 'EMPTY',
# 'model_type': 'qwenvl_oai',
# 'model': 'Qwen2-VL-7B-Instruct',
# 'model_server': 'http://localhost:8000/v1', # api_base
# 'api_key': 'EMPTY',

# Using Qwen2-VL provided by Alibaba Cloud DashScope:
'model_type': 'qwenvl_dashscope',
'model': 'qwen-vl-max-0809',
'api_key': os.getenv('DASHSCOPE_API_KEY'),
'generate_cfg': dict(max_retries=10,)
}
llm = get_chat_model(llm_cfg_oai)

Expand All @@ -29,7 +38,7 @@ def test():
'content': [{
'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg'
}, {
'text': '图片中的内容是什么?请画一张内容相同,风格类似的图片。'
'text': '图片中的内容是什么?请画一张内容相同,风格类似的图片。把女人换成男人'
}]
}]

Expand Down
2 changes: 1 addition & 1 deletion qwen_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.0.7'
__version__ = '0.0.8'
from .agent import Agent
from .multi_agent_hub import MultiAgentHub

Expand Down
11 changes: 9 additions & 2 deletions qwen_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Agent(ABC):

def __init__(self,
function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
llm: Optional[Union[Dict, BaseChatModel]] = None,
llm: Optional[Union[dict, BaseChatModel]] = None,
system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
name: Optional[str] = None,
description: Optional[str] = None,
Expand All @@ -48,7 +48,7 @@ def __init__(self,
for tool in function_list:
self._init_tool(tool)

self.system_message = system_message
self.system_message = system_message or self.SYSTEM_MESSAGE
self.name = name
self.description = description

Expand Down Expand Up @@ -226,3 +226,10 @@ def _detect_tool(self, message: Message) -> Tuple[bool, str, str, str]:
text = ''

return (func_name is not None), func_name, func_args, text


# The most basic form of an agent is just a LLM, not augmented with any tool or workflow.
class BasicAgent(Agent):

def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]:
return self._call_llm(messages)
3 changes: 2 additions & 1 deletion qwen_agent/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from qwen_agent.agent import Agent
from qwen_agent.agent import Agent, BasicAgent
from qwen_agent.multi_agent_hub import MultiAgentHub

from .article_agent import ArticleAgent
Expand All @@ -20,6 +20,7 @@

__all__ = [
'Agent',
'BasicAgent',
'MultiAgentHub',
'DocQAAgent',
'ParallelDocQA',
Expand Down
1 change: 1 addition & 0 deletions qwen_agent/agents/fncall_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _run(self, messages: List[Message], lang: Literal['en', 'zh'] = 'en', **kwar
used_any_tool = True
if not used_any_tool:
break
yield response

def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str:
if tool_name not in self.function_map:
Expand Down
20 changes: 11 additions & 9 deletions qwen_agent/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Dict, Optional
from typing import Union

from .azure import TextChatAtAzure
from .base import LLM_REGISTRY, BaseChatModel, ModelServiceError
from .oai import TextChatAtOAI
from .openvino import OpenVINO
from .qwen_dashscope import QwenChatAtDS
from .qwenvl_dashscope import QwenVLChatAtDS
from .qwenvl_oai import QwenVLChatAtOAI
from .azure import TextChatAtAZURE


def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel:
def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel:
"""The interface of instantiating LLM objects.
Args:
Expand All @@ -30,7 +30,9 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel:
Returns:
LLM object.
"""
cfg = cfg or {}
if isinstance(cfg, str):
cfg = {'model': cfg}

if 'model_type' in cfg:
model_type = cfg['model_type']
if model_type in LLM_REGISTRY:
Expand All @@ -40,15 +42,15 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel:

# Deduce model_type from model and model_server if model_type is not provided:

if 'azure_endpoint' in cfg:
model_type = 'azure'
return LLM_REGISTRY[model_type](cfg)

if 'model_server' in cfg:
if cfg['model_server'].strip().startswith('http'):
model_type = 'oai'
return LLM_REGISTRY[model_type](cfg)

if 'azure_endpoint' in cfg:
model_type = 'azure'
return LLM_REGISTRY[model_type](cfg)

model = cfg.get('model', '')

if 'qwen-vl' in model:
Expand All @@ -66,7 +68,7 @@ def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel:
'BaseChatModel',
'QwenChatAtDS',
'TextChatAtOAI',
'TextChatAtAZURE',
'TextChatAtAzure',
'QwenVLChatAtDS',
'QwenVLChatAtOAI',
'OpenVINO',
Expand Down
87 changes: 9 additions & 78 deletions qwen_agent/llm/azure.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,42 @@
import copy
import logging
import os
from pprint import pformat
from typing import Dict, Iterator, List, Optional
from typing import Dict, Optional

import openai

from openai import OpenAIError

from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel
from qwen_agent.llm.schema import ASSISTANT, Message
from qwen_agent.log import logger
from qwen_agent.llm.base import register_llm
from qwen_agent.llm.oai import TextChatAtOAI


@register_llm('azure')
class TextChatAtAZURE(BaseFnCallModel):
class TextChatAtAzure(TextChatAtOAI):

def __init__(self, cfg: Optional[Dict] = None):
super().__init__(cfg)
self.model = self.model or 'gpt-3.5-turbo'
cfg = cfg or {}

api_base = cfg.get(
'api_base',
cfg.get(
'base_url',
cfg.get('model_server', cfg.get('azure_endpoint','')),
cfg.get('model_server', cfg.get('azure_endpoint', '')),
),
).strip()

api_key = cfg.get('api_key', '')
if not api_key:
api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
api_key = api_key.strip()
api_version = cfg.get('api_version','2024-06-01')

api_version = cfg.get('api_version', '2024-06-01')

api_kwargs = {}
if api_base:
api_kwargs['azure_endpoint'] = api_base
if api_key:
api_kwargs['api_key'] = api_key
if api_version:
api_kwargs['api_version'] = api_version


def _chat_complete_create(*args, **kwargs):
# OpenAI API v1 does not allow the following args, must pass by extra_body
Expand All @@ -56,71 +48,10 @@ def _chat_complete_create(*args, **kwargs):
kwargs['extra_body'][k] = kwargs.pop(k)
if 'request_timeout' in kwargs:
kwargs['timeout'] = kwargs.pop('request_timeout')

client = openai.AzureOpenAI(**api_kwargs)

# client = openai.OpenAI(**api_kwargs)
return client.chat.completions.create(*args, **kwargs)

self._chat_complete_create = _chat_complete_create

def _chat_stream(
self,
messages: List[Message],
delta_stream: bool,
generate_cfg: dict,
) -> Iterator[List[Message]]:
messages = [msg.model_dump() for msg in messages]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}')
try:
response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg)
if delta_stream:
for chunk in response:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'Chunk received: {chunk}')
if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
yield [Message(ASSISTANT, chunk.choices[0].delta.content)]
else:
full_response = ''
for chunk in response:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'Chunk received: {chunk}')
if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
full_response += chunk.choices[0].delta.content
yield [Message(ASSISTANT, full_response)]
except OpenAIError as ex:
raise ModelServiceError(exception=ex)

def _chat_no_stream(
self,
messages: List[Message],
generate_cfg: dict,
) -> List[Message]:
messages = [msg.model_dump() for msg in messages]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'LLM Input:\n{pretty_format_messages(messages, indent=2)}')
try:
response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg)
return [Message(ASSISTANT, response.choices[0].message.content)]
except OpenAIError as ex:
raise ModelServiceError(exception=ex)


def pretty_format_messages(messages: List[dict], indent: int = 2) -> str:
messages_show = []
for msg in messages:
assert isinstance(msg, dict)
msg_show = copy.deepcopy(msg)
if isinstance(msg['content'], list):
content = []
for item in msg['content']:
(t, v), = item.items()
if (t != 'text') and v.startswith('data:'):
v = v[:64] + ' ...'
content.append({t: v})
else:
content = msg['content']
msg_show['content'] = content
messages_show.append(msg_show)
return pformat(messages_show, indent=indent)
Loading

0 comments on commit cb86572

Please sign in to comment.