NNsPOD #221

Closed · wants to merge 3 commits
41 changes: 37 additions & 4 deletions ezyrb/ann.py
@@ -99,7 +99,7 @@ def _build_model(self, points, values):
layers_torch.append(nn.Linear(layers[-2], layers[-1]))
self.model = nn.Sequential(*layers_torch)

def fit(self, points, values):
def fit(self, points, values, optimizer = torch.optim.Adam, learning_rate = 0.001, frequency_print = 0):
Member commented: optimizer and learning_rate should be constructor arguments
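A minimal sketch of the suggested refactor, assuming the ANN constructor signature (layers, function, stop_training, loss) used elsewhere in this PR; the class name ANNSketch and the placeholder model construction are invented for illustration:

```python
import torch
import torch.nn as nn

class ANNSketch:
    """Sketch: optimizer and learning_rate become constructor arguments."""

    def __init__(self, layers, function, stop_training, loss=nn.MSELoss(),
                 optimizer=torch.optim.Adam, learning_rate=0.001):
        self.layers = layers
        self.function = function
        self.stop_training = stop_training
        self.loss = loss
        self.optimizer_class = optimizer    # stored as a class, instantiated in fit()
        self.learning_rate = learning_rate

    def fit(self, points, values, frequency_print=0):
        # placeholder for _build_model(): a single hidden layer
        self.model = nn.Sequential(
            nn.Linear(points.shape[1], self.layers[0]), self.function,
            nn.Linear(self.layers[0], values.shape[1]))
        # the optimizer is now created from the constructor arguments
        self.optimizer = self.optimizer_class(self.model.parameters(),
                                              lr=self.learning_rate)
```

With this layout a different optimizer is chosen at construction time, e.g. ANNSketch([20, 20], nn.Tanh(), [2000, 1e-5], optimizer=torch.optim.SGD, learning_rate=0.01), and fit() keeps its original signature.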

"""
Build the ANN given 'points' and 'values' and perform training.

@@ -119,14 +119,16 @@ def fit(self, points, values):
:param numpy.ndarray points: the coordinates of the given (training)
points.
:param numpy.ndarray values: the (training) values in the points.
        :param torch.optim.Optimizer optimizer: the optimizer class used to
            train the neural network. Default is `torch.optim.Adam`.
        :param float learning_rate: the learning rate passed to the optimizer.
            Default is 0.001.
        :param int frequency_print: the number of epochs between two prints of
            the loss value. If 0 (default), nothing is printed.
"""

self._build_model(points, values)
self.optimizer = torch.optim.Adam(self.model.parameters())
self.optimizer = optimizer(self.model.parameters(), lr = learning_rate)

points = self._convert_numpy_to_torch(points)
values = self._convert_numpy_to_torch(values)

n_epoch = 1
flag = True
while flag:
@@ -143,7 +145,9 @@ def fit(self, points, values):
elif isinstance(criteria, float): # stop criteria is float
if loss.item() < criteria:
flag = False

if frequency_print != 0:
if n_epoch % frequency_print == 1:
print(loss.item())
n_epoch += 1

def predict(self, new_point):
@@ -157,3 +161,32 @@ def predict(self, new_point):
new_point = self._convert_numpy_to_torch(np.array(new_point))
y_new = self.model(new_point)
return self._convert_torch_to_numpy(y_new)


    def save_state(self, filename):
        """
        Save the current network state (model and optimizer) to the file
        `filename` using `torch.save`.
        """
checkpoint = {
'model_state': self.model.state_dict(),
'optimizer_state' : self.optimizer.state_dict(),
'optimizer_class' : self.optimizer.__class__,
'model_class' : self.model.__class__
}

torch.save(checkpoint, filename)

    def load_state(self, filename, points, values):
        """
        Load a network state previously saved with `save_state`. The model is
        rebuilt for the given `points` and `values` before restoring the
        weights and the optimizer state.

        :return: the updated ANN instance.
        """
checkpoint = torch.load(filename)

self._build_model(points, values)

self.model.load_state_dict(checkpoint['model_state'])

self.optimizer = checkpoint['optimizer_class'](self.model.parameters())
self.optimizer.load_state_dict(checkpoint['optimizer_state'])

# self.trained_epoch = checkpoint['epoch']
# self.history = checkpoint['history']

return self
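For reference, a usage sketch of the extended fit signature together with save_state/load_state; the training data, layer sizes and checkpoint filename are invented, and the ANN constructor is assumed to take (layers, function, stop_training) as it is used elsewhere in this PR:

```python
import numpy as np
import torch
from ezyrb.ann import ANN

# toy data, invented for illustration: 10 one-dimensional parameter points
# mapped to 3-component outputs
points = np.linspace(0, 1, 10).reshape(-1, 1)
values = np.hstack([np.sin(points), np.cos(points), points**2])

# stop after 2000 epochs or when the loss drops below 1e-5
ann = ANN([20, 20], torch.nn.Tanh(), [2000, 1e-5])
ann.fit(points, values,
        optimizer=torch.optim.Adam,   # optimizer class (new argument)
        learning_rate=1e-3,           # learning rate (new argument)
        frequency_print=500)          # print the loss every ~500 epochs (new argument)

ann.save_state('ann_checkpoint.pth')

# later: rebuild the model for the same points/values and restore its state
restored = ANN([20, 20], torch.nn.Tanh(), [2000, 1e-5]).load_state(
    'ann_checkpoint.pth', points, values)
print(restored.predict(points[:2]))
```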
43 changes: 24 additions & 19 deletions ezyrb/database.py
@@ -31,12 +31,6 @@ def __init__(self,
raise RuntimeError(
'Parameters and Snapshots are not both provided')

if space is not None and snapshots is None:
raise RuntimeError(
'Snapshot data is not provided with Spatial data')

if space is not None and snapshots is None:
raise RuntimeError

if space is not None and snapshots is None:
raise RuntimeError(
@@ -82,6 +76,7 @@ def space(self):
"""
return self._space


def __getitem__(self, val):
"""
This method returns a new Database with the selected parameters and
@@ -90,24 +85,24 @@ def __getitem__(self, val):
.. warning:: The new parameters and snapshots are a view of the
original Database.
"""

if isinstance(val, int):
if self._space is None:
return Database(np.reshape(self._parameters[val],
(1,len(self._parameters[val]))),
np.reshape(self._snapshots[val],
(1,len(self._snapshots[val]))),
return Database(self._parameters[val].reshape(1,len(self._parameters[val])),
self._snapshots[val].reshape(1, len(self._snapshots[val])),
self.scaler_parameters,
self.scaler_snapshots)

return Database(np.reshape(self._parameters[val],
(1,len(self._parameters[val]))),
np.reshape(self._snapshots[val],
(1,len(self._snapshots[val]))),
            try:
                # if the spatial data is 2-D, reshape it to
                # (1, n_points, n_components); for 1-D data the call to len()
                # below raises and we fall back to a single row
                self._space[val][0]
                space = self._space[val].reshape(1, len(self._space[val]), len(self._space[val][0]))
            except:
                space = self._space[val].reshape(1, -1)
return Database(self._parameters[val].reshape(1,len(self._parameters[val])),
self._snapshots[val].reshape(1, len(self._snapshots[val])),
self.scaler_parameters,
self.scaler_snapshots,
np.reshape(self._space[val],
(1,len(self._space[val]))))

space)

if self._space is None:
return Database(self._parameters[val],
self._snapshots[val],
@@ -145,10 +140,20 @@ def add(self, parameters, snapshots, space=None):
raise RuntimeError('No Spatial Value given')

if (self._space is not None) or (space is not None):
if space.shape != snapshots.shape:
if len(space) != len(snapshots) or len(space[0]) != len(snapshots[0]):
raise RuntimeError(
'shape of space and snapshots are different.')

if self._parameters is None and self._snapshots is None:
self._parameters = parameters
self._snapshots = snapshots
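A small sketch of how the reworked __getitem__ behaves; the data below is made up, and the printed shapes are what the integer-indexing branch above should produce for 1-D spatial data:

```python
import numpy as np
from ezyrb.database import Database

# made-up data: 4 parameter values, snapshots sampled on a shared 100-point 1-D grid
params = np.linspace(0, 1, 4).reshape(-1, 1)
space = np.tile(np.linspace(0, 5, 100), (4, 1))
snapshots = np.sin(space + params)          # broadcast to shape (4, 100)

db = Database(parameters=params, snapshots=snapshots, space=space)

# integer indexing returns a one-entry Database that keeps the spatial data,
# with parameters, snapshots and space reshaped to single rows
single = db[2]
print(single.parameters.shape, single.snapshots.shape, single.space.shape)
# expected: (1, 1) (1, 100) (1, 100)
```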
245 changes: 245 additions & 0 deletions ezyrb/nnspod.py
@@ -0,0 +1,245 @@
import numpy as np
import torch
import torch.nn as nn
from .ann import ANN
from .pod import POD
from .database import Database


class NNsPOD(POD):
def __init__(self,
interp_loss, interp_layers, interp_function, interp_stop_training, ref_point,
shift_layers, shift_function, shift_stop_training, shift_loss = nn.MSELoss(),
method = "svd"):
        '''
        :param torch.nn.Module interp_loss: loss function for the interpnet.
        :param list interp_layers: list with the number of neurons in each
            interpnet layer.
        :param torch.nn.modules.activation interp_function: activation function
            for the interpnet.
        :param int, float or list interp_stop_training: stop criteria for the
            interpnet training (number of epochs and/or desired tolerance).
        :param int ref_point: index of the reference snapshot in the database.
        :param list shift_layers: list with the number of neurons in each
            shiftnet layer.
        :param torch.nn.modules.activation shift_function: activation function
            for the shiftnet.
        :param int, float or list shift_stop_training: stop criteria for the
            shiftnet training (number of epochs and/or desired tolerance).
        :param torch.nn.Module shift_loss: loss function for the shiftnet
            (MSE by default).
        :param str method: the POD method, 'svd' by default.
        '''
super().__init__(method)
self.interp_loss = interp_loss
self.interp_layers = interp_layers
self.interp_function = interp_function
self.interp_stop_training = interp_stop_training
self.shift_loss = shift_loss
self.shift_layers = shift_layers
self.shift_function = shift_function
self.shift_stop_training = shift_stop_training
self.ref_point = ref_point

def reshape2dto1d(self, x, y):
"""
reshapes two n by n arrays into one n^2 by 2 array
:param numpy.array x: x value of data
:param numpy.array y: y value of data
"""
x = x.reshape(-1,1)
y = y.reshape(-1,1)
coords = np.concatenate((x, y), axis = 1)
coords = np.array(coords).reshape(-1,2)

return coords

def reshape1dto2d(self, snapshots):
"""
turns 1d list of data into 2d
:param array-like snapshots: data to be reshaped
"""
return snapshots.reshape(int(np.sqrt(len(snapshots))), int(np.sqrt(len(snapshots))))

def train_interpnet(self,ref_data, retrain = False, frequency_print = 0, save = True, interp_file = None):
"""
trains the Interpnet given 1d data:

:param database ref_data: the reference data that the rest of the data will be shifted to

:param boolean retrain: True if the interpNetShould be retrained, False if it should be loaded
"""
if interp_file:
print(interp_file)
self.interp_path = interp_file

self.interp_net = ANN(self.interp_layers, self.interp_function, self.interp_stop_training, self.interp_loss)
if len(ref_data.space.shape) > 2:
space = ref_data.space.reshape(-1, 2)
else:
space = ref_data.space.reshape(-1,1)
snapshots = ref_data.snapshots.reshape(-1,1)
if not retrain:
try:
self.interp_net = self.interp_net.load_state(self.interp_path, space, snapshots)
print("loaded interpnet")
            except:
                # fall back to training from scratch if no saved state can be loaded
self.interp_net.fit(space, snapshots, frequency_print = frequency_print)
if save:
self.interp_net.save_state(self.interp_path)
else:
self.interp_net.fit(space, snapshots, frequency_print = frequency_print)
if save:
self.interp_net.save_state(self.interp_path)

def shift(self, x, y, shift_quantity):
"""
shifts data by shift_quanity
"""
return(x+shift_quantity, y)

def pre_shift(self,x,y, ref_y):
"""
moves data so that the max of y and max of ref_y are at the same x coordinate
"""
maxy = 0
        for i, n in enumerate(y):
if n > y[maxy]:
maxy = i
maxref = 0
for i, n in enumerate(ref_y):
if n > ref_y[maxref]:
maxref = i

return self.shift(x, y, x[maxref]-x[maxy])[0]

def make_points(self, x, params):
"""
creates points that can be used to train and predict shiftnet
"""
if len(x.shape)> 1:
points = np.zeros((len(x),3))
for j, s in enumerate(x):
points[j][0] = s[0]
points[j][1] = s[1]
points[j][2] = params[0]
else:
points = np.zeros((len(x),2))
for j, s in enumerate(x):
points[j][0] = s
points[j][1] = params[0]
return points

def build_model(self, dim = 1):
"""
builds model based on dimension of input data
"""
layers = self.layers.copy()
layers.insert(0, dim + 1)
layers.append(dim)
layers_torch = []
for i in range(len(layers) - 2):
layers_torch.append(nn.Linear(layers[i], layers[i + 1]))
layers_torch.append(self.function)
layers_torch.append(nn.Linear(layers[-2], layers[-1]))
self.model = nn.Sequential(*layers_torch)

def train_shiftnet(self, db, ref_data, preshift = False,
optimizer = torch.optim.Adam, learning_rate = 0.0001,
frequency_print = 0):
"""
Trains and evaluates shiftnet given 1d data 'db'

:param Database db: data at a certain parameter value
:param list shift_layers: ordered list with number of neurons in each layer
:param torch.nn.module.activation shift_function: the activation function used by the shiftnet
:param int, float, or list stop_training:
int: number of epochs before stopping
float: desired tolarance before stopping training
list: a int and a float, stops when either desired epochs or tolerance is reached
:param Database db: data at the reference datapoint
:param boolean preshift: True if preshift is desired otherwise false.
"""
self.layers = self.shift_layers
self.function = self.shift_function
if preshift:
x = self.pre_shift(db.space[0], db.snapshots[0], ref_data.snapshots[0])
else:
x = db.space[0]
if len(db.space.shape) > 2:
x_reshaped = x.reshape(-1,2)
self.build_model(dim = 2)
else:
self.build_model(dim = 1)
x_reshaped = x.reshape(-1,1)

values = db.snapshots.reshape(-1,1)

self.stop_training = self.shift_stop_training
points = self.make_points(x, db.parameters)

self.optimizer = optimizer(self.model.parameters(), lr = learning_rate)

self.loss = self.shift_loss
points = torch.from_numpy(points).float()
self.loss_trend = []
n_epoch = 1
flag = True
while flag:
shift = self.model(points)
x_shift, y = self.shift(
torch.from_numpy(x_reshaped).float(),
torch.from_numpy(values).float(),
shift)
ref_interp = self.interp_net.model(x_shift)
            loss = self.loss(ref_interp, y)
            # reset accumulated gradients before the backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
self.loss_trend.append(loss.item())
for criteria in self.stop_training:
if isinstance(criteria, int): # stop criteria is an integer
if n_epoch == criteria:
flag = False
elif isinstance(criteria, float): # stop criteria is float
if loss.item() < criteria:
flag = False
if frequency_print != 0:
if n_epoch % frequency_print == 1:
print(loss.item())
n_epoch += 1

new_point = self.make_points(x, db.parameters)
shift = self.model(torch.from_numpy(new_point).float())
x_new = self.shift(
torch.from_numpy(x_reshaped).float(),
torch.from_numpy(values).float(),
shift)[0]
x_ret = x_new.detach().numpy()
return x_ret

    def fit(self, db, interp_file):
        """
        Shifts every snapshot in 'db' onto the reference frame, interpolates it
        on the reference grid with the interpnet and performs a POD on the
        aligned snapshots.

        :param Database db: the database of parameters, snapshots and space.
        :param str interp_file: path of the file used to load/save the
            interpnet state.
        """
        self.interp_path = interp_file
self.train_interpnet(db[self.ref_point], retrain = False, frequency_print = 25)
new_x = np.zeros(shape = db.space.shape)
i = 0
while i < db.parameters.shape[0]:
if len(db.space.shape) > 2:
new_x[i] = self.train_shiftnet(db[i], db[self.ref_point], preshift = True, frequency_print = 50).reshape(-1, 2)
else:
new_x[i] = self.train_shiftnet(db[i], db[self.ref_point], preshift = True, frequency_print = 50).reshape(-1)
i+=1
if i == self.ref_point:
new_x[self.ref_point] = db.space[self.ref_point]
i +=1
db = Database(space = new_x, snapshots = db.snapshots, parameters = db.parameters)

i = 0
new_snapshots = np.zeros(shape = db.snapshots.shape)
new_space = np.zeros(shape = db.space.shape)
while i < db.parameters.shape[0]:
if len(db.space.shape) > 2:
new_snapshots[i] = self.interp_net.model(torch.from_numpy(db.space[i].reshape(-1,2)).float()).detach().numpy().reshape(-1)
else:
new_snapshots[i] = self.interp_net.model(torch.from_numpy(db.space[i].reshape(-1,1)).float()).detach().numpy().reshape(-1)
new_space[i] = db.space[self.ref_point]
i+=1
if i == self.ref_point:
new_snapshots[self.ref_point] = db.snapshots[self.ref_point]
new_space[self.ref_point] = db.space[self.ref_point]
i +=1

db = Database(space = new_space, snapshots = new_snapshots, parameters = db.parameters)
POD_ = POD(method = 'svd')
return POD_.fit(db.snapshots)
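To give an idea of the intended workflow, a hypothetical end-to-end sketch: the travelling-wave data, layer sizes, activation functions and stop criteria are all invented, and tutorials/interpnet1d.pth is the checkpoint added in this PR (if it cannot be loaded, train_interpnet falls back to training from scratch):

```python
import numpy as np
import torch.nn as nn
from ezyrb.database import Database
from ezyrb.nnspod import NNsPOD

# made-up 1-D travelling wave: a Gaussian bump moving across a 200-point grid
# as the parameter changes
params = np.linspace(0, 1, 10).reshape(-1, 1)
space = np.tile(np.linspace(0, 5, 200), (10, 1))
snapshots = np.exp(-(space - 1 - 3 * params)**2 / 0.05)

db = Database(parameters=params, snapshots=snapshots, space=space)

nnspod = NNsPOD(interp_loss=nn.MSELoss(),
                interp_layers=[40, 40],
                interp_function=nn.Sigmoid(),
                interp_stop_training=[3000, 1e-4],
                ref_point=4,                       # index of the reference snapshot
                shift_layers=[20, 20],
                shift_function=nn.Tanh(),
                shift_stop_training=[3000, 1e-4])

# shifts every snapshot onto the frame of db[4], interpolates it on the
# reference grid and returns a POD instance fitted on the aligned snapshots
pod = nnspod.fit(db, interp_file="tutorials/interpnet1d.pth")
```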

Binary file added tutorials/interpnet1d.pth
Binary file added tutorials/interpnet2d.pth