diff --git a/.gitignore b/.gitignore index 5b91d12..3e8c0ed 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,7 @@ coverage.xml *dataset_infos.json # mypy -.mypy_cache \ No newline at end of file +.mypy_cache + +# Misc +debug_scripts/ \ No newline at end of file diff --git a/README.md b/README.md index 82b8f25..d9e7e8f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,16 @@ -# Trapper (Transformers wRAPPER) +

Trapper (Transformers wRAPPER)

+ +

+Python versions +PyPI version +Latest Release +Open in Colab +
+Build status +Dependencies +Code style: black +License: MIT +

Trapper is an NLP library that aims to make it easier to train transformer based models on downstream tasks. It wraps the HuggingFace's `transformers` library to @@ -300,15 +312,15 @@ thanks to configuration file based experiments. ### Training a POS Tagging Model on CONLL2003 Since the transformers library lacks a direct support for POS tagging, we added an -example project that trains a transformer model on `CONLL2003` POS tagging dataset -and perform inference using it. You can find it in `examples/pos_tagging`. It is a +[example project](./examples/pos_tagging) that trains a transformer model on `CONLL2003` POS tagging dataset +and performs inference using it. It is a self-contained project including its own requirements file, therefore you can copy the folder into another directory to use as a template for your own project. Please follow its `README.md` to get started. ### Training a Question Answering Model on SQuAD Dataset -You can use the notebook in `examples/question_answering/question_answering. ipynb` +You can use the notebook in the [Example QA Project](./examples/question_answering) (`examples/question_answering/question_answering.ipynb`) to follow the steps while training a transformer model on SQuAD v1. ## Installation diff --git a/examples/pos_tagging/README.md b/examples/pos_tagging/README.md index dbb816a..c960c58 100644 --- a/examples/pos_tagging/README.md +++ b/examples/pos_tagging/README.md @@ -4,8 +4,7 @@ This project show how to train a transformer model from on CONLL2003 dataset from `HuggingFace datasets`. You can explore the dataset from [its page](https://huggingface.co/datasets/conll2003). This project is intended to serve as a demo for using trapper as a library to train and evaluate a -transformer model on a custom task/dataset as well as perform inference using it. We -start by creating a fresh python environment and install the dependencies. +transformer model on a custom task/dataset as well as perform inference using it. 
To see an example of a supported task, see the [Question answering example](../question_answering). We start by creating a fresh python environment and installing the dependencies. ## Environment Creation and Dependency Installation diff --git a/examples/pos_tagging/requirements.txt b/examples/pos_tagging/requirements.txt index fbbaf6f..00f4c12 100644 --- a/examples/pos_tagging/requirements.txt +++ b/examples/pos_tagging/requirements.txt @@ -1 +1 @@ -trapper==0.0.3 +trapper==0.0.4 diff --git a/examples/question_answering/README.md b/examples/question_answering/README.md new file mode 100644 index 0000000..62f083b --- /dev/null +++ b/examples/question_answering/README.md @@ -0,0 +1,6 @@ +## Question Answering Demo + +Open in Colab + + +This notebook serves as an example for demonstrating training and inference using `trapper`. The question-answering task is already supported by `trapper`, and thus in this notebook we only give a basic [configuration file](./experiment.jsonnet) and let `trapper` take care of the rest. For an example of implementing a custom task using trapper, see the [POS tagging example](../pos_tagging). 
diff --git a/examples/question_answering/experiments/question-answering/experiment.jsonnet b/examples/question_answering/experiment.jsonnet similarity index 95% rename from examples/question_answering/experiments/question-answering/experiment.jsonnet rename to examples/question_answering/experiment.jsonnet index 3037151..d42bbfb 100644 --- a/examples/question_answering/experiments/question-answering/experiment.jsonnet +++ b/examples/question_answering/experiment.jsonnet @@ -30,10 +30,10 @@ local result_dir = std.extVar("OUTPUT_PATH"); "type": "default", "output_dir": checkpoint_dir, "result_dir": result_dir, - "num_train_epochs": 2, + "num_train_epochs": 10, "per_device_train_batch_size": 2, "gradient_accumulation_steps": 12, - "per_device_eval_batch_size": 4, + "per_device_eval_batch_size": 2, "logging_dir": checkpoint_dir + "/logs", "no_cuda": false, "logging_steps": 500, diff --git a/examples/question_answering/experiment.py b/examples/question_answering/experiment.py deleted file mode 100644 index 6151255..0000000 --- a/examples/question_answering/experiment.py +++ /dev/null @@ -1,127 +0,0 @@ -import argparse -import os -import warnings -from typing import Dict, Optional, Tuple - -import requests - -from examples.question_answering.util import ( - DATASET_DIR, - DEFAULT_EXTRA_VARIABLES, - get_dir_from_task, -) -from trapper.training.train import run_experiment - -__arguments__ = ["config", "task", "experiment_name"] - - -def download_squad( - task: str, version: str = "1.1", overwrite: bool = False -) -> Tuple[str, str]: - """ - Downloads SQuAD dataset with given version. - - Args: - task: - version: SQuAD dataset version. - overwrite: If true, overwrites the destination file. - - Returns: (train set path, dev set path) local paths of downloaded dataset files. 
- - """ - destination_dir = DATASET_DIR.format(task=task) - os.makedirs(destination_dir, exist_ok=True) - dataset_base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/" - - train_set = f"train-v{version}.json" - dev_set = f"dev-v{version}.json" - - datasets = [train_set, dev_set] - paths = [] - - for dataset in datasets: - dest_name = "train.json" if "train" in dataset else "dev.json" - url = os.path.join(dataset_base_url, dataset) - dest = os.path.join(destination_dir, dest_name) - paths.append(dest) - - if not overwrite and os.path.exists(dest): - warnings.warn(f"{dest} already exists, not overwriting.") - continue - - r = requests.get(url, allow_redirects=True) - - with open(dest, "wb") as out_file: - out_file.write(r.content) - - return paths[0], paths[1] - - -def create_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--config", type=str, required=True, help="Path to experiment.jsonnet" - ) - parser.add_argument("--task", type=str, required=True, help="Name of the task") - parser.add_argument("--experiment-name", type=str, default=None) - - # Handle unset arguments - parsed, unknown = parser.parse_known_args() - for arg in unknown: - if arg.startswith(("-", "--")): - # you can pass any arguments to add_argument - parser.add_argument(arg.split("=")[0]) - - return parser.parse_args() - - -def get_extra_variables(args): - ext_vars = {} - for arg, val in args.__dict__.items(): - if arg not in __arguments__: - ext_vars[arg.upper()] = val - return ext_vars - - -def validate_extra_variables( - extra_vars: Dict[str, str], task: Optional[str] = None -): - for key, val in DEFAULT_EXTRA_VARIABLES.items(): - if key not in extra_vars: - extra_vars[key] = get_dir_from_task(val, task=task) - - return extra_vars - - -def start_experiment(config: str, task: str, ext_vars: Dict[str, str]): - ext_vars = validate_extra_variables(extra_vars=ext_vars, task=task) - result = run_experiment( - config_path=config, - ext_vars=ext_vars, - ) - 
print("Training complete.") - return result - - -def main(): - experiment_name = "roberta-base-training-example" - task = "question-answering" - working_dir = os.getcwd() - experiments_dir = os.path.join(working_dir, "experiments") - task_dir = get_dir_from_task(os.path.join(experiments_dir, "{task}"), task=task) - experiment_dir = os.path.join(task_dir, experiment_name) - checkpoint_dir = os.path.join(experiment_dir, "checkpoints") - output_dir = os.path.join(experiment_dir, "outputs") - ext_vars = { - # Used to feed the jsonnet config file with file paths - "OUTPUT_PATH": output_dir, - "CHECKPOINT_PATH": checkpoint_dir, - } - config_path = os.path.join( - task_dir, "experiment.jsonnet" - ) # default experiment params - start_experiment(config=config_path, task=task, ext_vars=ext_vars) - - -if __name__ == "__main__": - main() diff --git a/examples/question_answering/question_answering.ipynb b/examples/question_answering/question_answering.ipynb index ff525d7..049cb0f 100644 --- a/examples/question_answering/question_answering.ipynb +++ b/examples/question_answering/question_answering.ipynb @@ -1,23 +1,21 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "# Training\n", - "\n", - "This notebook serves as a walkthrough for training with trapper package." + "!pip install trapper jury" ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], + "cell_type": "markdown", + "metadata": {}, "source": [ - "cd .." + "# Training\n", + "\n", + "This notebook serves as a walkthrough for training with trapper package." 
] }, { @@ -30,15 +28,21 @@ "\n", "from copy import deepcopy\n", "import os\n", - "import json\n", - "from typing import Any, Dict, List, Tuple, Union\n", - "import warnings\n", + "from typing import Dict, List, Union\n", "\n", "from jury import Jury\n", - "import requests\n", - "from tqdm import tqdm\n", "\n", - "from trapper.training.train import run_experiment" + "from trapper.training.train import run_experiment\n", + "from trapper.common.notebook_utils import prepare_data, load_json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prepare_data()" ] }, { @@ -72,50 +76,10 @@ "EXPERIMENT_NAME = \"roberta-base-training-example\"\n", "\n", "WORKING_DIR = os.getcwd()\n", - "EXPERIMENTS_DIR = os.path.join(WORKING_DIR, \"examples/experiments\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Helper functions\n", - "\n", - "Some useful helper functions to ease training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_dir_from_task(path: str, task: str):\n", - " task = \"unnamed-task\" if task is None else task\n", - " return path.format(task=task)\n", - "\n", - "def start_experiment(config: str, task: str, ext_vars: Dict[str, str]):\n", - " result = run_experiment(\n", - " config_path=config,\n", - " ext_vars=ext_vars,\n", - " )\n", + "PROJECT_ROOT = os.path.dirname(os.path.dirname(WORKING_DIR))\n", + "EXPERIMENT_DIR = os.path.join(WORKING_DIR, EXPERIMENT_NAME)\n", + "CONFIG_PATH = os.path.join(WORKING_DIR, \"experiment.jsonnet\") # default experiment params\n", "\n", - " print(\"Training complete.\")\n", - " return result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "TASK = \"question-answering\"\n", - "TASK_DIR = get_dir_from_task(os.path.join(EXPERIMENTS_DIR, \"{task}\"), task=TASK)\n", - "DATASET_DIR = 
os.path.join(TASK_DIR, \"datasets\")\n", - "EXPERIMENT_DIR = os.path.join(TASK_DIR, EXPERIMENT_NAME)\n", "MODEL_DIR = os.path.join(EXPERIMENT_DIR, \"model\")\n", "CHECKPOINT_DIR = os.path.join(EXPERIMENT_DIR, \"checkpoints\")\n", "OUTPUT_DIR = os.path.join(EXPERIMENT_DIR, \"outputs\")" @@ -125,7 +89,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [], "source": [ @@ -135,7 +99,10 @@ " \"CHECKPOINT_PATH\": CHECKPOINT_DIR\n", "}\n", "\n", - "CONFIG_PATH = os.path.join(TASK_DIR, \"experiment.jsonnet\") # default experiment params" + "result = run_experiment(\n", + " config_path=CONFIG_PATH,\n", + " ext_vars=ext_vars,\n", + ")" ] }, { @@ -145,19 +112,6 @@ "scrolled": true }, "outputs": [], - "source": [ - "result = start_experiment(\n", - " config=CONFIG_PATH,\n", - " task=TASK,\n", - " ext_vars=ext_vars,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "result" ] @@ -177,6 +131,7 @@ "metadata": {}, "outputs": [], "source": [ + "# required to register the pipeline\n", "from trapper.pipelines.question_answering_pipeline import SquadQuestionAnsweringPipeline\n", "from trapper.pipelines.pipeline import create_pipeline_from_checkpoint" ] @@ -196,16 +151,6 @@ "metadata": {}, "outputs": [], "source": [ - "def save_json(samples: List[Dict], path: str):\n", - " with open(path, \"w\") as jf:\n", - " json.dump(samples, jf)\n", - "\n", - "\n", - "def load_json(path: str):\n", - " with open(path, \"r\") as jf:\n", - " return json.load(jf)\n", - "\n", - "\n", "def prepare_samples(data: Union[str, Dict]):\n", " if isinstance(data, str):\n", " data = load_json(data)\n", @@ -226,8 +171,10 @@ "\n", "def prepare_samples_for_pipeline(samples: List[Dict]):\n", " pipeline_samples = deepcopy(samples)\n", - " for sample in pipeline_samples:\n", + " for i, sample in enumerate(pipeline_samples):\n", " sample.pop(\"gold_answers\")\n", + " if \"id\" not in 
sample:\n", + " sample[\"id\"] = str(i)\n", " return pipeline_samples\n", "\n", "\n", @@ -245,10 +192,10 @@ "metadata": {}, "outputs": [], "source": [ - "SQUAD_DEV = \"/home/devrimcavusoglu/lab/bb/nqg/datasets/squad/squad-qa/dev-v1.1.json\"\n", - "EXPORT_PATH = \"/home/devrimcavusoglu/lab/bb/nqg/notebooks/qa-predictions_pipeline.json\"\n", + "DEV_SET = \"squad_qa_test_fixture/dev.json\"\n", + "EXPORT_PATH = os.path.join(WORKING_DIR, \"qa-outputs.json\")\n", "\n", - "PRETRAINED_MODEL_PATH = \"/home/devrimcavusoglu/lab/bb/nqg/models/question_answering/roberta-large_ep_5_ebs_384_lr_5e-05-v1\"\n", + "PRETRAINED_MODEL_PATH = OUTPUT_DIR\n", "EXPERIMENT_CONFIG = os.path.join(PRETRAINED_MODEL_PATH, \"experiment_config.json\")" ] }, @@ -264,7 +211,6 @@ " checkpoint_path=PRETRAINED_MODEL_PATH,\n", " experiment_config_path=EXPERIMENT_CONFIG,\n", " task=\"squad-question-answering\",\n", - " device=0\n", ")" ] }, @@ -272,31 +218,24 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": true + "scrolled": false }, "outputs": [], "source": [ - "samples = prepare_samples(SQUAD_DEV)" + "samples = prepare_samples(DEV_SET)" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "predictions = predict(qa_pipeline, samples)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "save_json(predictions, EXPORT_PATH)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -313,13 +252,15 @@ "metadata": {}, "outputs": [], "source": [ - "jury = Jury(metrics=[\"squad_f1\", \"squad_em\"])" + "jury = Jury(metrics=\"squad\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "jury.evaluate(references=references, predictions=hypotheses)" @@ -342,9 +283,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - 
"version": "3.7.10" + "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/examples/question_answering/util.py b/examples/question_answering/util.py deleted file mode 100644 index 58754ed..0000000 --- a/examples/question_answering/util.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import warnings -from typing import Tuple - -import requests - -WORKING_DIR = os.path.abspath(os.path.dirname(__file__)) -EXPERIMENTS_DIR = os.path.join(WORKING_DIR, "experiments") -TASK_DIR = os.path.join(EXPERIMENTS_DIR, "{task}") -DATASET_DIR = os.path.join(TASK_DIR, "datasets") -CHECKPOINT_DIR = os.path.join(TASK_DIR, "checkpoints") -OUTPUT_DIR = os.path.join(TASK_DIR, "outputs") - -DEFAULT_EXTRA_VARIABLES = { - "OUTPUT_PATH": OUTPUT_DIR, - "CHECKPOINT_PATH": CHECKPOINT_DIR, -} - - -def download_squad( - task: str, version: str = "1.1", overwrite: bool = False -) -> Tuple[str, str]: - """ - Downloads SQuAD dataset with given version. - - Args: - version: SQuAD dataset version. - overwrite: If true, overwrites the destination file. - - Returns: (train set path, dev set path) local paths of downloaded dataset files. 
- - """ - destination_dir = DATASET_DIR.format(task=task) - os.makedirs(destination_dir, exist_ok=True) - dataset_base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/" - - train_set = f"train-v{version}.json" - dev_set = f"dev-v{version}.json" - - datasets = [train_set, dev_set] - paths = [] - - for dataset in datasets: - url = os.path.join(dataset_base_url, dataset) - dest = os.path.join(destination_dir, dataset) - paths.append(dest) - - if not overwrite and os.path.exists(dest): - warnings.warn(f"{dest} already exists, not overwriting.") - continue - - r = requests.get(url, allow_redirects=True) - - with open(dest, "wb") as out_file: - out_file.write(r.content) - - return paths[0], paths[1] - - -def create_dir(prefix: str, name: str): - path = os.path.join(prefix, name) - os.makedirs(path, exist_ok=True) - return path - - -def get_dir_from_task(path: str, task: str): - task = "unnamed-task" if task is None else task - return path.format(task=task) diff --git a/trapper/common/notebook_utils/__init__.py b/trapper/common/notebook_utils/__init__.py new file mode 100644 index 0000000..ffb6057 --- /dev/null +++ b/trapper/common/notebook_utils/__init__.py @@ -0,0 +1,3 @@ +from trapper.common.notebook_utils.file_transfer import download_from_url +from trapper.common.notebook_utils.io import load_json, save_json +from trapper.common.notebook_utils.prepare_data import prepare_data diff --git a/trapper/common/notebook_utils/file_transfer.py b/trapper/common/notebook_utils/file_transfer.py new file mode 100644 index 0000000..e5d4a11 --- /dev/null +++ b/trapper/common/notebook_utils/file_transfer.py @@ -0,0 +1,26 @@ +import os +import urllib.request +from pathlib import Path +from typing import Optional + + +def download_from_url(url: str, destination: Optional[str] = None) -> None: + """ + Utility function to download data from a specified url. + + Args: + url: Source url of data to be downloaded. + destination: Destination where the downloaded data is placed. 
If None, + base name of the url is used, i.e if url="a/b.txt", it will be + downloaded to "./b.txt". + """ + if destination is None: + destination = os.path.basename(url) + + Path(destination).parent.mkdir(parents=True, exist_ok=True) + + if not os.path.exists(destination): + urllib.request.urlretrieve( + url, + destination, + ) diff --git a/trapper/common/notebook_utils/io.py b/trapper/common/notebook_utils/io.py new file mode 100644 index 0000000..985fb91 --- /dev/null +++ b/trapper/common/notebook_utils/io.py @@ -0,0 +1,12 @@ +import json +from typing import Dict, List + + +def load_json(path: str): + with open(path, "r") as jf: + return json.load(jf) + + +def save_json(samples: List[Dict], path: str): + with open(path, "w") as jf: + json.dump(samples, jf) diff --git a/trapper/common/notebook_utils/prepare_data.py b/trapper/common/notebook_utils/prepare_data.py new file mode 100644 index 0000000..327bfc3 --- /dev/null +++ b/trapper/common/notebook_utils/prepare_data.py @@ -0,0 +1,26 @@ +import os + +from trapper.common.notebook_utils.file_transfer import download_from_url + +FIXTURES_PATH = "squad_qa_test_fixture" +SQUAD_QA_FIXTURES = { + "dev.json": "https://raw.githubusercontent.com/obss/trapper/main/test_fixtures/hf_datasets/squad_qa_test_fixture/dev.json", + "train.json": "https://raw.githubusercontent.com/obss/trapper/main/test_fixtures/hf_datasets/squad_qa_test_fixture/train.json", + "squad_qa_test_fixture.py": "https://raw.githubusercontent.com/obss/trapper/main/test_fixtures/hf_datasets/squad_qa_test_fixture/squad_qa_test_fixture.py", +} +EXPERIMENT_CONFIG = "https://raw.githubusercontent.com/obss/trapper/main/examples/question_answering/experiment.jsonnet" + + +def download_fixture_data(): + for file, url in SQUAD_QA_FIXTURES.items(): + destination = os.path.join(FIXTURES_PATH, file) + download_from_url(url, destination) + + +def fetch_experiment_config(): + download_from_url(EXPERIMENT_CONFIG) + + +def prepare_data(): + download_fixture_data() + 
fetch_experiment_config() diff --git a/trapper/data/data_adapters/question_answering_adapter.py b/trapper/data/data_adapters/question_answering_adapter.py index e47f3f5..020537a 100644 --- a/trapper/data/data_adapters/question_answering_adapter.py +++ b/trapper/data/data_adapters/question_answering_adapter.py @@ -18,12 +18,6 @@ class DataAdapterForQuestionAnswering(DataAdapter): CONTEXT_TOKEN_TYPE_ID = 0 QUESTION_TOKEN_TYPE_ID = 1 - def __init__( - self, - tokenizer_wrapper: TokenizerWrapper, - ): - super().__init__(tokenizer_wrapper) - def __call__(self, raw_instance: IndexedInstance) -> IndexedInstance: """ Create a sequence with the following fields: diff --git a/trapper/data/data_processors/squad/question_answering_processor.py b/trapper/data/data_processors/squad/question_answering_processor.py index 1576711..d847560 100644 --- a/trapper/data/data_processors/squad/question_answering_processor.py +++ b/trapper/data/data_processors/squad/question_answering_processor.py @@ -19,11 +19,9 @@ class SquadQuestionAnsweringDataProcessor(SquadDataProcessor): MAX_SEQUENCE_LEN = 512 def process(self, instance_dict: Dict[str, Any]) -> IndexedInstance: - id_ = instance_dict["id"] + qa_id = instance_dict["id"] context = instance_dict["context"] - question = convert_spandict_to_spantuple( - {"text": instance_dict["question"], "start": -1} - ) + question = instance_dict["question"] if self._is_input_too_long(context, question): return self.filtered_instance() # Rename SQuAD answer_start as start for trapper tuple conversion. 
@@ -35,7 +33,7 @@ def process(self, instance_dict: Dict[str, Any]) -> IndexedInstance: return self.text_to_instance( context=context, question=question, - id_=id_, + id_=qa_id, answer=first_answer, ) except ImproperDataInstanceError: @@ -47,17 +45,16 @@ def filtered_instance() -> IndexedInstance: "answer": [-1], "answer_position_tokenized": {"start": -1, "end": -1}, "context": [-1], - "qa_id": -1, + "qa_id": "", "question": [-1], "__discard_sample": True, } def text_to_instance( - self, context: str, question: SpanTuple, id_: str, answer: SpanTuple = None + self, context: str, question: str, id_: str, answer: SpanTuple = None ) -> IndexedInstance: - question = self._join_whitespace_prefix(context, question) tokenized_context = self._tokenizer.tokenize(context) - tokenized_question = self._tokenizer.tokenize(question.text) + tokenized_question = self._tokenizer.tokenize(question) self._chop_excess_context_tokens(tokenized_context, tokenized_question) instance = { @@ -75,9 +72,9 @@ def text_to_instance( instance["qa_id"] = id_ return instance - def _is_input_too_long(self, context: str, question: SpanTuple) -> bool: + def _is_input_too_long(self, context: str, question: str) -> bool: context_tokens = self.tokenizer.tokenize(context) - question_tokens = self.tokenizer.tokenize(question.text) + question_tokens = self.tokenizer.tokenize(question) return ( len(context_tokens) + len(question_tokens) diff --git a/trapper/pipelines/question_answering_pipeline.py b/trapper/pipelines/question_answering_pipeline.py index 509d39e..be4995c 100644 --- a/trapper/pipelines/question_answering_pipeline.py +++ b/trapper/pipelines/question_answering_pipeline.py @@ -68,17 +68,21 @@ def normalize(self, item): elif isinstance(item[k], str) and len(item[k]) == 0: raise ValueError("`{}` cannot be empty".format(k)) - question = {"text": item["question"], "start": None} - item["question"] = self._convert_to_span_tuple(question) + self._add_id(item) return item raise ValueError("{} argument 
needs to be of type dict".format(item)) @staticmethod def _convert_to_span_tuple(span: Union[SpanDict, SpanTuple]) -> SpanTuple: - if isinstance(span, SpanDict): + if isinstance(span, dict): span = convert_spandict_to_spantuple(span) return span + @staticmethod + def _add_id(item: Dict) -> None: + item["id_"] = item["id"] + item.pop("id") + def __call__(self, *args, **kwargs): if args is not None and len(args) > 0: inputs = self._handle_single_input(args) @@ -166,7 +170,7 @@ def __init__( framework: Optional[str] = None, device: int = -1, task: str = "", - **kwargs + **kwargs, # For the ignored arguments ): super().__init__( model=model, @@ -175,7 +179,6 @@ def __init__( framework=framework, device=device, task=task, - **kwargs, ) self._args_parser = QuestionAnsweringArgumentHandler() @@ -348,30 +351,37 @@ def _construct_answer( start_token_ind: int, end_token_ind: int, ) -> SpanTuple: - answer_start_ind = self._get_answer_start_ind( - context, input_ids, start_token_ind - ) - answer_inds = list(range(start_token_ind, end_token_ind)) - answer_token_ids = [input_ids[ind] for ind in answer_inds] - decoded_answer = self.tokenizer.decode( - answer_token_ids, skip_special_tokens=True - ).strip() - case_corrected_answer = context[ - answer_start_ind : answer_start_ind + len(decoded_answer) - ] - answer: SpanDict = { - "start": answer_start_ind, - "text": case_corrected_answer, - } + answer_start_ind = self._get_answer_start_ind(context, start_token_ind) + if answer_start_ind is None: + answer: SpanDict = { + "start": -1, + "text": "", + } + else: + answer_token_ids = input_ids[start_token_ind:end_token_ind] + decoded_answer = self.tokenizer.decode( + answer_token_ids, skip_special_tokens=True + ).strip() + case_corrected_answer = context[ + answer_start_ind : answer_start_ind + len(decoded_answer) + ] + answer: SpanDict = { + "start": answer_start_ind, + "text": case_corrected_answer, + } return convert_spandict_to_spantuple(answer) - def _get_answer_start_ind(self, 
context, input_ids, start_token_ind): - answer_prefix_inds = list(range(0, start_token_ind)) - answer_prefix_token_ids = [input_ids[ind] for ind in answer_prefix_inds] + def _get_answer_start_ind(self, context, start_token_ind): + context_tokenized = self.tokenizer(context)["input_ids"] + if start_token_ind > len(context_tokenized): + return None + + answer_prefix_token_ids = context_tokenized[0:start_token_ind] answer_prefix = self.tokenizer.decode( answer_prefix_token_ids, skip_special_tokens=True ) answer_start_ind = len(answer_prefix) + if context[answer_start_ind] == " ": answer_start_ind += 1 return answer_start_ind diff --git a/trapper/version.py b/trapper/version.py index 98c891c..e427974 100644 --- a/trapper/version.py +++ b/trapper/version.py @@ -1,5 +1,5 @@ _MAJOR = "0" _MINOR = "0" -_PATCH = "3" +_PATCH = "4" VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _PATCH)