Merge pull request #10 from EleutherAI/cdf
Add QuantileNormalizer class
norabelrose authored Jan 18, 2024
2 parents 9b18b3d + 8f3af13 commit 3d33e1f
Showing 8 changed files with 506 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -158,3 +158,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Weights & Biases logs
*.ckpt
wandb/
4 changes: 4 additions & 0 deletions concept_erasure/__init__.py
@@ -3,12 +3,15 @@
from .leace import ErasureMethod, LeaceEraser, LeaceFitter
from .oracle import OracleEraser, OracleFitter
from .quadratic import QuadraticEditor, QuadraticEraser, QuadraticFitter
from .quantile import QuantileNormalizer, cdf, icdf
from .shrinkage import optimal_linear_shrinkage
from .utils import assert_type

__all__ = [
"assert_type",
"cdf",
"groupby",
"icdf",
"optimal_linear_shrinkage",
"ConceptScrubber",
"ErasureMethod",
@@ -20,4 +23,5 @@
"QuadraticEditor",
"QuadraticEraser",
"QuadraticFitter",
"QuantileNormalizer",
]
7 changes: 5 additions & 2 deletions concept_erasure/quadratic.py
@@ -51,8 +51,11 @@ class QuadraticEditor:

def transport(self, x: Tensor, source_z: int, target_z: int) -> Tensor:
"""Transport `x` from class `source_z` to class `target_z`"""
+        x_ = x.flatten(1)
+
         T = self.ot_maps[source_z, target_z]
-        return (x - self.class_means[source_z]) @ T.mH + self.class_means[target_z]
+        x_ = (x_ - self.class_means[source_z]) @ T.mH + self.class_means[target_z]
+        return x_.view_as(x)

def __call__(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor:
"""Transport `x` from classes `source_z` to class `target_z`."""
@@ -146,7 +149,7 @@ def update_single(self, x: Tensor, z: int) -> "QuadraticFitter":

return self

-    def editor(self, device: str | None = None) -> QuadraticEditor:
+    def editor(self, device: str | torch.device | None = None) -> QuadraticEditor:
"""Quadratic editor for the concept."""
sigma = self.sigma_xx
device = device or sigma.device
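For context, the effect of this change is that `QuadraticEditor.transport` now accepts inputs of arbitrary shape: it flattens them internally and restores the original shape with `view_as`. A minimal sketch of the new behavior (not part of this commit; toy data, and assuming `QuadraticFitter.fit` mirrors the other fitters exported by this package):

```python
import torch

from concept_erasure import QuadraticFitter

# Hypothetical toy data: 1000 RGB "images" of shape 3x8x8 from two classes.
x = torch.randn(1000, 3, 8, 8)
z = torch.randint(0, 2, (1000,))

# Fit on flattened features (the `fit` classmethod is assumed here).
fitter = QuadraticFitter.fit(x.flatten(1), z)
editor = fitter.editor()

# transport() can now be called with the original, unflattened shape;
# the result is reshaped back to match the input.
moved = editor.transport(x[z == 0], source_z=0, target_z=1)
assert moved.shape == x[z == 0].shape
```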
105 changes: 105 additions & 0 deletions concept_erasure/quantile.py
@@ -0,0 +1,105 @@
import torch
from torch import Tensor

from .groupby import groupby


def cdf(x: float | Tensor, q: Tensor) -> Tensor:
"""Evaluate empirical CDF defined by quantiles `q` on `x`.
Args:
x: `[...]` Scalar or tensor of data points of arbitrary shape.
q: `[..., num_quantiles]` batch of quantiles. Must be sorted, and
should broadcast with `x` except for the last dimension.
Returns:
`[...]` Empirical CDF evaluated for each element of `x`.
"""
n = q.shape[-1]
    assert n > 2, "Must have more than two quantiles to interpolate."

# Approach used by SciPy interp1d with kind='previous'
# Shift x toward +inf by epsilon to appropriately handle ties
x = torch.nextafter(torch.as_tensor(x), q.new_tensor(torch.inf))
return torch.searchsorted(q, x, out_int32=True) / n


def icdf(p: Tensor, q: Tensor) -> Tensor:
"""(Pseudo-)inverse of the ECDF defined by quantiles `q`.
Returns the *smallest* `x` such that the ECDF of `x` is greater than or
equal to `p`.
NOTE: Strictly speaking, this function should return `-inf` when `p` is exactly
zero, because there is no smallest `x` such that `p(x) = 0`. But in practice we
want this function to always return a finite value, so we clip to the minimum
value in `q`.
    Args:
        p: `[...]` Tensor of probabilities of arbitrary shape.
        q: `[..., num_quantiles]` batch of quantiles. Must be sorted, and
            should broadcast with `p` except for the last dimension.
    Returns:
        `[...]` Value of the inverse ECDF (quantile function) for each element of `p`.
"""
n = q.shape[-1]
    assert n > 2, "Must have more than two quantiles to interpolate."

soft_ranks = torch.nextafter(p * n, p.new_tensor(0.0))
return q.gather(-1, soft_ranks.long())


class QuantileNormalizer:
"""Componentwise quantile normalization."""

lut: Tensor
"""`[k, ..., num_bins]` batch of lookup tables."""

dim: int
"""Dimension along which to group the data."""

def __init__(
self,
x: Tensor,
z: Tensor,
num_bins: int = 256,
dim: int = 0,
):
# Efficiently get a view onto each class
grouped = groupby(x, z, dim=dim)
self.dim = dim

k = len(grouped.labels)
self.lut = x.new_empty([k, *x.shape[1:], num_bins])

grid = torch.linspace(0, 1, num_bins, device=x.device)
for i, grp in grouped:
self.lut[i] = grp.quantile(grid, dim=dim).movedim(0, -1)

@property
def num_bins(self) -> int:
return self.lut.shape[-1]

def cdf(self, z: int, x: Tensor) -> Tensor:
return cdf(x.movedim(0, -1), self.lut[z]).movedim(-1, 0)

def sample(self, z: int, n: int) -> Tensor:
lut = self.lut[z]

# Sample p from uniform distribution, then apply inverse CDF
p = torch.rand(*lut[..., 0].shape, n, device=lut.device)
return icdf(p, lut).movedim(-1, 0)

def transport(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor:
"""Transport `x` from class `source_z` to class `target_z`"""
return (
groupby(x, source_z, dim=self.dim)
.map(
# Probability integral transform, followed by inverse for target class
lambda z, x: icdf(
cdf(x.movedim(0, -1), self.lut[z]), self.lut[target_z]
).movedim(-1, 0)
)
.coalesce()
)
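For orientation, the new `QuantileNormalizer` API can be exercised roughly as follows. This is a minimal sketch against the signatures shown above, not code from the commit; the data and shapes are hypothetical:

```python
import torch

from concept_erasure import QuantileNormalizer

# Hypothetical toy data: 1000 samples with 16 features, grouped into two classes.
x = torch.randn(1000, 16)
z = torch.randint(0, 2, (1000,))

normalizer = QuantileNormalizer(x, z, num_bins=64)
print(normalizer.lut.shape)  # torch.Size([2, 16, 64]): per-class, per-feature quantile tables

# Probability integral transform of class-0 rows w.r.t. class 0's quantiles.
p = normalizer.cdf(0, x[z == 0])

# Draw 10 fresh samples from the empirical distribution of class 1.
samples = normalizer.sample(1, n=10)

# Move every row to class 1's componentwise distribution.
moved = normalizer.transport(x, z, target_z=1)
assert moved.shape == x.shape
```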
173 changes: 173 additions & 0 deletions experiments/prediction_steering/models.py
@@ -0,0 +1,173 @@
from itertools import pairwise
from typing import Literal

import pytorch_lightning as pl
import torch
import torchmetrics as tm
import torchvision as tv
from torch import nn
from torch.optim import RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR


class Mlp(pl.LightningModule):
def __init__(self, k, h=512, **kwargs):
super().__init__()
self.save_hyperparameters()

self.build_net()
self.train_acc = tm.Accuracy("multiclass", num_classes=k)
self.val_acc = tm.Accuracy("multiclass", num_classes=k)
self.test_acc = tm.Accuracy("multiclass", num_classes=k)

def build_net(self):
sizes = [3 * 32 * 32] + [self.hparams["h"]] * 4

self.net = nn.Sequential(
*[
MlpBlock(
in_dim,
out_dim,
device=self.device,
dtype=self.dtype,
residual=True,
act="gelu",
)
for in_dim, out_dim in pairwise(sizes)
]
)
self.net.append(nn.Linear(self.hparams["h"], self.hparams["k"]))

def forward(self, x):
return self.net(x)

def training_step(self, batch, batch_idx):
x, y = batch

y_hat = self(x)
loss = torch.nn.functional.cross_entropy(y_hat, y)
self.log("train_loss", loss)

self.train_acc(y_hat, y)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
# Log the norm of the weights
fc = self.net[-1] if isinstance(self.net, nn.Sequential) else None
if isinstance(fc, nn.Linear):
self.log("weight_norm", fc.weight.data.norm())

return loss

def validation_step(self, batch, batch_idx):
x, y = batch

y_hat = self(x)
loss = torch.nn.functional.cross_entropy(y_hat, y)

self.val_acc(y_hat, y)
self.log("val_loss", loss)
self.log("val_acc", self.val_acc, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
x, y = batch

y_hat = self(x)
loss = torch.nn.functional.cross_entropy(y_hat, y)

self.test_acc(y_hat, y)
self.log("test_loss", loss)
self.log("test_acc", self.test_acc, prog_bar=True)
return loss

def configure_optimizers(self):
opt = RAdam(self.parameters(), lr=1e-4)
return [opt], [CosineAnnealingLR(opt, T_max=200)]


class MlpMixer(Mlp):
def build_net(self):
from mlp_mixer_pytorch import MLPMixer

self.net = MLPMixer(
image_size=32,
channels=3,
patch_size=self.hparams.get("patch_size", 4),
num_classes=self.hparams["k"],
dim=512,
depth=6,
dropout=0.1,
)


class ResNet(Mlp):
def build_net(self):
self.net = tv.models.resnet18(pretrained=False, num_classes=self.hparams["k"])


class ViT(MlpMixer):
def build_net(self):
from vit_pytorch import ViT

self.net = ViT(
image_size=32,
patch_size=self.hparams.get("patch_size", 4),
num_classes=self.hparams["k"],
dim=512,
depth=6,
heads=8,
mlp_dim=512,
dropout=0.1,
emb_dropout=0.1,
)


class MlpBlock(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
device=None,
dtype=None,
residual: bool = True,
*,
act: Literal["relu", "gelu"] = "relu",
norm: Literal["batch", "layer"] = "batch",
):
super().__init__()

self.linear1 = nn.Linear(
in_features, out_features, bias=False, device=device, dtype=dtype
)
self.linear2 = nn.Linear(
out_features, out_features, bias=False, device=device, dtype=dtype
)
self.act_fn = nn.ReLU() if act == "relu" else nn.GELU()

norm_cls = nn.BatchNorm1d if norm == "batch" else nn.LayerNorm
self.bn1 = norm_cls(out_features, device=device, dtype=dtype)
self.bn2 = norm_cls(out_features, device=device, dtype=dtype)
self.downsample = (
nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
if in_features != out_features
else None
)
self.residual = residual

def forward(self, x):
identity = x

out = self.linear1(x)
out = self.bn1(out)
out = self.act_fn(out)

out = self.linear2(out)
out = self.bn2(out)

if self.downsample is not None:
identity = self.downsample(identity)

if self.residual:
out += identity

out = self.act_fn(out)
return out
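
The 3 * 32 * 32 input size and `image_size=32` suggest these models target CIFAR-10-shaped data. A minimal Lightning training sketch under that assumption (the `models` import path, dataset choice, and hyperparameters here are hypothetical, not part of the commit):

```python
import pytorch_lightning as pl
import torchvision as tv
from torch.utils.data import DataLoader

from models import Mlp  # hypothetical import path for the file above

# Flatten each image so the Mlp sees 3 * 32 * 32 = 3072-dimensional vectors.
transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Lambda(lambda t: t.flatten()),
])
train = tv.datasets.CIFAR10("data", train=True, download=True, transform=transform)

model = Mlp(k=10)  # 10 classes; hidden width defaults to h=512
trainer = pl.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(model, DataLoader(train, batch_size=128, shuffle=True))
```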
[Diffs for the remaining 3 changed files are not shown.]
