Skip to content

Commit

Permalink
Codestyle, github workflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
devrimcavusoglu committed Oct 14, 2024
1 parent 3454c86 commit 481efc2
Show file tree
Hide file tree
Showing 14 changed files with 490 additions and 354 deletions.
58 changes: 58 additions & 0 deletions .github/workflows/check_format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
name: Tests

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
operating-system: [ubuntu-latest, windows-latest, macos-latest]
# for Python 3.10, ref https://github.com/actions/setup-python/issues/160#issuecomment-724485470
python-version: [3.8, 3.9, '3.10', '3.11']
fail-fast: false

steps:
- name: Checkout
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Restore Ubuntu cache
uses: actions/cache@v1
if: matrix.operating-system == 'ubuntu-latest'
with:
path: ~/.cache/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Restore MacOS cache
uses: actions/cache@v1
if: matrix.operating-system == 'macos-latest'
with:
path: ~/Library/Caches/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Restore Windows cache
uses: actions/cache@v1
if: matrix.operating-system == 'windows-latest'
with:
path: ~\AppData\Local\pip\Cache
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Conda env create
run: conda env create -f environment.yml

- name: Lint with flake8 and black
run: |
conda activate disgem
python -m scripts.run_code_style check
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Clone the repository.

```bash
git clone https://github.com/obss/disgem.git
cd disgem
```

In the project root, create a virtual environment (preferably using conda) as follows:
Expand All @@ -43,3 +44,20 @@ The following provides an example to generate distractors for CLOTH test-high da
```shell
python -m generate data/CLOTH/test/high --data-format cloth --top-k 3 --dispersion 0 --output-path cloth_test_outputs.json
```


## Contributing

Format and check the code style of the codebase as follows.

To check the codestyle,

```bash
python -m scripts.run_code_style check
```

To format the codebase,

```bash
python -m scripts.run_code_style format
```
2 changes: 1 addition & 1 deletion disgem/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from disgem.distractor_generator import MaskedLMBasedDistractorGenerator
from disgem.distractor_generator import MaskedLMBasedDistractorGenerator
33 changes: 17 additions & 16 deletions disgem/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Union, OrderedDict
from typing import Dict, List, Union

from disgem.util import read_json

Expand Down Expand Up @@ -38,12 +38,7 @@ def __iter__(self) -> Instance:
for i, ans in enumerate(instance.answers):
d = instance.distractors_collection[i] if instance.distractors_collection is not None else None
q = instance.questions[i] if instance.questions is not None else None
yield Instance(
context=instance.context,
answer=ans,
distractors=d,
question=q
)
yield Instance(context=instance.context, answer=ans, distractors=d, question=q)


class SquadLoader(DataLoader):
Expand Down Expand Up @@ -81,7 +76,12 @@ def read(self, filepath):
answer["end"] += len(question) + 1
elif self.prepend_question == "mid":
# the following prepends just before the mask.
ctx = paragraph["context"][:answer["start"]] + question + " " + paragraph["context"][answer["start"]:]
ctx = (
paragraph["context"][: answer["start"]]
+ question
+ " "
+ paragraph["context"][answer["start"] :]
)
answer["start"] += len(question) + 1
answer["end"] += len(question) + 1
elif self.prepend_question == "end":
Expand All @@ -102,6 +102,7 @@ class ClothLoader(DataLoader):
a form compatible for distractor generation.
See the home page for CLOTH: https://www.cs.cmu.edu/~glai1/data/cloth/
"""

_dataset_mask_str = " _ "

def __iter__(self) -> Instance:
Expand All @@ -112,11 +113,7 @@ def __iter__(self) -> Instance:
"""
for instance in self.dataset:
for i, ans in enumerate(instance.answers):
yield Instance(
context=instance.context,
answer=ans,
distractors=instance.distractors_collection[i]
)
yield Instance(context=instance.context, answer=ans, distractors=instance.distractors_collection[i])

@staticmethod
def replace_nth(text: str, substr: str, replace: str, nth: int):
Expand Down Expand Up @@ -145,7 +142,7 @@ def read(self, filepath):
assert filepath.is_dir(), "`filepath` for CLOTH dataset needs to be a directory."

