diff --git a/.gitignore b/.gitignore
index 5b91d12..3e8c0ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,7 @@ coverage.xml
*dataset_infos.json
# mypy
-.mypy_cache
\ No newline at end of file
+.mypy_cache
+
+# Misc
+debug_scripts/
\ No newline at end of file
diff --git a/README.md b/README.md
index 82b8f25..d9e7e8f 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,16 @@
-# Trapper (Transformers wRAPPER)
+
+# Trapper (Transformers wRAPPER)
+
+
+
+
+
+
+
+
+
+
+
+
Trapper is an NLP library that aims to make it easier to train transformer based
models on downstream tasks. It wraps the HuggingFace's `transformers` library to
@@ -300,15 +312,15 @@ thanks to configuration file based experiments.
### Training a POS Tagging Model on CONLL2003
Since the transformers library lacks a direct support for POS tagging, we added an
-example project that trains a transformer model on `CONLL2003` POS tagging dataset
-and perform inference using it. You can find it in `examples/pos_tagging`. It is a
+[example project](./examples/pos_tagging) that trains a transformer model on the `CONLL2003` POS tagging dataset
+and performs inference using it. It is a
self-contained project including its own requirements file, therefore you can copy
the folder into another directory to use as a template for your own project. Please
follow its `README.md` to get started.
### Training a Question Answering Model on SQuAD Dataset
-You can use the notebook in `examples/question_answering/question_answering. ipynb`
+You can use the notebook `examples/question_answering/question_answering.ipynb` from the [Example QA Project](./examples/question_answering)
to follow the steps while training a transformer model on SQuAD v1.
## Installation
diff --git a/examples/pos_tagging/README.md b/examples/pos_tagging/README.md
index dbb816a..c960c58 100644
--- a/examples/pos_tagging/README.md
+++ b/examples/pos_tagging/README.md
@@ -4,8 +4,7 @@ This project show how to train a transformer model from on CONLL2003 dataset
from `HuggingFace datasets`. You can explore the dataset
from [its page](https://huggingface.co/datasets/conll2003). This project is intended
to serve as a demo for using trapper as a library to train and evaluate a
-transformer model on a custom task/dataset as well as perform inference using it. We
-start by creating a fresh python environment and install the dependencies.
+transformer model on a custom task/dataset as well as to perform inference using it. For an example of a task that trapper already supports, see the [question answering example](../question_answering). We start by creating a fresh Python environment and installing the dependencies.
## Environment Creation and Dependency Installation
diff --git a/examples/pos_tagging/requirements.txt b/examples/pos_tagging/requirements.txt
index fbbaf6f..00f4c12 100644
--- a/examples/pos_tagging/requirements.txt
+++ b/examples/pos_tagging/requirements.txt
@@ -1 +1 @@
-trapper==0.0.3
+trapper==0.0.4
diff --git a/examples/question_answering/README.md b/examples/question_answering/README.md
new file mode 100644
index 0000000..62f083b
--- /dev/null
+++ b/examples/question_answering/README.md
@@ -0,0 +1,6 @@
+## Question Answering Demo
+
+
+
+
+This notebook serves as an example demonstrating training and inference using `trapper`. The question answering task is already supported by `trapper`, so in this notebook we only provide a basic [configuration file](./experiment.jsonnet) and let `trapper` take care of the rest. To see how to implement a custom task using `trapper`, see the [POS tagging example](../pos_tagging).
diff --git a/examples/question_answering/experiments/question-answering/experiment.jsonnet b/examples/question_answering/experiment.jsonnet
similarity index 95%
rename from examples/question_answering/experiments/question-answering/experiment.jsonnet
rename to examples/question_answering/experiment.jsonnet
index 3037151..d42bbfb 100644
--- a/examples/question_answering/experiments/question-answering/experiment.jsonnet
+++ b/examples/question_answering/experiment.jsonnet
@@ -30,10 +30,10 @@ local result_dir = std.extVar("OUTPUT_PATH");
"type": "default",
"output_dir": checkpoint_dir,
"result_dir": result_dir,
- "num_train_epochs": 2,
+ "num_train_epochs": 10,
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 12,
- "per_device_eval_batch_size": 4,
+ "per_device_eval_batch_size": 2,
"logging_dir": checkpoint_dir + "/logs",
"no_cuda": false,
"logging_steps": 500,
diff --git a/examples/question_answering/experiment.py b/examples/question_answering/experiment.py
deleted file mode 100644
index 6151255..0000000
--- a/examples/question_answering/experiment.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import argparse
-import os
-import warnings
-from typing import Dict, Optional, Tuple
-
-import requests
-
-from examples.question_answering.util import (
- DATASET_DIR,
- DEFAULT_EXTRA_VARIABLES,
- get_dir_from_task,
-)
-from trapper.training.train import run_experiment
-
-__arguments__ = ["config", "task", "experiment_name"]
-
-
-def download_squad(
- task: str, version: str = "1.1", overwrite: bool = False
-) -> Tuple[str, str]:
- """
- Downloads SQuAD dataset with given version.
-
- Args:
- task:
- version: SQuAD dataset version.
- overwrite: If true, overwrites the destination file.
-
- Returns: (train set path, dev set path) local paths of downloaded dataset files.
-
- """
- destination_dir = DATASET_DIR.format(task=task)
- os.makedirs(destination_dir, exist_ok=True)
- dataset_base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
-
- train_set = f"train-v{version}.json"
- dev_set = f"dev-v{version}.json"
-
- datasets = [train_set, dev_set]
- paths = []
-
- for dataset in datasets:
- dest_name = "train.json" if "train" in dataset else "dev.json"
- url = os.path.join(dataset_base_url, dataset)
- dest = os.path.join(destination_dir, dest_name)
- paths.append(dest)
-
- if not overwrite and os.path.exists(dest):
- warnings.warn(f"{dest} already exists, not overwriting.")
- continue
-
- r = requests.get(url, allow_redirects=True)
-
- with open(dest, "wb") as out_file:
- out_file.write(r.content)
-
- return paths[0], paths[1]
-
-
-def create_arguments():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--config", type=str, required=True, help="Path to experiment.jsonnet"
- )
- parser.add_argument("--task", type=str, required=True, help="Name of the task")
- parser.add_argument("--experiment-name", type=str, default=None)
-
- # Handle unset arguments
- parsed, unknown = parser.parse_known_args()
- for arg in unknown:
- if arg.startswith(("-", "--")):
- # you can pass any arguments to add_argument
- parser.add_argument(arg.split("=")[0])
-
- return parser.parse_args()
-
-
-def get_extra_variables(args):
- ext_vars = {}
- for arg, val in args.__dict__.items():
- if arg not in __arguments__:
- ext_vars[arg.upper()] = val
- return ext_vars
-
-
-def validate_extra_variables(
- extra_vars: Dict[str, str], task: Optional[str] = None
-):
- for key, val in DEFAULT_EXTRA_VARIABLES.items():
- if key not in extra_vars:
- extra_vars[key] = get_dir_from_task(val, task=task)
-
- return extra_vars
-
-
-def start_experiment(config: str, task: str, ext_vars: Dict[str, str]):
- ext_vars = validate_extra_variables(extra_vars=ext_vars, task=task)
- result = run_experiment(
- config_path=config,
- ext_vars=ext_vars,
- )
- print("Training complete.")
- return result
-
-
-def main():
- experiment_name = "roberta-base-training-example"
- task = "question-answering"
- working_dir = os.getcwd()
- experiments_dir = os.path.join(working_dir, "experiments")
- task_dir = get_dir_from_task(os.path.join(experiments_dir, "{task}"), task=task)
- experiment_dir = os.path.join(task_dir, experiment_name)
- checkpoint_dir = os.path.join(experiment_dir, "checkpoints")
- output_dir = os.path.join(experiment_dir, "outputs")
- ext_vars = {
- # Used to feed the jsonnet config file with file paths
- "OUTPUT_PATH": output_dir,
- "CHECKPOINT_PATH": checkpoint_dir,
- }
- config_path = os.path.join(
- task_dir, "experiment.jsonnet"
- ) # default experiment params
- start_experiment(config=config_path, task=task, ext_vars=ext_vars)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/question_answering/question_answering.ipynb b/examples/question_answering/question_answering.ipynb
index ff525d7..049cb0f 100644
--- a/examples/question_answering/question_answering.ipynb
+++ b/examples/question_answering/question_answering.ipynb
@@ -1,23 +1,21 @@
{
"cells": [
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
+ "outputs": [],
"source": [
- "# Training\n",
- "\n",
- "This notebook serves as a walkthrough for training with trapper package."
+ "!pip install trapper jury"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "cd .."
+ "# Training\n",
+ "\n",
+ "This notebook serves as a walkthrough for training with trapper package."
]
},
{
@@ -30,15 +28,21 @@
"\n",
"from copy import deepcopy\n",
"import os\n",
- "import json\n",
- "from typing import Any, Dict, List, Tuple, Union\n",
- "import warnings\n",
+ "from typing import Dict, List, Union\n",
"\n",
"from jury import Jury\n",
- "import requests\n",
- "from tqdm import tqdm\n",
"\n",
- "from trapper.training.train import run_experiment"
+ "from trapper.training.train import run_experiment\n",
+ "from trapper.common.notebook_utils import prepare_data, load_json"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prepare_data()"
]
},
{
@@ -72,50 +76,10 @@
"EXPERIMENT_NAME = \"roberta-base-training-example\"\n",
"\n",
"WORKING_DIR = os.getcwd()\n",
- "EXPERIMENTS_DIR = os.path.join(WORKING_DIR, \"examples/experiments\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Helper functions\n",
- "\n",
- "Some useful helper functions to ease training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_dir_from_task(path: str, task: str):\n",
- " task = \"unnamed-task\" if task is None else task\n",
- " return path.format(task=task)\n",
- "\n",
- "def start_experiment(config: str, task: str, ext_vars: Dict[str, str]):\n",
- " result = run_experiment(\n",
- " config_path=config,\n",
- " ext_vars=ext_vars,\n",
- " )\n",
+ "PROJECT_ROOT = os.path.dirname(os.path.dirname(WORKING_DIR))\n",
+ "EXPERIMENT_DIR = os.path.join(WORKING_DIR, EXPERIMENT_NAME)\n",
+ "CONFIG_PATH = os.path.join(WORKING_DIR, \"experiment.jsonnet\") # default experiment params\n",
"\n",
- " print(\"Training complete.\")\n",
- " return result"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "TASK = \"question-answering\"\n",
- "TASK_DIR = get_dir_from_task(os.path.join(EXPERIMENTS_DIR, \"{task}\"), task=TASK)\n",
- "DATASET_DIR = os.path.join(TASK_DIR, \"datasets\")\n",
- "EXPERIMENT_DIR = os.path.join(TASK_DIR, EXPERIMENT_NAME)\n",
"MODEL_DIR = os.path.join(EXPERIMENT_DIR, \"model\")\n",
"CHECKPOINT_DIR = os.path.join(EXPERIMENT_DIR, \"checkpoints\")\n",
"OUTPUT_DIR = os.path.join(EXPERIMENT_DIR, \"outputs\")"
@@ -125,7 +89,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "scrolled": false
+ "scrolled": true
},
"outputs": [],
"source": [
@@ -135,7 +99,10 @@
" \"CHECKPOINT_PATH\": CHECKPOINT_DIR\n",
"}\n",
"\n",
- "CONFIG_PATH = os.path.join(TASK_DIR, \"experiment.jsonnet\") # default experiment params"
+ "result = run_experiment(\n",
+ " config_path=CONFIG_PATH,\n",
+ " ext_vars=ext_vars,\n",
+ ")"
]
},
{
@@ -145,19 +112,6 @@
"scrolled": true
},
"outputs": [],
- "source": [
- "result = start_experiment(\n",
- " config=CONFIG_PATH,\n",
- " task=TASK,\n",
- " ext_vars=ext_vars,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
"source": [
"result"
]
@@ -177,6 +131,7 @@
"metadata": {},
"outputs": [],
"source": [
+ "# required to register the pipeline\n",
"from trapper.pipelines.question_answering_pipeline import SquadQuestionAnsweringPipeline\n",
"from trapper.pipelines.pipeline import create_pipeline_from_checkpoint"
]
@@ -196,16 +151,6 @@
"metadata": {},
"outputs": [],
"source": [
- "def save_json(samples: List[Dict], path: str):\n",
- " with open(path, \"w\") as jf:\n",
- " json.dump(samples, jf)\n",
- "\n",
- "\n",
- "def load_json(path: str):\n",
- " with open(path, \"r\") as jf:\n",
- " return json.load(jf)\n",
- "\n",
- "\n",
"def prepare_samples(data: Union[str, Dict]):\n",
" if isinstance(data, str):\n",
" data = load_json(data)\n",
@@ -226,8 +171,10 @@
"\n",
"def prepare_samples_for_pipeline(samples: List[Dict]):\n",
" pipeline_samples = deepcopy(samples)\n",
- " for sample in pipeline_samples:\n",
+ " for i, sample in enumerate(pipeline_samples):\n",
" sample.pop(\"gold_answers\")\n",
+ " if \"id\" not in sample:\n",
+ " sample[\"id\"] = str(i)\n",
" return pipeline_samples\n",
"\n",
"\n",
@@ -245,10 +192,10 @@
"metadata": {},
"outputs": [],
"source": [
- "SQUAD_DEV = \"/home/devrimcavusoglu/lab/bb/nqg/datasets/squad/squad-qa/dev-v1.1.json\"\n",
- "EXPORT_PATH = \"/home/devrimcavusoglu/lab/bb/nqg/notebooks/qa-predictions_pipeline.json\"\n",
+ "DEV_SET = \"squad_qa_test_fixture/dev.json\"\n",
+ "EXPORT_PATH = os.path.join(WORKING_DIR, \"qa-outputs.json\")\n",
"\n",
- "PRETRAINED_MODEL_PATH = \"/home/devrimcavusoglu/lab/bb/nqg/models/question_answering/roberta-large_ep_5_ebs_384_lr_5e-05-v1\"\n",
+ "PRETRAINED_MODEL_PATH = OUTPUT_DIR\n",
"EXPERIMENT_CONFIG = os.path.join(PRETRAINED_MODEL_PATH, \"experiment_config.json\")"
]
},
@@ -264,7 +211,6 @@
" checkpoint_path=PRETRAINED_MODEL_PATH,\n",
" experiment_config_path=EXPERIMENT_CONFIG,\n",
" task=\"squad-question-answering\",\n",
- " device=0\n",
")"
]
},
@@ -272,31 +218,24 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "scrolled": true
+ "scrolled": false
},
"outputs": [],
"source": [
- "samples = prepare_samples(SQUAD_DEV)"
+ "samples = prepare_samples(DEV_SET)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [],
"source": [
"predictions = predict(qa_pipeline, samples)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "save_json(predictions, EXPORT_PATH)"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -313,13 +252,15 @@
"metadata": {},
"outputs": [],
"source": [
- "jury = Jury(metrics=[\"squad_f1\", \"squad_em\"])"
+ "jury = Jury(metrics=\"squad\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [],
"source": [
"jury.evaluate(references=references, predictions=hypotheses)"
@@ -342,9 +283,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.10"
+ "version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
-}
+}
\ No newline at end of file
diff --git a/examples/question_answering/util.py b/examples/question_answering/util.py
deleted file mode 100644
index 58754ed..0000000
--- a/examples/question_answering/util.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os
-import warnings
-from typing import Tuple
-
-import requests
-
-WORKING_DIR = os.path.abspath(os.path.dirname(__file__))
-EXPERIMENTS_DIR = os.path.join(WORKING_DIR, "experiments")
-TASK_DIR = os.path.join(EXPERIMENTS_DIR, "{task}")
-DATASET_DIR = os.path.join(TASK_DIR, "datasets")
-CHECKPOINT_DIR = os.path.join(TASK_DIR, "checkpoints")
-OUTPUT_DIR = os.path.join(TASK_DIR, "outputs")
-
-DEFAULT_EXTRA_VARIABLES = {
- "OUTPUT_PATH": OUTPUT_DIR,
- "CHECKPOINT_PATH": CHECKPOINT_DIR,
-}
-
-
-def download_squad(
- task: str, version: str = "1.1", overwrite: bool = False
-) -> Tuple[str, str]:
- """
- Downloads SQuAD dataset with given version.
-
- Args:
- version: SQuAD dataset version.
- overwrite: If true, overwrites the destination file.
-
- Returns: (train set path, dev set path) local paths of downloaded dataset files.
-
- """
- destination_dir = DATASET_DIR.format(task=task)
- os.makedirs(destination_dir, exist_ok=True)
- dataset_base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
-
- train_set = f"train-v{version}.json"
- dev_set = f"dev-v{version}.json"
-
- datasets = [train_set, dev_set]
- paths = []
-
- for dataset in datasets:
- url = os.path.join(dataset_base_url, dataset)
- dest = os.path.join(destination_dir, dataset)
- paths.append(dest)
-
- if not overwrite and os.path.exists(dest):
- warnings.warn(f"{dest} already exists, not overwriting.")
- continue
-
- r = requests.get(url, allow_redirects=True)
-
- with open(dest, "wb") as out_file:
- out_file.write(r.content)
-
- return paths[0], paths[1]
-
-
-def create_dir(prefix: str, name: str):
- path = os.path.join(prefix, name)
- os.makedirs(path, exist_ok=True)
- return path
-
-
-def get_dir_from_task(path: str, task: str):
- task = "unnamed-task" if task is None else task
- return path.format(task=task)
diff --git a/trapper/common/notebook_utils/__init__.py b/trapper/common/notebook_utils/__init__.py
new file mode 100644
index 0000000..ffb6057
--- /dev/null
+++ b/trapper/common/notebook_utils/__init__.py
@@ -0,0 +1,3 @@
+from trapper.common.notebook_utils.file_transfer import download_from_url
+from trapper.common.notebook_utils.io import load_json, save_json
+from trapper.common.notebook_utils.prepare_data import prepare_data
diff --git a/trapper/common/notebook_utils/file_transfer.py b/trapper/common/notebook_utils/file_transfer.py
new file mode 100644
index 0000000..e5d4a11
--- /dev/null
+++ b/trapper/common/notebook_utils/file_transfer.py
@@ -0,0 +1,46 @@
+import os
+import urllib.request
+from pathlib import Path
+from typing import Optional
+from urllib.parse import urlparse
+
+
+def download_from_url(url: str, destination: Optional[str] = None) -> None:
+    """
+    Utility function to download data from a specified url.
+
+    The download is skipped when `destination` already exists. Data is first
+    written to a temporary ".part" file next to the destination and renamed
+    once the transfer has finished, so an interrupted download does not leave
+    a truncated file that later calls would skip as already downloaded.
+
+    Args:
+        url: Source url of data to be downloaded.
+        destination: Destination where the downloaded data is placed. If None,
+            base name of the url path is used, i.e if url="a/b.txt", it will
+            be downloaded to "./b.txt".
+
+    Raises:
+        ValueError: If no destination file name is given or can be derived from `url`.
+    """
+    if destination is None:
+        # Take the base name from the path component only, so a query string
+        # or fragment does not end up in the local file name.
+        destination = os.path.basename(urlparse(url).path)
+    if not destination:
+        raise ValueError(f"Cannot infer a file name from url: {url!r}")
+
+    target = Path(destination)
+    if target.exists():
+        return
+
+    target.parent.mkdir(parents=True, exist_ok=True)
+    partial = target.with_name(target.name + ".part")
+    try:
+        urllib.request.urlretrieve(url, str(partial))
+        os.replace(partial, target)
+    finally:
+        # Only present when the transfer failed; a successful os.replace has
+        # already moved the file away.
+        if partial.exists():
+            partial.unlink()
diff --git a/trapper/common/notebook_utils/io.py b/trapper/common/notebook_utils/io.py
new file mode 100644
index 0000000..985fb91
--- /dev/null
+++ b/trapper/common/notebook_utils/io.py
@@ -0,0 +1,32 @@
+import json
+from typing import Any, Dict, List
+
+
+def load_json(path: str) -> Any:
+    """
+    Read a JSON file and return the decoded object.
+
+    The file is opened explicitly as UTF-8 instead of with the platform
+    default encoding, so non-ASCII text is decoded the same way on every
+    system.
+
+    Args:
+        path: Path of the JSON file to read.
+
+    Returns:
+        Whatever `json.load` produces for the file content.
+    """
+    with open(path, "r", encoding="utf-8") as jf:
+        return json.load(jf)
+
+
+def save_json(samples: List[Dict], path: str) -> None:
+    """
+    Write `samples` to a file as JSON, overwriting the file if it exists.
+
+    Args:
+        samples: Data to serialize with `json.dump`.
+        path: Path of the JSON file to write.
+    """
+    with open(path, "w", encoding="utf-8") as jf:
+        json.dump(samples, jf)
diff --git a/trapper/common/notebook_utils/prepare_data.py b/trapper/common/notebook_utils/prepare_data.py
new file mode 100644
index 0000000..327bfc3
--- /dev/null
+++ b/trapper/common/notebook_utils/prepare_data.py
@@ -0,0 +1,39 @@
+import os
+
+from trapper.common.notebook_utils.file_transfer import download_from_url
+
+# Local directory (relative path) that receives the SQuAD QA fixture files.
+FIXTURES_PATH = "squad_qa_test_fixture"
+# Fixture file name -> url it is downloaded from.
+SQUAD_QA_FIXTURES = {
+    "dev.json": "https://raw.githubusercontent.com/obss/trapper/main/test_fixtures/hf_datasets/squad_qa_test_fixture/dev.json",
+    "train.json": "https://raw.githubusercontent.com/obss/trapper/main/test_fixtures/hf_datasets/squad_qa_test_fixture/train.json",
+    "squad_qa_test_fixture.py": "https://raw.githubusercontent.com/obss/trapper/main/test_fixtures/hf_datasets/squad_qa_test_fixture/squad_qa_test_fixture.py",
+}
+# Url of the experiment config of the question answering example.
+EXPERIMENT_CONFIG = "https://raw.githubusercontent.com/obss/trapper/main/examples/question_answering/experiment.jsonnet"
+
+
+def download_fixture_data() -> None:
+    """
+    Download every file listed in `SQUAD_QA_FIXTURES` into `FIXTURES_PATH`.
+    """
+    for file_name, source_url in SQUAD_QA_FIXTURES.items():
+        download_from_url(source_url, os.path.join(FIXTURES_PATH, file_name))
+
+
+def fetch_experiment_config() -> None:
+    """
+    Download the file at `EXPERIMENT_CONFIG`; no destination is passed, so
+    `download_from_url` chooses the local file name.
+    """
+    download_from_url(EXPERIMENT_CONFIG)
+
+
+def prepare_data() -> None:
+    """
+    Download the fixture dataset files first and the experiment config
+    afterwards.
+    """
+    download_fixture_data()
+    fetch_experiment_config()
diff --git a/trapper/data/data_adapters/question_answering_adapter.py b/trapper/data/data_adapters/question_answering_adapter.py
index e47f3f5..020537a 100644
--- a/trapper/data/data_adapters/question_answering_adapter.py
+++ b/trapper/data/data_adapters/question_answering_adapter.py
@@ -18,12 +18,6 @@ class DataAdapterForQuestionAnswering(DataAdapter):
CONTEXT_TOKEN_TYPE_ID = 0
QUESTION_TOKEN_TYPE_ID = 1
- def __init__(
- self,
- tokenizer_wrapper: TokenizerWrapper,
- ):
- super().__init__(tokenizer_wrapper)
-
def __call__(self, raw_instance: IndexedInstance) -> IndexedInstance:
"""
Create a sequence with the following fields:
diff --git a/trapper/data/data_processors/squad/question_answering_processor.py b/trapper/data/data_processors/squad/question_answering_processor.py
index 1576711..d847560 100644
--- a/trapper/data/data_processors/squad/question_answering_processor.py
+++ b/trapper/data/data_processors/squad/question_answering_processor.py
@@ -19,11 +19,9 @@ class SquadQuestionAnsweringDataProcessor(SquadDataProcessor):
MAX_SEQUENCE_LEN = 512
def process(self, instance_dict: Dict[str, Any]) -> IndexedInstance:
- id_ = instance_dict["id"]
+ qa_id = instance_dict["id"]
context = instance_dict["context"]
- question = convert_spandict_to_spantuple(
- {"text": instance_dict["question"], "start": -1}
- )
+ question = instance_dict["question"]
if self._is_input_too_long(context, question):
return self.filtered_instance()
# Rename SQuAD answer_start as start for trapper tuple conversion.
@@ -35,7 +33,7 @@ def process(self, instance_dict: Dict[str, Any]) -> IndexedInstance:
return self.text_to_instance(
context=context,
question=question,
- id_=id_,
+ id_=qa_id,
answer=first_answer,
)
except ImproperDataInstanceError:
@@ -47,17 +45,16 @@ def filtered_instance() -> IndexedInstance:
"answer": [-1],
"answer_position_tokenized": {"start": -1, "end": -1},
"context": [-1],
- "qa_id": -1,
+ "qa_id": "",
"question": [-1],
"__discard_sample": True,
}
def text_to_instance(
- self, context: str, question: SpanTuple, id_: str, answer: SpanTuple = None
+ self, context: str, question: str, id_: str, answer: SpanTuple = None
) -> IndexedInstance:
- question = self._join_whitespace_prefix(context, question)
tokenized_context = self._tokenizer.tokenize(context)
- tokenized_question = self._tokenizer.tokenize(question.text)
+ tokenized_question = self._tokenizer.tokenize(question)
self._chop_excess_context_tokens(tokenized_context, tokenized_question)
instance = {
@@ -75,9 +72,9 @@ def text_to_instance(
instance["qa_id"] = id_
return instance
- def _is_input_too_long(self, context: str, question: SpanTuple) -> bool:
+ def _is_input_too_long(self, context: str, question: str) -> bool:
context_tokens = self.tokenizer.tokenize(context)
- question_tokens = self.tokenizer.tokenize(question.text)
+ question_tokens = self.tokenizer.tokenize(question)
return (
len(context_tokens)
+ len(question_tokens)
diff --git a/trapper/pipelines/question_answering_pipeline.py b/trapper/pipelines/question_answering_pipeline.py
index 509d39e..be4995c 100644
--- a/trapper/pipelines/question_answering_pipeline.py
+++ b/trapper/pipelines/question_answering_pipeline.py
@@ -68,17 +68,21 @@ def normalize(self, item):
elif isinstance(item[k], str) and len(item[k]) == 0:
raise ValueError("`{}` cannot be empty".format(k))
- question = {"text": item["question"], "start": None}
- item["question"] = self._convert_to_span_tuple(question)
+ self._add_id(item)
return item
raise ValueError("{} argument needs to be of type dict".format(item))
@staticmethod
def _convert_to_span_tuple(span: Union[SpanDict, SpanTuple]) -> SpanTuple:
- if isinstance(span, SpanDict):
+ if isinstance(span, dict):
span = convert_spandict_to_spantuple(span)
return span
+ @staticmethod
+ def _add_id(item: Dict) -> None:
+ item["id_"] = item["id"]
+ item.pop("id")
+
def __call__(self, *args, **kwargs):
if args is not None and len(args) > 0:
inputs = self._handle_single_input(args)
@@ -166,7 +170,7 @@ def __init__(
framework: Optional[str] = None,
device: int = -1,
task: str = "",
- **kwargs
+ **kwargs, # For the ignored arguments
):
super().__init__(
model=model,
@@ -175,7 +179,6 @@ def __init__(
framework=framework,
device=device,
task=task,
- **kwargs,
)
self._args_parser = QuestionAnsweringArgumentHandler()
@@ -348,30 +351,37 @@ def _construct_answer(
start_token_ind: int,
end_token_ind: int,
) -> SpanTuple:
- answer_start_ind = self._get_answer_start_ind(
- context, input_ids, start_token_ind
- )
- answer_inds = list(range(start_token_ind, end_token_ind))
- answer_token_ids = [input_ids[ind] for ind in answer_inds]
- decoded_answer = self.tokenizer.decode(
- answer_token_ids, skip_special_tokens=True
- ).strip()
- case_corrected_answer = context[
- answer_start_ind : answer_start_ind + len(decoded_answer)
- ]
- answer: SpanDict = {
- "start": answer_start_ind,
- "text": case_corrected_answer,
- }
+ answer_start_ind = self._get_answer_start_ind(context, start_token_ind)
+ if answer_start_ind is None:
+ answer: SpanDict = {
+ "start": -1,
+ "text": "",
+ }
+ else:
+ answer_token_ids = input_ids[start_token_ind:end_token_ind]
+ decoded_answer = self.tokenizer.decode(
+ answer_token_ids, skip_special_tokens=True
+ ).strip()
+ case_corrected_answer = context[
+ answer_start_ind : answer_start_ind + len(decoded_answer)
+ ]
+ answer: SpanDict = {
+ "start": answer_start_ind,
+ "text": case_corrected_answer,
+ }
return convert_spandict_to_spantuple(answer)
- def _get_answer_start_ind(self, context, input_ids, start_token_ind):
- answer_prefix_inds = list(range(0, start_token_ind))
- answer_prefix_token_ids = [input_ids[ind] for ind in answer_prefix_inds]
+ def _get_answer_start_ind(self, context, start_token_ind):
+ context_tokenized = self.tokenizer(context)["input_ids"]
+ if start_token_ind > len(context_tokenized):
+ return None
+
+ answer_prefix_token_ids = context_tokenized[0:start_token_ind]
answer_prefix = self.tokenizer.decode(
answer_prefix_token_ids, skip_special_tokens=True
)
answer_start_ind = len(answer_prefix)
+
if context[answer_start_ind] == " ":
answer_start_ind += 1
return answer_start_ind
diff --git a/trapper/version.py b/trapper/version.py
index 98c891c..e427974 100644
--- a/trapper/version.py
+++ b/trapper/version.py
@@ -1,5 +1,5 @@
_MAJOR = "0"
_MINOR = "0"
-_PATCH = "3"
+_PATCH = "4"
VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _PATCH)