From 76252fd4d7f7dde648a242b9f83e96a8bc3fc953 Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Tue, 9 Apr 2024 07:39:06 +0000
Subject: [PATCH] Prerelease(LLMLingua): fix the chunk issue and prepare for
 v0.2.2

---
 .../evaluation/eval_meetingbank_qa.py        | 169 ++++++++++--------
 experiments/llmlingua2/evaluation/metrics.py |  16 +-
 llmlingua/prompt_compressor.py               |   6 +-
 llmlingua/version.py                         |   2 +-
 4 files changed, 109 insertions(+), 84 deletions(-)

diff --git a/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py b/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
index 54835bb..b962dc8 100644
--- a/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
+++ b/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
@@ -32,89 +32,104 @@ args = parser.parse_args()
 os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
-data = json.load(open(args.load_prompt_from))
-data = data.values() if isinstance(data, dict) else data
-
-print(f"num data: {len(data)}")
-
-model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
-
-results = defaultdict(dict)
-results_list = defaultdict(list)
-if os.path.exists(args.save_path):
-    prev_results = json.load(open(args.save_path))
-    results.update(prev_results)
-if os.path.exists(
-    os.path.join(
-        os.path.dirname(args.save_path),
-        os.path.basename(args.save_path).replace("answer", "answer_list"),
-    )
-):
-    results_list = json.load(
-        open(
-            os.path.join(
-                os.path.dirname(args.save_path),
-                os.path.basename(args.save_path).replace("answer", "answer_list"),
+
+
+def predict():
+    data = json.load(open(args.load_prompt_from))
+    data = data.values() if isinstance(data, dict) else data
+
+    print(f"num data: {len(data)}")
+
+    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+    results = defaultdict(dict)
+    results_list = defaultdict(list)
+    if os.path.exists(args.save_path):
+        prev_results = json.load(open(args.save_path))
+        results.update(prev_results)
+    if os.path.exists(
+        os.path.join(
+            os.path.dirname(args.save_path),
+            os.path.basename(args.save_path).replace("answer", "answer_list"),
+        )
+    ):
+        results_list = json.load(
+            open(
+                os.path.join(
+                    os.path.dirname(args.save_path),
+                    os.path.basename(args.save_path).replace("answer", "answer_list"),
+                )
             )
         )
-    )
-prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
-for sample in tqdm(data):
-    sample_idx = int(sample["idx"])
-    if sample_idx in results or str(sample_idx) in results:
-        print(f"{sample_idx}-th already processed.")
-        continue
-    if args.num_sample > 0 and int(sample_idx) > args.num_sample:
-        break
-    transcript = sample[args.load_key]
-    token_ids = tokenizer.encode(transcript)
-    if len(token_ids) > args.n_max_token - args.n_max_token_ans:
-        transcript = tokenizer.decode(
-            token_ids[: args.n_max_token - args.n_max_token_ans]
+    prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
+    for sample in tqdm(data):
+        sample_idx = int(sample["idx"])
+        if sample_idx in results or str(sample_idx) in results:
+            print(f"{sample_idx}-th already processed.")
+            continue
+        if args.num_sample > 0 and int(sample_idx) > args.num_sample:
+            break
+        transcript = sample[args.load_key]
+        token_ids = tokenizer.encode(transcript)
+        if len(token_ids) > args.n_max_token - args.n_max_token_ans:
+            transcript = tokenizer.decode(
+                token_ids[: args.n_max_token - args.n_max_token_ans]
+            )
+        qa_list = sample["QA_pairs"]
+        q_list = []
+        a_list = []
+        a_list_model = []
+        for qa in qa_list:
+            q = qa["question"]
+            a = qa["answer"]
+            query = prompt.format(transcript=transcript, question=q)
+            answer = query_llm(
+                query,
+                model,
+                args.model_name_or_path,
+                args.n_max_token_ans,
+                tokenizer=tokenizer,
+            )
+            q_list.append(q)
+            a_list.append(a)
+            a_list_model.append(answer)
+
+        results[sample_idx]["transcript"] = transcript
+        results[sample_idx]["questions"] = q_list[:]
+        results[sample_idx]["answers"] = a_list[:]
+        results[sample_idx]["model_answers"] = a_list_model[:]
+
+        results_list["questions"].extend(q_list[:])
+        results_list["answers"].extend(a_list[:])
+        results_list["model_answers"].extend(a_list_model[:])
+
+        json.dump(results, open(args.save_path, "w"), indent=4)
+        json.dump(
+            results_list,
+            open(
+                os.path.join(
+                    os.path.dirname(args.save_path),
+                    os.path.basename(args.save_path).replace("answer", "answer_list"),
+                ),
+                "w",
+            ),
+            indent=4,
         )
-    qa_list = sample["QA_pairs"]
-    q_list = []
-    a_list = []
-    a_list_model = []
-    for qa in qa_list:
-        q = qa["question"]
-        a = qa["answer"]
-        query = prompt.format(transcript=transcript, question=q)
-        answer = query_llm(
-            query,
-            model,
-            args.model_name_or_path,
-            args.n_max_token_ans,
-            tokenizer=tokenizer,
+
+
+predict()
+results_list = json.load(
+    open(
+        os.path.join(
+            os.path.dirname(args.save_path),
+            os.path.basename(args.save_path).replace("answer", "answer_list"),
         )
-        q_list.append(q)
-        a_list.append(a)
-        a_list_model.append(answer)
-
-    results[sample_idx]["transcript"] = transcript
-    results[sample_idx]["questions"] = q_list[:]
-    results[sample_idx]["answers"] = a_list[:]
-    results[sample_idx]["model_answers"] = a_list_model[:]
-
-    results_list["questions"].extend(q_list[:])
-    results_list["answers"].extend(a_list[:])
-    results_list["model_answers"].extend(a_list_model[:])
-
-    json.dump(results, open(args.save_path, "w"), indent=4)
-    json.dump(
-        results_list,
-        open(
-            os.path.join(
-                os.path.dirname(args.save_path),
-                os.path.basename(args.save_path).replace("answer", "answer_list"),
-            ),
-            "w",
-        ),
-        indent=4,
     )
-
-score_dict = evaluate_with_gt(results_list["answers"], results_list["model_answers"])
+)
+for i, ans in enumerate(results_list["answers"]):
+    results_list["answers"][i] = [results_list["answers"][i]]
+score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])
 json.dump(
     score_dict,
     open(
diff --git a/experiments/llmlingua2/evaluation/metrics.py b/experiments/llmlingua2/evaluation/metrics.py
index 3da9ecd..c4aa87a 100644
--- a/experiments/llmlingua2/evaluation/metrics.py
+++ b/experiments/llmlingua2/evaluation/metrics.py
@@ -156,6 +156,16 @@ def qa_f1_zh_score(prediction, ground_truth, **kwargs):
     return f1_score(prediction_tokens, ground_truth_tokens)
 
 
+def qa_score(prediction, ground_truths):
+    normalized_prediction = normalize_answer2(prediction)
+
+    for ground_truth in ground_truths:
+        normalized_ground_truth = normalize_answer2(ground_truth)
+        if normalized_ground_truth.lower() in normalized_prediction.lower():
+            return 1.0
+    return 0.0
+
+
 import regex
 
 
@@ -207,12 +217,10 @@ def eval_qa_f1_score(pred, ground_truths):
         pred_list = pred_list_truncated
 
     metrics = {
-        "qa_f1_score": 0.0,
-        "best_subspan_em": 0.0,
+        "qa_score": 0.0,
     }
     for pred, gts in zip(pred_list, gt_list):
-        metrics["qa_f1_score"] += eval_qa_f1_score(pred, gts)
-        metrics["best_subspan_em"] += best_subspan_em(pred, gts)
+        metrics["qa_score"] += qa_score(pred, gts)
 
     # average
     for metric_name, score in metrics.items():
         metrics[metric_name] = score * 100 / len(pred_list)
diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 9e852fb..91d9e4f 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -2407,8 +2407,10 @@ def split_string_to_words(input_string):
     keep_words = []
     word_labels = []
     assert len(words) == len(word_probs)
-    for word, word_porb in zip(words, word_probs):
-        if word_porb > threshold:
+    for word, word_prob in zip(words, word_probs):
+        if word_prob > threshold or (
+            threshold == 1.0 and word_prob == threshold
+        ):
             if (
                 drop_consecutive
                 and word in force_tokens
diff --git a/llmlingua/version.py b/llmlingua/version.py
index a844652..15ce103 100644
--- a/llmlingua/version.py
+++ b/llmlingua/version.py
@@ -5,7 +5,7 @@
 _MINOR = "2"
 # On master and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "1"
+_PATCH = "2"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
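
Reviewer note: the reworked evaluation marks a prediction correct when any
normalized gold answer appears as a substring of the normalized prediction,
then averages per-sample scores into a percentage; this is also why each gold
answer is wrapped in a one-element list before evaluate_with_gt is called.
Below is a minimal standalone sketch of that behavior, for illustration only:
simple_normalize is a hypothetical stand-in for the repo's normalize_answer2
helper, and the sample data is invented.

    # Sketch of the qa_score substring metric; assumes a simplified normalizer.
    import re
    import string

    def simple_normalize(text):
        # Hypothetical stand-in for normalize_answer2: lowercase, strip
        # punctuation, drop articles, collapse whitespace.
        text = "".join(ch for ch in text.lower() if ch not in string.punctuation)
        text = re.sub(r"\b(a|an|the)\b", " ", text)
        return " ".join(text.split())

    def qa_score(prediction, ground_truths):
        # 1.0 if any normalized gold answer is contained in the prediction.
        pred = simple_normalize(prediction)
        return float(any(simple_normalize(gt) in pred for gt in ground_truths))

    # Mirrors the patched flow: gold answers arrive as one-element lists and
    # per-sample scores are averaged to a percentage.
    preds = ["The council approved the budget.", "No decision was made."]
    golds = [["approved the budget"], ["postponed to May"]]
    print(sum(qa_score(p, g) for p, g in zip(preds, golds)) * 100 / len(preds))  # 50.0

Separately, the prompt_compressor.py change keeps words whose probability
exactly equals a threshold of 1.0, which the strict ">" comparison would
otherwise drop.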