From 212d2826e38a2d75859a2dbc60056926d6e1e95c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20M=C3=BCller?= Date: Tue, 16 Aug 2022 16:23:15 +0200 Subject: [PATCH] generator: fix/unify call signature --- src/gstools/field/generator.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/gstools/field/generator.py b/src/gstools/field/generator.py index 5c0494f8..07def21a 100644 --- a/src/gstools/field/generator.py +++ b/src/gstools/field/generator.py @@ -11,7 +11,7 @@ RandMeth IncomprRandMeth """ -# pylint: disable=C0103, W0222, C0412, W0221 +# pylint: disable=C0103, W0222, C0412 import warnings from abc import ABC, abstractmethod from copy import deepcopy as dcp @@ -70,8 +70,22 @@ def get_nugget(self, shape): """ @abstractmethod - def __call__(self, pos, **kwargs): - """Generate the field.""" + def __call__(self, pos, add_nugget=True): + """ + Generate the field. + + Parameters + ---------- + pos : (d, n), :class:`numpy.ndarray` + the position tuple with d dimensions and n points. + add_nugget : :class:`bool` + Whether to add nugget noise to the field. + + Returns + ------- + :class:`numpy.ndarray` + the random modes + """ @property @abstractmethod @@ -452,7 +466,7 @@ def __init__( self.mean_u = mean_velocity self._value_type = "vector" - def __call__(self, pos): + def __call__(self, pos, add_nugget=True): """Calculate the random modes for the randomization method. This method calls the `summate_incompr_*` Cython methods, @@ -464,6 +478,8 @@ def __call__(self, pos): ---------- pos : (d, n), :class:`numpy.ndarray` the position tuple with d dimensions and n points. + add_nugget : :class:`bool` + Whether to add nugget noise to the field. Returns ------- @@ -474,10 +490,8 @@ def __call__(self, pos): summed_modes = summate_incompr( self._cov_sample, self._z_1, self._z_2, pos ) - nugget = self.get_nugget(summed_modes.shape) - + nugget = self.get_nugget(summed_modes.shape) if add_nugget else 0.0 e1 = self._create_unit_vector(summed_modes.shape) - return ( self.mean_u * e1 + self.mean_u