Skip to content

Commit

Permalink
update examples: build a multimodal tool-use agent with qwen2vl
Browse files Browse the repository at this point in the history
  • Loading branch information
gewenbin0992 authored and JianxinMa committed Aug 14, 2024
1 parent f3f04ff commit 5119eb0
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 5 deletions.
240 changes: 235 additions & 5 deletions examples/qwen2vl_assistant_tooluse.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,203 @@
import os
import re
import ssl
import urllib
import urllib.parse
import uuid
from io import BytesIO
from pprint import pprint
from typing import List, Union

import requests
from PIL import Image

from qwen_agent.agents import FnCallAgent
from qwen_agent.gui import WebUI
from qwen_agent.llm.schema import ContentItem
from qwen_agent.tools.base import BaseToolWithFileAccess, register_tool

ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), 'resource')


@register_tool('express_tracking')
class ExpressTracking(BaseToolWithFileAccess):
API_URL = 'https://market.aliyun.com/apimarket/detail/cmapi021863#sku=yuncode15863000017'
description = '全国快递物流查询-快递查询接口'
parameters = [
{
'name': 'no',
'type': 'string',
'description': '快递单号 【顺丰和丰网请输入单号:收件人或寄件人手机号后四位。例如:123456789:1234】',
'required': True
},
{
'name': 'type',
'type': 'string',
'description': '快递公司字母简写:不知道可不填95%能自动识别,填写查询速度会更快',
'required': False
},
]

def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> str:
super().call(params=params, files=files)
params = self._verify_json_format_args(params)

id = params['no'].strip()
company = params.get('type', '').strip()

host = 'https://wuliu.market.alicloudapi.com'
path = '/kdi'
method = 'GET'
appcode = os.environ['AppCode_ExpressTracking'] # 开通服务后 买家中心-查看AppCode
querys = f'no={id}&type={company}'

bodys = {}
url = host + path + '?' + querys
header = {'Authorization': 'APPCODE ' + appcode}
try:
res = requests.get(url, headers=header)
except:
return 'URL错误'
httpStatusCode = res.status_code

if (httpStatusCode == 200):
# print("正常请求计费(其他均不计费)")
import json
try:
out = json.loads(res.text)
except:
out = eval(res.text)
return '```json' + json.dumps(out, ensure_ascii=False, indent=4) + '\n```'
else:
httpReason = res.headers['X-Ca-Error-Message']
if (httpStatusCode == 400 and httpReason == 'Invalid Param Location'):
return '参数错误'
elif (httpStatusCode == 400 and httpReason == 'Invalid AppCode'):
return 'AppCode错误'
elif (httpStatusCode == 400 and httpReason == 'Invalid Url'):
return '请求的 Method、Path 或者环境错误'
elif (httpStatusCode == 403 and httpReason == 'Unauthorized'):
return '服务未被授权(或URL和Path不正确)'
elif (httpStatusCode == 403 and httpReason == 'Quota Exhausted'):
return '套餐包次数用完'
elif (httpStatusCode == 403 and httpReason == 'Api Market Subscription quota exhausted'):
return '套餐包次数用完,请续购套餐'
elif (httpStatusCode == 500):
return 'API网关错误'
else:
return f'参数名错误 或 其他错误 \nhttpStatusCode:{httpStatusCode} \nhttpReason: {httpReason}'


@register_tool('area_to_weather')
class Area2Weather(BaseToolWithFileAccess):
API_URL = 'https://market.aliyun.com/apimarket/detail/cmapi010812#sku=yuncode4812000017'
description = '地名查询天气预报,调用此API查询未来7天的天气,不支持具体时刻天气查询'
parameters = [
{
'name': 'area',
'type': 'string',
'description': '地区名称',
'required': True
},
{
'name': 'needMoreDay',
'type': 'string',
'description': '是否需要返回7天数据中的后4天。1为返回,0为不返回。',
'required': False
},
{
'name': 'needIndex',
'type': 'string',
'description': '是否需要返回指数数据,比如穿衣指数、紫外线指数等。1为返回,0为不返回。',
'required': False
},
{
'name': 'need3HourForcast',
'type': 'string',
'description': '是否需要每小时数据的累积数组。由于本系统是半小时刷一次实时状态,因此实时数组最大长度为48。每天0点长度初始化为0. 1为需要 0为不',
'required': False
},
{
'name': 'needAlarm',
'type': 'string',
'description': '是否需要天气预警。1为需要,0为不需要。',
'required': False
},
{
'name': 'needHourData',
'type': 'string',
'description': '是否需要每小时数据的累积数组。由于本系统是半小时刷一次实时状态,因此实时数组最大长度为48。每天0点长度初始化为0.',
'required': False
},
]

def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> str:
super().call(params=params, files=files)
params = self._verify_json_format_args(params)

