Llama Config #317

Merged: dirkgr merged 115 commits into main from the Llama branch on Nov 2, 2023
Conversation

dirkgr (Member) commented Oct 5, 2023

This is a config that tracks Llama as closely as possible.

Differences that we know of:

  • Uses MQA instead of GQA (see the sketch after this list)
  • Different data
  • Different tokenizer
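
As an illustration of that first difference, here is a minimal sketch (not the repo's attention code; all names are made up) of how MQA, GQA, and standard MHA differ only in the number of key/value heads backing the query heads:

```python
# Illustrative only: compare KV-projection sizes for MQA / GQA / MHA.
import torch

d_model, n_heads, head_dim = 512, 8, 64


def kv_projection(n_kv_heads: int) -> torch.nn.Linear:
    # Each KV head produces head_dim-sized keys and values.
    return torch.nn.Linear(d_model, 2 * n_kv_heads * head_dim)


mqa_kv = kv_projection(n_kv_heads=1)        # MQA: all 8 query heads share one KV head
gqa_kv = kv_projection(n_kv_heads=4)        # GQA: pairs of query heads share a KV head
mha_kv = kv_projection(n_kv_heads=n_heads)  # MHA: one KV head per query head

for name, proj in [("MQA", mqa_kv), ("GQA", gqa_kv), ("MHA", mha_kv)]:
    print(name, "KV projection params:", sum(p.numel() for p in proj.parameters()))
```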

TODO:

dirkgr requested a review from 2015aroras on October 31, 2023 at 19:13
dirkgr (Member Author) commented Oct 31, 2023

I want to merge this as it is, and add @2015aroras's Llama block later. @AkshitaB needs some of the changes in here for her config.

dirkgr (Member Author) commented Oct 31, 2023

My top concern is whether any of the code changes in here affect the stability of existing runs, like mitchish.

dirkgr requested a review from AkshitaB on October 31, 2023 at 19:16
@@ -423,8 +424,8 @@ class OptimizerConfig(BaseConfig):
     learning_rate: float = 1.0e-4
     weight_decay: float = 0.01
     betas: Tuple[float, float] = (0.9, 0.95)
-    no_decay_norm_and_bias: bool = True
-    """Do not apply weight decay to norms and biases."""
+    decay_norm_and_bias: bool = False
Member: This is a breaking change. I'd suggest keeping the old option and marking it as deprecated, or we need a preprocessor that renames the old option to the new one when present.
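
A minimal sketch of the preprocessor idea (the function name and dict-based config are assumptions, not the repo's actual API): map the deprecated name onto its replacement, invert the value, and warn so old configs keep working.

```python
import warnings
from typing import Any, Dict


def upgrade_optimizer_config(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Rename the deprecated `no_decay_norm_and_bias` flag to `decay_norm_and_bias`."""
    if "no_decay_norm_and_bias" in raw:
        warnings.warn(
            "`no_decay_norm_and_bias` is deprecated; use `decay_norm_and_bias` instead.",
            DeprecationWarning,
        )
        # The new flag is the negation of the old one.
        raw.setdefault("decay_norm_and_bias", not raw.pop("no_decay_norm_and_bias"))
    return raw


print(upgrade_optimizer_config({"no_decay_norm_and_bias": True}))
# -> {'decay_norm_and_bias': False}
```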

Contributor: It renames the option, and also inverts its value (the new flag is the negation of the old one).

Member Author: I did that in 7244c0b. Please read it carefully! There are a lot of "not"s and "no"s around, and this is exactly the kind of thing where you can run for 500B tokens before you notice you screwed up.

@@ -15,6 +15,7 @@ def init_weights(
     d: Optional[int] = None,
     layer_id: Optional[int] = None,
     std_factor: float = 1.0,
+    type_of_module: str = "",
Member: Nit: make a StrEnum for this.
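
A sketch of what the nit suggests (the variant names are illustrative, not the repo's actual values):

```python
from enum import Enum


class StrEnum(str, Enum):
    """Minimal str-valued enum, standing in for the repo's own StrEnum helper."""

    def __str__(self) -> str:
        return self.value


class ModuleType(StrEnum):
    # Hypothetical variants for the kinds of modules init_weights distinguishes.
    in_module = "in"
    out_module = "out"
    emb = "emb"
    final_out = "final_out"


# Callers then pass ModuleType.emb instead of a bare string like "emb",
# so typos fail loudly instead of silently falling through.
print(ModuleType("emb") is ModuleType.emb)  # True
```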

olmo/config.py (outdated)
@@ -594,6 +595,15 @@ class ShardedCheckpointerType(StrEnum):
     local = "local"


+class ActivationCheckpointingStrategy(StrEnum):
+    none = "none"
Member: Nit: instead of having a "none" variant, we could make ActivationCheckpointingStrategy optional. I think that's more Pythonic.
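
For comparison, a sketch of the two alternatives being discussed (field and variant names are illustrative):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class ActivationCheckpointingStrategy(str, Enum):
    whole_layer = "whole_layer"
    fine_grained = "fine_grained"


@dataclass
class TrainConfig:
    # Option A (what the diff above does): add an explicit `none` variant to the enum.
    # Option B (this suggestion): make the field Optional and use None for "disabled".
    activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None


cfg = TrainConfig()
if cfg.activation_checkpointing is None:
    print("activation checkpointing disabled")
else:
    print(f"checkpointing with strategy: {cfg.activation_checkpointing.value}")
```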

2015aroras (Collaborator) left a review comment

LGTM other than Pete's comment on the breaking change.

},
}
)
if len(no_decay_sorted) > 0:
Collaborator: This could result in fewer than 2 param groups overall. Just checking that there are no foreseeable problems with that.
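
For context, a minimal sketch (names assumed, not the repo's code) of how the no-decay group can end up empty and why the guard matters before handing groups to the optimizer:

```python
import torch

decay_norm_and_bias = True  # new-style flag: apply weight decay to norms and biases too

model = torch.nn.Linear(4, 4, bias=True)
decay, no_decay = [], []
for name, p in model.named_parameters():
    if name.endswith("bias") and not decay_norm_and_bias:
        no_decay.append(p)
    else:
        decay.append(p)

# With decay_norm_and_bias=True everything lands in `decay`, so the
# no-decay group can legitimately be empty; only add it when non-empty.
param_groups = [{"params": decay, "weight_decay": 0.01}]
if len(no_decay) > 0:
    param_groups.append({"params": no_decay, "weight_decay": 0.0})

optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
print(len(optimizer.param_groups))  # -> 1 here, since everything is decayed
```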

Member Author: I actually removed some checks that relied on there always being two param groups. I think that change should already be in here?

Collaborator: I don't recall the checks, but if you have accounted for them, then I'm not worried.

-        if pn.endswith("bias"):
-            # all biases will not be decayed
+        if pn.endswith("bias"):
+            if cfg.optimizer.decay_norm_and_bias:
Collaborator: You could potentially break decay_norm_and_bias up further into decay_norm and decay_bias, or similar.

Member Author: If we wanted to actually experiment with that, I would.
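
Purely for illustration, the finer-grained split suggested above might look like this as config fields (hypothetical, not something the PR ships):

```python
from dataclasses import dataclass


@dataclass
class OptimizerConfig:
    # Hypothetical split of decay_norm_and_bias into separate knobs.
    decay_norm: bool = False
    decay_bias: bool = False
```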

AkshitaB (Contributor) commented

> My top concern is whether any of the code changes in here affect the stability of existing runs, like mitchish.

For my run, the goal, I think, is to stay as close as possible to the "original run" plus weight un-tying. Will committing all the changes here cause unexpected discrepancies? For the decay-related arguments specifically, is it ok to simply keep the current behavior on main and use the older no_decay_norm_and_bias flag? The caveat is that it won't decay embeddings, if I'm reading it right.

dirkgr (Member Author) commented Nov 1, 2023

> The caveat is that it won't decay embeddings, if I'm reading it right.

The new run should decay everything, including embeddings. The new run should use the new flags.

> Will committing all the changes here cause unexpected discrepancies?

If it does, then that's a problem in this PR. I hope it does not, but there is a lot of room for error here, which is why I'm glad we're all looking at it instead of just me.
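
A hedged sketch of the flag combination being described for the new run (decay_norm_and_bias comes from the diff above; decay_embeddings is an assumed name for illustration):

```python
optimizer_overrides = dict(
    learning_rate=1.0e-4,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    decay_norm_and_bias=True,  # new flag: apply weight decay to norms and biases
    decay_embeddings=True,     # assumed flag: decay the embeddings as well
)
```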

dirkgr (Member Author) commented Nov 1, 2023

Argh, I think those function pointer shenanigans make torch.compile() not work :-(
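
For concreteness, a minimal sketch (not the PR's code) of the kind of function-pointer indirection being referred to: a block stores a checkpointing callable and routes its forward through it, which is the sort of dynamic dispatch torch.compile can have trouble tracing.

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def __init__(self, d: int = 16):
        super().__init__()
        self.ff = torch.nn.Linear(d, d)
        self._activation_checkpoint_fn = None  # set when checkpointing is enabled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._activation_checkpoint_fn is not None:
            # Recompute activations during backward instead of storing them.
            return self._activation_checkpoint_fn(self._inner_forward, x)
        return self._inner_forward(x)

    def _inner_forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.ff(x))


block = Block()
block._activation_checkpoint_fn = lambda fn, *args: checkpoint(fn, *args, use_reentrant=False)
y = block(torch.randn(2, 16, requires_grad=True))
y.sum().backward()
```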

dirkgr (Member Author) commented Nov 1, 2023

Never mind, compile() and activation checkpointing were never going to work at the same time.

AkshitaB (Contributor) commented Nov 1, 2023

> The new run should decay everything, including embeddings. The new run should use the new flags.

Alright, I'll wait for this code to be merged.

2015aroras (Collaborator) left a review comment

LGTM

dirkgr merged commit da91f34 into main on Nov 2, 2023 (10 checks passed)
dirkgr deleted the Llama branch on November 2, 2023 at 00:34