Skip to content

Commit

Permalink
fix and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ndem0 committed Apr 16, 2024
1 parent 0f9a4e8 commit 6bb8295
Show file tree
Hide file tree
Showing 12 changed files with 91 additions and 30 deletions.
12 changes: 6 additions & 6 deletions ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def __init__(self, parameters=None, snapshots=None):
if parameters is None and snapshots is None:
return

# if len(parameters) != len(snapshots):
# raise ValueError

if parameters is None:
parameters = [None] * len(snapshots)
elif snapshots is None:
snapshots = [None] * len(parameters)

if len(parameters) != len(snapshots):
raise ValueError

for param, snap in zip(parameters, snapshots):
self.add(Parameter(param), Snapshot(snap))
Expand All @@ -41,8 +41,6 @@ def parameters_matrix(self):
:rtype: numpy.ndarray
"""
print(self._pairs)
print(self._pairs[0])
return np.asarray([pair[0].values for pair in self._pairs])

@property
Expand Down Expand Up @@ -81,7 +79,9 @@ def __len__(self):

def __str__(self):
""" Print minimal info about the Database """
return str(self.parameters_matrix)
s = 'Database with {} snapshots and {} parameters'.format(
self.snapshots_matrix.shape[1], self.parameters_matrix.shape[1])
return s

def add(self, parameter, snapshot):
"""
Expand Down
8 changes: 4 additions & 4 deletions ezyrb/plugin/automatic_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def _train_shift_network(self, db):

n_epoch += 1

def fom_preprocessing(self, rom):
db = rom._full_database
def fit_preprocessing(self, rom):
db = rom.database

reference_snapshot = db._pairs[self.reference_index][1]
self.reference_snapshot = reference_snapshot
Expand All @@ -154,11 +154,11 @@ def fom_preprocessing(self, rom):
snap.values = self.interpolator.predict(
reference_snapshot.space.reshape(-1, 1)).flatten()

def fom_postprocessing(self, rom):
def predict_postprocessing(self, rom):

ref_space = self.reference_snapshot.space

