-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
DistributionMixin
and some corresponding noise models
- Loading branch information
1 parent
d5da7a5
commit 6e46c65
Showing
4 changed files
with
176 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import numpy | ||
import scipy.stats | ||
|
||
from .. core import DistributionMixin | ||
from .. utils import HAS_PYMC, pm | ||
|
||
|
||
class NormalNoise(DistributionMixin): | ||
"""Normal noise, predicted in terms of mean and standard deviation.""" | ||
scipy_dist = scipy.stats.norm | ||
pymc_dist = pm.Normal if HAS_PYMC else None | ||
|
||
@staticmethod | ||
def to_scipy(*params): | ||
return dict(loc=params[0], scale=params[1]) | ||
|
||
@staticmethod | ||
def to_pymc(*params): | ||
return dict(mu=params[0], sigma=params[1]) | ||
|
||
|
||
class LaplaceNoise(DistributionMixin): | ||
"""Normal noise, predicted in terms of mean and scale.""" | ||
scipy_dist = scipy.stats.laplace | ||
pymc_dist = pm.Laplace if HAS_PYMC else None | ||
|
||
@staticmethod | ||
def to_scipy(*params): | ||
return dict(loc=params[0], scale=params[1]) | ||
|
||
@staticmethod | ||
def to_pymc(*params): | ||
return dict(mu=params[0], b=params[1]) | ||
|
||
|
||
class LogNormalNoise(DistributionMixin): | ||
"""Log-Normal noise, predicted in logarithmic mean and standard deviation. | ||
⚠ This corresponds to the NumPy/Aesara/PyMC parametrization! | ||
""" | ||
scipy_dist = scipy.stats.lognorm | ||
pymc_dist = pm.Lognormal if HAS_PYMC else None | ||
|
||
@staticmethod | ||
def to_scipy(*params): | ||
# SciPy wants linear scale mean and log scale standard deviation! | ||
return dict(scale=numpy.exp(params[0]), s=params[1]) | ||
|
||
@staticmethod | ||
def to_pymc(*params): | ||
return dict(mu=params[0], sigma=params[1]) | ||
|
||
|
||
class StudentTNoise(DistributionMixin): | ||
"""Student-t noise, predicted in terms of mean, scale and degree of freedom.""" | ||
scipy_dist = scipy.stats.t | ||
pymc_dist = pm.StudentT if HAS_PYMC else None | ||
|
||
@staticmethod | ||
def to_scipy(*params): | ||
return dict(loc=params[0], scale=params[1], df=params[2]) | ||
|
||
@staticmethod | ||
def to_pymc(*params): | ||
return dict(mu=params[0], sigma=params[1], nu=params[2]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters