Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added minor updates in comments for clarity and proposed some u… #44

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions notebooks/04_gan/03_cgan/cgan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@
"BATCH_SIZE = 128\n",
"Z_DIM = 32\n",
"LEARNING_RATE = 0.00005\n",
"ADAM_BETA_1 = 0.5\n",
"ADAM_BETA_2 = 0.999\n",
"EPOCHS = 20\n",
"CRITIC_STEPS = 3\n",
"GP_WEIGHT = 10.0\n",
"LOAD_MODEL = False\n",
"# Uncomment below if you want to play with these parameters and see how they are affect learning process\n",
"#ADAM_BETA_1 = 0.5\n",
"#ADAM_BETA_2 = 0.999\n",
"ADAM_BETA_1 = 0.5\n",
"ADAM_BETA_2 = 0.9\n",
"LABEL = \"Blond_Hair\""
Expand Down Expand Up @@ -162,7 +163,8 @@
"outputs": [],
"source": [
"# Show some faces from the training set\n",
"train_sample = sample_batch(train)"
"#train_sample = sample_batch(train)\n",
"#display(train_sample, cmap=None)"
]
},
{
Expand All @@ -172,7 +174,25 @@
"metadata": {},
"outputs": [],
"source": [
"display(train_sample, cmap=None)"
"# Show some faces from the training set\n",
"# The updated version of the image display below, mainly concentrated on 2 things:\n",
"# 1. Make sure that only blond hair celebrties being selected in example\n",
"# 2. Show explicitly filtered labels and image shape for education purposes\n",
"\n",
"for img_batch, label_batch in train_data.take(1):\n",
" # Create filter mask for labels=1 (i.e. Blond_Hair)\n",
" mask=tf.equal(label_batch,1)\n",
"\n",
" # Apply mask for filter labels and images\n",
" filtered_images=tf.boolean_mask(img_batch, mask)\n",
" filtered_labels=tf.boolean_mask(label_batch, mask)\n",
" # Convert to numpy array\n",
" filtered_images_np=filtered_images.numpy()\n",
"\n",
" print('Filtered images shape: ', filtered_images.shape)\n",
" print('filtered labels: ', filtered_labels)\n",
"\n",
"display(filtered_images_np, cmap=None)"
]
},
{
Expand Down Expand Up @@ -435,6 +455,7 @@
"source": [
"# Create a model save checkpoint\n",
"model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
" # NOTE: For tensorflow 11 and upper you may be required to change extension from .ckpt to .weights.h5\n",
" filepath=\"./checkpoint/checkpoint.ckpt\",\n",
" save_weights_only=True,\n",
" save_freq=\"epoch\",\n",
Expand Down Expand Up @@ -492,7 +513,7 @@
"history = cgan.fit(\n",
" train,\n",
" epochs=EPOCHS * 100,\n",
" steps_per_epoch=1,\n",
" steps_per_epoch=1, # NOTE: You probably want to tune up this parameter. Steps per epoch=1 means that is literally take one batch per epoch, this will cause rapid education but very poor results.\n",
" callbacks=[\n",
" model_checkpoint_callback,\n",
" tensorboard_callback,\n",
Expand Down