Skip to content

Commit

Permalink
Use replaces= for swiglu (#1200)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Mar 26, 2024
1 parent d296c98 commit e4d634d
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions extensions/thunder/unsloth/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def unsloth_cross_entropy_meta(logits: TensorProxy, labels: TensorProxy) -> Tupl


def unsloth_cross_entropy_backward_impl(dlosses: Tensor, logits: Tensor, labels: Tensor, logsumexp: Tensor) -> Tensor:
    """Backward implementation for the unsloth fused cross-entropy loss.

    Args:
        dlosses: upstream gradients of the losses.
        logits: logits given to the forward pass.
        labels: target indices (presumably class ids — verify against the forward meta).
        logsumexp: log-sum-exp values saved by the forward kernel.

    Returns:
        The gradient with respect to ``logits``.
    """
    # clone() because the kernel writes the grads in the logits
    return kernels.cross_entropy_loss._cross_entropy_backward_impl(dlosses, logits.clone(), logsumexp, labels)


Expand Down Expand Up @@ -152,34 +151,31 @@ def unsloth_cross_entropy_grad(
"""


def swiglu(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Eager SwiGLU activation: ``silu(e) * g``.

    Registered below (as ``litgpt_swiglu``) with ``replaces=`` so that thunder
    can substitute the executor's operator wherever this function is called.

    Args:
        e: input to the SiLU gate branch.
        g: tensor multiplied element-wise by ``silu(e)``.

    Returns:
        ``torch.nn.functional.silu(e) * g``.
    """
    return torch.nn.functional.silu(e) * g


swiglu = unsloth_ex.register_operator("swiglu", meta=swiglu_forward_meta, fn=swiglu_forward)


from litgpt.model import LLaMAMLP as OriginalLLaMAMLP


class ThunderLLaMAMLP(OriginalLLaMAMLP):
    """LLaMA MLP whose activation goes through the module-level ``swiglu``.

    Identical to ``litgpt.model.LLaMAMLP`` except that the SwiGLU computation
    is the free function ``swiglu``, which is registered with ``replaces=`` so
    the unsloth executor can swap in its fused kernel when jitting.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        # Calling the plain function works in eager mode; under thunder the
        # registered `litgpt_swiglu` operator replaces it (see register_operator
        # with `replaces=swiglu` below in this file).
        x = swiglu(x_fc_1, x_fc_2)
        return self.proj(x)


# Monkey-patch litgpt so models built after this import use the thunder-aware MLP.
litgpt.model.LLaMAMLP = ThunderLLaMAMLP


def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy:
return TensorProxy(like=e)


# Register the eager `swiglu` as a thunder operator. `replaces=swiglu` tells the
# executor to substitute this symbol wherever the plain Python function is
# called (which is why ThunderLLaMAMLP.forward can call `swiglu` directly).
litgpt_swiglu = unsloth_ex.register_operator("litgpt_swiglu", meta=swiglu_forward_meta, fn=swiglu, replaces=swiglu)


# Operator backed by the fused unsloth kernel (`kernels.swiglu_fg_kernel`);
# shares the same meta function as the eager `swiglu` since the output shape
# mirrors the first input.
unsloth_swiglu_forward = unsloth_ex.register_operator(
    "unsloth_swiglu_forward", meta=swiglu_forward_meta, fn=lambda *args: kernels.swiglu_fg_kernel(*args)
)
Expand Down Expand Up @@ -217,7 +213,7 @@ def unsloth_swiglu_grad(e: TensorProxy, g: TensorProxy) -> TensorProxy:


unsloth_ex.register_implementation(
swiglu,
litgpt_swiglu,
checker=swiglu_to_unsloth_checker,
execution_transform=unsloth_swiglu_forward,
grad_transform=unsloth_swiglu_grad,
Expand Down

0 comments on commit e4d634d

Please sign in to comment.