
Commit

fix the no event loop issue of code interpreter
JianxinMa committed Apr 7, 2024
1 parent 8b049e8 commit eeab80d
Showing 8 changed files with 90 additions and 55 deletions.
2 changes: 1 addition & 1 deletion benchmark/README.md
@@ -245,4 +245,4 @@ The inference_and_exec.py file contains the following configurable options:
- `--eval-code-exec-only`: Only evaluate code executable rate
- `--gen-exec-only`: Only generate and execute code without calculating evaluation metrics.
- `--gen-only`: Only generate without executing code or calculating evaluation metrics.
- `--vis-judger`: The model used to judge result correctness for the `Visualization` task; one of `gpt-4-vision-preview`, `qwen-vl-chat`, `qwen-vl-plus`. It defaults to `gpt-4-vision-preview` in version 20231206, and `qwen-vl-chat` has been deprecated.
- `--vis-judger`: The model used to judge result correctness for the `Visualization` task; one of `gpt-4-vision-preview`, `qwen-vl-chat`, `qwen-vl-plus`. It defaults to `gpt-4-vision-preview` in version 20231206, and `qwen-vl-chat` has been deprecated.
11 changes: 8 additions & 3 deletions benchmark/config.py
@@ -1,6 +1,6 @@
from parser import InternLMReActParser, ReActParser

from models import LLM, QwenVL, Qwen, QwenDashscopeVLModel
from models import LLM, Qwen, QwenDashscopeVLModel, QwenVL
from prompt import InternLMReAct, LlamaReAct, QwenReAct

react_prompt_map = {
@@ -15,7 +15,12 @@
'internlm': InternLMReActParser,
}

model_map = {'qwen': Qwen, 'llama': LLM, 'internlm': LLM, 'qwen-vl-chat': QwenVL}
model_map = {
'qwen': Qwen,
'llama': LLM,
'internlm': LLM,
'qwen-vl-chat': QwenVL
}

model_type_map = {
'qwen-72b-chat': 'qwen',
@@ -59,7 +64,7 @@ def get_react_parser(model_name):


def get_model(model_name):
if model_name in ["qwen-vl-plus"]:
if model_name in ['qwen-vl-plus']:
return QwenDashscopeVLModel(model=model_name)
model_path = model_path_map.get(model_name, None)
model_cls = model_map.get(model_type_map[model_name], LLM)
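
For orientation, a minimal sketch of how the two maps above are meant to compose inside `get_model` (the model name is just an example taken from `model_type_map`; nothing is instantiated here):

# 'qwen-vl-plus' short-circuits to the hosted DashScope wrapper; everything else
# resolves model name -> model family -> wrapper class, falling back to LLM.
name = 'qwen-72b-chat'
family = model_type_map[name]             # -> 'qwen'
wrapper_cls = model_map.get(family, LLM)  # -> Qwen
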
55 changes: 32 additions & 23 deletions benchmark/metrics/visualization.py
@@ -1,7 +1,8 @@
import base64
import logging
import os
import re
import base64

import torch
from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl
@@ -23,43 +24,46 @@


def encode_image(image_path):
with open(image_path, "rb") as image_file:
with open(image_path, 'rb') as image_file:
a = base64.b64encode(image_file.read()).decode('utf-8')
return a


def judger_model_inference(judger_model_name, judger_model, imgs=[], prompt=''):
output = ""
def judger_model_inference(judger_model_name,
judger_model,
imgs=[],
prompt=''):
output = ''
if judger_model_name == 'gpt-4-vision-preview':
logging.warning("This is an example of `gpt-4-vision-preview`. "
"Please set the API key and use according to your actual situation.")
logging.warning(
'This is an example of `gpt-4-vision-preview`. '
'Please set the API key and use according to your actual situation.'
)
from openai import OpenAI
client = OpenAI()
content_list = []
content_list.append({"type": "text", "text": prompt})
content_list.append({'type': 'text', 'text': prompt})
input_images = []
for img in imgs:
if 'http' not in img:
base64_image = encode_image(img)
img = f"data:image/jpeg;base64,{base64_image}"
input_images.append({"type": "image_url", 'image_url': img})
img = f'data:image/jpeg;base64,{base64_image}'
input_images.append({'type': 'image_url', 'image_url': img})
content_list.extend(input_images)
response = client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": content_list,
}
],
model='gpt-4-vision-preview',
messages=[{
'role': 'user',
'content': content_list,
}],
max_tokens=300,
)
output = response.choices[0]
elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
inputs = []
for img in imgs:
if 'http' not in img and judger_model_name == 'qwen-vl-plus':
img = "file://" + img
img = 'file://' + img
inputs.append({'image': img})
inputs.append({'text': prompt})

@@ -105,17 +109,21 @@ def check_images_observation(text, images, model_name):
eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}


def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
def eval_visualization_acc(output_fname,
model_name,
judger_model_name='gpt-4-vision-preview'):
if judger_model_name == 'gpt-4-vision-preview':
judger_model = None
elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
if judger_model_name == 'qwen-vl-chat':
logging.warning('In this benchmark of version 20231206, `Qwen-vl-chat` is no longer used as the '
'evaluation model for `Visualization` task.. If you insist on using it, '
'the evaluation results might differ from the official results.')
logging.warning(
'In this benchmark of version 20231206, `Qwen-vl-chat` is no longer used as the '
'evaluation model for `Visualization` task.. If you insist on using it, '
'the evaluation results might differ from the official results.'
)
judger_model = get_model(judger_model_name)
else:
raise Exception("Not supported judger model.")
raise Exception('Not supported judger model.')

one_action, one_action_right = 0, 0
zero_action, zero_action_right = 0, 0
@@ -139,7 +147,8 @@ def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vi
model_name):
input_prompt = eval_visual_prompt[item.get('lang', 'en')]
format_prompt = input_prompt.format(query=prompt)
output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
output = judger_model_inference(judger_model_name, judger_model,
images, format_prompt)
if 'right' in output.lower():
item['vis_acc'] = True
if '<|im_end|>' in item['query']:
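
For reference, a small self-contained sketch of the data-URL assembly that `judger_model_inference` performs above for `gpt-4-vision-preview`; the prompt text and image path are hypothetical, and no API call is made:

import base64


def to_data_url(image_path):
    # Same idea as encode_image() plus the f-string above: embed a local image
    # as a base64 data URL so it can be sent as an 'image_url' content item.
    with open(image_path, 'rb') as image_file:
        b64 = base64.b64encode(image_file.read()).decode('utf-8')
    return f'data:image/jpeg;base64,{b64}'


content_list = [
    {'type': 'text', 'text': 'Is the generated plot consistent with the query?'},
    {'type': 'image_url', 'image_url': to_data_url('output/plot_001.png')},
]
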
2 changes: 1 addition & 1 deletion benchmark/models/__init__.py
@@ -1,4 +1,4 @@
from models.base import HFModel # noqa
from models.dashscope import QwenDashscopeVLModel
from models.llm import LLM # noqa
from models.qwen import Qwen, QwenVL # noqa
from models.dashscope import QwenDashscopeVLModel
14 changes: 10 additions & 4 deletions benchmark/models/dashscope.py
@@ -1,13 +1,16 @@
import logging
from http import HTTPStatus
import time
from http import HTTPStatus

import dashscope


class QwenDashscopeVLModel(object):

def __init__(self, model, api_key):
self.model = model
dashscope.api_key = api_key.strip() or os.getenv('DASHSCOPE_API_KEY', default='')
dashscope.api_key = api_key.strip() or os.getenv('DASHSCOPE_API_KEY',
default='')
assert dashscope.api_key, 'DASHSCOPE_API_KEY is required.'

def generate(self, prompt, stop_words=[]):
@@ -19,7 +22,10 @@ def generate(self, prompt, stop_words=[]):
while count < MAX_TRY:
response = dashscope.MultiModalConversation.call(
self.model,
messages=[{'role': 'user', 'content': prompt}],
messages=[{
'role': 'user',
'content': prompt
}],
top_p=0.01,
top_k=1,
)
@@ -28,7 +34,7 @@ def generate(self, prompt, stop_words=[]):
for stop_str in stop_words:
idx = output.find(stop_str)
if idx != -1:
output = output[: idx + len(stop_str)]
output = output[:idx + len(stop_str)]
return output
else:
err = 'Error code: %s, error message: %s' % (
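
A quick note on the stop-word handling in `generate` above: the text is cut right after the first occurrence of each stop string. A tiny standalone sketch of that logic (the helper name is made up for illustration):

def truncate_at_stop_words(output, stop_words):
    # Mirrors the loop in QwenDashscopeVLModel.generate: keep everything up to
    # and including the first occurrence of each stop string, if present.
    for stop_str in stop_words:
        idx = output.find(stop_str)
        if idx != -1:
            output = output[:idx + len(stop_str)]
    return output


print(truncate_at_stop_words('Thought: plot\nAction: code\nObservation: done',
                             ['Observation:']))
# -> 'Thought: plot\nAction: code\nObservation:'
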
7 changes: 5 additions & 2 deletions benchmark/models/qwen.py
@@ -26,11 +26,14 @@ def generate(self, input_text, stop_words=[]):


class QwenVL(HFModel):

def __init__(self, model_path):
super().__init__(model_path)

def generate(self, inputs: list):
query = self.tokenizer.from_list_format(inputs)
response, _ = self.model.chat(self.tokenizer, query=query, history=None)
response, _ = self.model.chat(self.tokenizer,
query=query,
history=None)

return response
return response
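
A hedged usage sketch of `QwenVL.generate`; the checkpoint name and image path are assumptions for illustration, and `tokenizer.from_list_format` expects exactly this list of `{'image': ...}` / `{'text': ...}` items:

from models.qwen import QwenVL

# Hypothetical checkpoint and image; loading Qwen-VL-Chat requires the weights
# to be available locally or from the Hugging Face hub.
model = QwenVL('Qwen/Qwen-VL-Chat')
inputs = [
    {'image': 'output/plot_001.png'},
    {'text': 'Does this chart correctly answer the user query?'},
]
print(model.generate(inputs))
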
2 changes: 1 addition & 1 deletion benchmark/requirements.txt
@@ -3,11 +3,11 @@ func_timeout
json5
matplotlib
numpy
openai
pandas
PrettyTable
scipy
seaborn
sympy
transformers==4.33.1
transformers_stream_generator
openai
52 changes: 32 additions & 20 deletions qwen_agent/tools/code_interpreter.py
@@ -110,7 +110,7 @@ def _start_kernel(pid) -> BlockingKernelClient:

# Client
kc = BlockingKernelClient(connection_file=connection_file)
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
kc.load_connection_file()
kc.start_channels()
kc.wait_for_ready()
@@ -317,24 +317,36 @@ def call(self,
return result if result.strip() else 'Finished execution.'


def _get_multiline_input() -> str:
logger.info(
'// Press ENTER to make a new line. Press CTRL-D to end input.')
lines = []
while True:
try:
line = input()
except EOFError: # CTRL-D
break
lines.append(line)
logger.info('// Input received.')
if lines:
return '\n'.join(lines)
else:
return ''
#
# The _BasePolicy and AnyThreadEventLoopPolicy below are borrowed from Tornado.
# Ref: https://www.tornadoweb.org/en/stable/_modules/tornado/platform/asyncio.html#AnyThreadEventLoopPolicy
#

if sys.platform == 'win32' and hasattr(asyncio,
'WindowsSelectorEventLoopPolicy'):
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy

if __name__ == '__main__':
tool = CodeInterpreter()
while True:
logger.info(tool.call(_get_multiline_input()))

class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""

def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except RuntimeError:
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
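
The `AnyThreadEventLoopPolicy` above is the crux of the fix: under the default policy, `asyncio.get_event_loop()` raises `RuntimeError: There is no current event loop in thread ...` in any non-main thread, which is presumably the "no event loop" issue in the commit title. A minimal standalone sketch of the behavior (it re-declares an equivalent policy instead of importing from qwen_agent, just to stay self-contained):

import asyncio
import threading


class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
    """Create an event loop on demand in any thread, not only the main one."""

    def get_event_loop(self):
        try:
            return super().get_event_loop()
        except RuntimeError:  # "There is no current event loop in thread %r"
            loop = self.new_event_loop()
            self.set_event_loop(loop)
            return loop


def worker(label):
    try:
        print(label, '->', asyncio.get_event_loop())
    except RuntimeError as e:
        print(label, '->', e)


# Default policy: a worker thread cannot obtain a loop implicitly.
t = threading.Thread(target=worker, args=('default policy',))
t.start()
t.join()

# With the any-thread policy installed, the same call creates and registers one.
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
t = threading.Thread(target=worker, args=('any-thread policy',))
t.start()
t.join()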
