Prereleased(LLMLingua): fix the chunk issue and prepare for v0.2.2 #130

Merged
1 commit merged on Apr 9, 2024
169 changes: 92 additions & 77 deletions experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
@@ -32,89 +32,104 @@
 
 args = parser.parse_args()
 os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
-data = json.load(open(args.load_prompt_from))
-data = data.values() if isinstance(data, dict) else data
-
-print(f"num data: {len(data)}")
-
-model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
-
-results = defaultdict(dict)
-results_list = defaultdict(list)
-if os.path.exists(args.save_path):
-    prev_results = json.load(open(args.save_path))
-    results.update(prev_results)
-if os.path.exists(
-    os.path.join(
-        os.path.dirname(args.save_path),
-        os.path.basename(args.save_path).replace("answer", "answer_list"),
-    )
-):
-    results_list = json.load(
-        open(
-            os.path.join(
-                os.path.dirname(args.save_path),
-                os.path.basename(args.save_path).replace("answer", "answer_list"),
-            )
-        )
-    )
-
-prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
-for sample in tqdm(data):
-    sample_idx = int(sample["idx"])
-    if sample_idx in results or str(sample_idx) in results:
-        print(f"{sample_idx}-th already processed.")
-        continue
-    if args.num_sample > 0 and int(sample_idx) > args.num_sample:
-        break
-    transcript = sample[args.load_key]
-    token_ids = tokenizer.encode(transcript)
-    if len(token_ids) > args.n_max_token - args.n_max_token_ans:
-        transcript = tokenizer.decode(
-            token_ids[: args.n_max_token - args.n_max_token_ans]
-        )
-    qa_list = sample["QA_pairs"]
-    q_list = []
-    a_list = []
-    a_list_model = []
-    for qa in qa_list:
-        q = qa["question"]
-        a = qa["answer"]
-        query = prompt.format(transcript=transcript, question=q)
-        answer = query_llm(
-            query,
-            model,
-            args.model_name_or_path,
-            args.n_max_token_ans,
-            tokenizer=tokenizer,
-        )
-        q_list.append(q)
-        a_list.append(a)
-        a_list_model.append(answer)
-
-    results[sample_idx]["transcript"] = transcript
-    results[sample_idx]["questions"] = q_list[:]
-    results[sample_idx]["answers"] = a_list[:]
-    results[sample_idx]["model_answers"] = a_list_model[:]
-
-    results_list["questions"].extend(q_list[:])
-    results_list["answers"].extend(a_list[:])
-    results_list["model_answers"].extend(a_list_model[:])
-
-    json.dump(results, open(args.save_path, "w"), indent=4)
-    json.dump(
-        results_list,
-        open(
-            os.path.join(
-                os.path.dirname(args.save_path),
-                os.path.basename(args.save_path).replace("answer", "answer_list"),
-            ),
-            "w",
-        ),
-        indent=4,
-    )
-
-score_dict = evaluate_with_gt(results_list["answers"], results_list["model_answers"])
+
+
+def predict():
+    data = json.load(open(args.load_prompt_from))
+    data = data.values() if isinstance(data, dict) else data
+
+    print(f"num data: {len(data)}")
+
+    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+    results = defaultdict(dict)
+    results_list = defaultdict(list)
+    if os.path.exists(args.save_path):
+        prev_results = json.load(open(args.save_path))
+        results.update(prev_results)
+    if os.path.exists(
+        os.path.join(
+            os.path.dirname(args.save_path),
+            os.path.basename(args.save_path).replace("answer", "answer_list"),
+        )
+    ):
+        results_list = json.load(
+            open(
+                os.path.join(
+                    os.path.dirname(args.save_path),
+                    os.path.basename(args.save_path).replace("answer", "answer_list"),
+                )
+            )
+        )
+
+    prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
+    for sample in tqdm(data):
+        sample_idx = int(sample["idx"])
+        if sample_idx in results or str(sample_idx) in results:
+            print(f"{sample_idx}-th already processed.")
+            continue
+        if args.num_sample > 0 and int(sample_idx) > args.num_sample:
+            break
+        transcript = sample[args.load_key]
+        token_ids = tokenizer.encode(transcript)
+        if len(token_ids) > args.n_max_token - args.n_max_token_ans:
+            transcript = tokenizer.decode(
+                token_ids[: args.n_max_token - args.n_max_token_ans]
+            )
+        qa_list = sample["QA_pairs"]
+        q_list = []
+        a_list = []
+        a_list_model = []
+        for qa in qa_list:
+            q = qa["question"]
+            a = qa["answer"]
+            query = prompt.format(transcript=transcript, question=q)
+            answer = query_llm(
+                query,
+                model,
+                args.model_name_or_path,
+                args.n_max_token_ans,
+                tokenizer=tokenizer,
+            )
+            q_list.append(q)
+            a_list.append(a)
+            a_list_model.append(answer)
+
+        results[sample_idx]["transcript"] = transcript
+        results[sample_idx]["questions"] = q_list[:]
+        results[sample_idx]["answers"] = a_list[:]
+        results[sample_idx]["model_answers"] = a_list_model[:]
+
+        results_list["questions"].extend(q_list[:])
+        results_list["answers"].extend(a_list[:])
+        results_list["model_answers"].extend(a_list_model[:])
+
+        json.dump(results, open(args.save_path, "w"), indent=4)
+        json.dump(
+            results_list,
+            open(
+                os.path.join(
+                    os.path.dirname(args.save_path),
+                    os.path.basename(args.save_path).replace("answer", "answer_list"),
+                ),
+                "w",
+            ),
+            indent=4,
+        )
+
+
+predict()
+results_list = json.load(
+    open(
+        os.path.join(
+            os.path.dirname(args.save_path),
+            os.path.basename(args.save_path).replace("answer", "answer_list"),
+        )
+    )
+)
+for i, ans in enumerate(results_list["answers"]):
+    results_list["answers"][i] = [results_list["answers"][i]]
+score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])
 json.dump(
     score_dict,
     open(
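Beyond wrapping the prediction loop in a predict() function, the behavioral change in this file is the evaluation tail: the flat answer list is reloaded from the "answer_list" JSON, each gold answer is wrapped in a single-element list, and evaluate_with_gt now takes the model answers first and the ground-truth lists second. A minimal, self-contained sketch of that reshaping, using an in-memory toy results_list and a stand-in scorer that only mimics the calling convention of metrics.evaluate_with_gt:

def evaluate_with_gt(pred_list, gt_list):
    # Stand-in scorer: fraction of samples where any acceptable ground truth
    # appears verbatim in the prediction (the real scorer also normalizes).
    hits = sum(
        any(gt.lower() in pred.lower() for gt in gts)
        for pred, gts in zip(pred_list, gt_list)
    )
    return {"qa_score": 100.0 * hits / max(len(pred_list), 1)}


# Toy stand-in for the lists the script reloads from disk (illustrative only).
results_list = {
    "questions": ["What did the council vote on?"],
    "answers": ["the annual budget"],
    "model_answers": ["The council voted to approve the annual budget."],
}

# Each gold answer becomes a one-element list, because the scorer treats its
# second argument as a list of acceptable ground truths per sample.
for i, ans in enumerate(results_list["answers"]):
    results_list["answers"][i] = [ans]

score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])
print(score_dict)  # {'qa_score': 100.0}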
16 changes: 12 additions & 4 deletions experiments/llmlingua2/evaluation/metrics.py
@@ -156,6 +156,16 @@ def qa_f1_zh_score(prediction, ground_truth, **kwargs):
     return f1_score(prediction_tokens, ground_truth_tokens)
 
 
+def qa_score(prediction, ground_truths):
+    normalized_prediction = normalize_answer2(prediction)
+
+    for ground_truth in ground_truths:
+        normalized_ground_truth = normalize_answer2(ground_truth)
+        if normalized_ground_truth.lower() in normalized_prediction.lower():
+            return 1.0
+    return 0.0
+
+
 import regex
 
 
@@ -207,12 +217,10 @@ def eval_qa_f1_score(pred, ground_truths):
         pred_list = pred_list_truncated
 
     metrics = {
-        "qa_f1_score": 0.0,
-        "best_subspan_em": 0.0,
+        "qa_score": 0.0,
     }
     for pred, gts in zip(pred_list, gt_list):
-        metrics["qa_f1_score"] += eval_qa_f1_score(pred, gts)
-        metrics["best_subspan_em"] += best_subspan_em(pred, gts)
+        metrics["qa_score"] += qa_score(pred, gts)
     # average
     for metric_name, score in metrics.items():
        metrics[metric_name] = score * 100 / len(pred_list)
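The new qa_score metric is an answer-containment check: it returns 1.0 when any normalized ground truth occurs as a substring of the normalized prediction, and 0.0 otherwise. A self-contained sketch of that behavior, using a stand-in normalize_answer2 (the real normalizer is defined elsewhere in metrics.py; lowercasing, punctuation stripping, article removal, and whitespace collapsing are assumed here):

import re
import string


def normalize_answer2(s):
    # Stand-in normalizer (assumed behavior, not the repo's implementation).
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def qa_score(prediction, ground_truths):
    # 1.0 if any normalized ground truth is contained in the normalized
    # prediction, else 0.0.
    normalized_prediction = normalize_answer2(prediction)
    for ground_truth in ground_truths:
        if normalize_answer2(ground_truth).lower() in normalized_prediction.lower():
            return 1.0
    return 0.0


print(qa_score("The council approved the budget on Tuesday.", ["approved the budget"]))  # 1.0
print(qa_score("The motion was tabled.", ["approved the budget"]))  # 0.0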
6 changes: 4 additions & 2 deletions llmlingua/prompt_compressor.py
@@ -2407,8 +2407,10 @@ def split_string_to_words(input_string):
         keep_words = []
         word_labels = []
         assert len(words) == len(word_probs)
-        for word, word_porb in zip(words, word_probs):
-            if word_porb > threshold:
+        for word, word_prob in zip(words, word_probs):
+            if word_prob > threshold or (
+                threshold == 1.0 and word_prob == threshold
+            ):
                 if (
                     drop_consecutive
                     and word in force_tokens
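The widened condition handles the threshold == 1.0 edge case: with a strict greater-than comparison alone, no word can ever be kept at that setting, since word probabilities top out at 1.0, so the extra clause keeps words whose probability exactly equals a threshold of 1.0 (the loop-variable typo word_porb is fixed in the same change). A small standalone sketch of the patched predicate, with illustrative names:

def keep_word(word_prob, threshold):
    # Patched condition: strictly above the threshold, plus a special case so
    # a threshold of 1.0 still keeps fully confident words.
    return word_prob > threshold or (threshold == 1.0 and word_prob == threshold)


probs = [0.30, 0.99, 1.00]
print([keep_word(p, 0.50) for p in probs])  # [False, True, True]
print([keep_word(p, 1.00) for p in probs])  # [False, False, True]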
2 changes: 1 addition & 1 deletion llmlingua/version.py
@@ -5,7 +5,7 @@
 _MINOR = "2"
 # On master and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "1"
+_PATCH = "2"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
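The patch bump takes the package from v0.2.1 to v0.2.2. A quick sketch of how such components typically compose into the release string; the actual composition line in version.py is outside this diff, and _MAJOR = "0" is implied by the PR title rather than shown here:

_MAJOR = "0"   # not visible in the diff; implied by the v0.2.2 target
_MINOR = "2"
_PATCH = "2"
_SUFFIX = ""   # nightly builds would carry a ".dev$DATE" suffix instead

VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}{_SUFFIX}"
print(VERSION)  # 0.2.2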