-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_qa.patch
45 lines (39 loc) · 2.09 KB
/
run_qa.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
--- hugging/run_qa.py 2022-07-30 13:30:22.000000000 -0400
+++ hugging-akbc/run_qa.py 2022-07-30 13:27:34.000000000 -0400
@@ -420,7 +420,7 @@
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
- # We will select sample from whole data if argument is specified
+ # We will select sample from whole data if agument is specified
train_dataset = train_dataset.select(range(data_args.max_train_samples))
# Create train feature from dataset
with training_args.main_process_first(desc="train dataset map pre-processing"):
@@ -537,7 +537,7 @@
# Post-processing:
def post_processing_function(examples, features, predictions, stage="eval"):
# Post-processing: we match the start logits and end logits to answers in the original context.
- predictions = postprocess_qa_predictions(
+ predictions, nbest = postprocess_qa_predictions(
examples=examples,
features=features,
predictions=predictions,
@@ -558,7 +558,7 @@
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
- return EvalPrediction(predictions=formatted_predictions, label_ids=references)
+ return EvalPrediction(predictions=formatted_predictions, label_ids=references),nbest
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
@@ -613,6 +613,7 @@
if training_args.do_predict:
logger.info("*** Predict ***")
results = trainer.predict(predict_dataset, predict_examples)
+ '''
metrics = results.metrics
max_predict_samples = (
@@ -622,6 +623,7 @@
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
+ '''
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
if data_args.dataset_name is not None: