[WIP] Review changes for sweep experiment #4

Open

wants to merge 49 commits into base: main

Changes from 46 of 49 commits

Commits
2cb88f7
concept editing experiments
AlexTMallen Sep 23, 2023
6ffb97e
top1 and loss matrices
AlexTMallen Sep 23, 2023
99ca187
fix segfault
AlexTMallen Sep 23, 2023
c2d4749
both loss matrices
AlexTMallen Sep 23, 2023
e9250b9
plotting loss matrices
AlexTMallen Sep 23, 2023
bfadef8
Merge branch 'main' into editing-experiments
AlexTMallen Sep 23, 2023
5407aae
use double precision
AlexTMallen Sep 24, 2023
bbdff6c
Merge branch 'main' into editing-experiments
AlexTMallen Sep 24, 2023
c93ca68
naive quadratic erasure and editing
AlexTMallen Sep 24, 2023
870a365
create sentiment dataset
AlexTMallen Sep 26, 2023
ed96b3f
Merge branch 'main' into editing-experiments
AlexTMallen Sep 26, 2023
06a9054
Prepare for gutting Sweep.run
norabelrose Sep 26, 2023
58d3dfe
Refactor MVP
norabelrose Sep 26, 2023
ffbfb93
Don't early stop too early
norabelrose Sep 26, 2023
ba1d62e
minor changes
AlexTMallen Sep 26, 2023
7582992
Merge branch 'refactor' into editing-experiments
AlexTMallen Sep 26, 2023
5ca1ec0
Better hparams
norabelrose Sep 26, 2023
18c050e
Merge branch 'refactor' into editing-experiments
AlexTMallen Sep 26, 2023
2d6b815
Remove duplicate code in Sweep.run
norabelrose Sep 26, 2023
2f8cb88
Backtracking in Probe.fit
norabelrose Sep 26, 2023
45b1dea
batched eval; use CIFAR test; multiple random seeds
AlexTMallen Sep 26, 2023
3e3518e
Sweep now calls preprocessor on test set
norabelrose Sep 26, 2023
9bfb71a
more efficient eval
AlexTMallen Sep 26, 2023
ed6acd3
Merge remote-tracking branch 'origin/refactor' into editing-experiments
AlexTMallen Sep 26, 2023
d08f398
MlpProbe now has a residual architecture
norabelrose Sep 27, 2023
4597bd3
sentiment erasure, fix binary loss
AlexTMallen Sep 27, 2023
99488cc
label smoothing
norabelrose Sep 27, 2023
0f894b2
merge with refactor
AlexTMallen Sep 27, 2023
87fdbea
Merge branch 'editing-experiments' of github.com:EleutherAI/mdl into …
AlexTMallen Sep 27, 2023
917de33
Merge branch 'refactor' into editing-experiments
AlexTMallen Sep 27, 2023
71be6e7
remove gradscaler
norabelrose Sep 27, 2023
2fc340d
improved visionprobe for editing experiment
AlexTMallen Sep 27, 2023
a1c85ff
Merge branch 'refactor' into editing-experiments
AlexTMallen Sep 27, 2023
de35b2c
use vanilla LEACE in sweep
AlexTMallen Sep 30, 2023
351e7f8
save state
luciaquirke Oct 22, 2024
77b7c2c
Save state
luciaquirke Oct 22, 2024
7dbcf99
add gist with two more nets - vit and convnext
luciaquirke Oct 24, 2024
acb478d
Add probes; add logging; add new cli for sweep
luciaquirke Oct 24, 2024
e4c07c4
log distance from init in model params
luciaquirke Oct 24, 2024
bc9220a
update gitignore
luciaquirke Oct 24, 2024
5de0031
Add scalable resnet probe; fix trivial bugs
luciaquirke Oct 24, 2024
25a6fae
automate wandb run naming
luciaquirke Oct 24, 2024
21a2a8c
resolve stray todos
luciaquirke Oct 25, 2024
77ec2f4
remove old cli
luciaquirke Oct 28, 2024
7f452aa
plot mdl
luciaquirke Oct 30, 2024
c78161d
restrict logging to final MDL chunk; plot cleanup; check for duplicat…
luciaquirke Oct 31, 2024
d8d68b6
save progress
luciaquirke Nov 14, 2024
8af8fbd
Add muP; add checkpointing; add MlpProbe activation variations (swigl…
luciaquirke Nov 22, 2024
b72edcb
Add mdl and loss plots with activation functions plotted together
luciaquirke Nov 23, 2024
7 changes: 7 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
data

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -158,3 +160,8 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


erasers_cache
lightning_logs
wandb
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
@@ -16,7 +16,3 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
hooks:
- id: codespell
192 changes: 192 additions & 0 deletions experiments/cli.py
@@ -0,0 +1,192 @@
from argparse import ArgumentParser
from pathlib import Path
from functools import partial

import wandb
import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from torchvision.transforms.v2.functional import to_tensor
from concept_erasure import LeaceFitter, OracleEraser, OracleFitter, QuadraticFitter, LeaceEraser
from torch import Tensor
from torchvision.datasets import CIFAR10
from tqdm.auto import tqdm
import lovely_tensors as lt


from mdl.mlp_probe import ResMlpProbe, SeqMlpProbe, LinearProbe
from mdl.sweep import Sweep
from mdl.vision_probe import ViTProbe, ConvNextProbe
from mdl.resnet_probe import ResNetProbe

lt.monkey_patch()
torch.set_default_tensor_type(torch.DoubleTensor)


if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = ArgumentParser()
parser.add_argument("--name", type=str, default='')
parser.add_argument("--erasers", type=str, nargs="+", choices=["none", "leace", "oleace", "qleace"], default=["none"])
parser.add_argument("--net", type=str, choices=("mlp", "resmlp", "resnet", "convnext", "vit", "linear"))
parser.add_argument("--width", type=int, default=128)
parser.add_argument("--depth", type=int, default=3)
parser.add_argument("--num_seeds", type=int, default=4)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

nontest = CIFAR10(root="/mnt/ssd-1/alexm/cifar10/", download=True)
images, labels = zip(*nontest)
X: Tensor = torch.stack(list(map(to_tensor, images))).to(device)
Y = torch.tensor(labels).to(device)

# Shuffle deterministically
rng = torch.Generator(device=X.device).manual_seed(42)
perm = torch.randperm(len(X), generator=rng, device=X.device)
X, Y = X[perm], Y[perm]

k = int(Y.max()) + 1

# Split train and validation
val_size = 1024

X_train, X_val = X[:-val_size], X[-val_size:]
Y_train, Y_val = Y[:-val_size], Y[-val_size:]

# Test set is entirely separate
test = CIFAR10(root="/home/lucia/cifar10-test", train=False, download=True)
test_images, test_labels = zip(*test)
X_test: Tensor = torch.stack(list(map(to_tensor, test_images))).to(device)
Y_test = torch.tensor(test_labels).to(device)

# Populate eraser cache for training set if necessary
state_path = Path("erasers_cache") / "cifar10_state.pth"
state_path.parent.mkdir(exist_ok=True)
state = {} if not state_path.exists() else torch.load(state_path)
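# `state` maps an eraser name ("leace", "oleace", "qleace") to its fitted eraser,
# so each eraser is fit on the training set only once and reused across runs.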
for eraser_str in args.erasers:
if eraser_str == "none" or eraser_str in state:
continue

cls = {
"leace": LeaceFitter,
"oleace": OracleFitter,
"qleace": QuadraticFitter,
}[eraser_str]

fitter = cls(3 * 32 * 32, k, dtype=torch.float64, device=device, shrinkage=False)
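# Accumulate eraser statistics one example at a time: QLEACE is fit on integer
# class labels, while LEACE and O-LEACE are fit on one-hot label vectors.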
for x, y in tqdm(zip(X_train, Y_train)):
y = torch.as_tensor(y).view(1)
if eraser_str != "qleace":
y = F.one_hot(y, k)

fitter.update(x.view(1, -1).to(device), y.to(device))

state[eraser_str] = fitter.eraser
torch.save(state, state_path)

# Reduce size after eraser computation - cache does not differentiate between train set sizes
if args.debug:
X_train = X_train[:10_000]
Y_train = Y_train[:10_000]

model_cls = {
"mlp": SeqMlpProbe,
"resmlp": ResMlpProbe,
"resnet": ResNetProbe,
"vit": ViTProbe,
"convnext": ConvNextProbe,
"linear": LinearProbe,
}[args.net]

flatten = {
"mlp": True,
"resmlp": True,
"resnet": False,
"vit": False,
"convnext": False,
"linear": True
}
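# The MLP and linear probes take flat 3*32*32 = 3072-dim vectors; the ResNet,
# ViT, and ConvNeXt probes keep the CxHxW image layout.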

image_size = X.shape[-1]
padding = round(image_size * 0.125)
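# For 32x32 CIFAR-10 images this gives the standard 4-pixel crop padding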

if flatten[args.net]:
def reshape(x):
"reshape tensor to CxHxW"
return x.view(-1, X.shape[1], X.shape[2], X.shape[3])

augment = transforms.Compose([
transforms.Lambda(reshape),
transforms.RandomCrop(image_size, padding=padding),
transforms.RandomHorizontalFlip(),
transforms.Lambda(lambda x: x.flatten(1))
])
else:
augment = transforms.Compose([
transforms.RandomCrop(image_size, padding=padding),
transforms.RandomHorizontalFlip(),
])

def erase(x: Tensor, y: Tensor, eraser):
assert y.ndim == 1
assert x.ndim > 1 # otherwise requires unsqueeze

if isinstance(eraser, LeaceEraser):
x_erased = eraser(x.flatten(1))
else:
# OracleEraser and the quadratic eraser both take the labels at erase time
x_erased = eraser(x.flatten(1), y)

if flatten[args.net]:
return x_erased
return x_erased.reshape_as(x)


def none_transform(x, y):
if not flatten[args.net]:
return x
return x.flatten(1)

data = {}
for eraser_str in args.erasers:
transform = (
partial(erase, eraser=state[eraser_str].to(device))
if eraser_str != "none"
else none_transform
)

results = []
for seed in range(args.num_seeds):
if 'test' not in args.name:
run = wandb.init(
project="mdl",
entity="eleutherai",
name=f'{eraser_str if eraser_str != "none" else "baseline"} {args.name} w={args.width} d={args.depth} s={seed} {args.net}',
config={'eraser': eraser_str, **vars(args)}
)
else:
run = None

sweep = Sweep(
X.shape[1] * X.shape[2] * X.shape[3], k, device=X.device, dtype=torch.float64,
num_chunks=10, logger=run,
probe_cls=model_cls,
probe_kwargs=dict(num_layers=args.depth, hidden_size=args.width),
)

results.append(sweep.run(
X.double(), Y, seed=seed, transform=transform,
augment=augment, reduce_lr_on_plateau=False,
))

if 'test' not in args.name:
wandb.finish()

data[eraser_str] = results

data_path = Path(f"/mnt/ssd-1/lucia/results" if not args.debug else f"/mnt/ssd-1/lucia/debug-results")
data_path.mkdir(exist_ok=True)

torch.save(data, data_path / f"{args.net}_h={args.width}_d={args.depth}_{'_'.join(args.erasers)}_{args.name}.pth")
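
The script ends by saving a dict keyed by eraser name, each value a list of per-seed results from Sweep.run. A minimal sketch of loading such a file back for later plotting; the path below is hypothetical, just following the naming scheme above, and the structure of each per-seed entry depends on what Sweep.run returns:

import torch
from pathlib import Path

# Hypothetical file name following the f"{net}_h={width}_d={depth}_{erasers}_{name}.pth" pattern above
results_file = Path("/mnt/ssd-1/lucia/results/resmlp_h=128_d=3_none_leace_.pth")

# dict: eraser name ("none", "leace", ...) -> list of per-seed sweep results
# (newer PyTorch may require weights_only=False to unpickle arbitrary objects)
data = torch.load(results_file)

for eraser_name, per_seed_results in data.items():
    print(eraser_name, len(per_seed_results), "seeds")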