forked from SphericalKat/electrical-lstm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
36 lines (25 loc) · 828 Bytes
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 21 23:02:51 2020
@author: Tanmay Thakur
"""
import pickle
import numpy as np
import keras
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
# Load the (inputs, targets) pair produced by an earlier step.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load "dict.pickle" from a trusted source.
with open("dict.pickle", "rb") as data_file:
    X_train, y_train = pickle.load(data_file)

model = keras.models.load_model("recurrent_model_initial.h5")

# index of first validation input: the last quarter of the samples is used
# for validation.
first_validation_index = 3 * len(X_train) // 4
validation_target = y_train[first_validation_index:]
validation_predictions = []
error = []

# Score every validation sample against the target stored at the SAME index.
# Advancing the index before scoring would compare each prediction with the
# following sample's target and force the last sample to be skipped.
for i, target in enumerate(validation_target, start=first_validation_index):
    # The model is fed one sample at a time, so restore the leading batch
    # axis that indexing X_train[i] removed.
    sample = X_train[i].reshape(1, X_train.shape[1], X_train.shape[2])
    p = model.predict(sample)[0]

    # mean_squared_error takes (y_true, y_pred).
    error.append(mean_squared_error(target, p))

    # update the predictions list
    validation_predictions.append(p)

plt.plot(error)

# Persist the per-sample errors; the context manager guarantees the file is
# flushed and closed even if pickling fails.
with open("error.pickle", "wb") as error_file:
    pickle.dump(error, error_file)