[Feat] Add mix_evals audio2text (#420)
* Add mix_evals audio2text

* Fix task tags in datasets
kcz358 authored Nov 23, 2024
1 parent ce5e7c9 commit 1cc17b9
Showing 7 changed files with 220 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lmms_eval/tasks/air_bench/air_bench_chat.yaml
@@ -1,5 +1,5 @@
 group: air_bench_chat
-tasks:
+task:
 - air_bench_chat_sound
 - air_bench_chat_music
 - air_bench_chat_speech
2 changes: 1 addition & 1 deletion lmms_eval/tasks/air_bench/air_bench_foundation.yaml
@@ -1,5 +1,5 @@
 group: air_bench_foundation
-tasks:
+task:
 - air_bench_foundation_sound
 - air_bench_foundation_music
 - air_bench_foundation_speech
2 changes: 1 addition & 1 deletion lmms_eval/tasks/clotho_aqa/clotho_aqa.yaml
@@ -1,4 +1,4 @@
 group: clotho_aqa
-tasks:
+task:
 - clotho_aqa_val
 - clotho_aqa_test
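
In all three group configs the subtask list sat under a "tasks" key; the commit renames it to "task" (the "Fix task tags in datasets" part of the message). As a minimal, illustrative sketch (PyYAML, not part of the commit), this is what the corrected group YAML parses to:

import yaml  # PyYAML; illustrative only

# The corrected clotho_aqa group config.
group_yaml = """
group: clotho_aqa
task:
  - clotho_aqa_val
  - clotho_aqa_test
"""

cfg = yaml.safe_load(group_yaml)
print(cfg["group"])  # clotho_aqa
print(cfg["task"])   # ['clotho_aqa_val', 'clotho_aqa_test']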
10 changes: 10 additions & 0 deletions lmms_eval/tasks/mix_evals/audio2text/_default_template_yaml
@@ -0,0 +1,10 @@
dataset_kwargs:
  token: true
dataset_path: lmms-lab/MixEval-X-audio2text
lmms_eval_specific_kwargs:
  default:
    post_prompt: ""
    pre_prompt: ""
metadata:
  gpt_eval_model_name: gpt-4o-mini
  version: 0
@@ -0,0 +1,24 @@
task: "mix_evals_audio2text_freeform"
test_split: free_form
output_type: generate_until
doc_to_visual: !function utils.mix_evals_audio2text_doc_to_audio
doc_to_text: !function utils.mix_evals_audio2text_doc_to_text
doc_to_target: !function utils.mix_evals_audio2text_doc_to_target
process_results: !function utils.mix_evals_audio2text_process_results_freeform
metric_list:
- metric: gpt_eval
aggregation: !function utils.mix_evals_audio2text_gpt_eval
higher_is_better: true

generation_kwargs:
max_new_tokens: 64

include: _default_template_yaml

lmms_eval_specific_kwargs:
default:
pre_prompt: ""
post_prompt: "Answer the question using a single word or phrase."
gpt4v:
pre_prompt: "These are frames from a video. Please answer the following questions about the video with a short phrase."
post_prompt: ""
@@ -0,0 +1,24 @@
task: "mix_evals_audio2text_freeform_hard"
test_split: free_form_hard
output_type: generate_until
doc_to_visual: !function utils.mix_evals_audio2text_doc_to_audio
doc_to_text: !function utils.mix_evals_audio2text_doc_to_text
doc_to_target: !function utils.mix_evals_audio2text_doc_to_target
process_results: !function utils.mix_evals_audio2text_process_results_freeform
metric_list:
- metric: gpt_eval
aggregation: !function utils.mix_evals_audio2text_gpt_eval
higher_is_better: true

generation_kwargs:
max_new_tokens: 64

include: _default_template_yaml

lmms_eval_specific_kwargs:
default:
pre_prompt: ""
post_prompt: "Answer the question using a single word or phrase."
gpt4v:
pre_prompt: "These are frames from a video. Please answer the following questions about the video with a short phrase."
post_prompt: ""
159 changes: 159 additions & 0 deletions lmms_eval/tasks/mix_evals/audio2text/utils.py
@@ -0,0 +1,159 @@
import datetime
import json
import os
import re
import sys
import time
from pathlib import Path

import requests
import yaml
from loguru import logger as eval_logger

import lmms_eval.tasks._task_utils.file_utils as file_utils
from lmms_eval.filters.extraction import ExtendedRegexFilter

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)

    config = yaml.safe_load("".join(safe_data))

NUM_SECONDS_TO_SLEEP = 5
GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "openai":
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
elif API_TYPE == "azure":
    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
    headers = {
        "api-key": API_KEY,
        "Content-Type": "application/json",
    }

eval_prompt = """You are an AI assistant who will help me to evaluate the quality of a model response to a few candidate ground truth answers.
Some criterion
- Response that perfectly reflect the meaning of the ground truth: 1 point
- Response that reflect none of the key points in the ground truth: 0 point
- Some part in the response are correct but some parts in the ground truth are not mentioned in the response: 0.5 point
- Some part in the response are correct but other parts in the response are not mentioned in the ground truth: 0.5 point
Here're some examples about the scoring criterion and format:
model response: Steam Cleaning Services
ground truth: ["steam clean", "steam clean", "cleaning", "car", "steam clean"],
Point: 1
model response: A cowboy action shooter.
ground truth: ["man"]
Point: 1
model response: I'm sorry, but I can't assist with that request.
ground truth: ["quality"]
Point: 0
Let's begin this task:
model response: {model_response}
ground truth: {ground_truth}
Point:"""


def get_eval(model_response: str, ground_truth: str, max_tokens: int, retries: int = 5):
    global headers
    content = eval_prompt.format(model_response=model_response, ground_truth=ground_truth)

    messages = [
        {"role": "user", "content": content},
    ]

    payload = {
        "model": GPT_EVAL_MODEL_NAME,
        "messages": messages,
        "temperature": 0.2,
        "max_tokens": max_tokens,
    }

    for attempt in range(retries):
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            response_data = response.json()

            content = response_data["choices"][0]["message"]["content"].strip()
            if content != "":
                return content, response_data["model"]
            break  # Empty content: stop retrying and fall through to the default return

        except Exception as e:
            eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
            if attempt < retries - 1:  # If we have retries left, sleep and then continue to the next attempt
                time.sleep(NUM_SECONDS_TO_SLEEP)
            else:  # If this was the last attempt, log and return a zero score
                eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
                return "0", ""
    return "", ""


def mix_evals_audio2text_doc_to_audio(doc):
    return [doc["audio"]]


def mix_evals_audio2text_doc_to_target(doc):
    return doc["reference_answer"][0]


# This is the place where you format your question
def mix_evals_audio2text_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    if lmms_eval_specific_kwargs is None:
        lmms_eval_specific_kwargs = {}
    pre_prompt = ""
    post_prompt = ""
    if "pre_prompt" in lmms_eval_specific_kwargs:
        pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
    if "post_prompt" in lmms_eval_specific_kwargs:
        post_prompt = lmms_eval_specific_kwargs["post_prompt"]

    user_prompt = doc["query"]

    if pre_prompt:
        user_prompt = f"{pre_prompt}\n{user_prompt}"

    if post_prompt:
        user_prompt = f"{user_prompt}\n{post_prompt}"
    return user_prompt


def mix_evals_audio2text_process_results_freeform(doc, result):
    pred = result[0]
    ground_truth_str = doc["reference_answer"][0]
    content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str)
    eval_answer, model_name = get_eval(model_response=pred, ground_truth=ground_truth_str, max_tokens=1024)
    return {
        "gpt_eval": {"pred": pred, "id": doc["id"], "target": ground_truth_str, "eval_answer": eval_answer, "gpt_prompt": content},
    }


def mix_evals_audio2text_gpt_eval(results, args):
    score = 0
    for result in results:
        eval_answer = result["eval_answer"]
        try:
            # re.search returns None when the judge output contains no number,
            # so keep the parsing inside the try block and fall back to 0.0.
            eval_score = float(re.search(r"([0-9.]+)", eval_answer).group(1))
        except Exception as e:
            eval_logger.error(f"Error parsing eval_score: {e}")
            eval_score = 0.0
        score += eval_score

    return score / len(results)
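
For reference, a minimal sketch of how these helpers fit together offline: the doc below is hypothetical and the GPT judge is replaced by mocked per-sample results, so no API key is needed. It assumes the snippet is run from lmms_eval/tasks/mix_evals/audio2text/ so that utils.py can be imported directly.

# Hypothetical doc and mocked judge outputs; run next to utils.py.
from utils import mix_evals_audio2text_doc_to_text, mix_evals_audio2text_gpt_eval

doc = {"id": 0, "query": "What instrument is playing?", "reference_answer": ["piano"], "audio": None}
kwargs = {"pre_prompt": "", "post_prompt": "Answer the question using a single word or phrase."}

# Prompt sent to the model (the audio itself is supplied via doc_to_audio).
print(mix_evals_audio2text_doc_to_text(doc, kwargs))
# What instrument is playing?
# Answer the question using a single word or phrase.

# Aggregation: the first number in each judge answer is parsed and averaged,
# so these three mocked samples score (1 + 0.5 + 0) / 3 = 0.5.
mock_results = [
    {"eval_answer": "1"},
    {"eval_answer": "Point: 0.5"},
    {"eval_answer": "0"},
]
print(mix_evals_audio2text_gpt_eval(mock_results, args=None))  # 0.5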
