diff --git a/ezyrb/ann.py b/ezyrb/ann.py index d0dac282..47635548 100755 --- a/ezyrb/ann.py +++ b/ezyrb/ann.py @@ -99,7 +99,7 @@ def _build_model(self, points, values): layers_torch.append(nn.Linear(layers[-2], layers[-1])) self.model = nn.Sequential(*layers_torch) - def fit(self, points, values): + def fit(self, points, values, optimizer = torch.optim.Adam, learning_rate = 0.001, frequency_print = 0): """ Build the ANN given 'points' and 'values' and perform training. @@ -119,14 +119,16 @@ def fit(self, points, values): :param numpy.ndarray points: the coordinates of the given (training) points. :param numpy.ndarray values: the (training) values in the points. + :param torch.optimizer optimizer: the optimizer used for the neural network + :param float learning_rate: learning rate used in the optimizer + :param int frequency_print: the number of epochs between the print of each loss value """ self._build_model(points, values) - self.optimizer = torch.optim.Adam(self.model.parameters()) + self.optimizer = optimizer(self.model.parameters(), lr = learning_rate) points = self._convert_numpy_to_torch(points) values = self._convert_numpy_to_torch(values) - n_epoch = 1 flag = True while flag: @@ -143,7 +145,9 @@ def fit(self, points, values): elif isinstance(criteria, float): # stop criteria is float if loss.item() < criteria: flag = False - + if frequency_print != 0: + if n_epoch % frequency_print == 1: + print(loss.item()) n_epoch += 1 def predict(self, new_point): @@ -157,3 +161,32 @@ def predict(self, new_point): new_point = self._convert_numpy_to_torch(np.array(new_point)) y_new = self.model(new_point) return self._convert_torch_to_numpy(y_new) + + + def save_state(self, filename): + + checkpoint = { + 'model_state': self.model.state_dict(), + 'optimizer_state' : self.optimizer.state_dict(), + 'optimizer_class' : self.optimizer.__class__, + 'model_class' : self.model.__class__ + } + + torch.save(checkpoint, filename) + + def load_state(self, filename, points, values): + + checkpoint = torch.load(filename) + + self._build_model(points, values) + self.optimizer = checkpoint['optimizer_class'] + + self.model.load_state_dict(checkpoint['model_state']) + + self.optimizer = checkpoint['optimizer_class'](self.model.parameters()) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + + # self.trained_epoch = checkpoint['epoch'] + # self.history = checkpoint['history'] + + return self diff --git a/ezyrb/database.py b/ezyrb/database.py index da4fafa3..0c07d97c 100644 --- a/ezyrb/database.py +++ b/ezyrb/database.py @@ -31,12 +31,6 @@ def __init__(self, raise RuntimeError( 'Parameters and Snapshots are not both provided') - if space is not None and snapshots is None: - raise RuntimeError( - 'Snapshot data is not provided with Spatial data') - - if space is not None and snapshots is None: - raise RuntimeError if space is not None and snapshots is None: raise RuntimeError( @@ -82,6 +76,7 @@ def space(self): """ return self._space + def __getitem__(self, val): """ This method returns a new Database with the selected parameters and @@ -90,24 +85,24 @@ def __getitem__(self, val): .. warning:: The new parameters and snapshots are a view of the original Database. 
""" + if isinstance(val, int): if self._space is None: - return Database(np.reshape(self._parameters[val], - (1,len(self._parameters[val]))), - np.reshape(self._snapshots[val], - (1,len(self._snapshots[val]))), + return Database(self._parameters[val].reshape(1,len(self._parameters[val])), + self._snapshots[val].reshape(1, len(self._snapshots[val])), self.scaler_parameters, self.scaler_snapshots) - - return Database(np.reshape(self._parameters[val], - (1,len(self._parameters[val]))), - np.reshape(self._snapshots[val], - (1,len(self._snapshots[val]))), + try: + self._space[val][0] + space = self._space[val].reshape(1,len(self._space[val]), len(self._space[val][0])) + except: + space = self._space[val].reshape(1,-1) + return Database(self._parameters[val].reshape(1,len(self._parameters[val])), + self._snapshots[val].reshape(1, len(self._snapshots[val])), self.scaler_parameters, self.scaler_snapshots, - np.reshape(self._space[val], - (1,len(self._space[val])))) - + space) + if self._space is None: return Database(self._parameters[val], self._snapshots[val], @@ -145,10 +140,20 @@ def add(self, parameters, snapshots, space=None): raise RuntimeError('No Spatial Value given') if (self._space is not None) or (space is not None): - if space.shape != snapshots.shape: + if len(space) != len(snapshots) or len(space[0]) != len(snapshots[0]): raise RuntimeError( 'shape of space and snapshots are different.') + if self._space is not None: + if space is None: + raise RuntimeError('No Spatial Value given') + + if (self._space is not None) or (space is not None): + if len(space) != len(snapshots) or len(space[0]) != len(snapshots[0]): + raise RuntimeError( + 'shape of space and snapshots are different.') + + if self._parameters is None and self._snapshots is None: self._parameters = parameters self._snapshots = snapshots diff --git a/ezyrb/nnspod.py b/ezyrb/nnspod.py new file mode 100644 index 00000000..02f43b47 --- /dev/null +++ b/ezyrb/nnspod.py @@ -0,0 +1,245 @@ +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +import numpy as np +import torch +import torch.nn as nn +from .ann import ANN +from .pod import POD +from .database import Database + + +class NNsPOD(POD): + def __init__(self, + interp_loss, interp_layers, interp_function, interp_stop_training, ref_point, + shift_layers, shift_function, shift_stop_training, shift_loss = nn.MSELoss(), + method = "svd"): + ''' + :param list interp_layers: list with number of neurons in each layer + :param torch.nn.modules.activation interp_function: activation function for the interpnet + :param float interp_stop_training: desired tolerance for the interp training + :param torch.nn.Module interp_loss: loss function (MSE default) + ''' + ## add loss, layers, and functions variables + super().__init__(method) + self.interp_loss = interp_loss + self.interp_layers = interp_layers + self.interp_function = interp_function + self.interp_stop_training = interp_stop_training + self.shift_loss = shift_loss + self.shift_layers = shift_layers + self.shift_function = shift_function + self.shift_stop_training = shift_stop_training + self.ref_point = ref_point + + def reshape2dto1d(self, x, y): + """ + reshapes two n by n arrays into one n^2 by 2 array + :param numpy.array x: x value of data + :param numpy.array y: y value of data + """ + x = x.reshape(-1,1) + y = y.reshape(-1,1) + coords = np.concatenate((x, y), axis = 1) + coords = np.array(coords).reshape(-1,2) + + return coords + + def reshape1dto2d(self, snapshots): + """ + turns 1d list of data into 
2d + :param array-like snapshots: data to be reshaped + """ + return snapshots.reshape(int(np.sqrt(len(snapshots))), int(np.sqrt(len(snapshots)))) + + def train_interpnet(self,ref_data, retrain = False, frequency_print = 0, save = True, interp_file = None): + """ + trains the Interpnet given 1d data: + + :param database ref_data: the reference data that the rest of the data will be shifted to + + :param boolean retrain: True if the interpNetShould be retrained, False if it should be loaded + """ + if interp_file: + print(interp_file) + self.interp_path = interp_file + + self.interp_net = ANN(self.interp_layers, self.interp_function, self.interp_stop_training, self.interp_loss) + if len(ref_data.space.shape) > 2: + space = ref_data.space.reshape(-1, 2) + else: + space = ref_data.space.reshape(-1,1) + snapshots = ref_data.snapshots.reshape(-1,1) + if not retrain: + try: + self.interp_net = self.interp_net.load_state(self.interp_path, space, snapshots) + print("loaded interpnet") + except: + self.interp_net.fit(space, snapshots, frequency_print = frequency_print) + if save: + self.interp_net.save_state(self.interp_path) + else: + self.interp_net.fit(space, snapshots, frequency_print = frequency_print) + if save: + self.interp_net.save_state(self.interp_path) + + def shift(self, x, y, shift_quantity): + """ + shifts data by shift_quanity + """ + return(x+shift_quantity, y) + + def pre_shift(self,x,y, ref_y): + """ + moves data so that the max of y and max of ref_y are at the same x coordinate + """ + maxy = 0 + for i, n, in enumerate(y): + if n > y[maxy]: + maxy = i + maxref = 0 + for i, n in enumerate(ref_y): + if n > ref_y[maxref]: + maxref = i + + return self.shift(x, y, x[maxref]-x[maxy])[0] + + def make_points(self, x, params): + """ + creates points that can be used to train and predict shiftnet + """ + if len(x.shape)> 1: + points = np.zeros((len(x),3)) + for j, s in enumerate(x): + points[j][0] = s[0] + points[j][1] = s[1] + points[j][2] = params[0] + else: + points = np.zeros((len(x),2)) + for j, s in enumerate(x): + points[j][0] = s + points[j][1] = params[0] + return points + + def build_model(self, dim = 1): + """ + builds model based on dimension of input data + """ + layers = self.layers.copy() + layers.insert(0, dim + 1) + layers.append(dim) + layers_torch = [] + for i in range(len(layers) - 2): + layers_torch.append(nn.Linear(layers[i], layers[i + 1])) + layers_torch.append(self.function) + layers_torch.append(nn.Linear(layers[-2], layers[-1])) + self.model = nn.Sequential(*layers_torch) + + def train_shiftnet(self, db, ref_data, preshift = False, + optimizer = torch.optim.Adam, learning_rate = 0.0001, + frequency_print = 0): + """ + Trains and evaluates shiftnet given 1d data 'db' + + :param Database db: data at a certain parameter value + :param list shift_layers: ordered list with number of neurons in each layer + :param torch.nn.module.activation shift_function: the activation function used by the shiftnet + :param int, float, or list stop_training: + int: number of epochs before stopping + float: desired tolarance before stopping training + list: a int and a float, stops when either desired epochs or tolerance is reached + :param Database db: data at the reference datapoint + :param boolean preshift: True if preshift is desired otherwise false. 
+ """ + self.layers = self.shift_layers + self.function = self.shift_function + if preshift: + x = self.pre_shift(db.space[0], db.snapshots[0], ref_data.snapshots[0]) + else: + x = db.space[0] + if len(db.space.shape) > 2: + x_reshaped = x.reshape(-1,2) + self.build_model(dim = 2) + else: + self.build_model(dim = 1) + x_reshaped = x.reshape(-1,1) + + values = db.snapshots.reshape(-1,1) + + self.stop_training = self.shift_stop_training + points = self.make_points(x, db.parameters) + + self.optimizer = optimizer(self.model.parameters(), lr = learning_rate) + + self.loss = self.shift_loss + points = torch.from_numpy(points).float() + self.loss_trend = [] + n_epoch = 1 + flag = True + while flag: + shift = self.model(points) + x_shift, y = self.shift( + torch.from_numpy(x_reshaped).float(), + torch.from_numpy(values).float(), + shift) + ref_interp = self.interp_net.model(x_shift) + loss = self.loss(ref_interp, y) + loss.backward() + self.optimizer.step() + self.loss_trend.append(loss.item()) + for criteria in self.stop_training: + if isinstance(criteria, int): # stop criteria is an integer + if n_epoch == criteria: + flag = False + elif isinstance(criteria, float): # stop criteria is float + if loss.item() < criteria: + flag = False + if frequency_print != 0: + if n_epoch % frequency_print == 1: + print(loss.item()) + n_epoch += 1 + + new_point = self.make_points(x, db.parameters) + shift = self.model(torch.from_numpy(new_point).float()) + x_new = self.shift( + torch.from_numpy(x_reshaped).float(), + torch.from_numpy(values).float(), + shift)[0] + x_ret = x_new.detach().numpy() + return x_ret + + def fit(self, db, interp_file): + self.interp_path = interp_file + ## input variables: load files. + self.train_interpnet(db[self.ref_point], retrain = False, frequency_print = 25) + new_x = np.zeros(shape = db.space.shape) + i = 0 + while i < db.parameters.shape[0]: + if len(db.space.shape) > 2: + new_x[i] = self.train_shiftnet(db[i], db[self.ref_point], preshift = True, frequency_print = 50).reshape(-1, 2) + else: + new_x[i] = self.train_shiftnet(db[i], db[self.ref_point], preshift = True, frequency_print = 50).reshape(-1) + i+=1 + if i == self.ref_point: + new_x[self.ref_point] = db.space[self.ref_point] + i +=1 + db = Database(space = new_x, snapshots = db.snapshots, parameters = db.parameters) + + i = 0 + new_snapshots = np.zeros(shape = db.snapshots.shape) + new_space = np.zeros(shape = db.space.shape) + while i < db.parameters.shape[0]: + if len(db.space.shape) > 2: + new_snapshots[i] = self.interp_net.model(torch.from_numpy(db.space[i].reshape(-1,2)).float()).detach().numpy().reshape(-1) + else: + new_snapshots[i] = self.interp_net.model(torch.from_numpy(db.space[i].reshape(-1,1)).float()).detach().numpy().reshape(-1) + new_space[i] = db.space[self.ref_point] + i+=1 + if i == self.ref_point: + new_snapshots[self.ref_point] = db.snapshots[self.ref_point] + new_space[self.ref_point] = db.space[self.ref_point] + i +=1 + + db = Database(space = new_space, snapshots = new_snapshots, parameters = db.parameters) + POD_ = POD(method = 'svd') + return POD_.fit(db.snapshots) + diff --git a/tutorials/interpnet1d.pth b/tutorials/interpnet1d.pth new file mode 100644 index 00000000..847f4f09 Binary files /dev/null and b/tutorials/interpnet1d.pth differ diff --git a/tutorials/interpnet2d.pth b/tutorials/interpnet2d.pth new file mode 100644 index 00000000..19bcd99f Binary files /dev/null and b/tutorials/interpnet2d.pth differ diff --git a/tutorials/tutorial-3.ipynb b/tutorials/tutorial-3.ipynb new file mode 
100644 index 00000000..e5929d6e --- /dev/null +++ b/tutorials/tutorial-3.ipynb @@ -0,0 +1,1544 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# EZyRB Tutorial 3\n", + "## Use NNsPOD to help with POD\n", + "\n", + "In this tutorial we show how to set up and use the NNsPOD class to align all of the data, allowing the use of POD." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To do this we will show a simple example where the data is a moving Gaussian wave.\n", + "\n", + "The first step is to import the necessary packages." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib\n", + "import torch\n", + "import torch.nn as nn\n", + "from ezyrb.nnspod import NNsPOD\n", + "from ezyrb import Database\n", + "from ezyrb import POD\n", + "matplotlib.use('Qt5Agg')\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1d Data\n", + "\n", + "Now we make the data we will use. We define a simple Gaussian function and populate the space, snapshots, and parameters of the database." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "n_params = 15\n", + "params = np.linspace(0.5, 4.5, n_params).reshape(-1, 1) # actually the time steps\n", + "def gaussian(x, mu, sig):\n", + " return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))\n", + "def wave(t, res=256):\n", + " x = np.linspace(0, 5, res)\n", + " return x, gaussian(x, t, 0.1)\n", + "\n", + "db = np.array([wave(t)[1] for t in params])\n", + "db_array = np.array([wave(t)[1] for t in params])\n", + "space = wave(0)[0]\n", + "space_array = np.array([wave(t)[0] for t in params])\n", + "\n", + "database = Database(space = space_array, snapshots = db_array, parameters = params)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we construct an NNsPOD object, passing the interpnet settings (loss, layers, activation function, and stopping criterion), the index of the reference snapshot, and the corresponding shiftnet settings. The file used to save or load the interpnet is passed later, when it is trained; saving and reloading is especially useful with 2d data, where training can take hours depending on the size of the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "ref_data = 5\n", + "NNsPOD_tutorial = NNsPOD(None, [20,20], nn.Sigmoid(), [0.000001], ref_data,\n", + " shift_layers = [20,20,20], shift_function = nn.Tanh(), shift_stop_training =[10000, 0.00001]\n", + " \n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we train the interpnet. We pass in the reference database, whether to retrain the network or load a saved one from the given file, and how often the loss value should be printed. The stopping criterion (set in the constructor) can be a float (loss value to stop at), an int (epoch to stop at), or both (training stops at whichever is reached first); the loss function is MSE by default. If you choose to retrain, the loss value is printed as training proceeds."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interpnet1d.pth\n", + "loaded interpnet\n" + ] + } + ], + "source": [ + "\n", + "NNsPOD_tutorial.train_interpnet(database[ref_data],retrain = False, frequency_print = 5, interp_file = \"interpnet1d.pth\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can graph the original reference data together with the data we get from the interpnet after feeding it 1000 positional data points. The large points are the original data points, and the small points are the ones created by the interpnet. It should be clear that the interpnet is able to accurately replicate the Gaussian." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(database[ref_data].space, database[ref_data].snapshots, \"o\")\n", + "xi = np.linspace(0,5,1000).reshape(-1,1)\n", + "yi = NNsPOD_tutorial.interp_net.predict(xi)\n", + "plt.plot(xi,yi, \".\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we train the shiftnet on every snapshot except the reference. To do this we pass in the database at that parameter value, the reference database, and whether the data should be preshifted; the network shape, activation function, and stopping criteria were already set in the constructor. For the shiftnet it is useful to give both a loss value and an epoch limit as stopping criteria, because there is a minimum loss the network can reach, and with a threshold below that minimum the training would never stop.\n", + "\n", + "Training the shiftnet also prints the loss value while it runs." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.021576717495918274\n", + "0.018197430297732353\n", + "0.014862452633678913\n", + "0.011553604155778885\n", + " ... (per-epoch loss values for the remaining epochs and snapshots omitted for brevity) ...\n", + "4.444699516170658e-05\n" + ] + } + ], + "source": [ + "i = 0\n", + "while i < 10:\n", + " x_new = NNsPOD_tutorial.train_shiftnet(database[i], database[ref_data], preshift = True, frequency_print = 5, learning_rate = 0.0001) \n", + " db = database[i] \n", + " plt.plot(db.space, db.snapshots, \"go\")\n", + " plt.plot(x_new, db.snapshots.reshape(-1,1), \".\")\n", + " i+=1\n", + " if i == ref_data:\n", + " i +=1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we plot and show all the data. The original positions are represented by green circles, the reference data by blue plus marks, and the shifted data by the smaller dots. 
It should be clear that all of the data has been moved to allign with the reference data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(database[0].space, database[ref_data].snapshots, \"b+\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we use the NNsPOD fit function, and we get get the modes and singular values" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loaded interpnet\n", + "0.0011340980418026447\n", + "0.037661511451005936\n", + "0.0011428085854277015\n", + "0.018658315762877464\n", + "0.010687826201319695\n", + "0.005290625151246786\n", + "0.022084934636950493\n", + "0.006550148129463196\n", + "0.004208308644592762\n", + "0.01799725741147995\n", + "0.00940329022705555\n", + "0.00030273222364485264\n", + "0.012823530472815037\n", + "0.018274104222655296\n", + "0.008066706359386444\n", + "1.0090350770042278e-05\n", + "0.0055536674335598946\n", + "0.005161964800208807\n", + "0.015556512400507927\n", + "0.010334138758480549\n", + "0.008668169379234314\n", + "0.015472044236958027\n", + "0.01992838829755783\n", + "0.0002611768722999841\n", + "0.005396030843257904\n", + "0.004713937174528837\n", + "0.015291141346096992\n", + "0.023489514365792274\n", + "0.009426423348486423\n", + "0.002368635730817914\n", + "0.06547225266695023\n", + "0.020965032279491425\n", + "0.02389248088002205\n", + "0.027566159144043922\n", + "0.003045934485271573\n", + "0.031973134726285934\n", + "0.01897011697292328\n", + "0.0004547239514067769\n", + "0.01814550906419754\n", + "0.017512764781713486\n", + "0.0004336966958362609\n", + "0.013621974736452103\n", + "0.02712256833910942\n", + "0.016880594193935394\n", + "0.0003777860547415912\n", + "0.02320970594882965\n", + "0.0007515900651924312\n", + "0.015391875058412552\n", + "0.0584186390042305\n", + "0.005546580534428358\n", + "0.03250119462609291\n", + "0.013546831905841827\n", + "0.0033655676525086164\n", + "0.002723336685448885\n", + "0.005520799662917852\n", + "0.008607450872659683\n", + "0.032345738261938095\n", + "0.010342801921069622\n", + "0.009515229612588882\n", + "0.01155978161841631\n", + "0.05270132049918175\n", + "0.007957477122545242\n", + "0.052852530032396317\n", + "0.0010104936081916094\n", + "0.0340232215821743\n", + "0.011018244549632072\n", + "0.00018309219740331173\n", + "0.004560935311019421\n", + "0.0707082450389862\n", + "0.04232151806354523\n", + "0.03979899361729622\n", + "0.00926896184682846\n", + "0.01903892681002617\n", + "0.023522906005382538\n", + "0.02911665476858616\n", + "0.001565837999805808\n", + "0.028716061264276505\n", + "0.010102448053658009\n", + "0.011108173988759518\n", + "0.029574397951364517\n", + "0.003892436856403947\n", + "0.014879656955599785\n", + "0.0292720478028059\n", + "0.013431227765977383\n", + "0.00037743939901702106\n", + "0.014957211911678314\n", + "0.02001883275806904\n", + "0.007836512289941311\n", + "[[-2.58198463e-01 -6.90081522e-02 5.99533825e-02 9.61757257e-01\n", + " -1.76272665e-17 2.37019593e-15 -7.42092999e-16 -6.98440470e-17\n", + " 1.19577835e-69 3.05567705e-17 -1.59163809e-75 6.54787940e-18\n", + " 5.23830352e-17 1.74610117e-17 3.49220235e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " -1.78411177e-03 6.93120119e-01 6.54730631e-01 -1.67758763e-04\n", + " -5.87464405e-17 3.13725161e-16 4.50284033e-18 
-5.98746728e-18\n", + " 2.22170332e-17 -9.79297043e-18 -2.25677372e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 5.22611394e-04 -7.20761415e-01 6.24174333e-01 -1.45751506e-04\n", + " -5.58744461e-17 2.67685403e-16 1.69295668e-17 9.92463816e-19\n", + " 1.30866159e-17 4.05981073e-18 -1.43529889e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40175332e-04 3.08093071e-03 -1.41869223e-01 9.42843848e-01\n", + " 2.07932401e-12 -2.31166448e-12 8.64571015e-17 1.77970839e-17\n", + " 4.52233050e-17 1.98638112e-17 -1.36853316e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " 2.47004143e-01 9.02213364e-01 7.83226058e-06 3.89902753e-17\n", + " 1.12824117e-16 5.98602335e-18 2.95328832e-17]\n", + " [-2.58204862e-01 9.66090187e-01 6.21898366e-15 -1.02695630e-15\n", + " -1.01226374e-17 -1.79728023e-16 -4.12682016e-17 -9.70432135e-19\n", + " 3.62113052e-19 -1.60704204e-18 -6.68386587e-19 4.64828821e-18\n", + " 5.59261950e-18 5.27104376e-18 1.14534054e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " 8.57673342e-01 -3.73358324e-01 2.97837927e-07 2.81985641e-17\n", + " 8.20844100e-19 1.11086406e-18 1.03359653e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " -1.84116280e-01 -8.81495100e-02 9.12869574e-01 1.45507690e-11\n", + " 1.23555782e-11 1.32205716e-16 7.66640572e-17]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " -1.84112241e-01 -8.81411060e-02 -1.82575541e-01 7.38946389e-01\n", + " 5.03942690e-01 1.04296844e-06 1.87485991e-16]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " -1.84112241e-01 -8.81411060e-02 -1.82575541e-01 -6.72677008e-01\n", + " 5.89496092e-01 6.80515822e-07 1.95132432e-16]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " -1.84112241e-01 -8.81411060e-02 -1.82575541e-01 -2.20900801e-02\n", + " -3.64480863e-01 8.16496006e-01 -1.70663929e-11]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " -1.84112241e-01 -8.81411060e-02 -1.82575541e-01 -2.20896503e-02\n", + " -3.64478959e-01 -4.08248865e-01 7.07106781e-01]\n", + " [-2.58198463e-01 -6.90081522e-02 -1.22645165e-01 -6.66234422e-02\n", + " 1.40165630e-04 3.07004570e-03 -1.42129468e-01 -1.17816292e-01\n", + " -1.84112241e-01 -8.81411060e-02 -1.82575541e-01 -2.20896503e-02\n", + " -3.64478959e-01 -4.08248865e-01 -7.07106781e-01]\n", + " [-2.58198463e-01 -6.90081522e-02 6.44571715e-01 -1.14449696e-01\n", + " -7.07105497e-01 -1.13802502e-03 -7.22078888e-04 1.95316799e-07\n", + " 5.62273052e-18 3.29950898e-17 -7.41420268e-18 1.14415618e-18\n", + " 6.96225735e-18 -8.64775513e-18 9.70002130e-18]\n", + " [-2.58198463e-01 -6.90081522e-02 6.44571715e-01 -1.14449696e-01\n", + " 7.07105497e-01 1.13802502e-03 7.22078888e-04 -1.95316799e-07\n", + " -8.25505729e-18 -2.25160614e-17 -7.48958861e-18 9.94342099e-19\n", + " 6.82917753e-18 -8.59994584e-18 
9.86051632e-18]] [1.16440157e+01 1.54573545e-02 2.93212533e-15 1.71488636e-16\n", + " 1.44714566e-31 1.09336034e-31 5.76118658e-33 1.16764654e-36\n", + " 8.30550925e-48 3.59970132e-48 2.73053874e-53 6.28589384e-64\n", + " 2.73537325e-64 5.33603813e-70 3.78431391e-80]\n" + ] + } + ], + "source": [ + "pod = NNsPOD_tutorial.fit(interp_file = \"interpnet1d.pth\",\n", + " db = database)\n", + "print(pod.modes, pod.singular_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-2.58198890e-01 5.59788262e-01 7.87382013e-01 -6.19866111e-16\n", + " 2.68173182e-16 -5.67483059e-17 1.74610172e-17 -1.30957629e-17\n", + " 0.00000000e+00 8.73050861e-18 8.73050861e-18 -3.38307208e-17\n", + " 3.49220344e-17 3.27394073e-17 -3.49220344e-17]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 9.57427108e-01\n", + " 1.48987716e-17 -1.32856031e-16 1.33295622e-16 2.25883973e-16\n", + " -2.86503678e-17 -3.35625782e-17 -1.53950992e-17 3.31617815e-17\n", + " -2.01522788e-17 4.57926075e-18 -1.24573224e-17]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -7.44298646e-17 9.49214452e-01 -2.34327101e-02 8.44009971e-02\n", + " -2.02537247e-02 -2.91328554e-17 -8.91115352e-17 -6.85128563e-18\n", + " 3.65744161e-18 -1.55735773e-17 -6.07114327e-19]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " 3.41662710e-17 -1.39215195e-01 6.76185044e-01 6.39480270e-01\n", + " -1.53456212e-01 -1.46289570e-16 -6.60239786e-16 3.48154415e-17\n", + " 1.92611922e-17 -3.06586380e-17 9.90641208e-18]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " 3.41662710e-17 -1.27632667e-02 5.90717752e-01 -7.27660306e-01\n", + " 1.74616794e-01 1.62330859e-16 6.53685847e-16 2.55680424e-17\n", + " -1.53578642e-17 -1.39306772e-18 -1.86945872e-17]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 2.18747164e-01\n", + " 9.09477825e-01 -1.41285899e-16 -3.36323374e-16 6.64875574e-18\n", + " 2.07510909e-17 7.28456428e-18 2.34577989e-17]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 4.10409597e-01 8.29883618e-01 2.31657483e-15\n", + " 1.17115661e-15 -1.15618011e-17 5.01668996e-17]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 -8.86677882e-01 2.66355383e-01 -4.47763119e-16\n", + " -3.65961123e-17 -1.93388014e-17 1.76436076e-17]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 9.52536571e-02 -2.19247800e-01 7.50762736e-01\n", + " 4.86163877e-01 9.21342531e-16 5.89705626e-16]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 9.52536571e-02 -2.19247800e-01 -6.58416833e-01\n", + " 6.05381924e-01 -8.93690642e-17 -1.71658171e-16]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 9.52536571e-02 -2.19247800e-01 -3.07819674e-02\n", + " -3.63848600e-01 -6.66035383e-01 -4.72296025e-01]\n", + " 
[-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 9.52536571e-02 -2.19247800e-01 -3.07819674e-02\n", + " -3.63848600e-01 7.42038047e-01 -3.40655549e-01]\n", + " [-2.58198890e-01 -1.28910610e-01 6.98006745e-03 -8.70388280e-02\n", + " -2.62617044e-18 -9.96544987e-02 -1.55433761e-01 -3.07097321e-02\n", + " -1.30054955e-01 9.52536571e-02 -2.19247800e-01 -3.07819674e-02\n", + " -3.63848600e-01 -7.60026642e-02 8.12951574e-01]\n", + " [-2.58198890e-01 4.93569528e-01 -4.35571411e-01 -1.61823414e-17\n", + " -7.07106781e-01 2.36176330e-17 6.12564926e-17 -1.73267876e-16\n", + " 2.39476085e-17 1.10816779e-17 6.96944634e-18 -1.10715973e-17\n", + " 1.33585933e-18 4.39421446e-18 -1.29149755e-17]\n", + " [-2.58198890e-01 4.93569528e-01 -4.35571411e-01 -7.16934926e-17\n", + " 7.07106781e-01 2.36176330e-17 -5.84394272e-17 1.87554607e-16\n", + " -3.85024366e-17 1.10816779e-17 6.96944634e-18 -1.10715973e-17\n", + " 1.33585933e-18 4.39421446e-18 -1.29149755e-17]] [1.16444287e+001 9.09725165e-016 1.23874383e-016 1.45246438e-031\n", + " 2.81726618e-032 3.32593487e-047 6.98909361e-049 2.48708947e-064\n", + " 3.32046152e-065 1.76147819e-081 1.21941048e-081 5.94958006e-098\n", + " 3.77252322e-098 2.98723364e-114 4.78768273e-115]\n" + ] + } + ], + "source": [ + "n_params = 15\n", + "params = np.linspace(0.5, 4.5, n_params).reshape(-1, 1) # actually the time steps\n", + "def gaussian(x, mu, sig):\n", + " return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))\n", + "def wave(t, res=256):\n", + " x = np.linspace(0, 5, res)\n", + " return x, gaussian(x, 2, 0.1)\n", + "\n", + "db = np.array([wave(t)[1] for t in params])\n", + "db_array = np.array([wave(t)[1] for t in params])\n", + "space = wave(0)[0]\n", + "space_array = np.array([wave(t)[0] for t in params])\n", + "\n", + "database = Database(space = space_array, snapshots = db_array, parameters = params)\n", + "\n", + "pod = POD()\n", + "\n", + "pod.fit(database.snapshots)\n", + "print(pod.modes, pod.singular_values)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2d Data\n", + "\n", + "Now we do the same but with 2d data and implement some basic functions to help with shaping the data" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def reshape2dto1d(x, y):\n", + " x = x.reshape(-1,1)\n", + " y = y.reshape(-1,1)\n", + " coords = np.concatenate((x, y), axis = 1)\n", + " coords = np.array(coords).reshape(-1,2)\n", + " \n", + " return coords\n", + "\n", + "def reshape1dto2d(snapshots):\n", + " return snapshots.reshape(int(np.sqrt(len(snapshots))), int(np.sqrt(len(snapshots))))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we create the 2d gaussian and populate the database" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.5 0.5]\n", + "[0.94444444 0.94444444]\n", + "[1.38888889 1.38888889]\n", + "[1.83333333 1.83333333]\n", + "[2.27777778 2.27777778]\n", + "[2.72222222 2.72222222]\n", + "[3.16666667 3.16666667]\n", + "[3.61111111 3.61111111]\n", + "[4.05555556 4.05555556]\n", + "[4.5 4.5]\n", + "[0 0]\n" + ] + } + ], + "source": [ + "\n", + "n_params = 10\n", + "params = np.linspace(0.5, 4.5, n_params).reshape(-1, 1) # actually the time steps\n", + "\n", + "def gaussian(x, mu, sig):\n", + " 
print(mu)\n", + " gaussx, gaussy = np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.))).T\n", + " return gaussx * gaussy\n", + "def wave(t, res=256):\n", + " x = np.linspace(0, 5, res)\n", + " return x, gaussian(x, t, 0.1)\n", + "def wave2D(t, res=256):\n", + " x = np.linspace(0, 5, res)\n", + " y = np.linspace(0, 5, res)\n", + " gridx, gridy = np.meshgrid(x, y)\n", + " gridx, gridy = gridx.reshape(-1,1), gridy.reshape(-1,1)\n", + " wave = gaussian(np.hstack([gridx, gridy]), t*np.array([1, 1]), 0.1)\n", + " return gridx, gridy, wave\n", + "db = np.array([wave2D(t)[2] for t in params])\n", + "db_array = db.reshape(n_params, -1, 1)\n", + "gridx, gridy = wave2D(0)[0 :2]\n", + "space = reshape2dto1d(gridx,gridy)\n", + "space_array = np.array([space.copy() for t in params])\n", + "\n", + "database = Database(space = space_array, snapshots = db_array, parameters = params)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "ref_point = 5\n", + "NNsPOD_tutorial = NNsPOD(interp_loss = None, interp_layers = [40,40], interp_function = nn.Sigmoid(), interp_stop_training = [10000000,0.000001], ref_point = ref_point,\n", + " shift_layers = [20,20,20], shift_function = nn.PReLU(), shift_stop_training = [0.001])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we set the reference snapshot and train the interpnet. If you change the reference, the interpnet will need to be retrained, which can take a long time for 2d data, especially if the target loss value is very low." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interpnet2d.pth\n", + "loaded interpnet\n" + ] + } + ], + "source": [ + "\n", + "ref_data = 5\n", + "NNsPOD_tutorial.train_interpnet(database[ref_data], retrain = False, frequency_print = 5, interp_file = \"interpnet2d.pth\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we graph the reference data." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "x = np.linspace(0, 5, 256)\n", + "y = np.linspace(0, 5, 256)\n", + "gridx, gridy = np.meshgrid(x, y)\n", + " \n", + "plt.pcolor(gridx,gridy,database[ref_data].snapshots.reshape(256, 256))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we graph the interpolated data. It should be visible that we get the same function, but with better resolution." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "res = 1000\n", + "x = np.linspace(0, 5, res)\n", + "y = np.linspace(0, 5, res)\n", + "gridx, gridy = np.meshgrid(x, y)\n", + "input = NNsPOD_tutorial.reshape2dto1d(gridx, gridy)\n", + "output = NNsPOD_tutorial.interp_net.predict(input)\n", + "\n", + "toshow = NNsPOD_tutorial.reshape1dto2d(output)\n", + "plt.pcolor(gridx,gridy,toshow)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we use the shiftnet and graph the shifted data. For each parameter we first graph the reference data, then the input data, then it will take some time to shift it, and finally we graph the shifted data. The loss value is printed out during the shift as well." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.0018307577120140195\n",
+      "0.0016863499768078327\n",
+      "0.0015271008014678955\n",
+      "0.0013487492688000202\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.002387014217674732\n",
+      "0.0023464937694370747\n",
+      "0.002300079446285963\n",
+      "0.0022443446796387434\n",
+      "0.0021787662990391254\n",
+      "0.0021018616389483213\n",
+      "0.0020114330109208822\n",
+      "0.001906601944938302\n",
+      "0.0017864166293293238\n",
+      "0.0016403829213231802\n",
+      "0.001478851423598826\n",
+      "0.0013043213402852416\n"
+     ]
+    }
+   ],
+   "source": [
+    "i = 0\n",
+    "x = np.linspace(0, 5, 256)\n",
+    "y = np.linspace(0, 5, 256)\n",
+    "gridx, gridy = np.meshgrid(x, y)\n",
+    "while i < database.parameters.shape[0]:\n",
+    "    db = database[i]\n",
+    "    plt.pcolor(gridx,gridy,database[ref_data].snapshots.reshape(256, 256))\n",
+    "    plt.show()\n",
+    "    plt.pcolor(gridx,gridy,database[i].snapshots.reshape(256, 256))\n",
+    "    plt.show()\n",
+    "    x_new = NNsPOD_tutorial.train_shiftnet(database[i], database[ref_data], preshift = True, frequency_print = 5)\n",
+    "    x, y = np.hsplit(x_new, 2)\n",
+    "    x = NNsPOD_tutorial.reshape1dto2d(x)\n",
+    "    y = NNsPOD_tutorial.reshape1dto2d(y)\n",
+    "    snapshots = NNsPOD_tutorial.reshape1dto2d(db.snapshots.reshape(-1,1))\n",
+    "    plt.pcolor(x,y,snapshots)\n",
+    "    plt.show()\n",
+    "    res = 256\n",
+    "    i+=1\n",
+    "    if i == ref_data:\n",
+    "        i +=1\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.001274904003366828\n",
+      "0.001253377995453775\n",
+      "0.0012363020796328783\n",
+      "0.0012356039369478822\n",
+      "0.0012348229065537453\n",
+      "0.0012340829707682133\n",
+      "0.0012333827326074243\n",
+      "0.0012327131116762757\n",
+      "0.0012320721289142966\n",
+      "0.0012314565246924758\n",
+      "0.001230863039381802\n",
+      "0.0012302907416597009\n",
+      "0.0012297385837882757\n",
+      "0.0012292059836909175\n",
+      "0.001228693057782948\n",
+      "0.0012281996896490455\n",
+      "0.0012277251807972789\n",
+      "0.0012272682506591082\n"
+     ]
+    }
+   ],
+   "source": [
+    "pod = NNsPOD_tutorial.fit(interp_file = \"interpnet2d.pth\",\n",
+    "              db = database)\n",
+    "print(pod.modes, pod.singular_values)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.10.4 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "c3571620c9e7a2ef712c686809cf2d92a9d8fa44cb30698f5938f17078c44765"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
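Since the notebook above is stored as JSON inside the diff, a condensed, script-style recap of its 2D workflow may help reviewers. This is only a sketch: the class and keyword names are taken verbatim from the notebook cells, while the import paths (ezyrb.database, ezyrb.nnspod) are assumed from the files added in this PR.

# Sketch: the 2D NNsPOD workflow of tutorial-3.ipynb as a plain script (assumed imports).
import numpy as np
import torch.nn as nn

from ezyrb.database import Database   # assumption: direct module import
from ezyrb.nnspod import NNsPOD       # assumption: direct module import


def gaussian(x, mu, sig):
    # same 2D gaussian as in the notebook: product of the per-axis gaussians
    gaussx, gaussy = np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.))).T
    return gaussx * gaussy


def wave2D(t, res=256):
    # travelling 2D gaussian bump sampled on a res x res grid
    x = np.linspace(0, 5, res)
    y = np.linspace(0, 5, res)
    gridx, gridy = np.meshgrid(x, y)
    gridx, gridy = gridx.reshape(-1, 1), gridy.reshape(-1, 1)
    snapshot = gaussian(np.hstack([gridx, gridy]), t * np.array([1, 1]), 0.1)
    return gridx, gridy, snapshot


def reshape2dto1d(x, y):
    # flatten two grid arrays into an (n^2, 2) coordinate array
    return np.concatenate((x.reshape(-1, 1), y.reshape(-1, 1)), axis=1)


n_params = 10
params = np.linspace(0.5, 4.5, n_params).reshape(-1, 1)  # the time steps

db_array = np.array([wave2D(t)[2] for t in params]).reshape(n_params, -1, 1)
gridx, gridy = wave2D(0)[0:2]
space = reshape2dto1d(gridx, gridy)
space_array = np.array([space.copy() for _ in params])

database = Database(space=space_array, snapshots=db_array, parameters=params)

nnspod = NNsPOD(interp_loss=None, interp_layers=[40, 40],
                interp_function=nn.Sigmoid(),
                interp_stop_training=[10000000, 0.000001],
                ref_point=5,
                shift_layers=[20, 20, 20], shift_function=nn.PReLU(),
                shift_stop_training=[0.001])

# Train (or reload) the interpnet on the reference snapshot, shift the other
# snapshots onto it, and run the POD on the shifted snapshots, as the cells above do.
pod = nnspod.fit(interp_file="interpnet2d.pth", db=database)
print(pod.modes, pod.singular_values)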