Skip to content

Commit

Permalink
Merge pull request #15 from EleutherAI/to-device
Browse files Browse the repository at this point in the history
Add convenience fn to move erasers to new device
  • Loading branch information
norabelrose authored Oct 15, 2024
2 parents ff119e8 + ee37d7b commit 9f51753
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 2 deletions.
8 changes: 8 additions & 0 deletions concept_erasure/leace.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def __call__(self, x: Tensor) -> Tensor:
x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH
return x_.type_as(x)

def to(self, device: torch.device | str) -> "LeaceEraser":
"""Move eraser to a new device."""
return LeaceEraser(
self.proj_left.to(device),
self.proj_right.to(device),
self.bias.to(device) if self.bias is not None else None,
)


class LeaceFitter:
"""Fits an affine transform that surgically erases a concept from a representation.
Expand Down
4 changes: 4 additions & 0 deletions concept_erasure/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def __call__(self, x: Tensor, z: Tensor) -> Tensor:

return x.sub(expected_x).type_as(x)

def to(self, device: torch.device | str) -> "OracleEraser":
"""Move eraser to a new device."""
return OracleEraser(self.coef.to(device), self.mean_z.to(device))


class OracleFitter:
"""Compute stats needed for surgically erasing a concept Z from a random vector X.
Expand Down
8 changes: 8 additions & 0 deletions concept_erasure/quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def __call__(self, x: Tensor, z: Tensor) -> Tensor:
# Efficiently group `x` by `z`, optimally transport each group, then coalesce
return groupby(x, z).map(self.optimal_transport).coalesce()

def to(self, device: torch.device | str) -> "QuadraticEraser":
"""Move eraser to a new device."""
return QuadraticEraser(
self.class_means.to(device),
self.global_mean.to(device),
self.ot_maps.to(device),
)


@dataclass(frozen=True)
class QuadraticEditor:
Expand Down
6 changes: 4 additions & 2 deletions concept_erasure/scrubbing/neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from tqdm.auto import tqdm
from transformers import (
GPTNeoXForCausalLM,
GPTNeoXLayer,
GPTNeoXModel,
)
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
from transformers.models.gpt_neox.modeling_gpt_neox import (
GPTNeoXAttention,
GPTNeoXLayer,
)

from concept_erasure import ConceptScrubber, ErasureMethod, LeaceFitter
from concept_erasure.utils import assert_type
Expand Down

0 comments on commit 9f51753

Please sign in to comment.