"""autoencoder_estimator.py

Keras dense autoencoder plus a scikit-learn style estimator that flags rows
whose reconstruction error exceeds a threshold learned during ``fit``.
"""
import mlflow
import numpy as np
from sklearn.base import BaseEstimator
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers, losses
from sklearn.model_selection import train_test_split
class AutoEncoder(Model):
    """Fully connected autoencoder with an 8-unit bottleneck.

    The encoder is a 32 -> 16 -> 8 stack of ReLU ``Dense`` layers. The decoder
    mirrors it (16 -> 32, ReLU) and ends in a sigmoid ``Dense`` layer with
    ``output_dim`` units.

    Args:
        output_dim: Number of units in the final decoder layer, i.e. the
            width of the reconstruction. It must equal the number of input
            features for an element-wise reconstruction loss to line up.
            Defaults to 27, the value that was previously hard-coded, so
            ``AutoEncoder()`` builds exactly the same network as before.
    """

    def __init__(self, output_dim=27):
        super().__init__()
        self.encoder = tf.keras.Sequential([
            layers.Dense(32, activation="relu"),
            layers.Dense(16, activation="relu"),
            layers.Dense(8, activation="relu"),
        ])
        # The sigmoid bounds every reconstructed value to (0, 1).
        # NOTE(review): presumably inputs are scaled to [0, 1] upstream -
        # confirm against the preprocessing step.
        self.decoder = tf.keras.Sequential([
            layers.Dense(16, activation="relu"),
            layers.Dense(32, activation="relu"),
            layers.Dense(output_dim, activation="sigmoid"),
        ])

    def call(self, x):
        """Encode ``x`` to the bottleneck and return its reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
class AutoEncoderEstimator(BaseEstimator):
    """Scikit-learn style wrapper that flags rows an autoencoder rebuilds poorly.

    ``fit`` trains an ``AutoEncoder`` to reproduce its input and stores a
    threshold on the per-row mean absolute reconstruction error. ``predict``
    marks every row whose error is above that threshold.

    Args:
        epochs: Number of training epochs passed to Keras ``fit``.
        batch_size: Mini-batch size passed to Keras ``fit``.

    Attributes set by ``fit``:
        autoencoder: The trained ``AutoEncoder`` instance.
        threshold: Mean plus one standard deviation of the per-row
            reconstruction error over all rows given to ``fit``.
    """

    def __init__(self, epochs=10, batch_size=512) -> None:
        self.epochs = epochs
        self.batch_size = batch_size

    def fit(self, X, y=None):
        """Train the autoencoder on ``X`` and derive the error threshold.

        Args:
            X: 2-D numeric data, one row per sample. A pandas DataFrame or
                anything else ``np.asarray`` can convert is accepted.
            y: Ignored; accepted for scikit-learn API compatibility.

        Returns:
            The fitted estimator (``self``).
        """
        # np.asarray yields the same array as ``X.values`` for a DataFrame
        # while also accepting plain ndarrays.
        features = np.asarray(X)

        # shuffle=False keeps the row order: the last 20% of the rows become
        # the validation split.
        train_rows, val_rows = train_test_split(
            features, test_size=0.2, shuffle=False
        )
        train_rows = tf.cast(train_rows, tf.float32)
        val_rows = tf.cast(val_rows, tf.float32)

        self.autoencoder = AutoEncoder()
        self.autoencoder.compile(optimizer="adam", loss="mae")
        # Input and target are the same tensor: the network learns to
        # reproduce its own input.
        self.autoencoder.fit(
            train_rows,
            train_rows,
            epochs=self.epochs,
            batch_size=self.batch_size,
            validation_data=(val_rows, val_rows),
            shuffle=True,
        )

        # The threshold is computed over every row passed in (training and
        # validation split alike), not only the training portion.
        reconstructions = self.autoencoder.predict(features)
        row_errors = tf.keras.losses.mae(reconstructions, features)
        self.threshold = np.mean(row_errors) + np.std(row_errors)
        return self

    def transform(self, X, y=None):
        """Return the estimator itself; ``X`` and ``y`` are not used.

        NOTE(review): this returns ``self`` rather than transformed data.
        Kept unchanged for backward compatibility - confirm whether any
        caller relies on it.
        """
        return self

    def predict(self, X):
        """Flag the rows of ``X`` whose reconstruction error is too high.

        Args:
            X: 2-D numeric data with the same columns as the data used in
                ``fit``. A pandas DataFrame or anything else ``np.asarray``
                can convert is accepted.

        Returns:
            A boolean tensor with one entry per row; ``True`` where the
            row's mean absolute reconstruction error is strictly greater
            than ``self.threshold``.
        """
        features = np.asarray(X)
        reconstructions = self.autoencoder(features)
        # Same per-row error as in fit(); the original call referenced the
        # non-existent ``tf.keras.eses`` and raised AttributeError.
        row_errors = tf.keras.losses.mae(reconstructions, features)
        return tf.math.less(self.threshold, row_errors)