Llama Config #317
Conversation
I want to merge this as it is, and add @2015aroras's Llama block later. @AkshitaB needs some of the changes in here for her config.
My top concern is whether any of the code changed in here affects the stability of existing runs, like mitchish.
@@ -423,8 +424,8 @@ class OptimizerConfig(BaseConfig):
     learning_rate: float = 1.0e-4
     weight_decay: float = 0.01
     betas: Tuple[float, float] = (0.9, 0.95)
-    no_decay_norm_and_bias: bool = True
-    """Do not apply weight decay to norms and biases."""
+    decay_norm_and_bias: bool = False
This is a breaking change. I'd suggest keeping the old option and marking it as deprecated, or we need a preprocessor that renames the old option to the new one when present.
It renames the option, and also changes the value to `not value`.
I did that in 7244c0b. Please read that carefully! There are a lot of `not`s and "no"s around, and this is exactly the kind of thing where you can run for 500B tokens before you notice you screwed up.
olmo/initialization.py (Outdated)
@@ -15,6 +15,7 @@ def init_weights(
     d: Optional[int] = None,
     layer_id: Optional[int] = None,
     std_factor: float = 1.0,
+    type_of_module: str = "",
Nit: make a `StrEnum` for this.
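A sketch of what that nit would look like; `ModuleType` and its members are illustrative, not necessarily what the PR ended up with:

```python
from enum import Enum


class StrEnum(str, Enum):
    """Minimal str-backed enum base, standing in for the project's StrEnum."""


class ModuleType(StrEnum):
    # Illustrative member names for the kinds of modules init_weights cares about.
    in_module = "in"
    out_module = "out"
    emb = "emb"
    final_out = "final_out"


# The signature would then become something like
#     def init_weights(..., type_of_module: ModuleType = ModuleType.in_module, ...)
# instead of passing around bare strings.
```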
olmo/config.py (Outdated)
@@ -594,6 +595,15 @@ class ShardedCheckpointerType(StrEnum):
     local = "local"


+class ActivationCheckpointingStrategy(StrEnum):
+    none = "none"
Nit: instead of having a "none" variant we could make `ActivationCheckpointingStrategy` optional. I think that's more Pythonic.
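The two spellings side by side, as a sketch (the `fine_grained` member and the `TrainConfigSketch` field are made up for illustration):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class StrEnum(str, Enum):
    """Minimal stand-in for the project's StrEnum base."""


# Spelling A (as in the diff): an explicit sentinel member.
class ActivationCheckpointingStrategyA(StrEnum):
    none = "none"
    fine_grained = "fine_grained"  # illustrative non-trivial variant


# Spelling B (the suggestion): no sentinel; absence is expressed with Optional.
class ActivationCheckpointingStrategyB(StrEnum):
    fine_grained = "fine_grained"


@dataclass
class TrainConfigSketch:
    # None means "no activation checkpointing" without a special enum member.
    activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategyB] = None
```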
LGTM other than Pete's comment on the breaking change
},
}
)
if len(no_decay_sorted) > 0:
This could result in fewer than two param groups overall. Just checking that there are no foreseeable problems with that.
I actually removed some checks that relied on there always being two. I think that change should be in here?
I don't recall the checks, but if you have accounted for them I'm not worried.
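For context on the "fewer than two groups" point, a hedged sketch (the function and its flag handling are assumptions, not the PR's exact code) of how the split can collapse to a single group:

```python
import torch


def build_param_groups(model: torch.nn.Module, weight_decay: float, decay_norm_and_bias: bool):
    """Split parameters into decay / no-decay sets, emitting a group only when
    it is non-empty, so the optimizer may end up with a single group."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_norm_or_bias = name.endswith("bias") or "norm" in name.lower()
        if is_norm_or_bias and not decay_norm_and_bias:
            no_decay.append(param)
        else:
            decay.append(param)

    groups = []
    if decay:
        groups.append({"params": decay, "weight_decay": weight_decay})
    if no_decay:  # empty when decay_norm_and_bias is True -> only one group overall
        groups.append({"params": no_decay, "weight_decay": 0.0})
    return groups
```

An optimizer built from these groups works fine with a single group; only code that assumed there are always exactly two would need the removed checks.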
if pn.endswith("bias"): | ||
# all biases will not be decayed | ||
if pn.endswith("bias"): | ||
if cfg.optimizer.decay_norm_and_bias: |
You could potentially break up `decay_norm_and_bias` further into `decay_norm` and `decay_bias` or similar.
If we wanted to actually experiment with that, I would.
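If we did want to experiment with it, the per-parameter decision might look roughly like this (the `decay_norm` / `decay_bias` flags are the hypothetical split suggested above, not existing options):

```python
def should_decay(param_name: str, decay_norm: bool, decay_bias: bool) -> bool:
    """Decide whether a parameter gets weight decay under the hypothetical split flags."""
    if param_name.endswith("bias"):
        return decay_bias
    if "norm" in param_name.lower():  # LayerNorm / RMSNorm scales
        return decay_norm
    return True  # everything else (linear weights, embeddings, ...) decays as usual
```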
With my run, the goal, I think, is to be as close as possible to the "original run" plus weight un-tying. Will committing all the changes here cause unexpected discrepancies? For the specific decay-related arguments, is it ok to simply stick with the current state of main and use the older flags?
The new run should decay everything, including embeddings. The new run should use the new flags.
If it does, then that's a problem in this PR. I hope it does not, but there is a lot of room for error here, which is why I'm glad we're all looking at it instead of just me.
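In other words, the new run's overrides would flip the new decay flags on, roughly like this (the flag name for embeddings is an assumption; only `decay_norm_and_bias` appears in the diff above):

```python
# Sketch of the decay-related overrides for the new run.
optimizer_overrides = {
    "decay_norm_and_bias": True,  # decay norms and biases (new flag from this PR)
    "decay_embeddings": True,     # assumed flag name for "including embeddings"
}
```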
Argh, I think those function pointer shenanigans make
Nevermind,
Alright, I'll wait for this code to be merged.
LGTM
This is a config that tracks Llama as closely as possible.
Differences that we know of:
- Uses MQA instead of GQA (see the sketch below)

TODO:
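For reference, the MQA/GQA difference comes down to the number of key/value heads relative to query heads; a sketch with illustrative field names (not the actual config keys):

```python
from dataclasses import dataclass


@dataclass
class AttentionShapeSketch:
    n_heads: int     # number of query heads
    n_kv_heads: int  # number of key/value heads shared among the query heads


# Multi-query attention (MQA): one key/value head shared by all query heads.
mqa = AttentionShapeSketch(n_heads=32, n_kv_heads=1)

# Grouped-query attention (GQA): a few key/value heads, each shared by a group of query heads.
gqa = AttentionShapeSketch(n_heads=32, n_kv_heads=8)
```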