diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py
index 50e076a79d36..e17f486a5f63 100644
--- a/nemo/collections/llm/fn/activation.py
+++ b/nemo/collections/llm/fn/activation.py
@@ -25,7 +25,7 @@ def openai_gelu(x):
     return gelu_impl(x)
 
 
-@torch.jit.script
+#@torch.jit.script # remove until we have serialization
 def squared_relu(x):
     """Squared ReLU activation function."""
     return torch.pow(torch.nn.functional.relu(x), 2)
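
For context: with the decorator commented out, squared_relu is left as a plain eager Python function rather than a TorchScript ScriptFunction. Below is a minimal, self-contained sketch (not part of the patch) of what the eager function computes and of the serialization behavior the comment alludes to; the sample tensor and the pickle round-trip are illustrative assumptions, not code from the repository.

    # Minimal sketch of the eager squared_relu from the patch.
    import pickle

    import torch
    import torch.nn.functional as F


    def squared_relu(x):
        """Squared ReLU activation function: relu(x) ** 2."""
        return torch.pow(F.relu(x), 2)


    x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
    print(squared_relu(x))  # tensor([0.0000, 0.0000, 0.0000, 2.2500, 9.0000])

    # As a plain module-level function (decorator commented out), squared_relu
    # can be pickled by reference like any ordinary Python function; a
    # torch.jit.script-wrapped ScriptFunction generally cannot.
    restored = pickle.loads(pickle.dumps(squared_relu))
    assert restored(x).equal(squared_relu(x))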