From 481efc245a1ee93e4cb12ed64105c487146cd616 Mon Sep 17 00:00:00 2001
From: devrimcavusoglu
Date: Mon, 14 Oct 2024 15:23:27 +0300
Subject: [PATCH] Code style, GitHub workflow.

---
 .github/workflows/check_format.yml |  58 ++++
 README.md                          |  18 ++
 disgem/__init__.py                 |   2 +-
 disgem/data_loader.py              |  33 +-
 disgem/distractor_evaluator.py     |  21 +-
 disgem/distractor_generator.py     |  12 +-
 disgem/pipeline.py                 | 129 ++------
 disgem/util.py                     |   1 -
 environment.yml                    |   2 -
 generate.py                        | 488 ++++++++++++++++------------
 pyproject.toml                     |  14 +
 scripts/run_code_style.py          |  16 +
 scripts/utils.py                   |  41 +++
 setup.cfg                          |   9 +
 14 files changed, 490 insertions(+), 354 deletions(-)
 create mode 100644 .github/workflows/check_format.yml
 create mode 100644 pyproject.toml
 create mode 100644 scripts/run_code_style.py
 create mode 100644 scripts/utils.py
 create mode 100644 setup.cfg

diff --git a/.github/workflows/check_format.yml b/.github/workflows/check_format.yml
new file mode 100644
index 0000000..3164000
--- /dev/null
+++ b/.github/workflows/check_format.yml
@@ -0,0 +1,58 @@
+name: Code style check
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  build:
+    runs-on: ${{ matrix.operating-system }}
+    strategy:
+      matrix:
+        operating-system: [ubuntu-latest, windows-latest, macos-latest]
+        # for Python 3.10, ref https://github.com/actions/setup-python/issues/160#issuecomment-724485470
+        python-version: [3.8, 3.9, '3.10', '3.11']
+      fail-fast: false
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Restore Ubuntu cache
+        uses: actions/cache@v1
+        if: matrix.operating-system == 'ubuntu-latest'
+        with:
+          path: ~/.cache/pip
+          key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/environment.yml') }}
+          restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+      - name: Restore MacOS cache
+        uses: actions/cache@v1
+        if: matrix.operating-system == 'macos-latest'
+        with:
+          path: ~/Library/Caches/pip
+          key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/environment.yml') }}
+          restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+      - name: Restore Windows cache
+        uses: actions/cache@v1
+        if: matrix.operating-system == 'windows-latest'
+        with:
+          path: ~\AppData\Local\pip\Cache
+          key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/environment.yml') }}
+          restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+      - name: Conda env create
+        run: conda env create -f environment.yml
+
+      - name: Lint with flake8, isort and black
+        run: |
+          source $CONDA/etc/profile.d/conda.sh && conda activate disgem
+          python -m scripts.run_code_style check
diff --git a/README.md b/README.md
index 41f2c0e..30845bb 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ Clone the repository.
 
 ```bash
 git clone https://github.com/obss/disgem.git
+cd disgem
 ```
 
 In the project root, create a virtual environment (preferably using conda) as follows:
@@ -43,3 +44,20 @@ The following provides an example to generate distractors for CLOTH test-high da
 ```shell
 python -m generate data/CLOTH/test/high --data-format cloth --top-k 3 --dispersion 0 --output-path cloth_test_outputs.json
 ```
+
+
+## Contributing
+
+Format and check the code style of the codebase as follows.
+ +To check the codestyle, + +```bash +python -m scripts.run_code_style check +``` + +To format the codebase, + +```bash +python -m scripts.run_code_style format +``` \ No newline at end of file diff --git a/disgem/__init__.py b/disgem/__init__.py index c9856c5..0f3a53e 100644 --- a/disgem/__init__.py +++ b/disgem/__init__.py @@ -1 +1 @@ -from disgem.distractor_generator import MaskedLMBasedDistractorGenerator \ No newline at end of file +from disgem.distractor_generator import MaskedLMBasedDistractorGenerator diff --git a/disgem/data_loader.py b/disgem/data_loader.py index 1d70802..87bd99e 100644 --- a/disgem/data_loader.py +++ b/disgem/data_loader.py @@ -1,7 +1,7 @@ from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import List, Dict, Union, OrderedDict +from typing import Dict, List, Union from disgem.util import read_json @@ -38,12 +38,7 @@ def __iter__(self) -> Instance: for i, ans in enumerate(instance.answers): d = instance.distractors_collection[i] if instance.distractors_collection is not None else None q = instance.questions[i] if instance.questions is not None else None - yield Instance( - context=instance.context, - answer=ans, - distractors=d, - question=q - ) + yield Instance(context=instance.context, answer=ans, distractors=d, question=q) class SquadLoader(DataLoader): @@ -81,7 +76,12 @@ def read(self, filepath): answer["end"] += len(question) + 1 elif self.prepend_question == "mid": # the following prepends just before the mask. - ctx = paragraph["context"][:answer["start"]] + question + " " + paragraph["context"][answer["start"]:] + ctx = ( + paragraph["context"][: answer["start"]] + + question + + " " + + paragraph["context"][answer["start"] :] + ) answer["start"] += len(question) + 1 answer["end"] += len(question) + 1 elif self.prepend_question == "end": @@ -102,6 +102,7 @@ class ClothLoader(DataLoader): a form compatible for distractor generation. See the home page for CLOTH: https://www.cs.cmu.edu/~glai1/data/cloth/ """ + _dataset_mask_str = " _ " def __iter__(self) -> Instance: @@ -112,11 +113,7 @@ def __iter__(self) -> Instance: """ for instance in self.dataset: for i, ans in enumerate(instance.answers): - yield Instance( - context=instance.context, - answer=ans, - distractors=instance.distractors_collection[i] - ) + yield Instance(context=instance.context, answer=ans, distractors=instance.distractors_collection[i]) @staticmethod def replace_nth(text: str, substr: str, replace: str, nth: int): @@ -145,7 +142,7 @@ def read(self, filepath): assert filepath.is_dir(), "`filepath` for CLOTH dataset needs to be a directory." instances = [] - files = sorted(filepath.glob('*.json')) + files = sorted(filepath.glob("*.json")) for p in files: data = read_json(p) ctx = data["article"] @@ -158,7 +155,7 @@ def read(self, filepath): answer = {"text": ans, "start": start} answer["end"] = answer["start"] + len(answer["text"]) answers.append(answer) - distractors.append(choices[:opt] + choices[opt+1:]) # remove the answer from choices + distractors.append(choices[:opt] + choices[opt + 1 :]) # remove the answer from choices assert len(answers) == len(distractors), "The length of the `answers` and `distractors` must be equal." instances.append(InstanceCollection(context=ctx, answers=answers, distractors_collection=distractors)) return instances @@ -171,6 +168,7 @@ class CdgpClothLoader(DataLoader): as it is used and published in a related work. 
See the home page for CDGP style CLOTH: https://huggingface.co/datasets/AndyChiang/cloth """ + _dataset_mask_str = " [MASK] " def read(self, filepath): @@ -186,7 +184,9 @@ def read(self, filepath): else: ctx = ctx.replace(self._dataset_mask_str, ans, 1) answers = [{"text": ans, "start": start, "end": start + len(ans)}] - instances.append(InstanceCollection(context=ctx, answers=answers, distractors_collection=[instance["distractors"]])) + instances.append( + InstanceCollection(context=ctx, answers=answers, distractors_collection=[instance["distractors"]]) + ) return instances @@ -196,4 +196,5 @@ class DGenLoader(CdgpClothLoader): in a form compatible for distractor generation. See the home page for DGEN Dataset: AndyChiang/dgen """ + _dataset_mask_str = "**blank**" diff --git a/disgem/distractor_evaluator.py b/disgem/distractor_evaluator.py index ddee466..4cb2e34 100644 --- a/disgem/distractor_evaluator.py +++ b/disgem/distractor_evaluator.py @@ -50,6 +50,7 @@ def _load_model_and_tokenizer(self, model_name_or_path: str): def __call__(self, inputs: Dict[str, Any], *args, **kwargs): pass + class NLIBasedDistractorEvaluator(BaseDistractorEvaluator): """ NLI based distractor evaluation, meant to be designed to provide classification @@ -81,9 +82,7 @@ def preprocess( return sentence if isinstance(distractor, dict): distractor = distractor["token_str"] - context_with_distractor = replace_str( - sentence, distractor, answer["start"], answer["end"] - ) + context_with_distractor = replace_str(sentence, distractor, answer["start"], answer["end"]) if reverse: return context_with_distractor + " " + sentence return sentence + " " + context_with_distractor @@ -101,12 +100,8 @@ def preprocess_distractors( if isinstance(distractor2, dict): distractor2 = distractor2["token_str"] - context_with_d1 = replace_str( - sentence, distractor1, answer["start"], answer["end"] - ) - context_with_d2 = replace_str( - sentence, distractor2, answer["start"], answer["end"] - ) + context_with_d1 = replace_str(sentence, distractor1, answer["start"], answer["end"]) + context_with_d2 = replace_str(sentence, distractor2, answer["start"], answer["end"]) return context_with_d1 + " " + context_with_d2 @@ -135,13 +130,9 @@ def __call__( else: distractor1 = inputs["distractors"][distractor_ids[0]] distractor2 = inputs["distractors"][distractor_ids[1]] - input_text = self.preprocess_distractors( - **inputs, distractor1=distractor1, distractor2=distractor2 - ) + input_text = self.preprocess_distractors(**inputs, distractor1=distractor1, distractor2=distractor2) nli_output = self.get_model_output(input_text) - input_text_rev = self.preprocess_distractors( - **inputs, distractor1=distractor2, distractor2=distractor1 - ) + input_text_rev = self.preprocess_distractors(**inputs, distractor1=distractor2, distractor2=distractor1) nli_output_rev = self.get_model_output(input_text_rev) return f"{nli_output}-{nli_output_rev}" diff --git a/disgem/distractor_generator.py b/disgem/distractor_generator.py index d97c966..6032141 100644 --- a/disgem/distractor_generator.py +++ b/disgem/distractor_generator.py @@ -2,12 +2,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer -from disgem.pipeline import ( - DistractorGenerationPipeline, -) -from disgem.util import ( - DistractorGenerationOutput, -) +from disgem.pipeline import DistractorGenerationPipeline +from disgem.util import DistractorGenerationOutput class MaskedLMBasedDistractorGenerator: @@ -61,7 +57,5 @@ def __call__( for distractor in outputs.distractors ] else: - return [ - 
distractor["token_str"] for distractor in outputs.distractors - ] + return [distractor["token_str"] for distractor in outputs.distractors] return outputs diff --git a/disgem/pipeline.py b/disgem/pipeline.py index 6c99d0d..49c05ab 100644 --- a/disgem/pipeline.py +++ b/disgem/pipeline.py @@ -3,31 +3,13 @@ import numpy as np import spacy -from transformers import ( - FillMaskPipeline, - ModelCard, - PreTrainedTokenizer, - add_end_docstrings, - set_seed, -) +from transformers import FillMaskPipeline, ModelCard, PreTrainedTokenizer, add_end_docstrings from transformers.feature_extraction_utils import PreTrainedFeatureExtractor -from transformers.pipelines.base import ( - PIPELINE_INIT_ARGS, - ArgumentHandler, - GenericTensor, - PipelineException, -) +from transformers.pipelines.base import PIPELINE_INIT_ARGS, ArgumentHandler, GenericTensor, PipelineException from transformers.utils import logging -from disgem.distractor_evaluator import ( - NLIBasedDistractorEvaluator, -) -from disgem.util import ( - DistractorGenerationOutput, - geometric_mean, - harmonic_mean, - replace_str, -) +from disgem.distractor_evaluator import NLIBasedDistractorEvaluator +from disgem.util import DistractorGenerationOutput, geometric_mean, harmonic_mean, replace_str logger = logging.get_logger(__name__) @@ -122,9 +104,7 @@ def __init__( def _mask_answer(self, context: str, answer: Dict, n_mask: int): mask_str = " ".join([self.tokenizer.mask_token] * n_mask) - return replace_str( - context, mask_str, start_index=answer["start"], end_index=answer["end"] - ) + return replace_str(context, mask_str, start_index=answer["start"], end_index=answer["end"]) def _sanitize_parameters( self, @@ -186,9 +166,7 @@ def _sanitize_parameters( return preprocess_params, forward_params, postprocess_params - def get_masked_index( - self, input_ids: GenericTensor, as_tuple=False - ) -> Union[List[Tuple[int, int]], np.ndarray]: + def get_masked_index(self, input_ids: GenericTensor, as_tuple=False) -> Union[List[Tuple[int, int]], np.ndarray]: masked_index = super().get_masked_index(input_ids) if as_tuple: # noinspection PyTypeChecker @@ -216,9 +194,7 @@ def preprocess( if n_mask is not None: n_tokens = n_mask else: - n_tokens = len( - self.tokenizer(inputs["answer"]["text"])["input_ids"][1:-1] - ) + n_tokens = len(self.tokenizer(inputs["answer"]["text"])["input_ids"][1:-1]) l_disp, r_disp = self._get_lr_dispersion(n_tokens, dispersion) masked_inputs = [ self._mask_answer(**inputs, n_mask=n_mask) @@ -228,17 +204,12 @@ def preprocess( replace=False, ) ] - model_inputs = [ - self.tokenizer(masked_input, return_tensors=return_tensors) - for masked_input in masked_inputs - ] + model_inputs = [self.tokenizer(masked_input, return_tensors=return_tensors) for masked_input in masked_inputs] for model_input in model_inputs: self.ensure_exactly_one_mask_token(model_input) return model_inputs - def postprocess_all_outputs( - self, all_outputs: List[Dict[str, Any]], top_k: int - ) -> List[Dict[str, Any]]: + def postprocess_all_outputs(self, all_outputs: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]: return sorted(all_outputs, key=lambda d: d["ranking_score"], reverse=True) @staticmethod @@ -267,10 +238,7 @@ def _preprocess_input_for_distractor_evaluation( sentences = self.split_sentences(inputs["context"]) updated_answer = dict(text=inputs["answer"]["text"]) for candid in sentences: - if ( - inputs["answer"]["start"] >= candid.start_char - and inputs["answer"]["end"] <= candid.end_char - ): + if inputs["answer"]["start"] >= candid.start_char and 
inputs["answer"]["end"] <= candid.end_char: start, end = self.find_span_start_end( span_start=inputs["answer"]["start"], span_end=inputs["answer"]["end"], @@ -289,9 +257,7 @@ def _preprocess_input_for_distractor_evaluation( def _evaluate_answer_distractors( self, inputs: Dict[str, Any], outputs: List[Dict[str, Any]], top_k: int = 3 ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - processed_input = self._preprocess_input_for_distractor_evaluation( - inputs, outputs - ) + processed_input = self._preprocess_input_for_distractor_evaluation(inputs, outputs) # Compare distractors with the answer answer_distractor_evaluation_results = self.evaluator(processed_input) check = lambda x: "contradiction" in x[0] or "neutral" in x[0] @@ -307,9 +273,7 @@ def _evaluate_answer_distractors( while kept_index < top_k < len(processed_input["distractors"]): processed_input["distractors"][kept_index]["nli_output"] = [] for j in range(kept_index): - nli_out = self.evaluator( - processed_input, distractor_ids=(j, kept_index) - ) + nli_out = self.evaluator(processed_input, distractor_ids=(j, kept_index)) processed_input["distractors"][kept_index]["nli_output"].append( { "result": nli_out, @@ -318,9 +282,7 @@ def _evaluate_answer_distractors( ) if nli_out == "entailment-entailment": increment = False - discarded_distractors.append( - processed_input["distractors"].pop(kept_index) - ) + discarded_distractors.append(processed_input["distractors"].pop(kept_index)) break increment = True @@ -338,9 +300,7 @@ def _finalize_generation_outputs( self._fix_cocktail_shaker_list(output["token_list"]) self._fix_cocktail_shaker_list(output["token_str"]) outputs[i]["token_str_list"] = output["token_str"] - outputs[i]["token_str"] = self.tokenizer.decode( - output["token_list"] - ).strip() + outputs[i]["token_str"] = self.tokenizer.decode(output["token_list"]).strip() outputs[i]["score"] = np.prod(output["score_list"]) if self._use_harmonic_mean: outputs[i]["ranking_score"] = harmonic_mean(output["score_list"]) @@ -360,9 +320,7 @@ def _generate_distractors( ): idx = -1 if reverse else 0 try: - masked_indices = self.get_masked_index( - model_inputs[0]["input_ids"], as_tuple=True - ) + masked_indices = self.get_masked_index(model_inputs[0]["input_ids"], as_tuple=True) current_mask_index = masked_indices[idx] except (KeyError, IndexError): # End of the generation return self._finalize_generation_outputs( @@ -377,28 +335,20 @@ def _generate_distractors( if len(masked_indices) == 1: is_last = True - model_outputs = [ - self.forward(model_input, **forward_params) - for model_input in model_inputs - ] + model_outputs = [self.forward(model_input, **forward_params) for model_input in model_inputs] postprocess_params_ = deepcopy(postprocess_params) if is_start: postprocess_params_["top_k"] *= self._search_multiplier else: postprocess_params_["top_k"] = 1 - outputs = [ - self.postprocess(model_output, **postprocess_params_) - for model_output in model_outputs - ] + outputs = [self.postprocess(model_output, **postprocess_params_) for model_output in model_outputs] if is_start and is_last: outputs = [outputs] new_model_inputs = [] prev_outputs_ = [] for i, output in enumerate(outputs): if is_start: - output_at_idx = ( - [output[idx]] if isinstance(output[idx], dict) else output[idx] - ) + output_at_idx = [output[idx]] if isinstance(output[idx], dict) else output[idx] for out in output_at_idx: model_inputs_ = deepcopy(model_inputs[0]) model_inputs_["input_ids"][current_mask_index] = out["token"] @@ -410,9 +360,7 @@ def 
_generate_distractors( model_inputs_ = output[0] prev_outputs_.append(output[0]) else: - model_inputs_["input_ids"][current_mask_index] = output[idx][0][ - "token" - ] + model_inputs_["input_ids"][current_mask_index] = output[idx][0]["token"] prev_outputs_.append(output[idx][0]) new_model_inputs.append(model_inputs_) @@ -425,15 +373,9 @@ def _generate_distractors( else: for i, prev_output in enumerate(prev_outputs_): insert_idx = 0 if reverse else len(prev_outputs[i]["token_list"]) - prev_outputs[i]["score_list"].insert( - insert_idx, prev_output["score"] - ) - prev_outputs[i]["token_list"].insert( - insert_idx, prev_output["token"] - ) - prev_outputs[i]["token_str"].insert( - insert_idx, prev_output["token_str"] - ) + prev_outputs[i]["score_list"].insert(insert_idx, prev_output["score"]) + prev_outputs[i]["token_list"].insert(insert_idx, prev_output["token"]) + prev_outputs[i]["token_str"].insert(insert_idx, prev_output["token_str"]) prev_outputs[i]["sequence"] = prev_output["sequence"] if cocktail_shaker: @@ -470,20 +412,14 @@ def generate_distractors( cocktail_shaker=cocktail_shaker, ) all_outputs.extend(outputs) - all_outputs = self.postprocess_all_outputs( - all_outputs, **postprocess_params - ) + all_outputs = self.postprocess_all_outputs(all_outputs, **postprocess_params) return all_outputs def generate(self, model_inputs, forward_params, postprocess_params): if self._decoding == "l2r": - return self.generate_distractors( - model_inputs, forward_params, postprocess_params - ) + return self.generate_distractors(model_inputs, forward_params, postprocess_params) elif self._decoding == "r2l": - return self.generate_distractors( - model_inputs, forward_params, postprocess_params, reverse=True - ) + return self.generate_distractors(model_inputs, forward_params, postprocess_params, reverse=True) elif self._decoding == "ctl": return self.generate_distractors( model_inputs, @@ -493,20 +429,13 @@ def generate(self, model_inputs, forward_params, postprocess_params): ) else: raise ValueError( - f"Unknown unmasking strategy '{self._decoding}'. Supported types are " - f"('l2r', 'r2l', 'ctl')" + f"Unknown unmasking strategy '{self._decoding}'. 
Supported types are " f"('l2r', 'r2l', 'ctl')" ) - def run_single( - self, inputs, preprocess_params, forward_params, postprocess_params - ) -> DistractorGenerationOutput: + def run_single(self, inputs, preprocess_params, forward_params, postprocess_params) -> DistractorGenerationOutput: model_inputs = self.preprocess(inputs, **preprocess_params) - all_outputs = self.generate( - model_inputs, forward_params, postprocess_params - ) + all_outputs = self.generate(model_inputs, forward_params, postprocess_params) kept, discarded = self._evaluate_answer_distractors( inputs=inputs, outputs=all_outputs, top_k=postprocess_params["top_k"] ) - return DistractorGenerationOutput( - distractors=kept, discarded_distractors=discarded - ) \ No newline at end of file + return DistractorGenerationOutput(distractors=kept, discarded_distractors=discarded) diff --git a/disgem/util.py b/disgem/util.py index d7e23f4..a7c89f2 100644 --- a/disgem/util.py +++ b/disgem/util.py @@ -39,4 +39,3 @@ class DistractorGenerationOutput(ModelOutput): distractors: List[Dict] = None discarded_distractors: List[Dict] = None - diff --git a/environment.yml b/environment.yml index 31e518e..a6ddcef 100644 --- a/environment.yml +++ b/environment.yml @@ -14,9 +14,7 @@ dependencies: - click==8.0.4 - deepdiff==6.3.0 - en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl - - fire - flake8==3.9.2 - - gensim==4.3.1 - isort==5.9.2 - notebook - transformers==4.19.4 diff --git a/generate.py b/generate.py index 0449407..04c27ce 100644 --- a/generate.py +++ b/generate.py @@ -5,227 +5,295 @@ import numpy as np from tqdm import tqdm -from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer, set_seed +from transformers import AutoModelForMaskedLM, AutoTokenizer, FillMaskPipeline, set_seed from disgem import MaskedLMBasedDistractorGenerator -from disgem.data_loader import ClothLoader, CdgpClothLoader, SquadLoader, DGenLoader -from disgem.util import read_json, harmonic_mean +from disgem.data_loader import CdgpClothLoader, ClothLoader, DGenLoader, SquadLoader +from disgem.util import harmonic_mean, read_json def create_args(): - parser = argparse.ArgumentParser( - prog="DisGeM", - description="Distractor Generator for MCQ" - ) - parser.add_argument("filepath", type=str, help="Path to SQuAD style data.") - parser.add_argument("--data-format", type=str, default="squad", choices= ["cloth", "cdgp-cloth", "squad", "dgen"], - help="Data format whether SQuAD style or CLOTH style dataset. Default 'squad'.") - parser.add_argument("--model", type=str, default="roberta-large", help="Masked LM for distractor generation phase. Models are loaded from huggingface hub. Default 'roberta-large'.") - parser.add_argument("--top-k", type=int, default=3, help="Number of distractors. By default 3.") - parser.add_argument("--batch-size", type=int, default=1, help="Batch size, batched inference might be even slower, " - "see https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching. By default 1.") - parser.add_argument("--output-path", type=str, default=None, - help="File path to dump outputs. By default no output file is created.") - parser.add_argument("--output-format", type=str, default="cdgp", choices=["cdgp", "all"]) - parser.add_argument("--question-limit", type=int, default=100, help="Question limit to stop generation at. 
Default 100.") - parser.add_argument("--dispersion", type=int, default=1, help="Dispersion parameter to determine interval for sampling num mask tokens. By default 1.") - parser.add_argument("--device", type=int, default=-1, help="Device for generation phase. Set -1 for cpu, numbers 0,1,2,... refer to that gpu device. By default -1.") - parser.add_argument("--no-minify-output", action="store_true", help="If given, no minification is placed on outputs.") - parser.add_argument("--decoding", type=str, default="l2r", choices=["l2r", "r2l", "ctl"], - help="Generation strategy for the generation phase.By default 'snowball'.") - parser.add_argument("--n-mask", type=int, default=None, help="Number of mask tokens to be replaced with answer text. Default `none`.") - parser.add_argument("--use-geometric-mean", action="store_true", help="If given, uses geometric mean to determine final ranking, otherwise uses harmonic mean.") - parser.add_argument("--single-mask", action="store_true", help="If given, only applies a single mask to replace the answer. It is the same as setting `dispersion=0` and `n_mask=1`.") - parser.add_argument("--seed", type=int, default=42, help="Seed for RNG. Default 42.") - parser.add_argument("--prepend-question", type=str, default="none", choices=["none", "begin", "mid"], - help="If not `none`, prepends `question` to the context to guide the distractor generation with the question. " - "Default option is `none`.") - parser.add_argument("--evaluate", action="store_true", help="If given, starts evaluation process rather than generation. You must supply result json file for evaluation.") - return parser.parse_args() + parser = argparse.ArgumentParser(prog="DisGeM", description="Distractor Generator for MCQ") + parser.add_argument("filepath", type=str, help="Path to SQuAD style data.") + parser.add_argument( + "--data-format", + type=str, + default="squad", + choices=["cloth", "cdgp-cloth", "squad", "dgen"], + help="Data format whether SQuAD style or CLOTH style dataset. Default 'squad'.", + ) + parser.add_argument( + "--model", + type=str, + default="roberta-large", + help="Masked LM for distractor generation phase. Models are loaded from huggingface hub. Default 'roberta-large'.", + ) + parser.add_argument("--top-k", type=int, default=3, help="Number of distractors. By default 3.") + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size, batched inference might be even slower, " + "see https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching. By default 1.", + ) + parser.add_argument( + "--output-path", type=str, default=None, help="File path to dump outputs. By default no output file is created." + ) + parser.add_argument("--output-format", type=str, default="cdgp", choices=["cdgp", "all"]) + parser.add_argument( + "--question-limit", type=int, default=100, help="Question limit to stop generation at. Default 100." + ) + parser.add_argument( + "--dispersion", + type=int, + default=1, + help="Dispersion parameter to determine interval for sampling num mask tokens. By default 1.", + ) + parser.add_argument( + "--device", + type=int, + default=-1, + help="Device for generation phase. Set -1 for cpu, numbers 0,1,2,... refer to that gpu device. By default -1.", + ) + parser.add_argument( + "--no-minify-output", action="store_true", help="If given, no minification is placed on outputs." 
+    )
+    parser.add_argument(
+        "--decoding",
+        type=str,
+        default="l2r",
+        choices=["l2r", "r2l", "ctl"],
+        help="Generation strategy for the generation phase. By default 'l2r'.",
+    )
+    parser.add_argument(
+        "--n-mask",
+        type=int,
+        default=None,
+        help="Number of mask tokens to be replaced with answer text. Default `none`.",
+    )
+    parser.add_argument(
+        "--use-geometric-mean",
+        action="store_true",
+        help="If given, uses geometric mean to determine final ranking, otherwise uses harmonic mean.",
+    )
+    parser.add_argument(
+        "--single-mask",
+        action="store_true",
+        help="If given, only applies a single mask to replace the answer. It is the same as setting `dispersion=0` and `n_mask=1`.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Seed for RNG. Default 42.")
+    parser.add_argument(
+        "--prepend-question",
+        type=str,
+        default="none",
+        choices=["none", "begin", "mid"],
+        help="If not `none`, prepends `question` to the context to guide the distractor generation with the question. "
+        "Default option is `none`.",
+    )
+    parser.add_argument(
+        "--evaluate",
+        action="store_true",
+        help="If given, starts evaluation process rather than generation. You must supply a results JSON file for evaluation.",
+    )
+    return parser.parse_args()
 
 
 def main(args):
-    if args.prepend_question != "none":
-        warnings.warn("`--prepend-question` is only available for squad format.")
-    if args.batch_size > 1:
-        warnings.warn("Currently, batched inference is not supported.")
-        args.batch_size = 1
-    if args.data_format == "cloth":
-        data_loader = ClothLoader(args.filepath)
-    elif args.data_format == "cdgp-cloth":
-        data_loader = CdgpClothLoader(args.filepath)
-    elif args.data_format == "dgen":
-        data_loader = DGenLoader(args.filepath)
-    elif args.data_format == "squad":
-        data_loader = SquadLoader(args.filepath, prepend_question=args.prepend_question)
-    else:
-        raise ValueError(f"Unknown data format {args.data_format}.")
-
-    distractor_generator = MaskedLMBasedDistractorGenerator(
-        pretrained_model_name_or_path=args.model,
-        dispersion=args.dispersion,
-        n_mask=args.n_mask,
-        device=args.device,
-        decoding=args.decoding,
-        single_mask=args.single_mask
-    )
-
-    squad_answers = []
-    outputs = []
-    count = 0
-    pbar = tqdm(data_loader)
-    for instance in pbar:
-        pbar.set_postfix({"count": count})
-        if count == args.question_limit:
-            break
-
-        dgen_tokenizer = distractor_generator._pipeline.tokenizer
-        if len(dgen_tokenizer.encode(instance.context)) > dgen_tokenizer.model_max_length:
-            # Skip if tokenized context does not fit into model max input length
-            continue
-
-        if args.data_format == "squad":
-            if args.prepend_question:
-                pass
-            elif instance.answer in squad_answers:
-                # squad contains different questions for some answer spans. Since our
-                # framework does not depend on question, we skip these questions as
-                # it would yield the same distractors.
-                continue
-            else:
-                squad_answers.append(instance.answer)
-
-        generations = distractor_generator(
-            context=instance.context,
-            answer=instance.answer,
-            minify_output=not args.no_minify_output,
-            top_k=args.top_k,
-            use_harmonic_mean=not args.use_geometric_mean,
-            batch_size=args.batch_size
-        )
-        if args.data_format == "squad":  # no gt distractors/evaluation, put context as well
-            outputs.append(
-                {
-                    "context": instance.context,
-                    "question": instance.question,
-                    "answer": instance.answer,
-                    "generations": generations
-                }
-            )
-        else:
-            if args.output_format == "cdgp":
-                outputs.append(
-                    {
-                        "generations": generations,
-                        "distractors": instance.distractors
-                    }
-                )
-            else:
-                # For better readability, put blank in the output
-                ctx = instance.context[:instance.answer["start"]] + " ____ " + instance.context[instance.answer["end"]:]
-                outputs.append(
-                    {
-                        "context": ctx,
-                        "answer": instance.answer["text"],
-                        "generations": generations,
-                        "distractors": instance.distractors
-                    }
-                )
-        count += 1
-
-    if args.output_path is not None:
-        output_path = Path(args.output_path)
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(output_path.as_posix(), "w") as fd_out:
-            json.dump(outputs, fd_out)
+    if args.prepend_question != "none":
+        warnings.warn("`--prepend-question` is only available for squad format.")
+    if args.batch_size > 1:
+        warnings.warn("Currently, batched inference is not supported.")
+        args.batch_size = 1
+    if args.data_format == "cloth":
+        data_loader = ClothLoader(args.filepath)
+    elif args.data_format == "cdgp-cloth":
+        data_loader = CdgpClothLoader(args.filepath)
+    elif args.data_format == "dgen":
+        data_loader = DGenLoader(args.filepath)
+    elif args.data_format == "squad":
+        data_loader = SquadLoader(args.filepath, prepend_question=args.prepend_question)
+    else:
+        raise ValueError(f"Unknown data format {args.data_format}.")
+
+    distractor_generator = MaskedLMBasedDistractorGenerator(
+        pretrained_model_name_or_path=args.model,
+        dispersion=args.dispersion,
+        n_mask=args.n_mask,
+        device=args.device,
+        decoding=args.decoding,
+        single_mask=args.single_mask,
+    )
+
+    squad_answers = []
+    outputs = []
+    count = 0
+    pbar = tqdm(data_loader)
+    for instance in pbar:
+        pbar.set_postfix({"count": count})
+        if count == args.question_limit:
+            break
+
+        dgen_tokenizer = distractor_generator._pipeline.tokenizer
+        if len(dgen_tokenizer.encode(instance.context)) > dgen_tokenizer.model_max_length:
+            # Skip if tokenized context does not fit into model max input length
+            continue
+
+        if args.data_format == "squad":
+            if args.prepend_question != "none":
+                pass
+            elif instance.answer in squad_answers:
+                # squad contains different questions for some answer spans. Since our
+                # framework does not depend on question, we skip these questions as
+                # they would yield the same distractors.
+ continue + else: + squad_answers.append(instance.answer) + + generations = distractor_generator( + context=instance.context, + answer=instance.answer, + minify_output=not args.no_minify_output, + top_k=args.top_k, + use_harmonic_mean=not args.use_geometric_mean, + batch_size=args.batch_size, + ) + if args.data_format == "squad": # no gt distractors/evaluation, put context as well + outputs.append( + { + "context": instance.context, + "question": instance.question, + "answer": instance.answer, + "generations": generations, + } + ) + else: + if args.output_format == "cdgp": + outputs.append({"generations": generations, "distractors": instance.distractors}) + else: + # For better readability, put blank in the output + ctx = ( + instance.context[: instance.answer["start"]] + " ____ " + instance.context[instance.answer["end"] :] + ) + outputs.append( + { + "context": ctx, + "answer": instance.answer["text"], + "generations": generations, + "distractors": instance.distractors, + } + ) + count += 1 + + if args.output_path is not None: + output_path = Path(args.output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path.as_posix(), "w") as fd_out: + json.dump(outputs, fd_out) def evaluate(args): - """ - - Args: - args: - - Returns: - - """ - - # metrics - def precision(preds, targets, k: int = 1): - matches = [int(generation in targets) for generation in preds] - return sum(matches[:k]) / k - - def recall(preds, targets, k: int = 1): - matches = [int(generation in targets) for generation in preds] - return sum(matches[:k]) / len(targets) - - def f1(preds, targets, k: int = 1): - p = precision(preds, targets, k) - r = recall(preds, targets, k) - return harmonic_mean([p, r]) - - def ndcg_at_k(preds, targets, k: int = 1): - def dcg_at_k(r, k): - r = np.asfarray(r)[:k] - if r.size: - return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) - return 0. - r = [int(generation in targets) for generation in preds] - idcg = dcg_at_k(sorted(r, reverse=True), k) - if not idcg: - return 0. 
-        return dcg_at_k(r, k) / idcg
-
-    def mmr_at_k(preds, targets, k: int = 1):
-        matches = [int(generation in targets) for generation in preds]
-        k = len(matches) if k > len(matches) else k
-        for i in range(k):
-            if matches[i] == 1:
-                return 1 / (i+1)
-        return .0
-
-    outputs = read_json(args.filepath)
-    avg_eval = {
-        "P@1" : 0.0, "P@3": 0.0, "P@5": 0.0, "P@10" : 0.0,
-        "R@1": 0.0, "R@3": 0.0, "R@5": 0.0, "R@10": 0.0,
-        "F1@1": 0.0, "F1@3": 0.0, "F1@5" : 0.0, "F1@10": 0.0,
-        "MRR@1": 0.0, "MRR@3": 0.0, "MRR@5": 0.0, "MRR@10": 0.0,
-        "NDCG@1": 0.0, "NDCG@3": 0.0, "NDCG@5": 0.0, "NDCG@10": 0.0}
-    for output in outputs:
-        distractors = [d.lower() for d in output["distractors"]]
-        generations = [d.lower() for d in output["generations"]]
-
-        for key in avg_eval.keys():
-            metric, k = key.split("@")
-            if metric == "P":
-                metric_fn = precision
-            elif metric == "R":
-                metric_fn = recall
-            elif metric == "F1":
-                metric_fn = f1
-            elif metric == "NDCG":
-                metric_fn = ndcg_at_k
-            elif metric == "MRR":
-                metric_fn = mmr_at_k
-            else:
-                continue
-            avg_eval[key] += metric_fn(preds=generations, targets=distractors, k=int(k))
-
-    # calculate average
-    for key in avg_eval.keys():
-        avg_eval[key] /= len(outputs)
-        avg_eval[key] = str(round(100 * avg_eval[key], 4)) + "%"
-
-    print(json.dumps(avg_eval, indent=2))
-    if args.output_path is not None:
-        with open(args.output_path, "w") as fd_out:
-            json.dump(avg_eval, fd_out, indent=2)
+    """
+    Evaluate generated distractors against the gold distractors read from `args.filepath`,
+    reporting P@k, R@k, F1@k, MRR@k and NDCG@k averaged over all instances.
+    """
+
+    # metrics
+    def precision(preds, targets, k: int = 1):
+        matches = [int(generation in targets) for generation in preds]
+        return sum(matches[:k]) / k
+
+    def recall(preds, targets, k: int = 1):
+        matches = [int(generation in targets) for generation in preds]
+        return sum(matches[:k]) / len(targets)
+
+    def f1(preds, targets, k: int = 1):
+        p = precision(preds, targets, k)
+        r = recall(preds, targets, k)
+        return harmonic_mean([p, r])
+
+    def ndcg_at_k(preds, targets, k: int = 1):
+        def dcg_at_k(r, k):
+            r = np.asfarray(r)[:k]
+            if r.size:
+                return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
+            return 0.0
+
+        r = [int(generation in targets) for generation in preds]
+        idcg = dcg_at_k(sorted(r, reverse=True), k)
+        if not idcg:
+            return 0.0
+        return dcg_at_k(r, k) / idcg
+
+    def mrr_at_k(preds, targets, k: int = 1):
+        matches = [int(generation in targets) for generation in preds]
+        k = len(matches) if k > len(matches) else k
+        for i in range(k):
+            if matches[i] == 1:
+                return 1 / (i + 1)
+        return 0.0
+
+    outputs = read_json(args.filepath)
+    avg_eval = {
+        "P@1": 0.0,
+        "P@3": 0.0,
+        "P@5": 0.0,
+        "P@10": 0.0,
+        "R@1": 0.0,
+        "R@3": 0.0,
+        "R@5": 0.0,
+        "R@10": 0.0,
+        "F1@1": 0.0,
+        "F1@3": 0.0,
+        "F1@5": 0.0,
+        "F1@10": 0.0,
+        "MRR@1": 0.0,
+        "MRR@3": 0.0,
+        "MRR@5": 0.0,
+        "MRR@10": 0.0,
+        "NDCG@1": 0.0,
+        "NDCG@3": 0.0,
+        "NDCG@5": 0.0,
+        "NDCG@10": 0.0,
+    }
+    for output in outputs:
+        distractors = [d.lower() for d in output["distractors"]]
+        generations = [d.lower() for d in output["generations"]]
+
+        for key in avg_eval.keys():
+            metric, k = key.split("@")
+            if metric == "P":
+                metric_fn = precision
+            elif metric == "R":
+                metric_fn = recall
+            elif metric == "F1":
+                metric_fn = f1
+            elif metric == "NDCG":
+                metric_fn = ndcg_at_k
+            elif metric == "MRR":
+                metric_fn = mrr_at_k
+            else:
+                continue
+            avg_eval[key] += metric_fn(preds=generations, targets=distractors, k=int(k))
+
+    # calculate average
+    for key in avg_eval.keys():
+        avg_eval[key] /= len(outputs)
+        avg_eval[key] = str(round(100 * avg_eval[key], 4)) + "%"
+
+    print(json.dumps(avg_eval, indent=2))
+    if args.output_path is not None:
+        with open(args.output_path, "w") as fd_out:
+            json.dump(avg_eval, fd_out, indent=2)
 
 
 if __name__ == "__main__":
-    args = create_args()
-    set_seed(args.seed)
-    if args.evaluate:
-        evaluate(args)
-    else:
-        main(args)
+    args = create_args()
+    set_seed(args.seed)
+    if args.evaluate:
+        evaluate(args)
+    else:
+        main(args)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..3dbb71a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,14 @@
+[tool.black]
+line-length = 120
+exclude = '''
+(
+  /(
+    | .git
+    | venv
+    | .venv
+  )/
+)
+'''
+
+[tool.pytest.ini_options]
+tmp_path_retention_policy = "failed"
\ No newline at end of file
diff --git a/scripts/run_code_style.py b/scripts/run_code_style.py
new file mode 100644
index 0000000..55aabed
--- /dev/null
+++ b/scripts/run_code_style.py
@@ -0,0 +1,16 @@
+import sys
+
+from scripts.utils import shell, validate_and_exit
+
+if __name__ == "__main__":
+    arg = sys.argv[1]
+
+    if arg == "check":
+        sts_flake = shell("flake8 disgem scripts generate.py --config setup.cfg")
+        sts_isort = shell("isort . --check --settings setup.cfg")
+        sts_black = shell("black . --check --config pyproject.toml")
+        validate_and_exit(flake8=sts_flake, isort=sts_isort, black=sts_black)
+    elif arg == "format":
+        sts_isort = shell("isort . --settings setup.cfg")
+        sts_black = shell("black . --config pyproject.toml")
+        validate_and_exit(isort=sts_isort, black=sts_black)
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000..5b25767
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,41 @@
+import os
+import shutil
+import sys
+
+
+def shell(command, exit_status=0):
+    """
+    Run `command` through the system shell and compare its exit status against the expected one.
+
+    Args:
+        command: (str) Command string to run through the system shell.
+        exit_status: (int) Expected exit status of the command.
+
+    Returns: 0 if the actual exit status matches `exit_status`, otherwise the actual exit status.
+
+    """
+    actual_exit_status = os.system(command)
+    if actual_exit_status == exit_status:
+        return 0
+    return actual_exit_status
+
+
+def validate_and_exit(expected_out_status=0, **kwargs):
+    if all([arg == expected_out_status for arg in kwargs.values()]):
+        # Expected status, OK
+        sys.exit(0)
+    else:
+        # Failure
+        print_console_centered("Summary Results")
+        fail_count = 0
+        for component, exit_status in kwargs.items():
+            if exit_status != expected_out_status:
+                print(f"{component} failed.")
+                fail_count += 1
+        print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure")
+        sys.exit(1)
+
+
+def print_console_centered(text: str, fill_char="="):
+    w, _ = shutil.get_terminal_size((80, 20))
+    print(f" {text} ".center(w, fill_char))
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..85f524b
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,9 @@
+[flake8]
+max-line-length = 120
+select = E9,F63,F7,F82
+per-file-ignores = __init__.py: F401
+max-complexity = 10
+
+[isort]
+line_length=120
+profile=black
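
Appendix (editor's note, not part of the patch): `ClothLoader.read` in the `disgem/data_loader.py` diff above turns each blank's answer letter into an index over that blank's option list. The sketch below illustrates that core logic on a made-up miniature file; the `cloth_file` dict values are hypothetical, while the keys (`article`, `options`, `answers`), the `" _ "` blank marker, and the letter-to-index arithmetic are taken from the loader code.

```python
# Hypothetical miniature of one CLOTH-style JSON file (illustrative values only).
cloth_file = {
    "article": "She opened the door and _ the light.",  # blanks appear as " _ "
    "options": [["turned on", "turned off", "put on", "took off"]],  # one list per blank
    "answers": ["A"],  # one answer letter (A-D) per blank
}

# Core of the loader's logic: the letter indexes into the blank's option list;
# the chosen option becomes the answer, the remaining three become distractors.
for letter, choices in zip(cloth_file["answers"], cloth_file["options"]):
    opt = ord(letter) - ord("A")
    answer = choices[opt]
    distractors = choices[:opt] + choices[opt + 1 :]
    print(answer, distractors)  # turned on ['turned off', 'put on', 'took off']
```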
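Similarly, a standalone sketch of the @k ranking metrics that `evaluate` computes in `generate.py`; the helper bodies mirror the nested functions in the patch (with the reciprocal-rank helper named `mrr_at_k` for what it computes), and the `preds`/`targets` lists are made up for illustration.

```python
# Toy walk-through of the @k metrics used by evaluate(); the data is illustrative only.
preds = ["happy", "angry", "calm", "tired"]  # generated distractors, best-ranked first
targets = ["angry", "calm", "upset"]  # gold distractors

def precision_at_k(preds, targets, k=1):
    # Fraction of the top-k generations that are gold distractors.
    matches = [int(g in targets) for g in preds]
    return sum(matches[:k]) / k

def recall_at_k(preds, targets, k=1):
    # Fraction of the gold distractors recovered within the top-k generations.
    matches = [int(g in targets) for g in preds]
    return sum(matches[:k]) / len(targets)

def mrr_at_k(preds, targets, k=1):
    # Reciprocal rank of the first gold distractor within the top-k.
    for i, g in enumerate(preds[:k]):
        if g in targets:
            return 1 / (i + 1)
    return 0.0

print(precision_at_k(preds, targets, k=3))  # 2/3: "angry" and "calm" hit in the top-3
print(recall_at_k(preds, targets, k=3))     # 2/3: 2 of the 3 gold distractors recovered
print(mrr_at_k(preds, targets, k=3))        # 1/2: first hit at rank 2
```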