Skip to content

Commit

Permalink
Remove capture from notebooks/influence_imagenet and use tag to hide …
Browse files Browse the repository at this point in the history
…output
  • Loading branch information
schroedk committed Oct 5, 2023
1 parent 9840d70 commit 9429c3e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ plugins:
- hide
remove_input_tags:
- hide-input
remove_all_outputs_tags:
- hide-output
binder: true
binder_service_name: "gh"
binder_branch: "develop"
Expand Down
54 changes: 44 additions & 10 deletions notebooks/influence_imagenet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"# Influence functions for neural networks\n",
"\n",
Expand Down Expand Up @@ -261,10 +267,17 @@
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"%%capture\n",
"model_ft = new_resnet_model(output_size=len(label_names))\n",
"mgr = TrainingManager(\n",
" \"model_ft\",\n",
Expand Down Expand Up @@ -376,10 +389,17 @@
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"%%capture\n",
"influences = compute_influences(\n",
" TorchTwiceDifferentiable(mgr.model, mgr.loss),\n",
" train_data,\n",
Expand Down Expand Up @@ -662,10 +682,17 @@
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"%%capture\n",
"corrupted_model = new_resnet_model(output_size=len(label_names))\n",
"corrupted_dataset, corrupted_indices = corrupt_imagenet(\n",
" dataset=train_ds,\n",
Expand Down Expand Up @@ -756,10 +783,17 @@
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"%%capture\n",
"influences = compute_influences(\n",
" TorchTwiceDifferentiable(mgr.model, mgr.loss),\n",
" corrupted_data,\n",
Expand Down Expand Up @@ -1060,5 +1094,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

0 comments on commit 9429c3e

Please sign in to comment.