for param, snap in rom._full_database._pairs:
for param, snap in rom.predict_full_database._pairs:
input_shift = np.hstack([
ref_space.reshape(-1, 1),
np.ones(shape=(ref_space.shape[0], 1))*param.values])
Expand Down
10 changes: 5 additions & 5 deletions ezyrb/plugin/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def __init__(self, shift_function, interpolator, parameter_index=0,
self.parameter_index = parameter_index
self.reference_index = reference_index

def fom_preprocessing(self, rom):
db = rom._full_database
def fit_preprocessing(self, rom):
db = rom.database

reference_snapshot = db._pairs[self.reference_index][1]

Expand All @@ -68,10 +68,10 @@ def fom_preprocessing(self, rom):
snap.values = self.interpolator.predict(
reference_snapshot.space.reshape(-1, 1)).flatten()

rom._full_database = db
rom.database = db

def fom_postprocessing(self, rom):
for param, snap in rom._full_database._pairs:
def predict_postprocessing(self, rom):
for param, snap in rom.predict_full_database._pairs:
snap.space = (
rom.database._pairs[self.reference_index][1].space +
self.__shift_function(param.values)
Expand Down
7 changes: 6 additions & 1 deletion ezyrb/reducedordermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,15 @@ def predict(self, parameters):
# print(self.predict_reduced_database._pairs[0])
# print(self.predict_reduced_database._pairs[0][1].values)

print(self.predict_reduced_database.parameters_matrix)
print(self.approximation.predict(
self.predict_reduced_database.parameters_matrix))
self.predict_reduced_database = Database(
self.predict_reduced_database.parameters_matrix,
self.approximation.predict(
self.predict_reduced_database.parameters_matrix)
self.predict_reduced_database.parameters_matrix).reshape(
self.predict_reduced_database.parameters_matrix.shape[0], -1
)
)
# print(self.predict_reduced_database)
# print(self.predict_reduced_database._pairs)
Expand Down
8 changes: 6 additions & 2 deletions ezyrb/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def __init__(self, values, space=None):

@property
def values(self):
""" Get the snapshot values. """
"""
Get the snapshot values.
"""
return self._values

@values.setter
Expand All @@ -25,7 +27,9 @@ def values(self, new_values):

@property
def space(self):
""" Get the snapshot space. """
"""
Get the snapshot space.
"""
return self._space

@space.setter
Expand Down
50 changes: 50 additions & 0 deletions tests/test_approximation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np

from ezyrb import (GPR, Linear, RBF, ANN, KNeighborsRegressor,
RadiusNeighborsRegressor)
import sklearn
import pytest

import torch.nn as nn
np.random.seed(17)

def get_xy():
npts = 10
dinput = 4

inp = np.random.uniform(-1, 1, size=(npts, dinput))
out = np.array([
np.sin(inp[:, 0]) + np.sin(inp[:, 1]**2),
np.cos(inp[:, 2]) + np.cos(inp[:, 3]**2)
]).T

return inp, out

@pytest.mark.parametrize("model,kwargs", [
(GPR, {}),
(ANN, {'layers': [20, 20], 'function': nn.Tanh(), 'stop_training': 1e-8, 'last_identity': True}),
(KNeighborsRegressor, {'n_neighbors': 1}),
(RadiusNeighborsRegressor, {'radius': 0.1}),
(Linear, {}),
])
class TestApproximation:
def test_constructor_empty(self, model, kwargs):
model = model(**kwargs)

def test_fit(self, model, kwargs):
x, y = get_xy()
approx = model(**kwargs)
approx.fit(x[:, 0].reshape(-1, 1), y[:, 0].reshape(-1, 1))

approx = model(**kwargs)
approx.fit(x, y)

def test_predict_01(self, model, kwargs):
x, y = get_xy()
approx = model(**kwargs)
approx.fit(x, y)
test_y = approx.predict(x)
if isinstance(approx, ANN):
np.testing.assert_array_almost_equal(y, test_y, decimal=3)
else:
np.testing.assert_array_almost_equal(y, test_y, decimal=6)
6 changes: 3 additions & 3 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def test_constructor_arg_wrong(self):
Database(np.random.uniform(size=(9, 3)),
np.random.uniform(size=(10, 8)))

def test_constructor_error(self):
with self.assertRaises(TypeError):
Database(np.eye(5))
# def test_constructor_error(self):
# with self.assertRaises(TypeError):
# Database(np.eye(5))

def test_getitem(self):
org = Database(np.random.uniform(size=(10, 3)),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_k_neighbors_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_with_db_predict(self):
rom = ReducedOrderModel(db, pod, reg)

rom.fit()
pred = rom.predict([[1], [2], [3]])
pred = rom.predict(db)
np.testing.assert_equal(pred.snapshots_matrix, np.array([1, 5, 3])[:,None])

def test_wrong1(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nnshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_fit_train():
rom = ROM(db, pod, rbf, plugins=[nnspod])
rom.fit()

pred = rom.predict(db.parameters_matrix)
pred = rom.predict(db)

error = 0.0
for (_, snap), (_, truth_snap) in zip(pred._pairs, db._pairs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_radius_neighbors_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_with_db_predict(self):

rom.fit()
pred = rom.predict([[1], [2], [3]])
np.testing.assert_equal(pred.snapshots_matrix, np.array([1, 5, 3])[:,None])
np.testing.assert_equal(pred, np.array([1, 5, 3])[:,None])



Expand Down
7 changes: 4 additions & 3 deletions tests/test_reducedordermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,11 @@ def test_multi_db(self):
pod2 = POD(rank=1)
gpr = GPR()
db1 = Database(param, snapshots.T)
db2 = Database(param, snapshots.T)
rom = MROM({'p': db1}, {'a': pod, 'b':pod2}, gpr).fit()
print(rom.predict([-.5, -.5]))
assert False
pred = rom.predict([-.5, -.5])
assert isinstance(pred, dict)
assert len(pred) == 2


"""
def test_optimal_mu(self):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def test_predict_ref():
])
rom.fit()
pred = rom.predict(db._pairs[0][0].values)
print(pred)
np.testing.assert_array_almost_equal(
pred._pairs[0][1].values, db._pairs[0][1].values, decimal=1)
pred[0], db._pairs[0][1].values, decimal=1)


def test_predict():
Expand All @@ -69,12 +70,12 @@ def test_predict():
ShiftSnapshots(shift, Linear(fill_value=0.0))
])
rom.fit()
pred = rom.predict(db._pairs[10][0].values)
pred_db = rom.predict(db)

from scipy import spatial
tree = spatial.KDTree(db._pairs[10][1].space.reshape(-1, 1))
error = 0.0
for coord, value in zip(pred._pairs[0][1].space, pred._pairs[0][1].values):
for coord, value in zip(pred_db._pairs[0][1].space, pred_db._pairs[0][1].values):
a = tree.query(coord)
error += np.abs(value - db._pairs[10][1].values[a[1]])

Expand Down

0 comments on commit 6bb8295

Please sign in to comment.