diff --git a/notebooks/04_gan/03_cgan/cgan.ipynb b/notebooks/04_gan/03_cgan/cgan.ipynb index 8b9a06e..c5917ee 100644 --- a/notebooks/04_gan/03_cgan/cgan.ipynb +++ b/notebooks/04_gan/03_cgan/cgan.ipynb @@ -70,12 +70,13 @@ "BATCH_SIZE = 128\n", "Z_DIM = 32\n", "LEARNING_RATE = 0.00005\n", - "ADAM_BETA_1 = 0.5\n", - "ADAM_BETA_2 = 0.999\n", "EPOCHS = 20\n", "CRITIC_STEPS = 3\n", "GP_WEIGHT = 10.0\n", "LOAD_MODEL = False\n", + "# Uncomment below if you want to play with these parameters and see how they affect the learning process\n", + "#ADAM_BETA_1 = 0.5\n", + "#ADAM_BETA_2 = 0.999\n", "ADAM_BETA_1 = 0.5\n", "ADAM_BETA_2 = 0.9\n", "LABEL = \"Blond_Hair\"" @@ -162,7 +163,8 @@ "outputs": [], "source": [ "# Show some faces from the training set\n", - "train_sample = sample_batch(train)" + "#train_sample = sample_batch(train)\n", + "#display(train_sample, cmap=None)" ] }, { @@ -172,7 +174,25 @@ "metadata": {}, "outputs": [], "source": [ - "display(train_sample, cmap=None)" + "# Show some faces from the training set\n", + "# The updated version of the image display below mainly concentrates on 2 things:\n", + "# 1. Make sure that only blond-hair celebrities are selected in the example\n", + "# 2. Explicitly show the filtered labels and image shape for educational purposes\n", + "\n", + "for img_batch, label_batch in train_data.take(1):\n", + "    # Create filter mask for labels=1 (i.e. 
Blond_Hair)\n", "    mask=tf.equal(label_batch,1)\n", "\n", "    # Apply mask to filter labels and images\n", "    filtered_images=tf.boolean_mask(img_batch, mask)\n", "    filtered_labels=tf.boolean_mask(label_batch, mask)\n", "    # Convert to numpy array\n", "    filtered_images_np=filtered_images.numpy()\n", "\n", "    print('Filtered images shape: ', filtered_images.shape)\n", "    print('filtered labels: ', filtered_labels)\n", "\n", "display(filtered_images_np, cmap=None)" ] }, { @@ -435,6 +455,7 @@ "source": [ "# Create a model save checkpoint\n", "model_checkpoint_callback = callbacks.ModelCheckpoint(\n", + "    # NOTE: For TensorFlow 2.11 and later you may be required to change the extension from .ckpt to .weights.h5\n", "    filepath=\"./checkpoint/checkpoint.ckpt\",\n", "    save_weights_only=True,\n", "    save_freq=\"epoch\",\n", @@ -492,7 +513,7 @@ "history = cgan.fit(\n", "    train,\n", "    epochs=EPOCHS * 100,\n", - "    steps_per_epoch=1,\n", + "    steps_per_epoch=1, # NOTE: You probably want to tune this parameter. steps_per_epoch=1 means it literally takes one batch per epoch, which makes training fast but gives very poor results.\n", "    callbacks=[\n", "        model_checkpoint_callback,\n", "        tensorboard_callback,\n",