Merge pull request #10 from EleutherAI/cdf
Add QuantileNormalizer class
Showing 8 changed files with 506 additions and 2 deletions.
@@ -0,0 +1,105 @@
import torch
from torch import Tensor

from .groupby import groupby


def cdf(x: float | Tensor, q: Tensor) -> Tensor:
    """Evaluate the empirical CDF defined by quantiles `q` at `x`.

    Args:
        x: `[...]` Scalar or tensor of data points of arbitrary shape.
        q: `[..., num_quantiles]` batch of quantiles. Must be sorted, and
            should broadcast with `x` except for the last dimension.

    Returns:
        `[...]` Empirical CDF evaluated for each element of `x`.
    """
    n = q.shape[-1]
    assert n > 2, "Need more than two quantiles to interpolate."

    # Approach used by SciPy interp1d with kind='previous':
    # shift x toward +inf by one ulp to appropriately handle ties.
    x = torch.nextafter(torch.as_tensor(x), q.new_tensor(torch.inf))
    return torch.searchsorted(q, x, out_int32=True) / n


def icdf(p: Tensor, q: Tensor) -> Tensor:
    """(Pseudo-)inverse of the ECDF defined by quantiles `q`.

    Returns the *smallest* `x` such that the ECDF of `x` is greater than or
    equal to `p`.

    NOTE: Strictly speaking, this function should return `-inf` when `p` is exactly
    zero, because there is no smallest `x` such that `p(x) = 0`. But in practice we
    want this function to always return a finite value, so we clip to the minimum
    value in `q`.

    Args:
        p: `[...]` Tensor of probabilities in `[0, 1]` of arbitrary shape.
        q: `[..., num_quantiles]` batch of quantiles. Must be sorted, and
            should broadcast with `p` except for the last dimension.

    Returns:
        `[...]` Pseudo-inverse of the ECDF evaluated for each element of `p`.
    """
    n = q.shape[-1]
    assert n > 2, "Need more than two quantiles to interpolate."

    # Nudge p * n down by one ulp so that p == 1.0 maps to the last quantile
    # instead of indexing out of bounds; p == 0.0 maps to the minimum quantile.
    soft_ranks = torch.nextafter(p * n, p.new_tensor(0.0))
    return q.gather(-1, soft_ranks.long())


class QuantileNormalizer:
    """Componentwise quantile normalization."""

    lut: Tensor
    """`[k, ..., num_bins]` batch of lookup tables."""

    dim: int
    """Dimension along which to group the data."""

    def __init__(
        self,
        x: Tensor,
        z: Tensor,
        num_bins: int = 256,
        dim: int = 0,
    ):
        # Efficiently get a view onto each class
        grouped = groupby(x, z, dim=dim)
        self.dim = dim

        k = len(grouped.labels)
        self.lut = x.new_empty([k, *x.shape[1:], num_bins])

        grid = torch.linspace(0, 1, num_bins, device=x.device)
        for i, grp in grouped:
            self.lut[i] = grp.quantile(grid, dim=dim).movedim(0, -1)

    @property
    def num_bins(self) -> int:
        return self.lut.shape[-1]

    def cdf(self, z: int, x: Tensor) -> Tensor:
        return cdf(x.movedim(0, -1), self.lut[z]).movedim(-1, 0)

    def sample(self, z: int, n: int) -> Tensor:
        lut = self.lut[z]

        # Sample p from uniform distribution, then apply inverse CDF
        p = torch.rand(*lut[..., 0].shape, n, device=lut.device)
        return icdf(p, lut).movedim(-1, 0)

    def transport(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor:
        """Transport `x` from class `source_z` to class `target_z`."""
        return (
            groupby(x, source_z, dim=self.dim)
            .map(
                # Probability integral transform, followed by inverse for target class
                lambda z, x: icdf(
                    cdf(x.movedim(0, -1), self.lut[z]), self.lut[target_z]
                ).movedim(-1, 0)
            )
            .coalesce()
        )
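
For context, here is a minimal usage sketch of the new class. It is not part of the diff: the import path is hypothetical (the file path is not visible above), and the Gaussian toy data merely stands in for a real per-class dataset.

import torch

# Hypothetical import path; the diff does not show where cdf, icdf and
# QuantileNormalizer actually live in the package.
from quantile_normalizer import QuantileNormalizer

# Toy data: 1000 samples with 5 features each, split into 3 classes whose
# per-feature distributions differ only by a mean shift.
z = torch.randint(0, 3, (1000,))
x = torch.randn(1000, 5) + z[:, None].float()

normalizer = QuantileNormalizer(x, z, num_bins=64)
print(normalizer.lut.shape)  # [3 classes, 5 features, 64 bins]

# Probability integral transform: CDF values for class-0 samples should be
# roughly Uniform(0, 1).
u = normalizer.cdf(0, x[z == 0])

# Draw 10 fresh samples from the empirical distribution of class 2.
samples = normalizer.sample(2, 10)  # shape [10, 5]

# Map every point onto the matching quantile of class 1's distribution.
transported = normalizer.transport(x, z, target_z=1)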
@@ -0,0 +1,173 @@
from itertools import pairwise
from typing import Literal

import pytorch_lightning as pl
import torch
import torchmetrics as tm
import torchvision as tv
from torch import nn
from torch.optim import RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR


class Mlp(pl.LightningModule):
    def __init__(self, k, h=512, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.build_net()
        self.train_acc = tm.Accuracy("multiclass", num_classes=k)
        self.val_acc = tm.Accuracy("multiclass", num_classes=k)
        self.test_acc = tm.Accuracy("multiclass", num_classes=k)

    def build_net(self):
        sizes = [3 * 32 * 32] + [self.hparams["h"]] * 4

        self.net = nn.Sequential(
            *[
                MlpBlock(
                    in_dim,
                    out_dim,
                    device=self.device,
                    dtype=self.dtype,
                    residual=True,
                    act="gelu",
                )
                for in_dim, out_dim in pairwise(sizes)
            ]
        )
        self.net.append(nn.Linear(self.hparams["h"], self.hparams["k"]))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)

        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        # Log the norm of the weights
        fc = self.net[-1] if isinstance(self.net, nn.Sequential) else None
        if isinstance(fc, nn.Linear):
            self.log("weight_norm", fc.weight.data.norm())

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)

        self.val_acc(y_hat, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.val_acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)

        self.test_acc(y_hat, y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        opt = RAdam(self.parameters(), lr=1e-4)
        return [opt], [CosineAnnealingLR(opt, T_max=200)]


class MlpMixer(Mlp):
    def build_net(self):
        from mlp_mixer_pytorch import MLPMixer

        self.net = MLPMixer(
            image_size=32,
            channels=3,
            patch_size=self.hparams.get("patch_size", 4),
            num_classes=self.hparams["k"],
            dim=512,
            depth=6,
            dropout=0.1,
        )


class ResNet(Mlp):
    def build_net(self):
        self.net = tv.models.resnet18(pretrained=False, num_classes=self.hparams["k"])


class ViT(MlpMixer):
    def build_net(self):
        from vit_pytorch import ViT

        self.net = ViT(
            image_size=32,
            patch_size=self.hparams.get("patch_size", 4),
            num_classes=self.hparams["k"],
            dim=512,
            depth=6,
            heads=8,
            mlp_dim=512,
            dropout=0.1,
            emb_dropout=0.1,
        )


class MlpBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        device=None,
        dtype=None,
        residual: bool = True,
        *,
        act: Literal["relu", "gelu"] = "relu",
        norm: Literal["batch", "layer"] = "batch",
    ):
        super().__init__()

        self.linear1 = nn.Linear(
            in_features, out_features, bias=False, device=device, dtype=dtype
        )
        self.linear2 = nn.Linear(
            out_features, out_features, bias=False, device=device, dtype=dtype
        )
        self.act_fn = nn.ReLU() if act == "relu" else nn.GELU()

        norm_cls = nn.BatchNorm1d if norm == "batch" else nn.LayerNorm
        self.bn1 = norm_cls(out_features, device=device, dtype=dtype)
        self.bn2 = norm_cls(out_features, device=device, dtype=dtype)
        self.downsample = (
            nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
            if in_features != out_features
            else None
        )
        self.residual = residual

    def forward(self, x):
        identity = x

        out = self.linear1(x)
        out = self.bn1(out)
        out = self.act_fn(out)

        out = self.linear2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        if self.residual:
            out += identity

        out = self.act_fn(out)
        return out
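
And a short sketch of how one of these LightningModules might be trained. Again, this is not part of the diff: the import path, dataset, and trainer settings are assumptions, chosen to match the hard-coded 3 * 32 * 32 input size of `Mlp`.

import pytorch_lightning as pl
import torchvision as tv
from torch.utils.data import DataLoader

# Hypothetical import path; the diff does not show this file's location.
from models import Mlp

# CIFAR-10 images flattened to vectors, since Mlp expects 3 * 32 * 32 inputs.
transform = tv.transforms.Compose(
    [tv.transforms.ToTensor(), tv.transforms.Lambda(lambda img: img.flatten())]
)
train_set = tv.datasets.CIFAR10("data", train=True, download=True, transform=transform)
val_set = tv.datasets.CIFAR10("data", train=False, download=True, transform=transform)

model = Mlp(k=10, h=512)
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
trainer.fit(
    model,
    DataLoader(train_set, batch_size=128, shuffle=True),
    DataLoader(val_set, batch_size=128),
)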