Skip to content

Commit

Permalink
Update model to detect class imbalance and retrain it
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Oct 28, 2024
1 parent 226f4b2 commit 95f2f41
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 25 deletions.
2 changes: 1 addition & 1 deletion sqli_model/3/fingerprint.pb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
������ܾ?�ɘ��������月�� �����Ϗ�(����қ��P2
��������=�������ʒ���月�� �����Ϗ�(���������2
Binary file modified sqli_model/3/saved_model.pb
Binary file not shown.
Binary file modified sqli_model/3/variables/variables.data-00000-of-00001
Binary file not shown.
Binary file modified sqli_model/3/variables/variables.index
Binary file not shown.
70 changes: 46 additions & 24 deletions training/train_v3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import os
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
Expand All @@ -17,6 +18,7 @@
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -54,11 +56,7 @@ def build_model(input_dim, output_dim=128):
model.compile(
loss="binary_crossentropy",
optimizer="adam",
metrics=[
"accuracy",
tf.keras.metrics.Precision(name="precision"),
tf.keras.metrics.Recall(name="recall"),
],
metrics=["accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
)
return model

Expand All @@ -75,31 +73,43 @@ def calculate_f1_f2(precision, recall, beta=1):

def _resolve_metric_key(history_dict, metric):
    """Return the key of ``history_dict`` that holds ``metric``, or None."""
    if metric in history_dict:
        return metric
    # NOTE(review): auto-named Keras metrics appear to get a numeric suffix
    # (e.g. "precision_1") when several models are built in one process, as
    # the per-fold training does here -- confirm against the Keras version
    # in use.
    prefix = f"{metric}_"
    for key in history_dict:
        if key.startswith(prefix) and key[len(prefix):].isdigit():
            return key
    return None


def plot_history(history):
    """Plot the training and validation loss, accuracy, precision, and recall.

    Each metric found in ``history.history`` is drawn in its own cell of a
    2x2 grid and the figure is written to ``training_history.png`` in the
    current working directory. Metrics that were not recorded are skipped,
    leaving their cell empty.

    Args:
        history: Object exposing a ``history`` mapping from metric name to a
            per-epoch sequence of values (the value returned by
            ``model.fit``).
    """
    recorded = history.history
    fig = plt.figure(figsize=(12, 8))

    # Define metrics to plot
    metrics_to_plot = ["loss", "accuracy", "precision", "recall"]
    for i, metric in enumerate(metrics_to_plot, start=1):
        train_key = _resolve_metric_key(recorded, metric)
        if train_key is None:
            continue

        title = metric.capitalize()
        plt.subplot(2, 2, i)
        plt.plot(recorded[train_key], label=f"Training {title}")

        # The validation series only exists when validation data was given
        # to fit(); indexing it unconditionally would raise KeyError.
        val_key = f"val_{train_key}"
        if val_key in recorded:
            plt.plot(recorded[val_key], label=f"Validation {title}")

        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig("training_history.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# Main function
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: python train.py <input_file> <output_dir>")
sys.exit(1)

# Constants
MAX_WORDS = 10000
MAX_LEN = 100
EPOCHS = 50
BATCH_SIZE = 32

# Load and preprocess data
data = load_data(sys.argv[1])
X, tokenizer = preprocess_text(data)
y = data["Label"]
y = data["Label"].values # Convert to NumPy array to avoid KeyError in KFold

# Initialize cross-validation
k_folds = 5
Expand All @@ -111,7 +121,13 @@ def plot_history(history):

# Split the data
X_train, X_val = X[train_idx], X[val_idx]
y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
y_train, y_val = y[train_idx], y[val_idx]

# Compute class weights to handle imbalance
class_weights = compute_class_weight(
"balanced", classes=np.unique(y_train), y=y_train
)
class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}

# Build and train the model
model = build_model(input_dim=len(tokenizer.word_index) + 1)
Expand All @@ -121,15 +137,16 @@ def plot_history(history):
history = model.fit(
X_train,
y_train,
epochs=50,
batch_size=32,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
validation_data=(X_val, y_val),
class_weight=class_weight_dict,
callbacks=[early_stopping],
verbose=1,
)

# Make predictions to manually calculate metrics
y_val_pred = (model.predict(X_val) > 0.5).astype(int)
# Make predictions to calculate metrics
y_val_pred = (model.predict(X_val) > 0.8).astype(int)
accuracy = accuracy_score(y_val, y_val_pred)
precision = precision_score(y_val, y_val_pred)
recall = recall_score(y_val, y_val_pred)
Expand All @@ -143,12 +160,17 @@ def plot_history(history):
fold_metrics["f1"].append(f1_score)
fold_metrics["f2"].append(f2_score)

# Calculate average metrics across folds
# Calculate and display average metrics across folds
avg_metrics = {metric: np.mean(scores) for metric, scores in fold_metrics.items()}
print("\nCross-validation results:")
for metric, value in avg_metrics.items():
print(f"{metric.capitalize()}: {value:.2f}")

# Save the final model trained on the last fold
model.export(sys.argv[2])
output_dir = sys.argv[2]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.export(output_dir)

# Plot training history of the last fold
plot_history(history)

0 comments on commit 95f2f41

Please sign in to comment.