diff --git a/concept_erasure/leace.py b/concept_erasure/leace.py
index c599a80..ce58f36 100644
--- a/concept_erasure/leace.py
+++ b/concept_erasure/leace.py
@@ -46,6 +46,14 @@ def __call__(self, x: Tensor) -> Tensor:
         x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH
         return x_.type_as(x)
 
+    def to(self, device: torch.device | str) -> "LeaceEraser":
+        """Move eraser to a new device."""
+        return LeaceEraser(
+            self.proj_left.to(device),
+            self.proj_right.to(device),
+            self.bias.to(device) if self.bias is not None else None,
+        )
+
 
 class LeaceFitter:
     """Fits an affine transform that surgically erases a concept from a representation.
diff --git a/concept_erasure/oracle.py b/concept_erasure/oracle.py
index ae0a2f2..e1ac16b 100644
--- a/concept_erasure/oracle.py
+++ b/concept_erasure/oracle.py
@@ -27,6 +27,10 @@ def __call__(self, x: Tensor, z: Tensor) -> Tensor:
 
         return x.sub(expected_x).type_as(x)
 
+    def to(self, device: torch.device | str) -> "OracleEraser":
+        """Move eraser to a new device."""
+        return OracleEraser(self.coef.to(device), self.mean_z.to(device))
+
 
 class OracleFitter:
     """Compute stats needed for surgically erasing a concept Z from a random vector X.
diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py
index 75ad21d..480049f 100644
--- a/concept_erasure/quadratic.py
+++ b/concept_erasure/quadratic.py
@@ -38,6 +38,14 @@ def __call__(self, x: Tensor, z: Tensor) -> Tensor:
         # Efficiently group `x` by `z`, optimally transport each group, then coalesce
         return groupby(x, z).map(self.optimal_transport).coalesce()
 
+    def to(self, device: torch.device | str) -> "QuadraticEraser":
+        """Move eraser to a new device."""
+        return QuadraticEraser(
+            self.class_means.to(device),
+            self.global_mean.to(device),
+            self.ot_maps.to(device),
+        )
+
 
 @dataclass(frozen=True)
 class QuadraticEditor:
diff --git a/concept_erasure/scrubbing/neox.py b/concept_erasure/scrubbing/neox.py
index b435ee7..0fb44e9 100644
--- a/concept_erasure/scrubbing/neox.py
+++ b/concept_erasure/scrubbing/neox.py
@@ -6,10 +6,12 @@
 from tqdm.auto import tqdm
 from transformers import (
     GPTNeoXForCausalLM,
-    GPTNeoXLayer,
     GPTNeoXModel,
 )
-from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
+from transformers.models.gpt_neox.modeling_gpt_neox import (
+    GPTNeoXAttention,
+    GPTNeoXLayer,
+)
 
 from concept_erasure import ConceptScrubber, ErasureMethod, LeaceFitter
 from concept_erasure.utils import assert_type