From 29ad320688fedd2d9be42e7a658d2802f815d48d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ula=C5=9F=20Sert?= Date: Thu, 13 Jun 2024 20:11:28 +0300 Subject: [PATCH 1/2] Add Global Seed Remove Individual Seed --- disgem/pipeline.py | 4 ++-- generate.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/disgem/pipeline.py b/disgem/pipeline.py index 1e915da..7f91a5a 100644 --- a/disgem/pipeline.py +++ b/disgem/pipeline.py @@ -501,8 +501,8 @@ def generate(self, model_inputs, forward_params, postprocess_params): def run_single( self, inputs, preprocess_params, forward_params, postprocess_params ) -> DistractorGenerationOutput: - seed = forward_params.pop("seed") - set_seed(seed) + # seed = forward_params.pop("seed") + # set_seed(seed) model_inputs = self.preprocess(inputs, **preprocess_params) all_outputs = self.generate( model_inputs, forward_params, postprocess_params diff --git a/generate.py b/generate.py index d82b709..0449407 100644 --- a/generate.py +++ b/generate.py @@ -5,7 +5,7 @@ import numpy as np from tqdm import tqdm -from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer +from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer, set_seed from disgem import MaskedLMBasedDistractorGenerator from disgem.data_loader import ClothLoader, CdgpClothLoader, SquadLoader, DGenLoader @@ -224,6 +224,7 @@ def mmr_at_k(preds, targets, k: int = 1): if __name__ == "__main__": args = create_args() + set_seed(args.seed) if args.evaluate: evaluate(args) else: From 096816ab7dc9f543294436522cf8231071b3fa6e Mon Sep 17 00:00:00 2001 From: devrimcavusoglu Date: Tue, 24 Sep 2024 10:42:49 +0300 Subject: [PATCH 2/2] Removed commented out sections. --- .gitignore | 3 ++- disgem/pipeline.py | 3 --- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index aa9306a..7c0b8d3 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,5 @@ cython_debug/ # MISC /data -/outputs \ No newline at end of file +/outputs +/experiments diff --git a/disgem/pipeline.py b/disgem/pipeline.py index 7f91a5a..6c99d0d 100644 --- a/disgem/pipeline.py +++ b/disgem/pipeline.py @@ -296,7 +296,6 @@ def _evaluate_answer_distractors( answer_distractor_evaluation_results = self.evaluator(processed_input) check = lambda x: "contradiction" in x[0] or "neutral" in x[0] mask = np.apply_along_axis(check, arr=np.array(answer_distractor_evaluation_results).reshape(-1, 1), axis=-1) - # mask = np.array(answer_distractor_evaluation_results) != "contradiction" outputs_array = np.array(outputs, dtype=dict) filtered_distractors = outputs_array[mask] @@ -501,8 +500,6 @@ def generate(self, model_inputs, forward_params, postprocess_params): def run_single( self, inputs, preprocess_params, forward_params, postprocess_params ) -> DistractorGenerationOutput: - # seed = forward_params.pop("seed") - # set_seed(seed) model_inputs = self.preprocess(inputs, **preprocess_params) all_outputs = self.generate( model_inputs, forward_params, postprocess_params