Llama Config #317
Changes from 105 commits
The first file in the diff is the new training config itself:

@@ -0,0 +1,176 @@
```yaml
run_name: llama7-001
seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: olmo-medium
  group: llama7

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  rope: true
  flash_attention: true
  attention_dropout: 0.0
  attention_layer_norm: false
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  bias_for_layer_norm: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 3072
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 0
  pad_token_id: 1
  init_device: meta
  init_fn: full_megatron
  init_std: 0.006
  init_cutoff_factor: 3
  weight_tying: false

fsdp:
  precision: mixed
  wrapping_strategy: size_based
  sharding_strategy: SHARD_GRAD_OP

activation_checkpointing: by_layer

compile: null

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  decay_norm_and_bias: true
  decay_embeddings: true
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: cosine_with_warmup
  t_warmup: 2000
  alpha_f: 0.1

data:
  paths: ${path.glob:${oc.env:DATA_PATH}/v1_5-sample/gpt-neox-20b-pii-special/*.npy}
  pad_direction: right
  num_workers: 0
  drop_last: true
  pin_memory: true
  prefetch_factor: 16
  persistent_workers: true
  timeout: 0

tokenizer:
  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
  truncate_direction: right

save_folder: ${oc.env:CHECKPOINTS_PATH}/${oc.env:SLURM_JOB_ID,${run_name}}
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 1000000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 500
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 423855  # 2T tokens
global_train_batch_size: 1536
device_train_microbatch_size: 1

precision: amp_bf16

max_grad_norm: 1.0

speed_monitor:
  window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  # lump all the small datasets together (we still get separate metrics).
  - label: all-small-ppl-validation
    data:
      datasets:
        4chan-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/4chan/val.npy
        c4_100_domains-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
        c4_en-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/c4_en/val.npy
        gab-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/gab/val.npy
        ice-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/ice/val.npy
        m2d2_s2orc-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
        m2d2_wiki-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
        manosphere-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/manosphere/val.npy
        mc4_en-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/mc4_en/val.npy
        pile-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/pile/val.npy
        ptb-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/ptb/val.npy
        twitterAEE-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
        wikitext_103-validation:
          - ${oc.env:EVAL_DATA_PATH}/perplexity/v2_small_gptneox20b/wikitext_103/val.npy
      drop_last: true

  ##########################
  # Downstream evaluations #
  ##########################
  - label: piqa
    type: downstream

  - label: hellaswag
    type: downstream

  - label: winogrande
    type: downstream

  - label: openbook_qa
    type: downstream

  # - label: boolq  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: sciq
    type: downstream

  - label: arc_easy
    type: downstream

  # - label: arc_challenge  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: copa
    type: downstream

  - label: rte
    type: downstream

  - label: commitment_bank
    type: downstream

  - label: mrpc
    type: downstream

  - label: sst2
    type: downstream
```
The next file in the diff is the config module:

```diff
@@ -28,6 +28,7 @@
 __all__ = [
     "LogFilterType",
     "ActivationType",
+    "ActivationCheckpointingStrategy",
     "BlockType",
     "CompilerConfig",
     "LayerNormType",
@@ -166,11 +167,6 @@ class LayerNormType(StrEnum):
     probably the fastest implementation.
     """

-    low_precision_rms = "low_precision_rms"
-    """
-    A low-precision version of RMSNorm.
-    """
-
     amd_compatible = "amd_compatible"
     """
     LayerNorm implemented manually to work around an issue with ROCm.
@@ -213,6 +209,11 @@ class InitFnType(StrEnum):
     is the input dimensionality of the kernel.
     """

+    full_megatron = "full_megatron"
+    """
+    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
+    """
+

 @dataclass
 class ModelConfig(BaseConfig):
@@ -423,8 +424,8 @@ class OptimizerConfig(BaseConfig):
     learning_rate: float = 1.0e-4
     weight_decay: float = 0.01
     betas: Tuple[float, float] = (0.9, 0.95)
-    no_decay_norm_and_bias: bool = True
-    """Do not apply weight decay to norms and biases."""
+    decay_norm_and_bias: bool = False
+    decay_embeddings: bool = False
     metrics_log_interval: Optional[int] = None
     """
     The interval with which to collect and log detailed parameter-specific metrics.
@@ -594,6 +595,15 @@ class ShardedCheckpointerType(StrEnum):
     local = "local"


+class ActivationCheckpointingStrategy(StrEnum):
+    none = "none"
+    whole_layer = "whole_layer"
+    one_in_two = "one_in_two"
+    one_in_three = "one_in_three"
+    one_in_four = "one_in_four"
+    fine_grained = "fine_grained"
+
+
 @dataclass
 class TrainConfig(BaseConfig):
     """
@@ -861,9 +871,9 @@ class TrainConfig(BaseConfig):
     Stop at a specific step.
     """

-    activation_checkpointing: bool = False
+    activation_checkpointing: ActivationCheckpointingStrategy = ActivationCheckpointingStrategy.none
     """
-    Use activation checkpointing on transformer blocks.
+    The activation checkpointing strategy to use.
     """

     @property
```

Review comment on `none = "none"`:

> Nit: instead of having a "none" variant we could make […]
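The `one_in_*` strategies presumably wrap only a fraction of the transformer blocks in activation checkpointing, trading recompute for activation memory. A rough, hypothetical illustration of how a trainer might consume the enum (this helper is not code from the PR):

```python
def should_checkpoint_block(strategy: str, block_idx: int) -> bool:
    """Hypothetical helper: decide whether a block's activations are
    recomputed in the backward pass instead of being stored."""
    if strategy == "whole_layer":
        return True  # checkpoint every transformer block
    if strategy == "one_in_two":
        return block_idx % 2 == 0
    if strategy == "one_in_three":
        return block_idx % 3 == 0
    if strategy == "one_in_four":
        return block_idx % 4 == 0
    return False  # "none"; "fine_grained" would checkpoint inside blocks
```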
The last file in the diff is the weight-initialization code:

```diff
@@ -15,6 +15,7 @@ def init_weights(
     d: Optional[int] = None,
     layer_id: Optional[int] = None,
     std_factor: float = 1.0,
+    type_of_module: str = "",
 ) -> None:
     """
     Initialize weights of a linear or embedding module.
@@ -44,6 +45,33 @@ def init_weights(
     elif config.init_fn == InitFnType.fan_in:
         std = std_factor / math.sqrt(d)
         nn.init.normal_(module.weight, mean=0.0, std=std)
+    elif config.init_fn == InitFnType.full_megatron:
+        cutoff_factor = config.init_cutoff_factor
+        if cutoff_factor is None:
+            cutoff_factor = 3
+
+        if type_of_module == "in":
+            # for att_proj (same as QKV), ff_proj
+            std = config.init_std
+        elif type_of_module == "out":
+            # for attn_out, ff_out
+            std = config.init_std / math.sqrt(2.0 * config.n_layers)
+        elif type_of_module == "emb":
+            # positional embeddings (wpe)
+            # token embeddings (wte)
+            std = config.init_std
+        elif type_of_module == "final_out":
+            # final output (ff_out)
+            std = config.d_model**-0.5
+        else:
+            raise RuntimeError(f"Unknown module type '{type_of_module}'")
+        nn.init.trunc_normal_(
+            module.weight,
+            mean=0.0,
+            std=std,
+            a=-cutoff_factor * std,
+            b=cutoff_factor * std,
+        )
     else:
         raise NotImplementedError(config.init_fn)
```

Review comment on `type_of_module: str = ""`:

> Nit: make a […]
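Plugging the values from the YAML config above (`init_std: 0.006`, `n_layers: 32`, `d_model: 4096`, `init_cutoff_factor: 3`) into this branch gives the following per-module-type stds. This is just an arithmetic check, not code from the PR:

```python
import math

init_std, n_layers, d_model, cutoff = 0.006, 32, 4096, 3

std_in = init_std                               # "in": att_proj, ff_proj -> 0.006
std_out = init_std / math.sqrt(2.0 * n_layers)  # "out": attn_out, ff_out -> 0.00075
std_final = d_model ** -0.5                     # "final_out": final ff_out -> 0.015625

for name, std in [("in", std_in), ("out", std_out), ("final_out", std_final)]:
    print(f"{name}: std={std:.6f}, truncated at ±{cutoff * std:.6f}")
```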
Review thread on the option rename:

> This is a breaking change. I'd suggest keeping the old option and marking it as deprecated, or we need a preprocessor that renames the old option to the new one when present.

> Renames, and also, changes the value to `not value`.

> I did that in 7244c0b. Please read that carefully! There are a lot of `not` and "no"s around, and this is exactly the kind of thing where you can run for 500B tokens before you notice you screwed up.