Skip to content

Commit

Permalink
feat: added notebooks for ASH + improved SCALE ones
Browse files Browse the repository at this point in the history
  • Loading branch information
paulnovello committed Apr 18, 2024
1 parent 5486a5c commit 1f906e1
Show file tree
Hide file tree
Showing 4 changed files with 868 additions and 83 deletions.
413 changes: 413 additions & 0 deletions docs/notebooks/tensorflow/demo_ash_tf.ipynb

Large diffs are not rendered by default.

77 changes: 27 additions & 50 deletions docs/notebooks/tensorflow/demo_scale_tf.ipynb

Large diffs are not rendered by default.

411 changes: 411 additions & 0 deletions docs/notebooks/torch/demo_ash_torch.ipynb

Large diffs are not rendered by default.

50 changes: 17 additions & 33 deletions docs/notebooks/torch/demo_scale_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"s = \\exp(\\frac{\\sum_{i} a_i}{\\sum_{a_i > P_p(a)} a_i})\n",
"$$\n",
"\n",
"Here, we focus on a toy convolutional network trained on MNIST[0-4] challenged on MNIST[5-9].\n",
"Here, we focus on a Resnet trained on CIFAR10, challenged on SVHN.\n",
"\n",
"**Reference** \n",
"_Scaling for Training Time and Post-hoc Out-of-distribution Detection Enhancement_, ICLR 2024\n",
Expand All @@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -80,25 +80,15 @@
"os.makedirs(data_path, exist_ok=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## MNIST[0-4] vs MNIST[5-9]\n",
"\n",
"We train a toy convolutional network on the MNIST dataset restricted to digits 0 to 4. After fitting the train subset of this dataset to different OOD methods with SCALE option enabled, we will compare the scores returned for MNIST[0-4] (in-distrubtion) and MNIST[5-9] (out-of-distribution) test subsets."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data loading\n",
"\n",
"* In-distribution data: MNIST[0-4]\n",
"* Out-of-distribution data: MNIST[5-9]\n",
"* In-distribution data: CIFAR-10 \n",
"* Out-of-distribution data: SVHN\n",
"\n",
"> **Note:** We denote In-Distribution (ID) data with `_in` and Out-Of-Distribution (OOD) data\n",
"with `_out` to avoid confusion with OOD detection which is the name of the task, and is\n",
Expand All @@ -107,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -166,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -194,7 +184,7 @@
" x = x.to(device)\n",
" preds.append(torch.argmax(model(x), dim=-1).detach().cpu())\n",
" labels.append(y)\n",
"print(f\"Test accuracy:\\t{accuracy_score(torch.cat(labels), torch.cat(preds)):.6f}\")\n"
"print(f\"Test accuracy:\\t{accuracy_score(torch.cat(labels), torch.cat(preds)):.6f}\")"
]
},
{
Expand All @@ -209,15 +199,21 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== ODIN ===\n",
"~ Without SCALE ~\n",
"~ Without SCALE ~\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"auroc 0.872446\n",
"fpr95tpr 0.489400\n"
]
Expand Down Expand Up @@ -373,22 +369,10 @@
" for k, v in metrics.items():\n",
" print(f\"{k:<10} {v:.6f}\")\n",
"\n",
" log_scale = d in [\"gen\"]\n",
" # hists / roc\n",
" plt.figure(figsize=(9, 3))\n",
" plt.subplot(121)\n",
" if d == \"msp\":\n",
" # Normalize scores for a better hist visualization\n",
" minim = np.min([np.min(scores_in), np.min(scores_out)])\n",
" scores_in_ = (\n",
" scores_in - 2 * minim + np.min(scores_in[np.where(scores_in != minim)])\n",
" )\n",
" scores_out_ = (\n",
" scores_out - 2 * minim + np.min(scores_in[np.where(scores_in != minim)])\n",
" )\n",
" plot_ood_scores(scores_in_, scores_out_, log_scale=log_scale)\n",
" else:\n",
" plot_ood_scores(scores_in, scores_out, log_scale=log_scale)\n",
" plot_ood_scores(scores_in, scores_out)\n",
" plt.subplot(122)\n",
" plot_roc_curve(scores_in, scores_out)\n",
" plt.tight_layout()\n",
Expand Down

0 comments on commit 1f906e1

Please sign in to comment.