area = urllib.parse.quote(params['area'].strip())
needMoreDay = str(params.get('needMoreDay', 0)).strip()
needIndex = str(params.get('needIndex', 0)).strip()
needHourData = str(params.get('needHourData', 0)).strip()
need3HourForcast = str(params.get('need3HourForcast', 0)).strip()
needAlarm = str(params.get('needAlarm', 0)).strip()

host = 'https://ali-weather.showapi.com'
path = '/spot-to-weather'
method = 'GET'
appcode = os.environ['AppCode_Area2Weather'] # 开通服务后 买家中心-查看AppCode
querys = f'area={area}&needMoreDay={needMoreDay}&needIndex={needIndex}&needHourData={needHourData}&need3HourForcast={need3HourForcast}&needAlarm={needAlarm}'
bodys = {}
url = host + path + '?' + querys

request = urllib.request.Request(url)
request.add_header('Authorization', 'APPCODE ' + appcode)
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
response = urllib.request.urlopen(request, context=ctx)
byte_string = response.read()
content = byte_string.decode('utf-8')
return content


@register_tool('weather_hour24')
class WeatherHour24(BaseToolWithFileAccess):
API_URL = 'https://market.aliyun.com/apimarket/detail/cmapi010812#sku=yuncode4812000017'
description = '查询24小时预报,调用此API查询24小时内具体时间的天气'
parameters = [
{
'name': 'area',
'type': 'string',
'description': '地区名称',
'required': True
},
]

def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> str:
super().call(params=params, files=files)
params = self._verify_json_format_args(params)

area = urllib.parse.quote(params['area'].strip())

host = 'https://ali-weather.showapi.com'
path = '/hour24'
method = 'GET'
appcode = os.environ('AppCode_weather_hour24') # 开通服务后 买家中心-查看AppCode
querys = f'area={area}&areaCode='
bodys = {}
url = host + path + '?' + querys

request = urllib.request.Request(url)
request.add_header('Authorization', 'APPCODE ' + appcode)
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
response = urllib.request.urlopen(request, context=ctx)
byte_string = response.read()
content = byte_string.decode('utf-8')
return content


@register_tool('crop_and_resize')
class CropResize(BaseToolWithFileAccess):
description = '这是一个放大镜功能,截取局部图像并放大从而查看更多细节,如果你无法直接看清细节时可以调用'
Expand Down Expand Up @@ -96,7 +279,7 @@ def call(self, params: Union[str, dict], files: List[str] = None, **kwargs) -> L
]


def test():
def init_agent_service():
llm_cfg_vl = {
# Using Qwen2-VL deployed at any openai-compatible service such as vLLM:
# 'model_type': 'qwenvl_oai',
Expand All @@ -116,7 +299,45 @@ def test():
'generate_cfg': dict(max_retries=10,)
}

agent = FnCallAgent(function_list=['crop_and_resize'], llm=llm_cfg_vl)
tools = [
'crop_and_resize',
'code_interpreter',
] # code_interpreter is a built-in tool in Qwen-Agent

# API tools
if 'AppCode_WeatherHour24' in os.environ:
tools.append('express_tracking')
else:
print(f'Please get AppCode from {WeatherHour24.API_URL} and execute:\nexport AppCode_WeatherHour24=xxx')
print('express_tracking is disabled!')

if 'AppCode_Area2Weather' in os.environ:
tools.append('weather_hour24')
else:
print(f'Please get AppCode from {Area2Weather.API_URL} and execute:\nexport AppCode_Area2Weather=xxx')
print('weather_hour24 is disabled!')

if 'AppCode_ExpressTracking' in os.environ:
tools.append('area_to_weather')
else:
print(f'Please get AppCode from {ExpressTracking.API_URL} and execute:\nexport AppCode_ExpressTracking=xxx')
print('area_to_weather is disabled!')

bot = FnCallAgent(
llm=llm_cfg_vl,
name='Qwen2-VL',
description='function calling',
function_list=tools,
)

return bot


def test():
# Define the agent
bot = init_agent_service()

# Chat
messages = [{
'role':
'user',
Expand All @@ -129,9 +350,18 @@ def test():
},
],
}]
response = agent.run_nonstream(messages=messages)
pprint(response, indent=4)

for response in bot.run(messages=messages):
print('bot response:', response)


def app_gui():
# Define the agent
bot = init_agent_service()

WebUI(bot).run()


if __name__ == '__main__':
test()
# app_gui()
2 changes: 2 additions & 0 deletions qwen_agent/llm/qwenvl_oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def _convert_local_images_to_base64(messages: List[Message]) -> List[Message]:
for item in msg.content:
t, v = item.get_type_and_value()
if t == 'image':
if v.startswith('file://'):
v = v[len('file://'):]
if (not v.startswith(('http://', 'https://', 'data:'))) and os.path.exists(v):
item.image = encode_image_as_base64(v, max_short_side_length=1080)
else:
Expand Down

0 comments on commit 5119eb0

Please sign in to comment.