-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring for pre- and post-processing, new design #232
Changes from all commits
fc61276
042be1a
8c32785
4870982
b5bd50c
7d1617b
d456557
31ea0bf
93e16d2
0153128
1838059
cb35c19
80e88dd
7e4377f
03e0344
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,16 @@ | ||
"""EZyRB package""" | ||
|
||
__all__ = [ | ||
'Database', 'Reduction', 'POD', 'Approximation', 'RBF', 'Linear', 'GPR', | ||
'Database', 'Snapshot', 'Reduction', 'POD', 'Approximation', 'RBF', 'Linear', 'GPR', | ||
'ANN', 'KNeighborsRegressor', 'RadiusNeighborsRegressor', 'AE', | ||
'ReducedOrderModel', 'PODAE', 'RegularGrid' | ||
] | ||
|
||
from .meta import * | ||
from .database import Database | ||
from .reduction import Reduction | ||
from .pod import POD | ||
from .ae import AE | ||
from .pod_ae import PODAE | ||
from .approximation import Approximation | ||
from .rbf import RBF | ||
from .linear import Linear | ||
from .regular_grid import RegularGrid | ||
from .gpr import GPR | ||
from .snapshot import Snapshot | ||
from .parameter import Parameter | ||
from .reducedordermodel import ReducedOrderModel | ||
from .ann import ANN | ||
from .kneighbors_regressor import KNeighborsRegressor | ||
from .radius_neighbors_regressor import RadiusNeighborsRegressor | ||
from .reduction import * | ||
from .approximation import * | ||
from .regular_grid import RegularGrid |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""EZyRB package""" | ||
|
||
__all__ = [ | ||
'Approximation', 'RBF', 'Linear', 'GPR', | ||
'ANN', 'KNeighborsRegressor', 'RadiusNeighborsRegressor' | ||
] | ||
|
||
from .approximation import Approximation | ||
from .rbf import RBF | ||
from .linear import Linear | ||
from .gpr import GPR | ||
from .ann import ANN | ||
from .kneighbors_regressor import KNeighborsRegressor | ||
from .radius_neighbors_regressor import RadiusNeighborsRegressor |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,9 @@ | |
|
||
import numpy as np | ||
|
||
from .parameter import Parameter | ||
from .snapshot import Snapshot | ||
|
||
class Database(): | ||
""" | ||
Database class | ||
|
@@ -14,66 +17,35 @@ class Database(): | |
None meaning no scaling. | ||
:param array_like space: the input spatial data | ||
""" | ||
def __init__(self, | ||
parameters=None, | ||
snapshots=None, | ||
scaler_parameters=None, | ||
scaler_snapshots=None, | ||
space=None): | ||
self._parameters = None | ||
self._snapshots = None | ||
self._space = None | ||
self.scaler_parameters = scaler_parameters | ||
self.scaler_snapshots = scaler_snapshots | ||
|
||
# if only parameters or snapshots are provided | ||
if (parameters is None) ^ (snapshots is None): | ||
raise RuntimeError( | ||
'Parameters and Snapshots are not both provided') | ||
|
||
if space is not None and snapshots is None: | ||
raise RuntimeError( | ||
'Snapshot data is not provided with Spatial data') | ||
|
||
if parameters is not None and snapshots is not None: | ||
if space is not None: | ||
self.add(parameters, snapshots, space) | ||
else: | ||
self.add(parameters, snapshots) | ||
def __init__(self, parameters=None, snapshots=None): | ||
self._pairs = [] | ||
|
||
if parameters is None and snapshots is None: | ||
return | ||
|
||
@property | ||
def parameters(self): | ||
""" | ||
The matrix containing the input parameters (by row). | ||
|
||
:rtype: numpy.ndarray | ||
""" | ||
if self.scaler_parameters: | ||
return self.scaler_parameters.fit_transform(self._parameters) | ||
if len(parameters) != len(snapshots): | ||
raise ValueError | ||
|
||
return self._parameters | ||
for param, snap in zip(parameters, snapshots): | ||
self.add(Parameter(param), Snapshot(snap)) | ||
|
||
@property | ||
def snapshots(self): | ||
def parameters_matrix(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. small typo, could be changed to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep the plural form at the moment, just to avoid changing all the tests! I'll refactor this class a bit. |
||
""" | ||
The matrix containing the snapshots (by row). | ||
The matrix containing the input parameters (by row). | ||
|
||
:rtype: numpy.ndarray | ||
""" | ||
if self.scaler_snapshots: | ||
return self.scaler_snapshots.fit_transform(self._snapshots) | ||
|
||
return self._snapshots | ||
return np.asarray([pair[0].values for pair in self._pairs]) | ||
|
||
@property | ||
def space(self): | ||
def snapshots_matrix(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo, could be changed to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep the plural form at the moment, just to avoid changing all the tests! I'll refactor this class a bit. |
||
""" | ||
The matrix containing spatial information (by row). | ||
The matrix containing the snapshots (by row). | ||
|
||
:rtype: numpy.ndarray | ||
""" | ||
return self._space | ||
return np.asarray([pair[1].flattened for pair in self._pairs]) | ||
|
||
def __getitem__(self, val): | ||
""" | ||
|
@@ -83,75 +55,85 @@ def __getitem__(self, val): | |
.. warning:: The new parameters and snapshots are a view of the | ||
original Database. | ||
""" | ||
if isinstance(val, int): | ||
if self._space is None: | ||
return Database(np.reshape(self._parameters[val], | ||
(1,len(self._parameters[val]))), | ||
np.reshape(self._snapshots[val], | ||
(1,len(self._snapshots[val]))), | ||
self.scaler_parameters, | ||
self.scaler_snapshots) | ||
|
||
return Database(np.reshape(self._parameters[val], | ||
(1,len(self._parameters[val]))), | ||
np.reshape(self._snapshots[val], | ||
(1,len(self._snapshots[val]))), | ||
self.scaler_parameters, | ||
self.scaler_snapshots, | ||
np.reshape(self._space[val], | ||
(1,len(self._space[val])))) | ||
|
||
if self._space is None: | ||
return Database(self._parameters[val], | ||
self._snapshots[val], | ||
self.scaler_parameters, | ||
self.scaler_snapshots) | ||
|
||
return Database(self._parameters[val], | ||
self._snapshots[val], | ||
self.scaler_parameters, | ||
self.scaler_snapshots, | ||
self._space[val]) | ||
if isinstance(val, np.ndarray): | ||
view = Database() | ||
for p, s in np.asarray(self._pairs)[val]: | ||
view.add(p, s) | ||
elif isinstance(val, (int, slice)): | ||
view = Database() | ||
view._pairs = self._pairs[val] | ||
return view | ||
|
||
def __len__(self): | ||
""" | ||
This method returns the number of snapshots. | ||
|
||
:rtype: int | ||
""" | ||
return len(self._snapshots) | ||
return len(self._pairs) | ||
|
||
def __str__(self): | ||
""" Print minimal info about the Database """ | ||
return str(self.parameters_matrix) | ||
|
||
def add(self, parameters, snapshots, space=None): | ||
def add(self, parameter, snapshot): | ||
""" | ||
Add (by row) new sets of snapshots and parameters to the original | ||
database. | ||
|
||
:param array_like parameters: the parameters to add. | ||
:param array_like snapshots: the snapshots to add. | ||
:param Parameter parameter: the parameter to add. | ||
:param Snapshot snapshot: the snapshot to add. | ||
""" | ||
if len(parameters) != len(snapshots): | ||
raise RuntimeError( | ||
'Different number of parameters and snapshots.') | ||
|
||
if self._space is not None and space is None: | ||
raise RuntimeError('No Spatial Value given') | ||
|
||
if (self._space is not None) or (space is not None): | ||
if space.shape != snapshots.shape: | ||
raise RuntimeError( | ||
'shape of space and snapshots are different.') | ||
|
||
if self._parameters is None and self._snapshots is None: | ||
self._parameters = parameters | ||
self._snapshots = snapshots | ||
if self._space is None: | ||
self._space = space | ||
elif self._space is None: | ||
self._parameters = np.vstack([self._parameters, parameters]) | ||
self._snapshots = np.vstack([self._snapshots, snapshots]) | ||
else: | ||
self._parameters = np.vstack([self._parameters, parameters]) | ||
self._snapshots = np.vstack([self._snapshots, snapshots]) | ||
self._space = np.vstack([self._space, space]) | ||
if not isinstance(parameter, Parameter): | ||
raise ValueError | ||
|
||
if not isinstance(snapshot, Snapshot): | ||
raise ValueError | ||
|
||
self._pairs.append((parameter, snapshot)) | ||
|
||
return self | ||
|
||
|
||
def split(self, chunks, seed=None): | ||
""" | ||
|
||
>>> db = Database(...) | ||
>>> train, test = db.split([0.8, 0.2]) # ratio | ||
>>> train, test = db.split([80, 20]) # n snapshots | ||
|
||
""" | ||
if all(isinstance(n, int) for n in chunks): | ||
if sum(chunks) != len(self): | ||
raise ValueError('chunk elements are inconsistent') | ||
|
||
ids = [ | ||
j for j, chunk in enumerate(chunks) | ||
for i in range(chunk) | ||
] | ||
np.random.shuffle(ids) | ||
|
||
|
||
elif all(isinstance(n, float) for n in chunks): | ||
if not np.isclose(sum(chunks), 1.): | ||
raise ValueError('chunk elements are inconsistent') | ||
|
||
cum_chunks = np.cumsum(chunks) | ||
cum_chunks = np.insert(cum_chunks, 0, 0.0) | ||
ids = np.ones(len(self)) * -1. | ||
tmp = np.random.uniform(0, 1, size=len(self)) | ||
for i in range(len(cum_chunks)-1): | ||
is_between = np.logical_and( | ||
tmp >= cum_chunks[i], tmp < cum_chunks[i+1]) | ||
ids[is_between] = i | ||
|
||
else: | ||
ValueError | ||
|
||
new_database = [Database() for _ in range(len(chunks))] | ||
for i, chunk in enumerate(chunks): | ||
chunk_ids = np.array(ids) == i | ||
for p, s in np.asarray(self._pairs)[chunk_ids]: | ||
new_database[i].add(p, s) | ||
|
||
return new_database |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" Module for parameter object """ | ||
import numpy as np | ||
|
||
class Parameter: | ||
|
||
def __init__(self, values): | ||
self.values = values | ||
|
||
@property | ||
def values(self): | ||
""" Get the snapshot values. """ | ||
return self._values | ||
|
||
@values.setter | ||
def values(self, new_values): | ||
if np.asarray(new_values).ndim != 1: | ||
raise ValueError('only 1D array are usable as parameter.') | ||
self._values = new_values |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" Plugins submodule """ | ||
|
||
__all__ = [ | ||
'Plugin', | ||
'DatabaseScaler', | ||
'ShiftSnapshots', | ||
'AutomaticShiftSnapshots', | ||
] | ||
|
||
from .scaler import DatabaseScaler | ||
from .plugin import Plugin | ||
from .shift import ShiftSnapshots | ||
from .automatic_shift import AutomaticShiftSnapshots |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might want to consider adding
axis=0, bounds_error=False, **kwargs
to the arguments offit(...)
. That way the user can control and specify these things.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this class, it is quite difficult to keep that general layout, mainly because we are internally using two different methods depending on the input dimension!