Llama Config #317

Merged (115 commits, Nov 2, 2023)
Commits
c6b5ee9
cast logits to fp32 before output
epwalsh Sep 28, 2023
27e7c84
revert logit manual cast
epwalsh Sep 28, 2023
1850ebb
add option for no weight tying
epwalsh Sep 28, 2023
b852ec4
init ff_out weights
epwalsh Sep 28, 2023
a2f1517
fix init device for ff_out
epwalsh Sep 28, 2023
387f659
refactor how we cache buffers
epwalsh Sep 29, 2023
160d143
cache rope sin and cos
epwalsh Sep 29, 2023
7fc33c5
Refactor how RoPE is applied
epwalsh Sep 29, 2023
ee95fd3
remove unused import
epwalsh Sep 29, 2023
95c806c
Add back Olmo.device property
epwalsh Sep 29, 2023
299b5cc
Merge branch 'main' into petew/tweaks
epwalsh Sep 29, 2023
5a628a3
give cache a type, make it required in constructors
epwalsh Oct 2, 2023
ba80eba
Merge branch 'main' into petew/tweaks
epwalsh Oct 3, 2023
5098b94
Allows us to use an intermediate size instead of an mlp ratio
dirkgr Oct 5, 2023
85e38fc
Makes RMSNorm like Llama's
dirkgr Oct 5, 2023
d0f61ca
Llama-like config
dirkgr Oct 5, 2023
8a3c9e5
Enable flash
dirkgr Oct 5, 2023
0d3ad37
Actually use intermediate size
dirkgr Oct 5, 2023
a3c722b
Merge remote-tracking branch 'origin/petew/tweaks' into Llama
dirkgr Oct 5, 2023
96102e2
Turn off weight tying
dirkgr Oct 5, 2023
3976c8b
Formatting
dirkgr Oct 5, 2023
f95c153
Make mypy happy
dirkgr Oct 5, 2023
186fd2a
Switch to the 1.5 data mix
dirkgr Oct 5, 2023
c245630
Longer sequence length
dirkgr Oct 6, 2023
0e6dfcd
Merge branch 'main' into petew/tweaks
epwalsh Oct 6, 2023
b331f8b
add mitch config
epwalsh Oct 6, 2023
75c5813
add option to override hidden size
epwalsh Oct 6, 2023
b9805ff
MCLI configs
epwalsh Oct 6, 2023
36370d0
rename config option to `mlp_hidden_size`
epwalsh Oct 6, 2023
7e8b88f
don't use adaptive clipping
epwalsh Oct 6, 2023
a4577b6
clean up mcli config
epwalsh Oct 6, 2023
6b68368
Add option to skip pre-train ckpt (for debuggin)
epwalsh Oct 6, 2023
20c16da
No QK norm, no affines
epwalsh Oct 6, 2023
f51b04e
enable flash
epwalsh Oct 6, 2023
de4ba36
update configs
epwalsh Oct 6, 2023
20aca2a
apply rotary in FP32
epwalsh Oct 6, 2023
c47ab78
clean up
epwalsh Oct 6, 2023
e8be916
Add v1.5 mix mitch-ish
epwalsh Oct 6, 2023
c3b510f
Merge remote-tracking branch 'origin/petew/tweaks' into Llama
dirkgr Oct 6, 2023
b33c23e
Use Pete's implementation of `mlp_hidden_size`
dirkgr Oct 6, 2023
71f2e91
No more MQA
dirkgr Oct 6, 2023
0672a22
We use hidden size differently
dirkgr Oct 6, 2023
293ef24
Different wrapping strategy
dirkgr Oct 7, 2023
67c9e31
Removes low precision RMSNorm. Fixes high precision RMSNorm.
dirkgr Oct 8, 2023
4cfcc5f
Fix comment
dirkgr Oct 10, 2023
3ecab3b
Disable autocast
dirkgr Oct 10, 2023
e2dceaa
Merge remote-tracking branch 'origin/main' into Llama
dirkgr Oct 10, 2023
27c6866
Fix merge gore
dirkgr Oct 11, 2023
98ccaf4
Full megatron init
dirkgr Oct 11, 2023
71b5050
Adds a script for running the Llama config on LUMI
dirkgr Oct 11, 2023
fe03d05
Adjust Llama config so it runs on LUMI
dirkgr Oct 11, 2023
18a8c6d
Typos
dirkgr Oct 11, 2023
f560f2c
Merge branch 'Llama' of https://github.com/allenai/LLM into Llama
dirkgr Oct 11, 2023
86b6a85
Turn off adaptive grad clipping
dirkgr Oct 11, 2023
6578709
Bring back global gradient clipping
epwalsh Oct 11, 2023
4f5d3b3
fix device type
epwalsh Oct 11, 2023
e342d88
collect metrics on GPU
epwalsh Oct 11, 2023
d8031f5
fix device type of numel tensor
epwalsh Oct 11, 2023
3327591
separate clipping strategies to methods
epwalsh Oct 11, 2023
07a9cab
clean up
epwalsh Oct 11, 2023
e66f8d7
wait longer for host-device sync of loss
epwalsh Oct 11, 2023
938c3c8
overlap z-batch-loss reduction too
epwalsh Oct 11, 2023
b476200
fix typo in comment
epwalsh Oct 11, 2023
ec8db60
add more no_grads for future proofing
epwalsh Oct 11, 2023
8d4ce6b
Merge branch 'Llama' of https://github.com/allenai/LLM into Llama
dirkgr Oct 11, 2023
68f3765
Merge remote-tracking branch 'origin/petew/global-grad-clipping' into…
dirkgr Oct 11, 2023
a80af4d
Decay everything
dirkgr Oct 12, 2023
2811f79
Merge remote-tracking branch 'origin/main' into Llama
dirkgr Oct 12, 2023
236bcb1
Separate settings for decaying norms and embeddings
dirkgr Oct 12, 2023
3b88be1
Fix Llama run name
dirkgr Oct 12, 2023
d00b4dc
Fix how we build parameter groups
dirkgr Oct 12, 2023
e4ce8d4
Add a config for debugging on Beaker
dirkgr Oct 12, 2023
f7aae4d
Put in the config we actually want
dirkgr Oct 12, 2023
50f8650
Merge pull request #329 from allenai/Llama-Decay
dirkgr Oct 12, 2023
ab3a73d
Decay embeddings
dirkgr Oct 13, 2023
e4db9c8
Only use unsharded checkpoints
dirkgr Oct 13, 2023
7c7211b
Merge branch 'main' into Llama
dirkgr Oct 23, 2023
d71c346
Merge branch 'main' into Llama
dirkgr Oct 24, 2023
561f79c
Merge remote-tracking branch 'origin/main' into Llama
dirkgr Oct 27, 2023
bf30c13
Rename config
dirkgr Oct 27, 2023
ea01559
LUMI scripts now live in a LUMI world
dirkgr Oct 27, 2023
db10402
Use new paths
dirkgr Oct 27, 2023
1a265b5
Rename some things
dirkgr Oct 28, 2023
1e7d85b
Adds a script for Llama7
dirkgr Oct 28, 2023
de8fd96
More renaming
dirkgr Oct 28, 2023
1e644f9
Implements fine-grained activation checkpointing
dirkgr Oct 28, 2023
7ff22c4
Forgot to rename this one
dirkgr Oct 28, 2023
6ed4425
Double-underscore names in Python are "private"-ish, not "protected"-…
dirkgr Oct 28, 2023
1b8ed30
Someone checked in debug code
dirkgr Oct 28, 2023
28c6ff9
Can't cache this
dirkgr Oct 28, 2023
37d0615
Does it work this way around?
dirkgr Oct 28, 2023
1baac62
Giving up on backwards compatibility
dirkgr Oct 28, 2023
6ec2e91
Activation checkpointing works exactly the other way around
dirkgr Oct 28, 2023
3bd8f31
Update job name
dirkgr Oct 28, 2023
f535069
Putting 16 nodes back
dirkgr Oct 28, 2023
33d2af6
Use more activation checkpointing
dirkgr Oct 29, 2023
8e44049
Fix sequence length
dirkgr Oct 29, 2023
8a43e65
Activation checkpointing in parts
dirkgr Oct 31, 2023
1c3eb88
Changing some defaults
dirkgr Oct 31, 2023
145ab24
Merge branch 'Llama-ActCheck' into Llama
dirkgr Oct 31, 2023
c1b9a59
Update configs/llama7.yaml
dirkgr Oct 31, 2023
de0ef38
Clean up configs
dirkgr Oct 31, 2023
41c1251
Merge branch 'Llama' of https://github.com/allenai/LLM into Llama
dirkgr Oct 31, 2023
771b04e
Fix and rename LUMI script
dirkgr Oct 31, 2023
581538c
Undo debugging change
dirkgr Oct 31, 2023
d88e02d
Merge remote-tracking branch 'origin/main' into Llama
dirkgr Nov 1, 2023
f44fae2
Use `StrEnum` for module type
dirkgr Nov 1, 2023
a4609d4
`None` means "no activation checkpointing"
dirkgr Nov 1, 2023
ba7bc3a
Productivity through formatting
dirkgr Nov 1, 2023
445d7c6
Resolve circular imports
dirkgr Nov 1, 2023
7244c0b
Preserve the functionality of existing configs
dirkgr Nov 1, 2023
4f90216
Fix activation checkpointing defaults
dirkgr Nov 1, 2023
f3491db
Try to bring back compile
dirkgr Nov 1, 2023
f52220f
Merge branch 'Llama' of https://github.com/allenai/LLM into Llama
dirkgr Nov 1, 2023
7b6add5
Revert "Try to bring back compile"
dirkgr Nov 1, 2023
configs/llama7-s3.yaml (623 additions, 0 deletions)

Large diffs are not rendered by default.

configs/llama7.yaml (176 additions, 0 deletions)

@@ -0,0 +1,176 @@
run_name: llama7-001
seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: olmo-medium
  group: llama7

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  rope: true
  flash_attention: true
  attention_dropout: 0.0
  attention_layer_norm: false
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  bias_for_layer_norm: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 3072
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 0
  pad_token_id: 1
  init_device: meta
  init_fn: full_megatron
  init_std: 0.006
  init_cutoff_factor: 3
  weight_tying: false

fsdp:
  precision: mixed
  wrapping_strategy: size_based
  sharding_strategy: SHARD_GRAD_OP

activation_checkpointing: by_layer

compile: null

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  decay_norm_and_bias: true
  decay_embeddings: true
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: cosine_with_warmup
  t_warmup: 2000
  alpha_f: 0.1

data:
  paths: ${path.glob:${oc.env:DATA_PATH}/v1_5-sample/gpt-neox-20b-pii-special/*.npy}
  pad_direction: right
  num_workers: 0
  drop_last: true
  pin_memory: true
  prefetch_factor: 16
  persistent_workers: true
  timeout: 0

tokenizer:
  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
  truncate_direction: right

save_folder: ${oc.env:CHECKPOINTS_PATH}/${oc.env:SLURM_JOB_ID,${run_name}}
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 1000000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 500
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 423855 # 2T tokens
global_train_batch_size: 1536
device_train_microbatch_size: 1

precision: amp_bf16

max_grad_norm: 1.0

speed_monitor:
  window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  # lump all the small datasets together (we still get separate metrics).
  - label: all-small-ppl-validation
    data:
      datasets:
        4chan-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/4chan/val.npy
        c4_100_domains-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
        c4_en-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/c4_en/val.npy
        gab-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/gab/val.npy
        ice-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/ice/val.npy
        m2d2_s2orc-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
        m2d2_wiki-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
        manosphere-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/manosphere/val.npy
        mc4_en-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/mc4_en/val.npy
        pile-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/pile/val.npy
        ptb-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/ptb/val.npy
        twitterAEE-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
        wikitext_103-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/wikitext_103/val.npy
      drop_last: true

  ##########################
  # Downstream evaluations #
  ##########################
  - label: piqa
    type: downstream

  - label: hellaswag
    type: downstream

  - label: winogrande
    type: downstream

  - label: openbook_qa
    type: downstream

  # - label: boolq  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: sciq
    type: downstream

  - label: arc_easy
    type: downstream

  # - label: arc_challenge  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: copa
    type: downstream

  - label: rte
    type: downstream

  - label: commitment_bank
    type: downstream

  - label: mrpc
    type: downstream

  - label: sst2
    type: downstream
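Aside (not part of the diff): the `# 2T tokens` comment on `max_duration` is consistent with the batch geometry in this config. A small sanity-check script, just arithmetic over the values above:

```python
# Rough check of the "2T tokens" comment in configs/llama7.yaml.
steps = 423_855      # max_duration (optimizer steps)
batch_size = 1_536   # global_train_batch_size (sequences per step)
seq_len = 3_072      # max_sequence_length (tokens per sequence)

total_tokens = steps * batch_size * seq_len
print(f"{total_tokens:,} tokens (~{total_tokens / 1e12:.2f}T)")
# -> 1,999,998,812,160 tokens (~2.00T)
```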
olmo/config.py (19 additions, 9 deletions)

@@ -28,6 +28,7 @@
__all__ = [
    "LogFilterType",
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "CompilerConfig",
    "LayerNormType",
@@ -166,11 +167,6 @@ class LayerNormType(StrEnum):
    probably the fastest implementation.
    """

    low_precision_rms = "low_precision_rms"
    """
    A low-precision version of RMSNorm.
    """

    amd_compatible = "amd_compatible"
    """
    LayerNorm implemented manually to work around an issue with ROCm.
@@ -213,6 +209,11 @@ class InitFnType(StrEnum):
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """


@dataclass
class ModelConfig(BaseConfig):
@@ -423,8 +424,8 @@ class OptimizerConfig(BaseConfig):
    learning_rate: float = 1.0e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    no_decay_norm_and_bias: bool = True
    """Do not apply weight decay to norms and biases."""
    decay_norm_and_bias: bool = False
Review thread on `decay_norm_and_bias`:

Member: This is a breaking change. I'd suggest keeping the old option and marking it as deprecated, or we need a preprocessor that renames the old option to the new one when present.

Contributor: Renames it, and also changes the value to `not value`.

Member (PR author): I did that in 7244c0b. Please read that carefully! There are a lot of `not`s and "no"s around, and this is exactly the kind of thing where you can run for 500B tokens before you notice you screwed up.

    decay_embeddings: bool = False
    metrics_log_interval: Optional[int] = None
    """
    The interval with which to collect and log detailed parameter-specific metrics.
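A minimal sketch of the kind of config preprocessor the reviewer describes in the thread above (hypothetical code, not what 7244c0b actually does): it renames the deprecated `no_decay_norm_and_bias` key to `decay_norm_and_bias` and flips the boolean so old configs keep their meaning.

```python
from typing import Any, Dict
import warnings


def upgrade_optimizer_config(optimizer_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Rename the deprecated `no_decay_norm_and_bias` option to `decay_norm_and_bias`,
    flipping the value so the old semantics are preserved."""
    if "no_decay_norm_and_bias" in optimizer_cfg:
        warnings.warn(
            "`no_decay_norm_and_bias` is deprecated; use `decay_norm_and_bias` instead.",
            DeprecationWarning,
        )
        old_value = optimizer_cfg.pop("no_decay_norm_and_bias")
        # Note the negation: "do NOT decay" becomes "decay = not old_value".
        optimizer_cfg.setdefault("decay_norm_and_bias", not old_value)
    return optimizer_cfg
```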
@@ -594,6 +595,15 @@ class ShardedCheckpointerType(StrEnum):
    local = "local"


class ActivationCheckpointingStrategy(StrEnum):
    none = "none"
Review comment (Member): Nit: instead of having a "none" variant we could make `ActivationCheckpointingStrategy` optional. I think that's more Pythonic.

    whole_layer = "whole_layer"
    one_in_two = "one_in_two"
    one_in_three = "one_in_three"
    one_in_four = "one_in_four"
    fine_grained = "fine_grained"


@dataclass
class TrainConfig(BaseConfig):
    """
@@ -861,9 +871,9 @@ class TrainConfig(BaseConfig):
    Stop at a specific step.
    """

    activation_checkpointing: bool = False
    activation_checkpointing: ActivationCheckpointingStrategy = ActivationCheckpointingStrategy.none
    """
    Use activation checkpointing on transformer blocks.
    The activation checkpointing strategy to use.
    """

    @property
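As a reading aid, here is a hypothetical sketch (not the wrapping logic from this PR) of how the new strategies can be interpreted as a per-block decision; the exact placement of checkpointed blocks and the sub-module handling for `fine_grained` are assumptions here.

```python
from typing import Optional


def should_checkpoint_block(strategy: Optional[str], block_idx: int) -> bool:
    """Illustrative only: which transformer blocks get activation checkpointing
    under each strategy. `None` means no checkpointing (see commit a4609d4)."""
    if strategy is None or strategy == "none":
        return False
    if strategy in ("whole_layer", "fine_grained"):
        # fine_grained presumably also checkpoints sub-modules within each block;
        # treated like whole_layer here for simplicity.
        return True
    if strategy == "one_in_two":
        return block_idx % 2 == 0
    if strategy == "one_in_three":
        return block_idx % 3 == 0
    if strategy == "one_in_four":
        return block_idx % 4 == 0
    raise ValueError(f"Unknown activation checkpointing strategy: {strategy}")
```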
olmo/initialization.py (28 additions, 0 deletions)

@@ -15,6 +15,7 @@ def init_weights(
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: str = "",
Review comment (Member): Nit: make a StrEnum for this.

) -> None:
    """
    Initialize weights of a linear or embedding module.
@@ -44,6 +45,33 @@ def init_weights(
    elif config.init_fn == InitFnType.fan_in:
        std = std_factor / math.sqrt(d)
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.full_megatron:
        cutoff_factor = config.init_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        if type_of_module == "in":
            # for att_proj (same as QKV), ff_proj
            std = config.init_std
        elif type_of_module == "out":
            # for attn_out, ff_out
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == "emb":
            # positional embeddings (wpe)
            # token embeddings (wte)
            std = config.init_std
        elif type_of_module == "final_out":
            # final output (ff_out)
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")
        nn.init.trunc_normal_(
            module.weight,
            mean=0.0,
            std=std,
            a=-cutoff_factor * std,
            b=cutoff_factor * std,
        )
    else:
        raise NotImplementedError(config.init_fn)
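For intuition, a quick hypothetical worked example (not code from this PR) of the standard deviations the `full_megatron` init produces for the `llama7.yaml` settings (`d_model=4096`, `n_layers=32`, `init_std=0.006`, `init_cutoff_factor=3`). The `ModuleType` enum below mimics the StrEnum the review asks for (added later in commit f44fae2); its member names are assumptions here.

```python
import math
from enum import Enum


class ModuleType(str, Enum):
    """Hypothetical stand-in for the module-type StrEnum."""
    in_module = "in"
    out_module = "out"
    emb = "emb"
    final_out = "final_out"


def megatron_std(module_type: ModuleType, d_model: int = 4096,
                 n_layers: int = 32, init_std: float = 0.006) -> float:
    """Reproduce the std selection from the full_megatron branch above."""
    if module_type in (ModuleType.in_module, ModuleType.emb):
        return init_std
    if module_type == ModuleType.out_module:
        return init_std / math.sqrt(2.0 * n_layers)  # 0.006 / 8 = 0.00075
    if module_type == ModuleType.final_out:
        return d_model ** -0.5                       # 4096 ** -0.5 = 0.015625
    raise ValueError(module_type)


for mt in ModuleType:
    std = megatron_std(mt)
    # Weights are drawn from a truncated normal clipped at ±cutoff_factor * std (±3σ here).
    print(f"{mt.value:>9}: std={std:.6f}, truncated at ±{3 * std:.6f}")
```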
