Merge pull request #62 from epfLLM/fix_iteration_time_linear_increase
Fixed the linear increase in iteration time observed when micro-batch size = 1
AleHD authored Sep 8, 2023
2 parents 10aaed8 + 483512a commit d7e3d04
Showing 1 changed file with 7 additions and 13 deletions.
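
The commit message doesn't state the root cause, but the change below only drops the torch.jit.script wrappers around the GLU modules, so per-call TorchScript overhead accumulating across iterations is the implied culprit. A hypothetical reproduction sketch (not the benchmark used in this PR; the tensor shape and step count are illustrative) that times the scripted and eager variants call by call:

import time

import torch
from megatron.model.glu_activations import SwiGLU  # class kept by this commit

x = torch.randn(1, 2048, 8192)  # micro-batch size 1, as in the report

for name, module in [("eager", SwiGLU()),
                     ("scripted", torch.jit.script(SwiGLU()))]:
    for step in range(5):
        start = time.perf_counter()
        module(x)
        elapsed = time.perf_counter() - start
        print(f"{name} step {step}: {elapsed * 1e3:.2f} ms")
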
20 changes: 7 additions & 13 deletions megatron/model/glu_activations.py
@@ -1,23 +1,17 @@
-# Extracted from: https://github.com/bigscience-workshop/Megatron-DeepSpeed
-
-import logging
+# Adapted from: https://github.com/bigscience-workshop/Megatron-DeepSpeed
 
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 
-logger = logging.getLogger(__name__)
-
-
 class _GLUBaseModule(nn.Module):
     def __init__(self, activation_fn):
         super().__init__()
         self.activation_fn = activation_fn
 
     def forward(self, x):
-        # dim=-1 breaks in jit for pt<1.10
-        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
+        x1, x2 = torch.chunk(x, 2, dim=-1)
         return x1 * self.activation_fn(x2)


@@ -41,15 +35,15 @@ def __init__(self):
         super().__init__(F.silu)
 
 
-liglu = torch.jit.script(LiGLU())
-geglu = torch.jit.script(GEGLU())
-reglu = torch.jit.script(ReGLU())
-swiglu = torch.jit.script(SwiGLU())
+liglu = LiGLU()
+geglu = GEGLU()
+reglu = ReGLU()
+swiglu = SwiGLU()
 
 
 GLU_ACTIVATIONS = {
     "geglu": geglu,
     "liglu": liglu,
     "reglu": reglu,
     "swiglu": swiglu,
-}
\ No newline at end of file
+}
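
With the wrappers gone, the exported names are plain nn.Module instances. A minimal usage sketch of the resulting eager-mode activations (the tensor shape is illustrative): a GLU chunks the last dimension in two and gates one half with the other, so the output's last dimension is half the input's.

import torch
from megatron.model.glu_activations import GLU_ACTIVATIONS

x = torch.randn(4, 128, 8192)       # (batch, seq, 2 * hidden)
y = GLU_ACTIVATIONS["swiglu"](x)    # computes x1 * F.silu(x2)
print(y.shape)                      # torch.Size([4, 128, 4096])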
