From 829a84b4a2a2336eb508a8c338992543c8b7d9f7 Mon Sep 17 00:00:00 2001 From: devrimcavusoglu Date: Mon, 30 Sep 2024 11:48:06 +0300 Subject: [PATCH 1/4] license and readme. --- LICENSE | 213 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 32 ++++++-- 2 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6682ad7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,213 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + Patent Disclaimer: Pending Patent Applications. Please be advised + that a patent application is pending for certain inventions or + technologies described or used within the Work. This License does + not grant any rights under any pending patent applications until such + patents are issued. If such patents are issued, any licenses for the + Work will be subject to the terms and scope of those patents. You are + advised to consult with legal counsel to fully understand the implications + of the patent status before using or distributing the Work or any + derivative thereof. Until the patent process is finalized, no explicit + or implicit patent licenses are granted under this License, except + as may be required by applicable law. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 Open Business Software Solutions (OBSS) + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index 18e6ceb..343c391 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,32 @@ -# disgem -Distractor Generation for Multiple Choice Question +# DisGeM: Distractor Generation for Multiple Choice Questions with Span Masking +Arxiv +DisGeM + + +A Distractor Generation framework utilizing Pre-trained Language Models (PLMs) that are pre-trained with Masked Language Modeling (MLM) objective. + +[Paper](https://arxiv.org/abs/2409.18263v1) + +### Abstract + +> Recent advancements in Natural Language Processing (NLP) have impacted numerous sub-fields such as natural language generation, natural language inference, question answering, and more. However, in the field of question generation, the creation of distractors for multiple-choice questions (MCQ) remains a challenging task. In this work, we present a simple, generic framework for distractor generation using readily available Large Language Models (LLMs). 
Unlike previous methods, our framework relies solely on pre-trained language models and does not require additional training on specific datasets. Building upon previous research, we introduce a two-stage framework consisting of candidate generation and candidate selection. Our proposed distractor generation framework outperforms previous methods without the need for training or fine-tuning. Human evaluations confirm that our approach produces more effective and engaging distractors. The related codebase is publicly available at https://github.com/obss/disgem.

## Installation

+Clone the repository.
+
+```bash
+git clone https://github.com/obss/disgem.git
+```
+
+In the project root, create a virtual environment (preferably using conda) as follows:
+
```shell
conda env create -f environment.yml
```

-## Generate Distractors
+## Datasets

Download datasets by the following command. This script will download CLOTH and DGen datasets.

@@ -16,10 +34,12 @@ Download datasets by the following command. This script will download CLOTH and
bash scripts/download_data.sh
```

-To generate distractors for CLOTH test-high dataset, run the following command. You can alter `top-k` and `dispersion` parameters.
+## Generate Distractors
+
+To see the available arguments for generation, run `python -m generate --help`.
+
+The following example generates distractors for the CLOTH test-high dataset. You can alter the `top-k` and `dispersion` parameters as needed.

```shell
python -m generate data/CLOTH/test/high --data-format cloth --top-k 3 --dispersion 0 --output-path cloth_test_outputs.json
```
-
-To see the arguments for generation see `python -m generate --help`.

From 254381201dc2fc4d8ced81ce11991bb5a8083dd7 Mon Sep 17 00:00:00 2001
From: devrimcavusoglu
Date: Mon, 30 Sep 2024 19:20:53 +0300
Subject: [PATCH 2/4] link change.

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 343c391..41f2c0e 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,12 @@
# DisGeM: Distractor Generation for Multiple Choice Questions with Span Masking
-Arxiv
+Arxiv
DisGeM


A Distractor Generation framework utilizing Pre-trained Language Models (PLMs) that are pre-trained with Masked Language Modeling (MLM) objective.

-[Paper](https://arxiv.org/abs/2409.18263v1)
+[Paper](https://arxiv.org/abs/2409.18263)

### Abstract

From 3454c86c0a129b9c7cc06955ea692c767e5aa6ab Mon Sep 17 00:00:00 2001
From: devrimcavusoglu
Date: Mon, 14 Oct 2024 15:09:37 +0300
Subject: [PATCH 3/4] license updated.

---
 LICENSE | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/LICENSE b/LICENSE
index 6682ad7..cfccd00 100644
--- a/LICENSE
+++ b/LICENSE
@@ -78,25 +78,22 @@
       where such license applies only to those patent claims licensable
       by such Contributor that are necessarily infringed by their
       Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
+      with the Work to which such Contribution(s) was submitted.
+
+      Patent Pending Notice: Certain aspects of the Work may be subject
+      to a pending patent application. The Contributor(s) makes no guarantees
+      regarding the outcome of the patent process, and the granting of this
+      License does not imply any waiver or release of patent rights beyond
+      what is explicitly covered here. You are advised to consult the patent
+      office for more information about the pending status.
+ + If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. - Patent Disclaimer: Pending Patent Applications. Please be advised - that a patent application is pending for certain inventions or - technologies described or used within the Work. This License does - not grant any rights under any pending patent applications until such - patents are issued. If such patents are issued, any licenses for the - Work will be subject to the terms and scope of those patents. You are - advised to consult with legal counsel to fully understand the implications - of the patent status before using or distributing the Work or any - derivative thereof. Until the patent process is finalized, no explicit - or implicit patent licenses are granted under this License, except - as may be required by applicable law. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without @@ -162,6 +159,12 @@ appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + Patent Disclaimer: As the Work is subject to a pending patent, + no warranties or guarantees are made regarding any patents that may + be granted in the future. The responsibility to check for patents + on any implementation of the Work is on You, and You assume all risks + regarding patent claims or potential infringements. + 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly From 481efc245a1ee93e4cb12ed64105c487146cd616 Mon Sep 17 00:00:00 2001 From: devrimcavusoglu Date: Mon, 14 Oct 2024 15:23:27 +0300 Subject: [PATCH 4/4] Codestyle, github workflow. 
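Adds code style tooling (black, isort, and flake8, configured via
pyproject.toml and setup.cfg) together with a GitHub Actions workflow that
runs the style checks on pushes and pull requests to main. The checks can
also be run locally with the commands introduced in this patch:

    python -m scripts.run_code_style check   # lint the codebase
    python -m scripts.run_code_style format  # auto-format the codebase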
---
 .github/workflows/check_format.yml | 58 ++++
 README.md | 18 ++
 disgem/__init__.py | 2 +-
 disgem/data_loader.py | 33 +-
 disgem/distractor_evaluator.py | 21 +-
 disgem/distractor_generator.py | 12 +-
 disgem/pipeline.py | 129 ++------
 disgem/util.py | 1 -
 environment.yml | 2 -
 generate.py | 488 ++++++++++++++++-------------
 pyproject.toml | 14 +
 scripts/run_code_style.py | 16 +
 scripts/utils.py | 41 +++
 setup.cfg | 9 +
 14 files changed, 490 insertions(+), 354 deletions(-)
 create mode 100644 .github/workflows/check_format.yml
 create mode 100644 pyproject.toml
 create mode 100644 scripts/run_code_style.py
 create mode 100644 scripts/utils.py
 create mode 100644 setup.cfg

diff --git a/.github/workflows/check_format.yml b/.github/workflows/check_format.yml
new file mode 100644
index 0000000..3164000
--- /dev/null
+++ b/.github/workflows/check_format.yml
@@ -0,0 +1,58 @@
+name: Code Style
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  build:
+    runs-on: ${{ matrix.operating-system }}
+    strategy:
+      matrix:
+        operating-system: [ubuntu-latest, windows-latest, macos-latest]
+        # for Python 3.10, ref https://github.com/actions/setup-python/issues/160#issuecomment-724485470
+        python-version: [3.8, 3.9, '3.10', '3.11']
+      fail-fast: false
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Restore Ubuntu cache
+        uses: actions/cache@v1
+        if: matrix.operating-system == 'ubuntu-latest'
+        with:
+          path: ~/.cache/pip
+          key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
+          restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+      - name: Restore MacOS cache
+        uses: actions/cache@v1
+        if: matrix.operating-system == 'macos-latest'
+        with:
+          path: ~/Library/Caches/pip
+          key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
+          restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+      - name: Restore Windows cache
+        uses: actions/cache@v1
+        if: matrix.operating-system == 'windows-latest'
+        with:
+          path: ~\AppData\Local\pip\Cache
+          key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
+          restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}-
+
+      - name: Set up conda env
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          activate-environment: disgem
+          environment-file: environment.yml
+          auto-activate-base: false
+
+      - name: Lint with flake8 and black
+        shell: bash -l {0}
+        run: python -m scripts.run_code_style check
diff --git a/README.md b/README.md
index 41f2c0e..30845bb 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ Clone the repository.
 ```bash
 git clone https://github.com/obss/disgem.git
+cd disgem
 ```

 In the project root, create a virtual environment (preferably using conda) as follows:
@@ -43,3 +44,20 @@
 python -m generate data/CLOTH/test/high --data-format cloth --top-k 3 --dispersion 0 --output-path cloth_test_outputs.json
 ```
+
+
+## Contributing
+
+Format and check the code style of the codebase as follows.
+ +To check the codestyle, + +```bash +python -m scripts.run_code_style check +``` + +To format the codebase, + +```bash +python -m scripts.run_code_style format +``` \ No newline at end of file diff --git a/disgem/__init__.py b/disgem/__init__.py index c9856c5..0f3a53e 100644 --- a/disgem/__init__.py +++ b/disgem/__init__.py @@ -1 +1 @@ -from disgem.distractor_generator import MaskedLMBasedDistractorGenerator \ No newline at end of file +from disgem.distractor_generator import MaskedLMBasedDistractorGenerator diff --git a/disgem/data_loader.py b/disgem/data_loader.py index 1d70802..87bd99e 100644 --- a/disgem/data_loader.py +++ b/disgem/data_loader.py @@ -1,7 +1,7 @@ from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import List, Dict, Union, OrderedDict +from typing import Dict, List, Union from disgem.util import read_json @@ -38,12 +38,7 @@ def __iter__(self) -> Instance: for i, ans in enumerate(instance.answers): d = instance.distractors_collection[i] if instance.distractors_collection is not None else None q = instance.questions[i] if instance.questions is not None else None - yield Instance( - context=instance.context, - answer=ans, - distractors=d, - question=q - ) + yield Instance(context=instance.context, answer=ans, distractors=d, question=q) class SquadLoader(DataLoader): @@ -81,7 +76,12 @@ def read(self, filepath): answer["end"] += len(question) + 1 elif self.prepend_question == "mid": # the following prepends just before the mask. - ctx = paragraph["context"][:answer["start"]] + question + " " + paragraph["context"][answer["start"]:] + ctx = ( + paragraph["context"][: answer["start"]] + + question + + " " + + paragraph["context"][answer["start"] :] + ) answer["start"] += len(question) + 1 answer["end"] += len(question) + 1 elif self.prepend_question == "end": @@ -102,6 +102,7 @@ class ClothLoader(DataLoader): a form compatible for distractor generation. See the home page for CLOTH: https://www.cs.cmu.edu/~glai1/data/cloth/ """ + _dataset_mask_str = " _ " def __iter__(self) -> Instance: @@ -112,11 +113,7 @@ def __iter__(self) -> Instance: """ for instance in self.dataset: for i, ans in enumerate(instance.answers): - yield Instance( - context=instance.context, - answer=ans, - distractors=instance.distractors_collection[i] - ) + yield Instance(context=instance.context, answer=ans, distractors=instance.distractors_collection[i]) @staticmethod def replace_nth(text: str, substr: str, replace: str, nth: int): @@ -145,7 +142,7 @@ def read(self, filepath): assert filepath.is_dir(), "`filepath` for CLOTH dataset needs to be a directory." instances = [] - files = sorted(filepath.glob('*.json')) + files = sorted(filepath.glob("*.json")) for p in files: data = read_json(p) ctx = data["article"] @@ -158,7 +155,7 @@ def read(self, filepath): answer = {"text": ans, "start": start} answer["end"] = answer["start"] + len(answer["text"]) answers.append(answer) - distractors.append(choices[:opt] + choices[opt+1:]) # remove the answer from choices + distractors.append(choices[:opt] + choices[opt + 1 :]) # remove the answer from choices assert len(answers) == len(distractors), "The length of the `answers` and `distractors` must be equal." instances.append(InstanceCollection(context=ctx, answers=answers, distractors_collection=distractors)) return instances @@ -171,6 +168,7 @@ class CdgpClothLoader(DataLoader): as it is used and published in a related work. 
See the home page for CDGP style CLOTH: https://huggingface.co/datasets/AndyChiang/cloth """ + _dataset_mask_str = " [MASK] " def read(self, filepath): @@ -186,7 +184,9 @@ def read(self, filepath): else: ctx = ctx.replace(self._dataset_mask_str, ans, 1) answers = [{"text": ans, "start": start, "end": start + len(ans)}] - instances.append(InstanceCollection(context=ctx, answers=answers, distractors_collection=[instance["distractors"]])) + instances.append( + InstanceCollection(context=ctx, answers=answers, distractors_collection=[instance["distractors"]]) + ) return instances @@ -196,4 +196,5 @@ class DGenLoader(CdgpClothLoader): in a form compatible for distractor generation. See the home page for DGEN Dataset: AndyChiang/dgen """ + _dataset_mask_str = "**blank**" diff --git a/disgem/distractor_evaluator.py b/disgem/distractor_evaluator.py index ddee466..4cb2e34 100644 --- a/disgem/distractor_evaluator.py +++ b/disgem/distractor_evaluator.py @@ -50,6 +50,7 @@ def _load_model_and_tokenizer(self, model_name_or_path: str): def __call__(self, inputs: Dict[str, Any], *args, **kwargs): pass + class NLIBasedDistractorEvaluator(BaseDistractorEvaluator): """ NLI based distractor evaluation, meant to be designed to provide classification @@ -81,9 +82,7 @@ def preprocess( return sentence if isinstance(distractor, dict): distractor = distractor["token_str"] - context_with_distractor = replace_str( - sentence, distractor, answer["start"], answer["end"] - ) + context_with_distractor = replace_str(sentence, distractor, answer["start"], answer["end"]) if reverse: return context_with_distractor + " " + sentence return sentence + " " + context_with_distractor @@ -101,12 +100,8 @@ def preprocess_distractors( if isinstance(distractor2, dict): distractor2 = distractor2["token_str"] - context_with_d1 = replace_str( - sentence, distractor1, answer["start"], answer["end"] - ) - context_with_d2 = replace_str( - sentence, distractor2, answer["start"], answer["end"] - ) + context_with_d1 = replace_str(sentence, distractor1, answer["start"], answer["end"]) + context_with_d2 = replace_str(sentence, distractor2, answer["start"], answer["end"]) return context_with_d1 + " " + context_with_d2 @@ -135,13 +130,9 @@ def __call__( else: distractor1 = inputs["distractors"][distractor_ids[0]] distractor2 = inputs["distractors"][distractor_ids[1]] - input_text = self.preprocess_distractors( - **inputs, distractor1=distractor1, distractor2=distractor2 - ) + input_text = self.preprocess_distractors(**inputs, distractor1=distractor1, distractor2=distractor2) nli_output = self.get_model_output(input_text) - input_text_rev = self.preprocess_distractors( - **inputs, distractor1=distractor2, distractor2=distractor1 - ) + input_text_rev = self.preprocess_distractors(**inputs, distractor1=distractor2, distractor2=distractor1) nli_output_rev = self.get_model_output(input_text_rev) return f"{nli_output}-{nli_output_rev}" diff --git a/disgem/distractor_generator.py b/disgem/distractor_generator.py index d97c966..6032141 100644 --- a/disgem/distractor_generator.py +++ b/disgem/distractor_generator.py @@ -2,12 +2,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer -from disgem.pipeline import ( - DistractorGenerationPipeline, -) -from disgem.util import ( - DistractorGenerationOutput, -) +from disgem.pipeline import DistractorGenerationPipeline +from disgem.util import DistractorGenerationOutput class MaskedLMBasedDistractorGenerator: @@ -61,7 +57,5 @@ def __call__( for distractor in outputs.distractors ] else: - return [ - 
distractor["token_str"] for distractor in outputs.distractors - ] + return [distractor["token_str"] for distractor in outputs.distractors] return outputs diff --git a/disgem/pipeline.py b/disgem/pipeline.py index 6c99d0d..49c05ab 100644 --- a/disgem/pipeline.py +++ b/disgem/pipeline.py @@ -3,31 +3,13 @@ import numpy as np import spacy -from transformers import ( - FillMaskPipeline, - ModelCard, - PreTrainedTokenizer, - add_end_docstrings, - set_seed, -) +from transformers import FillMaskPipeline, ModelCard, PreTrainedTokenizer, add_end_docstrings from transformers.feature_extraction_utils import PreTrainedFeatureExtractor -from transformers.pipelines.base import ( - PIPELINE_INIT_ARGS, - ArgumentHandler, - GenericTensor, - PipelineException, -) +from transformers.pipelines.base import PIPELINE_INIT_ARGS, ArgumentHandler, GenericTensor, PipelineException from transformers.utils import logging -from disgem.distractor_evaluator import ( - NLIBasedDistractorEvaluator, -) -from disgem.util import ( - DistractorGenerationOutput, - geometric_mean, - harmonic_mean, - replace_str, -) +from disgem.distractor_evaluator import NLIBasedDistractorEvaluator +from disgem.util import DistractorGenerationOutput, geometric_mean, harmonic_mean, replace_str logger = logging.get_logger(__name__) @@ -122,9 +104,7 @@ def __init__( def _mask_answer(self, context: str, answer: Dict, n_mask: int): mask_str = " ".join([self.tokenizer.mask_token] * n_mask) - return replace_str( - context, mask_str, start_index=answer["start"], end_index=answer["end"] - ) + return replace_str(context, mask_str, start_index=answer["start"], end_index=answer["end"]) def _sanitize_parameters( self, @@ -186,9 +166,7 @@ def _sanitize_parameters( return preprocess_params, forward_params, postprocess_params - def get_masked_index( - self, input_ids: GenericTensor, as_tuple=False - ) -> Union[List[Tuple[int, int]], np.ndarray]: + def get_masked_index(self, input_ids: GenericTensor, as_tuple=False) -> Union[List[Tuple[int, int]], np.ndarray]: masked_index = super().get_masked_index(input_ids) if as_tuple: # noinspection PyTypeChecker @@ -216,9 +194,7 @@ def preprocess( if n_mask is not None: n_tokens = n_mask else: - n_tokens = len( - self.tokenizer(inputs["answer"]["text"])["input_ids"][1:-1] - ) + n_tokens = len(self.tokenizer(inputs["answer"]["text"])["input_ids"][1:-1]) l_disp, r_disp = self._get_lr_dispersion(n_tokens, dispersion) masked_inputs = [ self._mask_answer(**inputs, n_mask=n_mask) @@ -228,17 +204,12 @@ def preprocess( replace=False, ) ] - model_inputs = [ - self.tokenizer(masked_input, return_tensors=return_tensors) - for masked_input in masked_inputs - ] + model_inputs = [self.tokenizer(masked_input, return_tensors=return_tensors) for masked_input in masked_inputs] for model_input in model_inputs: self.ensure_exactly_one_mask_token(model_input) return model_inputs - def postprocess_all_outputs( - self, all_outputs: List[Dict[str, Any]], top_k: int - ) -> List[Dict[str, Any]]: + def postprocess_all_outputs(self, all_outputs: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]: return sorted(all_outputs, key=lambda d: d["ranking_score"], reverse=True) @staticmethod @@ -267,10 +238,7 @@ def _preprocess_input_for_distractor_evaluation( sentences = self.split_sentences(inputs["context"]) updated_answer = dict(text=inputs["answer"]["text"]) for candid in sentences: - if ( - inputs["answer"]["start"] >= candid.start_char - and inputs["answer"]["end"] <= candid.end_char - ): + if inputs["answer"]["start"] >= candid.start_char and 
inputs["answer"]["end"] <= candid.end_char: start, end = self.find_span_start_end( span_start=inputs["answer"]["start"], span_end=inputs["answer"]["end"], @@ -289,9 +257,7 @@ def _preprocess_input_for_distractor_evaluation( def _evaluate_answer_distractors( self, inputs: Dict[str, Any], outputs: List[Dict[str, Any]], top_k: int = 3 ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - processed_input = self._preprocess_input_for_distractor_evaluation( - inputs, outputs - ) + processed_input = self._preprocess_input_for_distractor_evaluation(inputs, outputs) # Compare distractors with the answer answer_distractor_evaluation_results = self.evaluator(processed_input) check = lambda x: "contradiction" in x[0] or "neutral" in x[0] @@ -307,9 +273,7 @@ def _evaluate_answer_distractors( while kept_index < top_k < len(processed_input["distractors"]): processed_input["distractors"][kept_index]["nli_output"] = [] for j in range(kept_index): - nli_out = self.evaluator( - processed_input, distractor_ids=(j, kept_index) - ) + nli_out = self.evaluator(processed_input, distractor_ids=(j, kept_index)) processed_input["distractors"][kept_index]["nli_output"].append( { "result": nli_out, @@ -318,9 +282,7 @@ def _evaluate_answer_distractors( ) if nli_out == "entailment-entailment": increment = False - discarded_distractors.append( - processed_input["distractors"].pop(kept_index) - ) + discarded_distractors.append(processed_input["distractors"].pop(kept_index)) break increment = True @@ -338,9 +300,7 @@ def _finalize_generation_outputs( self._fix_cocktail_shaker_list(output["token_list"]) self._fix_cocktail_shaker_list(output["token_str"]) outputs[i]["token_str_list"] = output["token_str"] - outputs[i]["token_str"] = self.tokenizer.decode( - output["token_list"] - ).strip() + outputs[i]["token_str"] = self.tokenizer.decode(output["token_list"]).strip() outputs[i]["score"] = np.prod(output["score_list"]) if self._use_harmonic_mean: outputs[i]["ranking_score"] = harmonic_mean(output["score_list"]) @@ -360,9 +320,7 @@ def _generate_distractors( ): idx = -1 if reverse else 0 try: - masked_indices = self.get_masked_index( - model_inputs[0]["input_ids"], as_tuple=True - ) + masked_indices = self.get_masked_index(model_inputs[0]["input_ids"], as_tuple=True) current_mask_index = masked_indices[idx] except (KeyError, IndexError): # End of the generation return self._finalize_generation_outputs( @@ -377,28 +335,20 @@ def _generate_distractors( if len(masked_indices) == 1: is_last = True - model_outputs = [ - self.forward(model_input, **forward_params) - for model_input in model_inputs - ] + model_outputs = [self.forward(model_input, **forward_params) for model_input in model_inputs] postprocess_params_ = deepcopy(postprocess_params) if is_start: postprocess_params_["top_k"] *= self._search_multiplier else: postprocess_params_["top_k"] = 1 - outputs = [ - self.postprocess(model_output, **postprocess_params_) - for model_output in model_outputs - ] + outputs = [self.postprocess(model_output, **postprocess_params_) for model_output in model_outputs] if is_start and is_last: outputs = [outputs] new_model_inputs = [] prev_outputs_ = [] for i, output in enumerate(outputs): if is_start: - output_at_idx = ( - [output[idx]] if isinstance(output[idx], dict) else output[idx] - ) + output_at_idx = [output[idx]] if isinstance(output[idx], dict) else output[idx] for out in output_at_idx: model_inputs_ = deepcopy(model_inputs[0]) model_inputs_["input_ids"][current_mask_index] = out["token"] @@ -410,9 +360,7 @@ def 
_generate_distractors( model_inputs_ = output[0] prev_outputs_.append(output[0]) else: - model_inputs_["input_ids"][current_mask_index] = output[idx][0][ - "token" - ] + model_inputs_["input_ids"][current_mask_index] = output[idx][0]["token"] prev_outputs_.append(output[idx][0]) new_model_inputs.append(model_inputs_) @@ -425,15 +373,9 @@ def _generate_distractors( else: for i, prev_output in enumerate(prev_outputs_): insert_idx = 0 if reverse else len(prev_outputs[i]["token_list"]) - prev_outputs[i]["score_list"].insert( - insert_idx, prev_output["score"] - ) - prev_outputs[i]["token_list"].insert( - insert_idx, prev_output["token"] - ) - prev_outputs[i]["token_str"].insert( - insert_idx, prev_output["token_str"] - ) + prev_outputs[i]["score_list"].insert(insert_idx, prev_output["score"]) + prev_outputs[i]["token_list"].insert(insert_idx, prev_output["token"]) + prev_outputs[i]["token_str"].insert(insert_idx, prev_output["token_str"]) prev_outputs[i]["sequence"] = prev_output["sequence"] if cocktail_shaker: @@ -470,20 +412,14 @@ def generate_distractors( cocktail_shaker=cocktail_shaker, ) all_outputs.extend(outputs) - all_outputs = self.postprocess_all_outputs( - all_outputs, **postprocess_params - ) + all_outputs = self.postprocess_all_outputs(all_outputs, **postprocess_params) return all_outputs def generate(self, model_inputs, forward_params, postprocess_params): if self._decoding == "l2r": - return self.generate_distractors( - model_inputs, forward_params, postprocess_params - ) + return self.generate_distractors(model_inputs, forward_params, postprocess_params) elif self._decoding == "r2l": - return self.generate_distractors( - model_inputs, forward_params, postprocess_params, reverse=True - ) + return self.generate_distractors(model_inputs, forward_params, postprocess_params, reverse=True) elif self._decoding == "ctl": return self.generate_distractors( model_inputs, @@ -493,20 +429,13 @@ def generate(self, model_inputs, forward_params, postprocess_params): ) else: raise ValueError( - f"Unknown unmasking strategy '{self._decoding}'. Supported types are " - f"('l2r', 'r2l', 'ctl')" + f"Unknown unmasking strategy '{self._decoding}'. 
Supported types are " f"('l2r', 'r2l', 'ctl')" ) - def run_single( - self, inputs, preprocess_params, forward_params, postprocess_params - ) -> DistractorGenerationOutput: + def run_single(self, inputs, preprocess_params, forward_params, postprocess_params) -> DistractorGenerationOutput: model_inputs = self.preprocess(inputs, **preprocess_params) - all_outputs = self.generate( - model_inputs, forward_params, postprocess_params - ) + all_outputs = self.generate(model_inputs, forward_params, postprocess_params) kept, discarded = self._evaluate_answer_distractors( inputs=inputs, outputs=all_outputs, top_k=postprocess_params["top_k"] ) - return DistractorGenerationOutput( - distractors=kept, discarded_distractors=discarded - ) \ No newline at end of file + return DistractorGenerationOutput(distractors=kept, discarded_distractors=discarded) diff --git a/disgem/util.py b/disgem/util.py index d7e23f4..a7c89f2 100644 --- a/disgem/util.py +++ b/disgem/util.py @@ -39,4 +39,3 @@ class DistractorGenerationOutput(ModelOutput): distractors: List[Dict] = None discarded_distractors: List[Dict] = None - diff --git a/environment.yml b/environment.yml index 31e518e..a6ddcef 100644 --- a/environment.yml +++ b/environment.yml @@ -14,9 +14,7 @@ dependencies: - click==8.0.4 - deepdiff==6.3.0 - en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl - - fire - flake8==3.9.2 - - gensim==4.3.1 - isort==5.9.2 - notebook - transformers==4.19.4 diff --git a/generate.py b/generate.py index 0449407..04c27ce 100644 --- a/generate.py +++ b/generate.py @@ -5,227 +5,295 @@ import numpy as np from tqdm import tqdm -from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer, set_seed +from transformers import AutoModelForMaskedLM, AutoTokenizer, FillMaskPipeline, set_seed from disgem import MaskedLMBasedDistractorGenerator -from disgem.data_loader import ClothLoader, CdgpClothLoader, SquadLoader, DGenLoader -from disgem.util import read_json, harmonic_mean +from disgem.data_loader import CdgpClothLoader, ClothLoader, DGenLoader, SquadLoader +from disgem.util import harmonic_mean, read_json def create_args(): - parser = argparse.ArgumentParser( - prog="DisGeM", - description="Distractor Generator for MCQ" - ) - parser.add_argument("filepath", type=str, help="Path to SQuAD style data.") - parser.add_argument("--data-format", type=str, default="squad", choices= ["cloth", "cdgp-cloth", "squad", "dgen"], - help="Data format whether SQuAD style or CLOTH style dataset. Default 'squad'.") - parser.add_argument("--model", type=str, default="roberta-large", help="Masked LM for distractor generation phase. Models are loaded from huggingface hub. Default 'roberta-large'.") - parser.add_argument("--top-k", type=int, default=3, help="Number of distractors. By default 3.") - parser.add_argument("--batch-size", type=int, default=1, help="Batch size, batched inference might be even slower, " - "see https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching. By default 1.") - parser.add_argument("--output-path", type=str, default=None, - help="File path to dump outputs. By default no output file is created.") - parser.add_argument("--output-format", type=str, default="cdgp", choices=["cdgp", "all"]) - parser.add_argument("--question-limit", type=int, default=100, help="Question limit to stop generation at. 
Default 100.") - parser.add_argument("--dispersion", type=int, default=1, help="Dispersion parameter to determine interval for sampling num mask tokens. By default 1.") - parser.add_argument("--device", type=int, default=-1, help="Device for generation phase. Set -1 for cpu, numbers 0,1,2,... refer to that gpu device. By default -1.") - parser.add_argument("--no-minify-output", action="store_true", help="If given, no minification is placed on outputs.") - parser.add_argument("--decoding", type=str, default="l2r", choices=["l2r", "r2l", "ctl"], - help="Generation strategy for the generation phase.By default 'snowball'.") - parser.add_argument("--n-mask", type=int, default=None, help="Number of mask tokens to be replaced with answer text. Default `none`.") - parser.add_argument("--use-geometric-mean", action="store_true", help="If given, uses geometric mean to determine final ranking, otherwise uses harmonic mean.") - parser.add_argument("--single-mask", action="store_true", help="If given, only applies a single mask to replace the answer. It is the same as setting `dispersion=0` and `n_mask=1`.") - parser.add_argument("--seed", type=int, default=42, help="Seed for RNG. Default 42.") - parser.add_argument("--prepend-question", type=str, default="none", choices=["none", "begin", "mid"], - help="If not `none`, prepends `question` to the context to guide the distractor generation with the question. " - "Default option is `none`.") - parser.add_argument("--evaluate", action="store_true", help="If given, starts evaluation process rather than generation. You must supply result json file for evaluation.") - return parser.parse_args() + parser = argparse.ArgumentParser(prog="DisGeM", description="Distractor Generator for MCQ") + parser.add_argument("filepath", type=str, help="Path to SQuAD style data.") + parser.add_argument( + "--data-format", + type=str, + default="squad", + choices=["cloth", "cdgp-cloth", "squad", "dgen"], + help="Data format whether SQuAD style or CLOTH style dataset. Default 'squad'.", + ) + parser.add_argument( + "--model", + type=str, + default="roberta-large", + help="Masked LM for distractor generation phase. Models are loaded from huggingface hub. Default 'roberta-large'.", + ) + parser.add_argument("--top-k", type=int, default=3, help="Number of distractors. By default 3.") + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size, batched inference might be even slower, " + "see https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching. By default 1.", + ) + parser.add_argument( + "--output-path", type=str, default=None, help="File path to dump outputs. By default no output file is created." + ) + parser.add_argument("--output-format", type=str, default="cdgp", choices=["cdgp", "all"]) + parser.add_argument( + "--question-limit", type=int, default=100, help="Question limit to stop generation at. Default 100." + ) + parser.add_argument( + "--dispersion", + type=int, + default=1, + help="Dispersion parameter to determine interval for sampling num mask tokens. By default 1.", + ) + parser.add_argument( + "--device", + type=int, + default=-1, + help="Device for generation phase. Set -1 for cpu, numbers 0,1,2,... refer to that gpu device. By default -1.", + ) + parser.add_argument( + "--no-minify-output", action="store_true", help="If given, no minification is placed on outputs." 
+    )
+    parser.add_argument(
+        "--decoding",
+        type=str,
+        default="l2r",
+        choices=["l2r", "r2l", "ctl"],
+        help="Generation strategy for the generation phase. By default 'l2r'.",
+    )
+    parser.add_argument(
+        "--n-mask",
+        type=int,
+        default=None,
+        help="Number of mask tokens to be replaced with answer text. Default `none`.",
+    )
+    parser.add_argument(
+        "--use-geometric-mean",
+        action="store_true",
+        help="If given, uses geometric mean to determine final ranking, otherwise uses harmonic mean.",
+    )
+    parser.add_argument(
+        "--single-mask",
+        action="store_true",
+        help="If given, only applies a single mask to replace the answer. It is the same as setting `dispersion=0` and `n_mask=1`.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Seed for RNG. Default 42.")
+    parser.add_argument(
+        "--prepend-question",
+        type=str,
+        default="none",
+        choices=["none", "begin", "mid"],
+        help="If not `none`, prepends `question` to the context to guide the distractor generation with the question. "
+        "Default option is `none`.",
+    )
+    parser.add_argument(
+        "--evaluate",
+        action="store_true",
+        help="If given, starts evaluation process rather than generation. You must supply result json file for evaluation.",
+    )
+    return parser.parse_args()


def main(args):
-    if args.prepend_question != "none":
-        warnings.warn("`--prepend-question` is only available for squad format.")
-    if args.batch_size > 1:
-        warnings.warn("Currently, batched inference is not supported.")
-        args.batch_size = 1
-    if args.data_format == "cloth":
-        data_loader = ClothLoader(args.filepath)
-    elif args.data_format == "cdgp-cloth":
-        data_loader = CdgpClothLoader(args.filepath)
-    elif args.data_format == "dgen":
-        data_loader = DGenLoader(args.filepath)
-    elif args.data_format == "squad":
-        data_loader = SquadLoader(args.filepath, prepend_question=args.prepend_question)
-    else:
-        raise ValueError(f"Unknown data format {args.data_format}.")
-
-    distractor_generator = MaskedLMBasedDistractorGenerator(
-        pretrained_model_name_or_path=args.model,
-        dispersion=args.dispersion,
-        n_mask=args.n_mask,
-        device=args.device,
-        decoding=args.decoding,
-        single_mask=args.single_mask
-    )
-
-    squad_answers = []
-    outputs = []
-    count = 0
-    pbar = tqdm(data_loader)
-    for instance in pbar:
-        pbar.set_postfix({"count": count})
-        if count == args.question_limit:
-            break
-
-        dgen_tokenizer = distractor_generator._pipeline.tokenizer
-        if len(dgen_tokenizer.encode(instance.context)) > dgen_tokenizer.model_max_length:
-            # Skip if tokenized context does not fit into model max input length
-            continue
-
-        if args.data_format == "squad":
-            if args.prepend_question:
-                pass
-            elif instance.answer in squad_answers:
-                # squad contains different questions for some answer spans. Since our
-                # framework does not depend on question, we skip these questions as
-                # it would yield the same distractors.
- continue - else: - squad_answers.append(instance.answer) - - generations = distractor_generator( - context=instance.context, - answer=instance.answer, - minify_output=not args.no_minify_output, - top_k=args.top_k, - use_harmonic_mean=not args.use_geometric_mean, - batch_size=args.batch_size - ) - if args.data_format == "squad": # no gt distractors/evaluation, put context as well - outputs.append( - { - "context": instance.context, - "question": instance.question, - "answer": instance.answer, - "generations": generations - } - ) - else: - if args.output_format == "cdgp": - outputs.append( - { - "generations": generations, - "distractors": instance.distractors - } - ) - else: - # For better readability, put blank in the output - ctx = instance.context[:instance.answer["start"]] + " ____ " + instance.context[instance.answer["end"]:] - outputs.append( - { - "context": ctx, - "answer": instance.answer["text"], - "generations": generations, - "distractors": instance.distractors - } - ) - count += 1 - - if args.output_path is not None: - output_path = Path(args.output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path.as_posix(), "w") as fd_out: - json.dump(outputs, fd_out) + if args.prepend_question != "none": + warnings.warn("`--prepend-question` is only available for squad format.") + if args.batch_size > 1: + warnings.warn("Currently, batched inference is not supported.") + args.batch_size = 1 + if args.data_format == "cloth": + data_loader = ClothLoader(args.filepath) + elif args.data_format == "cdgp-cloth": + data_loader = CdgpClothLoader(args.filepath) + elif args.data_format == "dgen": + data_loader = DGenLoader(args.filepath) + elif args.data_format == "squad": + data_loader = SquadLoader(args.filepath, prepend_question=args.prepend_question) + else: + raise ValueError(f"Unknown data format {args.data_format}.") + + distractor_generator = MaskedLMBasedDistractorGenerator( + pretrained_model_name_or_path=args.model, + dispersion=args.dispersion, + n_mask=args.n_mask, + device=args.device, + decoding=args.decoding, + single_mask=args.single_mask, + ) + + squad_answers = [] + outputs = [] + count = 0 + pbar = tqdm(data_loader) + for instance in pbar: + pbar.set_postfix({"count": count}) + if count == args.question_limit: + break + + dgen_tokenizer = distractor_generator._pipeline.tokenizer + if len(dgen_tokenizer.encode(instance.context)) > dgen_tokenizer.model_max_length: + # Skip if tokenized context does not fit into model max input length + continue + + if args.data_format == "squad": + if args.prepend_question: + pass + elif instance.answer in squad_answers: + # squad contains different questions for some answer spans. Since our + # framework does not depend on question, we skip these questions as + # it would yield the same distractors. 
+ continue + else: + squad_answers.append(instance.answer) + + generations = distractor_generator( + context=instance.context, + answer=instance.answer, + minify_output=not args.no_minify_output, + top_k=args.top_k, + use_harmonic_mean=not args.use_geometric_mean, + batch_size=args.batch_size, + ) + if args.data_format == "squad": # no gt distractors/evaluation, put context as well + outputs.append( + { + "context": instance.context, + "question": instance.question, + "answer": instance.answer, + "generations": generations, + } + ) + else: + if args.output_format == "cdgp": + outputs.append({"generations": generations, "distractors": instance.distractors}) + else: + # For better readability, put blank in the output + ctx = ( + instance.context[: instance.answer["start"]] + " ____ " + instance.context[instance.answer["end"] :] + ) + outputs.append( + { + "context": ctx, + "answer": instance.answer["text"], + "generations": generations, + "distractors": instance.distractors, + } + ) + count += 1 + + if args.output_path is not None: + output_path = Path(args.output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path.as_posix(), "w") as fd_out: + json.dump(outputs, fd_out) def evaluate(args): - """ - - Args: - args: - - Returns: - - """ - - # metrics - def precision(preds, targets, k: int = 1): - matches = [int(generation in targets) for generation in preds] - return sum(matches[:k]) / k - - def recall(preds, targets, k: int = 1): - matches = [int(generation in targets) for generation in preds] - return sum(matches[:k]) / len(targets) - - def f1(preds, targets, k: int = 1): - p = precision(preds, targets, k) - r = recall(preds, targets, k) - return harmonic_mean([p, r]) - - def ndcg_at_k(preds, targets, k: int = 1): - def dcg_at_k(r, k): - r = np.asfarray(r)[:k] - if r.size: - return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) - return 0. - r = [int(generation in targets) for generation in preds] - idcg = dcg_at_k(sorted(r, reverse=True), k) - if not idcg: - return 0. 
- return dcg_at_k(r, k) / idcg - - def mmr_at_k(preds, targets, k: int = 1): - matches = [int(generation in targets) for generation in preds] - k = len(matches) if k > len(matches) else k - for i in range(k): - if matches[i] == 1: - return 1 / (i+1) - return .0 - - outputs = read_json(args.filepath) - avg_eval = { - "P@1" : 0.0, "P@3": 0.0, "P@5": 0.0, "P@10" : 0.0, - "R@1": 0.0, "R@3": 0.0, "R@5": 0.0, "R@10": 0.0, - "F1@1": 0.0, "F1@3": 0.0, "F1@5" : 0.0, "F1@10": 0.0, - "MRR@1": 0.0, "MRR@3": 0.0, "MRR@5": 0.0, "MRR@10": 0.0, - "NDCG@1": 0.0, "NDCG@3": 0.0, "NDCG@5": 0.0, "NDCG@10": 0.0} - for output in outputs: - distractors = [d.lower() for d in output["distractors"]] - generations = [d.lower() for d in output["generations"]] - - for key in avg_eval.keys(): - metric, k = key.split("@") - if metric == "P": - metric_fn = precision - elif metric == "R": - metric_fn = recall - elif metric == "F1": - metric_fn = f1 - elif metric == "NDCG": - metric_fn = ndcg_at_k - elif metric == "MRR": - metric_fn = mmr_at_k - else: - continue - avg_eval[key] += metric_fn(preds=generations, targets=distractors, k=int(k)) - - # calculate average - for key in avg_eval.keys(): - avg_eval[key] /= len(outputs) - avg_eval[key] = str(round(100 * avg_eval[key], 4)) + "%" - - print(json.dumps(avg_eval, indent=2)) - if args.output_path is not None: - with open(args.output_path, "w") as fd_out: - json.dump(avg_eval, fd_out, indent=2) + """ + + Args: + args: + + Returns: + + """ + + # metrics + def precision(preds, targets, k: int = 1): + matches = [int(generation in targets) for generation in preds] + return sum(matches[:k]) / k + + def recall(preds, targets, k: int = 1): + matches = [int(generation in targets) for generation in preds] + return sum(matches[:k]) / len(targets) + + def f1(preds, targets, k: int = 1): + p = precision(preds, targets, k) + r = recall(preds, targets, k) + return harmonic_mean([p, r]) + + def ndcg_at_k(preds, targets, k: int = 1): + def dcg_at_k(r, k): + r = np.asfarray(r)[:k] + if r.size: + return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) + return 0.0 + + r = [int(generation in targets) for generation in preds] + idcg = dcg_at_k(sorted(r, reverse=True), k) + if not idcg: + return 0.0 + return dcg_at_k(r, k) / idcg + + def mmr_at_k(preds, targets, k: int = 1): + matches = [int(generation in targets) for generation in preds] + k = len(matches) if k > len(matches) else k + for i in range(k): + if matches[i] == 1: + return 1 / (i + 1) + return 0.0 + + outputs = read_json(args.filepath) + avg_eval = { + "P@1": 0.0, + "P@3": 0.0, + "P@5": 0.0, + "P@10": 0.0, + "R@1": 0.0, + "R@3": 0.0, + "R@5": 0.0, + "R@10": 0.0, + "F1@1": 0.0, + "F1@3": 0.0, + "F1@5": 0.0, + "F1@10": 0.0, + "MRR@1": 0.0, + "MRR@3": 0.0, + "MRR@5": 0.0, + "MRR@10": 0.0, + "NDCG@1": 0.0, + "NDCG@3": 0.0, + "NDCG@5": 0.0, + "NDCG@10": 0.0, + } + for output in outputs: + distractors = [d.lower() for d in output["distractors"]] + generations = [d.lower() for d in output["generations"]] + + for key in avg_eval.keys(): + metric, k = key.split("@") + if metric == "P": + metric_fn = precision + elif metric == "R": + metric_fn = recall + elif metric == "F1": + metric_fn = f1 + elif metric == "NDCG": + metric_fn = ndcg_at_k + elif metric == "MRR": + metric_fn = mmr_at_k + else: + continue + avg_eval[key] += metric_fn(preds=generations, targets=distractors, k=int(k)) + + # calculate average + for key in avg_eval.keys(): + avg_eval[key] /= len(outputs) + avg_eval[key] = str(round(100 * avg_eval[key], 4)) + "%" + + 
print(json.dumps(avg_eval, indent=2))
+    if args.output_path is not None:
+        with open(args.output_path, "w") as fd_out:
+            json.dump(avg_eval, fd_out, indent=2)


if __name__ == "__main__":
-    args = create_args()
-    set_seed(args.seed)
-    if args.evaluate:
-        evaluate(args)
-    else:
-        main(args)
+    args = create_args()
+    set_seed(args.seed)
+    if args.evaluate:
+        evaluate(args)
+    else:
+        main(args)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..3dbb71a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,14 @@
+[tool.black]
+line-length = 120
+exclude = '''
+(
+  /(
+    | .git
+    | venv
+    | .venv
+  )/
+)
+'''
+
+[tool.pytest.ini_options]
+tmp_path_retention_policy = "failed"
\ No newline at end of file
diff --git a/scripts/run_code_style.py b/scripts/run_code_style.py
new file mode 100644
index 0000000..55aabed
--- /dev/null
+++ b/scripts/run_code_style.py
@@ -0,0 +1,16 @@
+import sys
+
+from scripts.utils import shell, validate_and_exit
+
+if __name__ == "__main__":
+    arg = sys.argv[1]
+
+    if arg == "check":
+        sts_flake = shell("flake8 disgem scripts generate.py --config setup.cfg")
+        sts_isort = shell("isort . --check --settings setup.cfg")
+        sts_black = shell("black . --check --config pyproject.toml")
+        validate_and_exit(flake8=sts_flake, isort=sts_isort, black=sts_black)
+    elif arg == "format":
+        sts_isort = shell("isort . --settings setup.cfg")
+        sts_black = shell("black . --config pyproject.toml")
+        validate_and_exit(isort=sts_isort, black=sts_black)
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000..5b25767
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,41 @@
+import os
+import shutil
+import sys
+
+
+def shell(command, exit_status=0):
+    """
+    Run `command` through the system shell.
+
+    Args:
+        command: (str) Command string to run through the system shell.
+        exit_status: (int) Expected exit status of the command run.
+
+    Returns: 0 if the actual exit status matches `exit_status`, otherwise the actual exit status.
+
+    """
+    actual_exit_status = os.system(command)
+    if actual_exit_status == exit_status:
+        return 0
+    return actual_exit_status
+
+
+def validate_and_exit(expected_out_status=0, **kwargs):
+    if all([arg == expected_out_status for arg in kwargs.values()]):
+        # Expected status, OK
+        sys.exit(0)
+    else:
+        # Failure
+        print_console_centered("Summary Results")
+        fail_count = 0
+        for component, exit_status in kwargs.items():
+            if exit_status != expected_out_status:
+                print(f"{component} failed.")
+                fail_count += 1
+        print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure")
+        sys.exit(1)
+
+
+def print_console_centered(text: str, fill_char="="):
+    w, _ = shutil.get_terminal_size((80, 20))
+    print(f" {text} ".center(w, fill_char))
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..85f524b
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,9 @@
+[flake8]
+max-line-length = 120
+select = E9,F63,F7,F82
+per-file-ignores = __init__.py: F401
+max-complexity = 10
+
+[isort]
+line_length=120
+profile=black
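For reference, the distractor generation API touched by these patches can also be driven directly from Python, without going through `python -m generate`. The sketch below is assembled from the signatures visible in this series (`MaskedLMBasedDistractorGenerator` in `disgem/distractor_generator.py` and its call site in `generate.py`'s `main()`); the example context, the answer span offsets, and the printed output shape are illustrative assumptions, not part of the patches.

```python
# Minimal usage sketch; mirrors the constructor/call arguments used in generate.py.
from disgem import MaskedLMBasedDistractorGenerator

generator = MaskedLMBasedDistractorGenerator(
    pretrained_model_name_or_path="roberta-large",  # default model in generate.py
    dispersion=1,  # interval width for sampling the number of mask tokens
    n_mask=None,  # None: derive the mask count from the answer span length
    device=-1,  # -1 runs on CPU; 0, 1, ... select a GPU
    decoding="l2r",  # one of "l2r", "r2l", "ctl"
    single_mask=False,
)

context = "The capital of France is Paris."
# Answer spans are given as character offsets into `context` (illustrative values).
answer = {"text": "Paris", "start": 25, "end": 30}

distractors = generator(
    context=context,
    answer=answer,
    minify_output=True,  # return plain strings instead of scored dicts
    top_k=3,  # number of distractors to keep
    use_harmonic_mean=True,  # ranking scheme used by default in generate.py
    batch_size=1,  # batched inference is currently unsupported
)
print(distractors)  # expected: a list of distractor strings
```

Note that `answer` follows the same `{"text", "start", "end"}` layout produced by the data loaders in `disgem/data_loader.py`.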