Skip to content

Commit

Permalink
Merge pull request #250 from GeoStat-Framework/custom-generators
Browse files Browse the repository at this point in the history
Add better support for custom generators
  • Loading branch information
LSchueler authored Aug 16, 2022
2 parents 661bbe6 + c837550 commit 4362175
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 38 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ target-version = [
extension-pkg-whitelist = [
"numpy",
"scipy",
"gstools_core",
]
ignore = "_version.py"
load-plugins = [
Expand Down
12 changes: 7 additions & 5 deletions src/gstools/field/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@

__all__ = ["Field"]

VALUE_TYPES = ["scalar", "vector"]
""":class:`list` of :class:`str`: valid field value types."""


def _pos_equal(pos1, pos2):
if pos1 is None or pos2 is None:
Expand Down Expand Up @@ -79,6 +76,9 @@ class Field:
Dimension of the field if no model is given.
"""

valid_value_types = ["scalar", "vector"]
""":class:`list` of :class:`str`: valid field value types."""

default_field_names = ["field"]
""":class:`list`: Default field names."""

Expand Down Expand Up @@ -663,8 +663,10 @@ def value_type(self):

@value_type.setter
def value_type(self, value_type):
if value_type not in VALUE_TYPES:
raise ValueError(f"Field: value type not in {VALUE_TYPES}")
if value_type not in self.valid_value_types:
raise ValueError(
f"Field: value type not in {self.valid_value_types}"
)
self._value_type = value_type

@property
Expand Down
27 changes: 16 additions & 11 deletions src/gstools/field/cond_srf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
CondSRF
"""
# pylint: disable=C0103, W0231, W0221, W0222, E1102

import numpy as np

from gstools.field.base import Field
from gstools.field.generator import RandMeth
from gstools.field.generator import Generator, RandMeth
from gstools.krige import Krige

__all__ = ["CondSRF"]
Expand All @@ -31,8 +32,8 @@ class CondSRF(Field):
----------
krige : :any:`Krige`
Kriging setup to condition the spatial random field.
generator : :class:`str`, optional
Name of the field generator to be used.
generator : :class:`str` or :any:`Generator`, optional
Name or class of the field generator to be used.
At the moment, only the following generator is provided:
* "RandMeth" : The Randomization Method.
Expand All @@ -44,6 +45,9 @@ class CondSRF(Field):
Have a look at the provided generators for further information.
"""

valid_value_types = ["scalar"]
""":class:`list` of :class:`str`: valid field value types."""

default_field_names = ["field", "raw_field", "raw_krige"]
""":class:`list`: Default field names."""

Expand Down Expand Up @@ -180,18 +184,19 @@ def set_generator(self, generator, **generator_kwargs):
Parameters
----------
generator : :class:`str`, optional
Name of the generator to use for field generation.
generator : :class:`str` or :any:`Generator`, optional
Name or class of the generator to use for field generation.
Default: "RandMeth"
**generator_kwargs
keyword arguments that are forwarded to the generator in use.
"""
if generator in GENERATOR:
gen = GENERATOR[generator]
self._generator = gen(self.model, **generator_kwargs)
self.value_type = self.generator.value_type
else:
raise ValueError(f"gstools.CondSRF: Unknown generator {generator}")
gen = GENERATOR[generator] if generator in GENERATOR else generator
if not (isinstance(gen, type) and issubclass(gen, Generator)):
raise ValueError(
f"gstools.CondSRF: Unknown or wrong generator: {generator}"
)
self._generator = gen(self.model, **generator_kwargs)
self.value_type = self.generator.value_type

def set_pos(self, pos, mesh_type="unstructured", info=False):
"""
Expand Down
83 changes: 72 additions & 11 deletions src/gstools/field/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
The following classes are provided
.. autosummary::
Generator
RandMeth
IncomprRandMeth
"""
# pylint: disable=C0103, W0222, C0412
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy as dcp

import numpy as np
Expand All @@ -26,13 +28,77 @@
else:
from gstools.field.summator import summate, summate_incompr

__all__ = ["RandMeth", "IncomprRandMeth"]
__all__ = ["Generator", "RandMeth", "IncomprRandMeth"]


SAMPLING = ["auto", "inversion", "mcmc"]


class RandMeth:
class Generator(ABC):
"""Abstract generator class."""

@abstractmethod
def update(self, model=None, seed=np.nan):
"""Update the model and the seed.
If model and seed are not different, nothing will be done.
Parameters
----------
model : :any:`CovModel` or :any:`None`, optional
covariance model. Default: :any:`None`
seed : :class:`int` or :any:`None` or :any:`numpy.nan`, optional
the seed of the random number generator.
If :any:`None`, a random seed is used. If :any:`numpy.nan`,
the actual seed will be kept. Default: :any:`numpy.nan`
"""

@abstractmethod
def get_nugget(self, shape):
"""
Generate normal distributed values for the nugget simulation.
Parameters
----------
shape : :class:`tuple`
the shape of the summed modes
Returns
-------
nugget : :class:`numpy.ndarray`
the nugget in the same shape as the summed modes
"""

@abstractmethod
def __call__(self, pos, add_nugget=True):
"""
Generate the field.
Parameters
----------
pos : (d, n), :class:`numpy.ndarray`
the position tuple with d dimensions and n points.
add_nugget : :class:`bool`
Whether to add nugget noise to the field.
Returns
-------
:class:`numpy.ndarray`
the random modes
"""

@property
@abstractmethod
def value_type(self):
""":class:`str`: Type of the field values (scalar, vector)."""

@property
def name(self):
""":class:`str`: Name of the generator."""
return self.__class__.__name__


class RandMeth(Generator):
r"""Randomization method for calculating isotropic random fields.
Parameters
Expand Down Expand Up @@ -310,11 +376,6 @@ def verbose(self):
def verbose(self, verbose):
self._verbose = bool(verbose)

@property
def name(self):
""":class:`str`: Name of the generator."""
return self.__class__.__name__

@property
def value_type(self):
""":class:`str`: Type of the field values (scalar, vector)."""
Expand Down Expand Up @@ -405,7 +466,7 @@ def __init__(
self.mean_u = mean_velocity
self._value_type = "vector"

def __call__(self, pos):
def __call__(self, pos, add_nugget=True):
"""Calculate the random modes for the randomization method.
This method calls the `summate_incompr_*` Cython methods,
Expand All @@ -417,6 +478,8 @@ def __call__(self, pos):
----------
pos : (d, n), :class:`numpy.ndarray`
the position tuple with d dimensions and n points.
add_nugget : :class:`bool`
Whether to add nugget noise to the field.
Returns
-------
Expand All @@ -427,10 +490,8 @@ def __call__(self, pos):
summed_modes = summate_incompr(
self._cov_sample, self._z_1, self._z_2, pos
)
nugget = self.get_nugget(summed_modes.shape)

nugget = self.get_nugget(summed_modes.shape) if add_nugget else 0.0
e1 = self._create_unit_vector(summed_modes.shape)

return (
self.mean_u * e1
+ self.mean_u
Expand Down
25 changes: 14 additions & 11 deletions src/gstools/field/srf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
SRF
"""
# pylint: disable=C0103, W0221, E1102

import numpy as np

from gstools.field.base import Field
from gstools.field.generator import IncomprRandMeth, RandMeth
from gstools.field.generator import Generator, IncomprRandMeth, RandMeth
from gstools.field.upscaling import var_coarse_graining, var_no_scaling

__all__ = ["SRF"]
Expand Down Expand Up @@ -63,8 +64,8 @@ class SRF(Field):
See: :any:`var_coarse_graining`
Default: "no_scaling"
generator : :class:`str`, optional
Name of the field generator to be used.
generator : :class:`str` or :any:`Generator`, optional
Name or class of the field generator to be used.
At the moment, the following generators are provided:
* "RandMeth" : The Randomization Method.
Expand Down Expand Up @@ -165,18 +166,20 @@ def set_generator(self, generator, **generator_kwargs):
Parameters
----------
generator : :class:`str`, optional
Name of the generator to use for field generation.
generator : :class:`str` or :any:`Generator`, optional
Name or class of the field generator to be used.
Default: "RandMeth"
**generator_kwargs
keyword arguments that are forwarded to the generator in use.
"""
if generator in GENERATOR:
gen = GENERATOR[generator]
self._generator = gen(self.model, **generator_kwargs)
self.value_type = self._generator.value_type
else:
raise ValueError(f"gstools.SRF: Unknown generator: {generator}")
gen = GENERATOR[generator] if generator in GENERATOR else generator
if not (isinstance(gen, type) and issubclass(gen, Generator)):
raise ValueError(
f"gstools.SRF: Unknown or wrong generator: {generator}"
)
self._generator = gen(self.model, **generator_kwargs)
self.value_type = self.generator.value_type

for val in [self.mean, self.trend]:
if not callable(val) and val is not None:
if np.size(val) > 1 and self.value_type == "scalar":
Expand Down
3 changes: 3 additions & 0 deletions src/gstools/krige/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class Krige(Field):
Springer, Berlin, Heidelberg (2003)
"""

valid_value_types = ["scalar"]
""":class:`list` of :class:`str`: valid field value types."""

default_field_names = ["field", "krige_var", "mean_field"]
""":class:`list`: Default field names."""

Expand Down

0 comments on commit 4362175

Please sign in to comment.