Skip to content

Commit

Permalink
better evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
sassa7777 committed Nov 28, 2024
1 parent 4f0c71f commit 6ad2a7a
Show file tree
Hide file tree
Showing 9 changed files with 7,410 additions and 14,719 deletions.
13 changes: 4 additions & 9 deletions learning/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tensorflow as tf
from tensorflow.keras.layers import Add, Dense, Input, LeakyReLU, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
import numpy as np
from tqdm import trange, tqdm
from random import shuffle
Expand Down Expand Up @@ -56,7 +56,7 @@ def digit(n, r):

# 学習
test_ratio = 0.001
n_epochs = 6
n_epochs = 5

diagonal8_idx = [[0, 9, 18, 27, 36, 45, 54, 63], [7, 14, 21, 28, 35, 42, 49, 56]]
for pattern in deepcopy(diagonal8_idx):
Expand Down Expand Up @@ -220,13 +220,8 @@ def lr_schedule(epoch, lr):


print(model.evaluate(test_data, test_labels))
checkpoint = ModelCheckpoint(
'models/best_model.h5',
monitor='val_loss',
save_best_only=True
)
early_stop = EarlyStopping(monitor='val_loss', patience=1, restore_best_weights=True)
history = model.fit(train_data, train_labels, epochs=n_epochs, validation_data=(test_data, test_labels), callbacks=[early_stop, checkpoint, tensorboard_callback], batch_size=32)
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
history = model.fit(train_data, train_labels, epochs=n_epochs, validation_data=(test_data, test_labels), callbacks=[early_stop, tensorboard_callback], batch_size=32)

now = datetime.datetime.today()
model.save('models/model.h5')
Expand Down
Binary file modified learning/models/model.h5
Binary file not shown.
Loading

0 comments on commit 6ad2a7a

Please sign in to comment.