diff --git a/.gitignore b/.gitignore index aa9306a..7c0b8d3 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,5 @@ cython_debug/ # MISC /data -/outputs \ No newline at end of file +/outputs +/experiments diff --git a/disgem/pipeline.py b/disgem/pipeline.py index 1e915da..6c99d0d 100644 --- a/disgem/pipeline.py +++ b/disgem/pipeline.py @@ -296,7 +296,6 @@ def _evaluate_answer_distractors( answer_distractor_evaluation_results = self.evaluator(processed_input) check = lambda x: "contradiction" in x[0] or "neutral" in x[0] mask = np.apply_along_axis(check, arr=np.array(answer_distractor_evaluation_results).reshape(-1, 1), axis=-1) - # mask = np.array(answer_distractor_evaluation_results) != "contradiction" outputs_array = np.array(outputs, dtype=dict) filtered_distractors = outputs_array[mask] @@ -501,8 +500,6 @@ def generate(self, model_inputs, forward_params, postprocess_params): def run_single( self, inputs, preprocess_params, forward_params, postprocess_params ) -> DistractorGenerationOutput: - seed = forward_params.pop("seed") - set_seed(seed) model_inputs = self.preprocess(inputs, **preprocess_params) all_outputs = self.generate( model_inputs, forward_params, postprocess_params diff --git a/generate.py b/generate.py index d82b709..0449407 100644 --- a/generate.py +++ b/generate.py @@ -5,7 +5,7 @@ import numpy as np from tqdm import tqdm -from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer +from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer, set_seed from disgem import MaskedLMBasedDistractorGenerator from disgem.data_loader import ClothLoader, CdgpClothLoader, SquadLoader, DGenLoader @@ -224,6 +224,7 @@ def mmr_at_k(preds, targets, k: int = 1): if __name__ == "__main__": args = create_args() + set_seed(args.seed) if args.evaluate: evaluate(args) else: