From c95f92a5119f276d908bcba93b2cd5421201e852 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 7 Mar 2024 17:19:04 -0800 Subject: [PATCH 1/2] Removes AMD Layer Norm --- olmo/config.py | 5 ----- olmo/model.py | 35 ----------------------------------- tests/model_test.py | 9 --------- 3 files changed, 49 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 66e176a7e..d9e257f88 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -171,11 +171,6 @@ class LayerNormType(StrEnum): probably the fastest implementation. """ - amd_compatible = "amd_compatible" - """ - LayerNorm implemented manually to work around an issue with ROCm. - """ - class ActivationType(StrEnum): gelu = "gelu" diff --git a/olmo/model.py b/olmo/model.py index bd7be3097..f975c7c98 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -57,7 +57,6 @@ "LayerNormBase", "LayerNorm", "RMSLayerNorm", - "AMDLayerNorm", "RotaryEmbedding", "Activation", "GELU", @@ -152,8 +151,6 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay return LayerNorm(config, size=size, low_precision=True, **kwargs) elif config.layer_norm_type == LayerNormType.rms: return RMSLayerNorm(config, size=size, **kwargs) - elif config.layer_norm_type == LayerNormType.amd_compatible: - return AMDLayerNorm(config, size=size, **kwargs) else: raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'") @@ -207,38 +204,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) -class AMDLayerNorm(LayerNormBase): - """ - LayerNorm implemented using PyTorch primitives. - - We do this to work around a bug in the PyTorch/ROCm implementation of layer norm that fails with a - segfault when the bias is not present. - """ - - def __init__( - self, - config: ModelConfig, - size: Optional[int] = None, - elementwise_affine: Optional[bool] = None, - eps: float = 1e-05, - ): - super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - og_dtype = x.dtype - x = self._cast_if_autocast_enabled(x, dtype=torch.float32) - with torch.autocast(enabled=False, device_type=x.device.type): - var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True) - var.add_(self.eps) - var.rsqrt_() # rsqrt should be more stable than 1/sqrt - x = var * (x - mean) - if self.weight is not None: - x.mul_(self.weight) - if self.bias is not None: - x.add_(self.bias) - return x.to(og_dtype) - - class RMSLayerNorm(LayerNormBase): """ RMS layer norm, a simplified :class:`LayerNorm` implementation diff --git a/tests/model_test.py b/tests/model_test.py index 18dd5401f..79f2b1a26 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -6,7 +6,6 @@ from olmo import BlockType, LayerNorm, Olmo, Tokenizer, TrainConfig from olmo.config import ModelConfig, PaddingDirection from olmo.data import DataCollator -from olmo.model import AMDLayerNorm @pytest.mark.parametrize( @@ -399,7 +398,6 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include train_config.model.layer_norm_with_affine = elementwise_affine train_config.model.include_bias = include_bias ln = LayerNorm.build(train_config.model) - amd_ln = AMDLayerNorm(train_config.model) needs_weight = elementwise_affine needs_bias = elementwise_affine and include_bias @@ -407,21 +405,17 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include if needs_weight: weight = torch.randn(train_config.model.d_model) ln.weight.copy_(weight) - amd_ln.weight.copy_(weight) else: weight = None if needs_bias: bias = torch.randn(train_config.model.d_model) ln.bias.copy_(bias) - amd_ln.bias.copy_(bias) else: bias = None assert ln.bias is None or ln.bias.requires_grad == needs_bias assert ln.weight is None or ln.weight.requires_grad == needs_weight - assert amd_ln.bias is None or amd_ln.bias.requires_grad == needs_bias - assert amd_ln.weight is None or amd_ln.weight.requires_grad == needs_weight x = torch.randn(16, 1024, train_config.model.d_model) x.requires_grad = False @@ -430,9 +424,6 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include y_actual = ln(x) torch.testing.assert_close(y_actual, y_expected) - y_actual = amd_ln(x) - torch.testing.assert_close(y_actual, y_expected) - def test_block_groups(): model_with_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=3)).eval() From 1810817d7ce905a553e98f4d39bb7e09bc25ec72 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 7 Mar 2024 17:20:00 -0800 Subject: [PATCH 2/2] Changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 183826d72..9e4e8536b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed legacy checkpoint unsharding to use processes and shared memory instead of threads +### Removed + +- Removed `AMDLayerNorm`, since the original layer norm bug has been fixed and we don't need this workaround anymore. + + ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02 ### Fixed