Skip to content

Commit

Permalink
Merge pull request #10 from obss/random-seed
Browse files Browse the repository at this point in the history
Add Global Seed Remove Individual Seed
  • Loading branch information
devrimcavusoglu authored Sep 24, 2024
2 parents f2e24fc + 096816a commit 6954681
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ cython_debug/

# MISC
/data
/outputs
/outputs
/experiments
3 changes: 0 additions & 3 deletions disgem/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def _evaluate_answer_distractors(
answer_distractor_evaluation_results = self.evaluator(processed_input)
check = lambda x: "contradiction" in x[0] or "neutral" in x[0]
mask = np.apply_along_axis(check, arr=np.array(answer_distractor_evaluation_results).reshape(-1, 1), axis=-1)
# mask = np.array(answer_distractor_evaluation_results) != "contradiction"
outputs_array = np.array(outputs, dtype=dict)
filtered_distractors = outputs_array[mask]

Expand Down Expand Up @@ -501,8 +500,6 @@ def generate(self, model_inputs, forward_params, postprocess_params):
def run_single(
self, inputs, preprocess_params, forward_params, postprocess_params
) -> DistractorGenerationOutput:
seed = forward_params.pop("seed")
set_seed(seed)
model_inputs = self.preprocess(inputs, **preprocess_params)
all_outputs = self.generate(
model_inputs, forward_params, postprocess_params
Expand Down
3 changes: 2 additions & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from tqdm import tqdm
from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer, set_seed

from disgem import MaskedLMBasedDistractorGenerator
from disgem.data_loader import ClothLoader, CdgpClothLoader, SquadLoader, DGenLoader
Expand Down Expand Up @@ -224,6 +224,7 @@ def mmr_at_k(preds, targets, k: int = 1):

if __name__ == "__main__":
args = create_args()
set_seed(args.seed)
if args.evaluate:
evaluate(args)
else:
Expand Down

0 comments on commit 6954681

Please sign in to comment.