Merge pull request #2305 from priankakariatyml:imbalanced_classification_tf_2.16_fixes

PiperOrigin-RevId: 664610546
copybara-github committed Aug 19, 2024
2 parents 7d4187e + 4b60ead commit 4a3e670
Showing 1 changed file with 45 additions and 48 deletions.
93 changes: 45 additions & 48 deletions site/en/tutorials/structured_data/imbalanced_data.ipynb
@@ -258,10 +258,10 @@
"train_df, val_df = train_test_split(train_df, test_size=0.2)\n",
"\n",
"# Form np arrays of labels and features.\n",
"train_labels = np.array(train_df.pop('Class'))\n",
"bool_train_labels = train_labels != 0\n",
"val_labels = np.array(val_df.pop('Class'))\n",
"test_labels = np.array(test_df.pop('Class'))\n",
"train_labels = np.array(train_df.pop('Class')).reshape(-1, 1)\n",
"bool_train_labels = train_labels[:, 0] != 0\n",
"val_labels = np.array(val_df.pop('Class')).reshape(-1, 1)\n",
"test_labels = np.array(test_df.pop('Class')).reshape(-1, 1)\n",
"\n",
"train_features = np.array(train_df)\n",
"val_features = np.array(val_df)\n",
@@ -291,18 +291,17 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "8a_Z_kBmr7Oh"
"id": "ueKV4cmcoRnf"
},
"source": [
"Given the small number of positive labels, this seems about right.\n",
"\n",
"Normalize the input features using the sklearn StandardScaler.\n",
"This will set the mean to 0 and standard deviation to 1.\n",
"\n",
"Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets. "
"Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets."
]
},
{
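The train-only normalization described in that cell can be sketched by hand with NumPy — the same arithmetic `StandardScaler` performs, using hypothetical stand-in arrays for the tutorial's feature matrices:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-ins for the tutorial's feature arrays.
train_features = rng.normal(loc=5.0, scale=3.0, size=(1000, 4))
val_features = rng.normal(loc=5.0, scale=3.0, size=(200, 4))

# Compute the statistics on the training split only...
mean = train_features.mean(axis=0)
std = train_features.std(axis=0)

# ...then apply those same statistics to every split, so no information
# leaks from the validation or test sets into the preprocessing.
train_scaled = (train_features - mean) / std
val_scaled = (val_features - mean) / std

print(train_scaled.mean(axis=0).round(6))  # ~0 on the training split
print(train_scaled.std(axis=0).round(6))   # ~1 on the training split
```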
@@ -352,7 +351,7 @@
"\n",
"Next compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:\n",
"\n",
"* Do these distributions make sense? \n",
"* Do these distributions make sense?\n",
" * Yes. You've normalized the input and these are mostly concentrated in the `+/- 2` range.\n",
"* Can you see the difference between the distributions?\n",
" * Yes, the positive examples contain a much higher rate of extreme values."
@@ -386,7 +385,7 @@
"source": [
"## Define the model and metrics\n",
"\n",
"Define a function that creates a simple neural network with a densely connected hidden layer, a [dropout](https://developers.google.com/machine-learning/glossary/#dropout_regularization) layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent: "
"Define a function that creates a simple neural network with a densely connected hidden layer, a [dropout](https://developers.google.com/machine-learning/glossary/#dropout_regularization) layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:"
]
},
{
@@ -403,7 +402,7 @@
" keras.metrics.TruePositives(name='tp'),\n",
" keras.metrics.FalsePositives(name='fp'),\n",
" keras.metrics.TrueNegatives(name='tn'),\n",
" keras.metrics.FalseNegatives(name='fn'), \n",
" keras.metrics.FalseNegatives(name='fn'),\n",
" keras.metrics.BinaryAccuracy(name='accuracy'),\n",
" keras.metrics.Precision(name='precision'),\n",
" keras.metrics.Recall(name='recall'),\n",
@@ -432,7 +431,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "SU0GX6E6mieP"
@@ -456,7 +454,7 @@
"In the end, one often wants to predict a class label, 0 or 1, *no fraud* or *fraud*.\n",
"This is called a deterministic classifier.\n",
"To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold $t$.\n",
"The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\\%$ and all the following metrics implicitly use this default. \n",
"The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\\%$ and all the following metrics implicitly use this default.\n",
"\n",
"* **False** negatives and **false** positives are samples that were **incorrectly** classified\n",
"* **True** negatives and **true** positives are samples that were **correctly** classified\n",
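The four counts and the derived threshold metrics above can be computed directly with NumPy; a toy example (labels and scores are made up, not the tutorial's data):

```python
import numpy as np

labels = np.array([0, 0, 1, 1, 0, 1])
probs = np.array([0.1, 0.6, 0.8, 0.3, 0.4, 0.9])

t = 0.5            # the default decision threshold
preds = probs > t  # deterministic labels from the probabilistic output

tp = int(np.sum(preds & (labels == 1)))   # correctly flagged frauds
fp = int(np.sum(preds & (labels == 0)))   # false alarms
tn = int(np.sum(~preds & (labels == 0)))  # correctly passed legit samples
fn = int(np.sum(~preds & (labels == 1)))  # missed frauds

precision = tp / (tp + fp)  # of predicted frauds, how many were real
recall = tp / (tp + fn)     # of real frauds, how many were caught
print(tp, fp, tn, fn)       # 2 1 2 1
```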
@@ -474,7 +472,7 @@
"The following metrics take into account all possible choices of thresholds $t$.\n",
"\n",
"* **AUC** refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.\n",
"* **AUPRC** refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds. \n",
"* **AUPRC** refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds.\n",
"\n",
"\n",
"#### Read more:\n",
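The rank interpretation of ROC-AUC quoted above can be verified by brute force: score every (positive, negative) pair and count how often the positive ranks higher. A toy example, not the tutorial's model:

```python
import numpy as np

labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

pos = scores[labels == 1]
neg = scores[labels == 0]

# Fraction of (positive, negative) pairs ranked correctly,
# counting ties as half a win.
wins = (pos[:, None] > neg[None, :]).sum()
ties = (pos[:, None] == neg[None, :]).sum()
auc = (wins + 0.5 * ties) / (len(pos) * len(neg))
print(auc)  # 0.75 — 3 of the 4 pairs are ordered correctly
```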
@@ -520,8 +518,9 @@
"EPOCHS = 100\n",
"BATCH_SIZE = 2048\n",
"\n",
"early_stopping = tf.keras.callbacks.EarlyStopping(\n",
" monitor='val_prc', \n",
"def early_stopping():\n",
" return tf.keras.callbacks.EarlyStopping(\n",
" monitor='val_prc',\n",
" verbose=1,\n",
" patience=10,\n",
" mode='max',\n",
@@ -584,7 +583,7 @@
"id": "PdbfWDuVpo6k"
},
"source": [
"With the default bias initialization the loss should be about `math.log(2) = 0.69314` "
"With the default bias initialization the loss should be about `math.log(2) = 0.69314`"
]
},
{
@@ -630,7 +629,7 @@
"id": "d1juXI9yY1KD"
},
"source": [
"Set that as the initial bias, and the model will give much more reasonable initial guesses. \n",
"Set that as the initial bias, and the model will give much more reasonable initial guesses.\n",
"\n",
"It should be near: `pos/total = 0.0018`"
]
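Both numbers quoted in these cells follow directly from the binary cross-entropy. Using the class counts of the credit-card fraud dataset (492 positives out of 284,807, the source of the `0.0018` figure):

```python
import math

pos, neg = 492, 284315  # class counts of the credit-card fraud dataset
p = pos / (pos + neg)   # base rate, about 0.0017

# With the default (zero) output bias the model predicts 0.5 everywhere,
# so the initial binary cross-entropy is exactly log(2).
default_loss = -(p * math.log(0.5) + (1 - p) * math.log(0.5))

# Setting the bias to log(pos/neg) makes the initial sigmoid output
# equal the base rate pos/total.
b0 = math.log(pos / neg)
initial_pred = 1 / (1 + math.exp(-b0))

print(round(default_loss, 5))  # 0.69315
print(round(initial_pred, 4))  # 0.0017
```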
@@ -700,7 +699,7 @@
},
"outputs": [],
"source": [
"initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')\n",
"initial_weights = os.path.join(tempfile.mkdtemp(), 'initial.weights.h5')\n",
"model.save_weights(initial_weights)"
]
},
@@ -714,7 +713,7 @@
"\n",
"Before moving on, quickly confirm that the careful bias initialization actually helped.\n",
"\n",
"Train the model for 20 epochs, with and without this careful initialization, and compare the losses: "
"Train the model for 20 epochs, with and without this careful initialization, and compare the losses:"
]
},
{
Expand All @@ -733,7 +732,7 @@
" train_labels,\n",
" batch_size=BATCH_SIZE,\n",
" epochs=20,\n",
" validation_data=(val_features, val_labels), \n",
" validation_data=(val_features, val_labels),\n",
" verbose=0)"
]
},
@@ -752,7 +751,7 @@
" train_labels,\n",
" batch_size=BATCH_SIZE,\n",
" epochs=20,\n",
" validation_data=(val_features, val_labels), \n",
" validation_data=(val_features, val_labels),\n",
" verbose=0)"
]
},
@@ -794,7 +793,7 @@
"id": "fKMioV0ddG3R"
},
"source": [
"The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage. "
"The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage."
]
},
{
@@ -821,7 +820,7 @@
" train_labels,\n",
" batch_size=BATCH_SIZE,\n",
" epochs=EPOCHS,\n",
" callbacks=[early_stopping],\n",
" callbacks=[early_stopping()],\n",
" validation_data=(val_features, val_labels))"
]
},
@@ -996,10 +995,9 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "P-QpQsip_F2Q"
"id": "kF8k-g9goRni"
},
"source": [
"### Plot the ROC\n",
@@ -1161,10 +1159,10 @@
" train_labels,\n",
" batch_size=BATCH_SIZE,\n",
" epochs=EPOCHS,\n",
" callbacks=[early_stopping],\n",
" callbacks=[early_stopping()],\n",
" validation_data=(val_features, val_labels),\n",
" # The class weights go here\n",
" class_weight=class_weight) "
" class_weight=class_weight)"
]
},
{
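The `class_weight` dictionary passed to `fit` above is built earlier in the tutorial; one standard scheme (sketched here with the same class counts — treat the exact formula as illustrative) weights each class inversely to its frequency, scaled so the total loss stays on the same order:

```python
pos, neg = 492, 284315  # class counts of the credit-card fraud dataset
total = pos + neg

# Scaling by total/2 keeps the summed loss comparable to the unweighted
# case, while each rare-class error counts roughly neg/pos times more.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}

print(round(weight_for_0, 2), round(weight_for_1, 2))  # 0.5 289.44
```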
@@ -1333,7 +1331,7 @@
"source": [
"#### Using NumPy\n",
"\n",
"You can balance the dataset manually by choosing the right number of random \n",
"You can balance the dataset manually by choosing the right number of random\n",
"indices from the positive examples:"
]
},
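That manual balancing can be sketched end to end with NumPy (hypothetical toy arrays stand in for the tutorial's `pos_features`/`neg_features`):

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical imbalanced toy data: 5% positives.
labels = np.array([1] * 50 + [0] * 950)
features = rng.normal(size=(1000, 3))

pos_features = features[labels == 1]
neg_features = features[labels == 0]

# Draw positive indices with replacement until both classes match in size.
ids = rng.integers(0, len(pos_features), size=len(neg_features))
res_pos_features = pos_features[ids]

balanced_features = np.concatenate([res_pos_features, neg_features], axis=0)
balanced_labels = np.concatenate(
    [np.ones(len(res_pos_features)), np.zeros(len(neg_features))])
print(balanced_labels.mean())  # 0.5 — perfectly balanced
```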
@@ -1485,7 +1483,7 @@
},
"outputs": [],
"source": [
"resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)\n",
"resampled_steps_per_epoch = int(np.ceil(2.0*neg/BATCH_SIZE))\n",
"resampled_steps_per_epoch"
]
},
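The `int(np.ceil(...))` above comes from the size of the balanced dataset: after oversampling it holds roughly `2*neg` examples, so one pass over it takes that many batches. With the dataset's counts (substitute your own split sizes):

```python
import math

neg = 284315      # negative examples in the full dataset
BATCH_SIZE = 2048

# The balanced (oversampled) dataset has about 2*neg examples,
# so one epoch over it takes this many batches.
resampled_steps_per_epoch = math.ceil(2.0 * neg / BATCH_SIZE)
print(resampled_steps_per_epoch)  # 278
```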
@@ -1499,7 +1497,7 @@
"\n",
"Now try training the model with the resampled data set instead of using class weights to see how these methods compare.\n",
"\n",
"Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps. "
"Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps."
]
},
{
@@ -1514,17 +1512,17 @@
"resampled_model.load_weights(initial_weights)\n",
"\n",
"# Reset the bias to zero, since this dataset is balanced.\n",
"output_layer = resampled_model.layers[-1] \n",
"output_layer = resampled_model.layers[-1]\n",
"output_layer.bias.assign([0])\n",
"\n",
"val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()\n",
"val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) \n",
"val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)\n",
"\n",
"resampled_history = resampled_model.fit(\n",
" resampled_ds,\n",
" epochs=EPOCHS,\n",
" steps_per_epoch=resampled_steps_per_epoch,\n",
" callbacks=[early_stopping],\n",
" callbacks=[early_stopping()],\n",
" validation_data=val_ds)"
]
},
@@ -1536,7 +1534,7 @@
"source": [
"If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.\n",
"\n",
"But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: Instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight. \n",
"But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: Instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight.\n",
"\n",
"This smoother gradient signal makes it easier to train the model."
]
@@ -1549,7 +1547,7 @@
"source": [
"### Check training history\n",
"\n",
"Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data. "
"Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data."
]
},
{
@@ -1578,7 +1576,7 @@
"id": "KFLxRL8eoDE5"
},
"source": [
"Because training is easier on the balanced data, the above training procedure may overfit quickly. \n",
"Because training is easier on the balanced data, the above training procedure may overfit quickly.\n",
"\n",
"So break up the epochs to give the `tf.keras.callbacks.EarlyStopping` finer control over when to stop training."
]
@@ -1595,15 +1593,15 @@
"resampled_model.load_weights(initial_weights)\n",
"\n",
"# Reset the bias to zero, since this dataset is balanced.\n",
"output_layer = resampled_model.layers[-1] \n",
"output_layer = resampled_model.layers[-1]\n",
"output_layer.bias.assign([0])\n",
"\n",
"resampled_history = resampled_model.fit(\n",
" resampled_ds,\n",
" # These are not real epochs\n",
" steps_per_epoch=20,\n",
" epochs=10*EPOCHS,\n",
" callbacks=[early_stopping],\n",
" callbacks=[early_stopping()],\n",
" validation_data=(val_ds))"
]
},
@@ -1696,7 +1694,7 @@
"id": "vayGnv0VOe_v"
},
"source": [
"### Plot the AUPRC\r\n"
"### Plot the AUPRC\n"
]
},
{
@@ -1707,14 +1705,14 @@
},
"outputs": [],
"source": [
"plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\r\n",
"plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\r\n",
"\r\n",
"plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\r\n",
"plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\r\n",
"\r\n",
"plot_prc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\r\n",
"plot_prc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\r\n",
"plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\n",
"plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\n",
"\n",
"plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\n",
"plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\n",
"\n",
"plot_prc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\n",
"plot_prc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\n",
"plt.legend(loc='lower right');"
]
},
Expand All @@ -1732,7 +1730,6 @@
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "imbalanced_data.ipynb",
"toc_visible": true
},