instances = []
files = sorted(filepath.glob('*.json'))
files = sorted(filepath.glob("*.json"))
for p in files:
data = read_json(p)
ctx = data["article"]
Expand All @@ -158,7 +155,7 @@ def read(self, filepath):
answer = {"text": ans, "start": start}
answer["end"] = answer["start"] + len(answer["text"])
answers.append(answer)
distractors.append(choices[:opt] + choices[opt+1:]) # remove the answer from choices
distractors.append(choices[:opt] + choices[opt + 1 :]) # remove the answer from choices
assert len(answers) == len(distractors), "The length of the `answers` and `distractors` must be equal."
instances.append(InstanceCollection(context=ctx, answers=answers, distractors_collection=distractors))
return instances
Expand All @@ -171,6 +168,7 @@ class CdgpClothLoader(DataLoader):
as it is used and published in a related work.
See the home page for CDGP style CLOTH: https://huggingface.co/datasets/AndyChiang/cloth
"""

_dataset_mask_str = " [MASK] "

def read(self, filepath):
Expand All @@ -186,7 +184,9 @@ def read(self, filepath):
else:
ctx = ctx.replace(self._dataset_mask_str, ans, 1)
answers = [{"text": ans, "start": start, "end": start + len(ans)}]
instances.append(InstanceCollection(context=ctx, answers=answers, distractors_collection=[instance["distractors"]]))
instances.append(
InstanceCollection(context=ctx, answers=answers, distractors_collection=[instance["distractors"]])
)
return instances


Expand All @@ -196,4 +196,5 @@ class DGenLoader(CdgpClothLoader):
in a form compatible for distractor generation.
See the home page for DGEN Dataset: AndyChiang/dgen
"""

_dataset_mask_str = "**blank**"
21 changes: 6 additions & 15 deletions disgem/distractor_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _load_model_and_tokenizer(self, model_name_or_path: str):
def __call__(self, inputs: Dict[str, Any], *args, **kwargs):
pass


class NLIBasedDistractorEvaluator(BaseDistractorEvaluator):
"""
NLI based distractor evaluation, meant to be designed to provide classification
Expand Down Expand Up @@ -81,9 +82,7 @@ def preprocess(
return sentence
if isinstance(distractor, dict):
distractor = distractor["token_str"]
context_with_distractor = replace_str(
sentence, distractor, answer["start"], answer["end"]
)
context_with_distractor = replace_str(sentence, distractor, answer["start"], answer["end"])
if reverse:
return context_with_distractor + " " + sentence
return sentence + " " + context_with_distractor
Expand All @@ -101,12 +100,8 @@ def preprocess_distractors(
if isinstance(distractor2, dict):
distractor2 = distractor2["token_str"]

context_with_d1 = replace_str(
sentence, distractor1, answer["start"], answer["end"]
)
context_with_d2 = replace_str(
sentence, distractor2, answer["start"], answer["end"]
)
context_with_d1 = replace_str(sentence, distractor1, answer["start"], answer["end"])
context_with_d2 = replace_str(sentence, distractor2, answer["start"], answer["end"])

return context_with_d1 + " " + context_with_d2

Expand Down Expand Up @@ -135,13 +130,9 @@ def __call__(
else:
distractor1 = inputs["distractors"][distractor_ids[0]]
distractor2 = inputs["distractors"][distractor_ids[1]]
input_text = self.preprocess_distractors(
**inputs, distractor1=distractor1, distractor2=distractor2
)
input_text = self.preprocess_distractors(**inputs, distractor1=distractor1, distractor2=distractor2)
nli_output = self.get_model_output(input_text)

input_text_rev = self.preprocess_distractors(
**inputs, distractor1=distractor2, distractor2=distractor1
)
input_text_rev = self.preprocess_distractors(**inputs, distractor1=distractor2, distractor2=distractor1)
nli_output_rev = self.get_model_output(input_text_rev)
return f"{nli_output}-{nli_output_rev}"
12 changes: 3 additions & 9 deletions disgem/distractor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@

from transformers import AutoModelForMaskedLM, AutoTokenizer

from disgem.pipeline import (
DistractorGenerationPipeline,
)
from disgem.util import (
DistractorGenerationOutput,
)
from disgem.pipeline import DistractorGenerationPipeline
from disgem.util import DistractorGenerationOutput


class MaskedLMBasedDistractorGenerator:
Expand Down Expand Up @@ -61,7 +57,5 @@ def __call__(
for distractor in outputs.distractors
]
else:
return [
distractor["token_str"] for distractor in outputs.distractors
]
return [distractor["token_str"] for distractor in outputs.distractors]
return outputs
Loading

0 comments on commit 481efc2

Please sign in to comment.