Skip to content

Commit

Permalink
TempNumpySeed simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Jun 9, 2022
1 parent 939b64e commit 8131101
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions disent/util/seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

import contextlib
import logging
import random
import numpy as np


log = logging.getLogger(__name__)
Expand All @@ -44,8 +42,10 @@ def seed(long=777):
log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!')
return
# seed python
import random
random.seed(long)
# seed numpy
import numpy as np
np.random.seed(long)
# seed torch - it can be slow to import
try:
Expand All @@ -60,34 +60,34 @@ def seed(long=777):


class TempNumpySeed(contextlib.ContextDecorator):
def __init__(self, seed=None, offset=0):
def __init__(self, seed: int = None):
# check and normalize seed
if seed is not None:
try:
seed = int(seed)
except:
raise ValueError(f'{seed=} is not int-like!')
# offset seed
if seed is not None:
seed += offset
raise ValueError(f'seed={seed} is not int-like!')
# save values
self._seed = seed
self._state = None

def __enter__(self):
if self._seed is not None:
import numpy as np
self._state = np.random.get_state()
np.random.seed(self._seed)

def __exit__(self, *args, **kwargs):
if self._seed is not None:
import numpy as np
np.random.set_state(self._state)
self._state = None

def _recreate_cm(self):
# TODO: do we need to override this?
return self


# ========================================================================= #
# END #
# ========================================================================= #

0 comments on commit 8131101

Please sign in to comment.