From e4d634ddbc018fbbd014f94436f89c7ad406db1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 26 Mar 2024 21:29:21 +0100 Subject: [PATCH] Use `replaces=` for swiglu (#1200) --- extensions/thunder/unsloth/executor.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/extensions/thunder/unsloth/executor.py b/extensions/thunder/unsloth/executor.py index a638af079f..5b13c4dee2 100644 --- a/extensions/thunder/unsloth/executor.py +++ b/extensions/thunder/unsloth/executor.py @@ -48,8 +48,7 @@ def unsloth_cross_entropy_meta(logits: TensorProxy, labels: TensorProxy) -> Tupl def unsloth_cross_entropy_backward_impl(dlosses: Tensor, logits: Tensor, labels: Tensor, logsumexp: Tensor) -> Tensor: - # clone() because the kernel writes the grads in the logits. - # If it works, we can remove this it, but it's not a thing we generally anticipate and support right now. + # clone() because the kernel writes the grads in the logits return kernels.cross_entropy_loss._cross_entropy_backward_impl(dlosses, logits.clone(), logsumexp, labels) @@ -152,17 +151,10 @@ def unsloth_cross_entropy_grad( """ -def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy: - return TensorProxy(like=e) - - -def swiglu_forward(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor: +def swiglu(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor: return torch.nn.functional.silu(e) * g -swiglu = unsloth_ex.register_operator("swiglu", meta=swiglu_forward_meta, fn=swiglu_forward) - - from litgpt.model import LLaMAMLP as OriginalLLaMAMLP @@ -170,16 +162,20 @@ class ThunderLLaMAMLP(OriginalLLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - # There's no `register_operator` for Modules and `swiglu_forward` is not a torch symbol that we can register to - # For now, some duplication and monkey patching is required - fn = swiglu if thunder.core.interpreter.is_jitting() else swiglu_forward - x = fn(x_fc_1, x_fc_2) + x = swiglu(x_fc_1, x_fc_2) return self.proj(x) litgpt.model.LLaMAMLP = ThunderLLaMAMLP +def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy: + return TensorProxy(like=e) + + +litgpt_swiglu = unsloth_ex.register_operator("litgpt_swiglu", meta=swiglu_forward_meta, fn=swiglu, replaces=swiglu) + + unsloth_swiglu_forward = unsloth_ex.register_operator( "unsloth_swiglu_forward", meta=swiglu_forward_meta, fn=lambda *args: kernels.swiglu_fg_kernel(*args) ) @@ -217,7 +213,7 @@ def unsloth_swiglu_grad(e: TensorProxy, g: TensorProxy) -> TensorProxy: unsloth_ex.register_implementation( - swiglu, + litgpt_swiglu, checker=swiglu_to_unsloth_checker, execution_transform=unsloth_swiglu_forward, grad_transform=unsloth_swiglu_grad,