From 813110109ce39d4e905d43246fc4c671fba3aa45 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 9 Jun 2022 04:27:29 +0200 Subject: [PATCH] TempNumpySeed simplification --- disent/util/seeds.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/disent/util/seeds.py b/disent/util/seeds.py index 9bb29ae2..912af68d 100644 --- a/disent/util/seeds.py +++ b/disent/util/seeds.py @@ -24,8 +24,6 @@ import contextlib import logging -import random -import numpy as np log = logging.getLogger(__name__) @@ -44,8 +42,10 @@ def seed(long=777): log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!') return # seed python + import random random.seed(long) # seed numpy + import numpy as np np.random.seed(long) # seed torch - it can be slow to import try: @@ -60,27 +60,26 @@ def seed(long=777): class TempNumpySeed(contextlib.ContextDecorator): - def __init__(self, seed=None, offset=0): + def __init__(self, seed: int = None): # check and normalize seed if seed is not None: try: seed = int(seed) except: - raise ValueError(f'{seed=} is not int-like!') - # offset seed - if seed is not None: - seed += offset + raise ValueError(f'seed={seed} is not int-like!') # save values self._seed = seed self._state = None def __enter__(self): if self._seed is not None: + import numpy as np self._state = np.random.get_state() np.random.seed(self._seed) def __exit__(self, *args, **kwargs): if self._seed is not None: + import numpy as np np.random.set_state(self._state) self._state = None @@ -88,6 +87,7 @@ def _recreate_cm(self): # TODO: do we need to override this? return self + # ========================================================================= # # END # # ========================================================================= #