Removes OLMo Parallel Block
dirkgr committed Feb 7, 2024
1 parent 09beb0c commit d0500ab
Showing 2 changed files with 0 additions and 91 deletions.
1 change: 0 additions & 1 deletion olmo/config.py
@@ -185,7 +185,6 @@ class ActivationType(StrEnum):
 
 class BlockType(StrEnum):
     sequential = "sequential"
-    parallel = "parallel"
 
     llama = "llama"
     """
90 changes: 0 additions & 90 deletions olmo/model.py
@@ -58,7 +58,6 @@
     "SwiGLU",
     "OlmoBlock",
     "OlmoSequentialBlock",
-    "OlmoParallelBlock",
     "Olmo",
     "OlmoOutput",
     "OlmoGenerateOutput",
@@ -610,8 +609,6 @@ def forward(
     def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBlock:
         if config.block_type == BlockType.sequential:
             return OlmoSequentialBlock(layer_id, config, cache)
-        elif config.block_type == BlockType.parallel:
-            return OlmoParallelBlock(layer_id, config, cache)
         elif config.block_type == BlockType.llama:
             return OlmoLlamaBlock(layer_id, config, cache)
         else:
@@ -709,93 +706,6 @@ def forward(
         return x, cache
 
 
-class OlmoParallelBlock(OlmoBlock):
-    """
-    This is a transformer block where the output is computed as ``MLP(LN(x)) + Attention(LN(x))``
-    as in the PaLM architecture, as opposed to the typical ``MLP(LN(x + Attention(LN(x))))``
-    as in :class:`OlmoSequentialBlock` (ignoring some skip connections).
-    The decoupling of the MLP and Attention functions allows us to fuse the separate input projections
-    into a single linear layer to increase throughput. In this configuration it's also straightforward
-    to fuse the output projections, but we found that didn't help.
-    """
-
-    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
-        super().__init__(layer_id, config, cache)
-        self.norm = LayerNorm.build(config)
-        # Fused attention and feed-forward projection.
-        # NOTE: we could also fuse the attention and feed-forward output projections but we
-        # found that didn't help, possibly because of the overhead of joining the `att` and
-        # `ff` activations together. See https://github.com/allenai/LLM/pull/79 for details.
-        if config.multi_query_attention:
-            self.fused_dims = (
-                config.d_model,
-                config.d_model // config.n_heads,
-                config.d_model // config.n_heads,
-                self.hidden_size,
-            )
-        else:
-            self.fused_dims = (config.d_model, config.d_model, config.d_model, self.hidden_size)
-        self.fused_attn_ff_proj = nn.Linear(
-            config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
-        )
-
-    def reset_parameters(self):
-        super().reset_parameters()
-        self.norm.reset_parameters()
-        # NOTE: the standard deviation for these weights does not depend on the layer.
-        init_weights(
-            self.config,
-            self.fused_attn_ff_proj,
-            d=self.config.d_model,
-            layer_id=None,
-            type_of_module=ModuleType.in_module,
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        attention_bias: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-        # Get query, key, value, and feed-forward projections.
-        # shape of q, k, v:
-        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
-        #  - for multi-query attn q: (batch_size, seq_len, d_model)
-        #                      k, v: (batch_size, seq_len, d_model // n_heads)
-        # shape of ff: (batch_size, seq_len, hidden_size)
-        if self._activation_checkpoint_fn is not None:
-            q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
-                self.fused_dims, dim=-1
-            )
-        else:
-            q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
-
-        # Get attention scores.
-        # shape: (B, T, C)
-        if self._activation_checkpoint_fn is not None:
-            att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-            )
-        else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
-
-        # Apply output projections (and activation function) and sum the results.
-        # We keep these projections separate because we found that we got better throughput this
-        # way compared to fusing them.
-        if self._activation_checkpoint_fn is not None:
-            return (
-                x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
-                cache,
-            )
-        else:
-            return (
-                x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
-                cache,
-            )
-
-
 class OlmoLlamaBlock(OlmoBlock):
     """
     This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
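For reference, the removed block computes the attention and MLP branches in parallel from one shared LayerNorm, ``MLP(LN(x)) + Attention(LN(x))``, and fuses their input projections into a single linear layer, whereas the surviving OlmoSequentialBlock applies the two sub-layers one after the other. The following is a minimal, self-contained sketch of that idea, not OLMo's actual implementation: the class and parameter names are hypothetical, and rotary embeddings, multi-query attention, KV caching, dropout, and the custom weight initialization are all omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ParallelBlockSketch(nn.Module):
    """PaLM-style block: x + Attention(LN(x)) + MLP(LN(x)), with a fused input projection.

    Contrast with the sequential form that remains in OLMo:
        x = x + Attention(LN(x)); x = x + MLP(LN(x))
    """

    def __init__(self, d_model: int, n_heads: int, hidden_size: int):
        super().__init__()
        self.n_heads = n_heads
        self.norm = nn.LayerNorm(d_model)  # a single LayerNorm feeds both branches
        # Fuse the q, k, v, and feed-forward input projections into one matmul.
        self.fused_dims = (d_model, d_model, d_model, hidden_size)
        self.fused_attn_ff_proj = nn.Linear(d_model, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(d_model, d_model, bias=False)
        self.ff_out = nn.Linear(hidden_size, d_model, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # One projection, then split into the attention inputs and the MLP input.
        q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
        # Standard multi-head causal attention over the q/k/v slices.
        q, k, v = (t.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) for t in (q, k, v))
        att = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        att = att.transpose(1, 2).reshape(B, T, C)
        # Both branch outputs are added to the same residual stream.
        return x + self.attn_out(att) + self.ff_out(self.act(ff))


x = torch.randn(2, 16, 64)                       # (batch, seq_len, d_model)
print(ParallelBlockSketch(64, 4, 256)(x).shape)  # torch.Size([2, 16, 64])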
