
Commit

Merge pull request swiss-ai#7 from TJ-Solergibert/llama3_converter
Llama3 conversion scripts 🦙
ischlag authored Jul 2, 2024
2 parents e3ec5e3 + eb68e41 commit c104c34
Showing 7 changed files with 743 additions and 40 deletions.
3 changes: 3 additions & 0 deletions src/nanotron/config/models_config.py
@@ -48,6 +48,9 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
True # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
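The new `rope_interleaved` flag selects between the two common RoPE layouts: the interleaved (GPT-J style) pairing that nanotron has used by default, and the half-split (GPT-NeoX / HF Llama style) pairing that HuggingFace Llama3 checkpoints expect. A minimal sketch of the difference in plain PyTorch; these helpers are illustrative only and are not part of nanotron or flash-attn:

```python
import torch

def rope_angles(head_dim: int, seq_len: int, base: float = 10000.0):
    """Standard RoPE angles: one rotation angle per position and per dimension pair."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, head_dim // 2]
    return angles.cos(), angles.sin()

def apply_rope_interleaved(x, cos, sin):
    """Interleaved (GPT-J style): rotate the pairs (x0, x1), (x2, x3), ..."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

def apply_rope_half(x, cos, sin):
    """Non-interleaved (GPT-NeoX / HF Llama style): rotate the pairs (x_i, x_{i + d/2})."""
    half = x.shape[-1] // 2
    x_lo, x_hi = x[..., :half], x[..., half:]
    return torch.cat([x_lo * cos - x_hi * sin, x_lo * sin + x_hi * cos], dim=-1)

q = torch.randn(16, 128)  # one head: [seq_len, head_dim]
cos, sin = rope_angles(head_dim=128, seq_len=16)
# Same rotation matrices, different pairing of dimensions, hence different outputs:
print(torch.allclose(apply_rope_interleaved(q, cos, sin), apply_rope_half(q, cos, sin)))  # False
# A checkpoint trained with one layout must be loaded with the same layout, which is
# why the converter below sets rope_interleaved=False for HF Llama3 weights.
```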
53 changes: 13 additions & 40 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Union

import torch
from torch import nn
@@ -188,35 +188,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim]
):
from flash_attn.flash_attn_interface import flash_attn_varlen_func

# TODO @thomasw21: Compute once, instead of computing for each layers.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])

# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
from flash_attn.flash_attn_interface import flash_attn_func

# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
attn_output = flash_attn_varlen_func(
# For now we are assuming that we use a causal mask. No magic here
causal = True
attn_output = flash_attn_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
@@ -324,7 +310,9 @@ def __init__(
)

# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
self.flash_rotary_embedding = FlashRotaryEmbedding(
dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta
)

self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
@@ -566,29 +554,14 @@ def forward(
# [batch_size, seq_length, num_heads, d_qk]
key_states, value_states = torch.split(key_value_states, 1, dim=2)

q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask

kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, d_qk]
# Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
query_states = query_states.view(
batch_size * q_length, self.n_local_q_heads, self.d_qk
) # [batch_size * q_length, self.n_heads, d_qk]

key_states = key_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size * kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size * kv_length, self.n_heads, d_v]
key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk)
value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v)

attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
)

attention_output = (
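After this change, the attention core keeps the padded `[batch_size, seq_length, n_heads, d_qk]` layout and always applies a causal mask via `flash_attn_func`, instead of flattening to `[batch_size * seq_length, ...]` and describing sequence boundaries with `cu_seqlens` for `flash_attn_varlen_func`; the per-sequence masks are no longer forwarded. A standalone sketch of the two call patterns, not part of the PR, assuming a CUDA device with flash-attn 2.x installed and using illustrative sizes:

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

batch, seqlen, n_heads, head_dim = 2, 16, 4, 64
q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# New path: padded [batch, seqlen, heads, head_dim] tensors, causal masking in the kernel.
out = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)

# Old path: flatten to [total_tokens, heads, head_dim] and pass cumulative sequence
# lengths. With every sequence at full length, the two calls compute the same thing.
q_flat, k_flat, v_flat = (t.reshape(batch * seqlen, n_heads, head_dim) for t in (q, k, v))
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")
out_varlen = flash_attn_varlen_func(
    q=q_flat, k=k_flat, v=v_flat,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen, max_seqlen_k=seqlen,
    dropout_p=0.0, softmax_scale=None, causal=True,
)
print(torch.allclose(out.float(), out_varlen.view(batch, seqlen, n_heads, head_dim).float(), atol=1e-2))
```

One side effect of the simpler call is that padding tokens are no longer excluded per sequence inside the kernel; presumably this is acceptable here because training batches are packed to full length.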
19 changes: 19 additions & 0 deletions tools/llama3/README.md
@@ -0,0 +1,19 @@
# Llama3 Weight conversion tool
This directory contains scripts to convert Llama3 checkpoints from HuggingFace to Nanotron and vice versa.

- Convert from HuggingFace to Nanotron

`torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct`
- Convert from Nanotron to HuggingFace

`torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B`

In summary, we will do the following:
- Initialize the HuggingFace model with the pretrained weights. The model definition is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py).
- Initialize a Nanotron model with empty weights. The model definition is [here](https://github.com/huggingface/nanotron/blob/main/src/nanotron/models/llama.py).
- Copy the parameters layer by layer from one model to the other.
- Store the Nanotron model along with the tokenizer.

When comparing the HuggingFace implementation with the Nanotron implementation, the main difference lies in the Q, K & V matrices and in the MLP projections. In the HuggingFace implementation, these matrices are separated [[1]](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L415), [[2]](https://github.com/huggingface/transformers/blob/1518508467d96b3866fc4ebcb7a5b3a2e0df2aa4/src/transformers/models/llama/modeling_llama.py#L194), while in the Nanotron implementation, they are concatenated [[1b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L310), [[2b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L149). It is crucial to pay attention to these details to convert the models correctly.
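For example, for Llama3-8B (32 query heads, 8 key/value heads, head dimension 128) the fused Nanotron `qkv_proj` weight stacks 4096 query rows, then 1024 key rows, then 1024 value rows. A small sketch of undoing that fusion for the Nanotron-to-HuggingFace direction, assuming the split order matches the concatenation order used in `convert_hf_to_nanotron.py` below; the variable names are illustrative, not the converter's actual code:

```python
import torch

hidden_size, n_q_heads, n_kv_heads = 4096, 32, 8   # Llama3-8B values from the HF config
head_dim = hidden_size // n_q_heads                 # 128

# Fused Nanotron weight: query rows, then key rows, then value rows.
qkv_proj = torch.randn((n_q_heads + 2 * n_kv_heads) * head_dim, hidden_size)

q_proj, k_proj, v_proj = torch.split(
    qkv_proj,
    [n_q_heads * head_dim, n_kv_heads * head_dim, n_kv_heads * head_dim],
    dim=0,
)
assert q_proj.shape == (4096, 4096) and k_proj.shape == v_proj.shape == (1024, 4096)
# The MLP is fused the same way: gate_up_proj = cat([gate_proj, up_proj], dim=0), so the
# reverse split is torch.split(gate_up_proj, [intermediate_size, intermediate_size], dim=0).
```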

To perform the conversion, at least **1 GPU** is required, although the operations themselves are carried out on the **CPU**. We convert the models with a parallel configuration of DP = PP = TP = 1, but note that the checkpoints generated by Nanotron are topology-agnostic.
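Not part of this PR, but a quick sanity check after a round trip (HF to Nanotron and back with the two commands above) is to compare the logits of the original and the reconverted HuggingFace models. The checkpoint paths below are the example ones from the commands, and loading two 8B models in bf16 needs roughly 32 GB of CPU RAM:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ORIGINAL = "meta-llama/Meta-Llama-3-8B-Instruct"
RECONVERTED = "hf_checkpoints/Converted-Nanotron-Llama-3-8B"  # output of convert_nanotron_to_hf.py

tokenizer = AutoTokenizer.from_pretrained(ORIGINAL)
inputs = tokenizer("The Swiss Alps are", return_tensors="pt")

all_logits = []
for path in (ORIGINAL, RECONVERTED):
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
    with torch.no_grad():
        all_logits.append(model(**inputs).logits)
    del model  # free memory before loading the second copy

# The conversion only reorders and copies weights, so the logits should agree
# up to bf16 numerical noise.
print(torch.allclose(all_logits[0], all_logits[1], atol=1e-2))
```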
266 changes: 266 additions & 0 deletions tools/llama3/convert_hf_to_nanotron.py
@@ -0,0 +1,266 @@
"""
torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct
"""
import argparse
import json
from dataclasses import asdict
from pathlib import Path

import torch
import yaml
from nanotron import logging
from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs
from nanotron.config.models_config import ExistingCheckpointInit
from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.models.llama import LlamaForTraining
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
from nanotron.serialize import TrainingMetadata, save_meta, save_weights
from nanotron.serialize.metadata import DataStageMetadata
from nanotron.trainer import mark_tied_parameters
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.get_logger(__name__)

DEVICE = torch.device("cpu")
TORCH_DTYPE = torch.bfloat16


def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="Nanotron Model")
group.add_argument(
"--nanotron-checkpoint-path",
type=str,
required=True,
help="A path to a directory to store the converted Nanotron Checkpoint",
)

group = parser.add_argument_group(title="HuggingFace Model")
group.add_argument(
"--pretrained-model-name-or-path",
type=str,
required=True,
help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub",
)

args = parser.parse_args()

return args


def main(args):
# Init Nanotron Parallel Utilities
parallel_config = ParallelismArgs(dp=1, pp=1, tp=1)

parallel_context = ParallelContext(
data_parallel_size=parallel_config.dp,
pipeline_parallel_size=parallel_config.pp,
tensor_parallel_size=parallel_config.tp,
)

set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs())

# Load Llama3-8B HF model
log_rank(
f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}",
logger=logger,
level=logging.INFO,
rank=0,
)
hf_model = AutoModelForCausalLM.from_pretrained(
args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2"
).to(DEVICE)
hf_config = hf_model.config

# Set Nanotron LlamaConfig
nanotron_llama_config = LlamaConfigNanotron(
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
hidden_act=hf_config.hidden_act,
hidden_size=hf_config.hidden_size,
initializer_range=hf_config.initializer_range,
intermediate_size=hf_config.intermediate_size,
is_llama_config=True,
max_position_embeddings=hf_config.max_position_embeddings,
num_attention_heads=hf_config.num_attention_heads,
num_hidden_layers=hf_config.num_hidden_layers,
num_key_value_heads=hf_config.num_key_value_heads,
pad_token_id=None,
pretraining_tp=hf_config.pretraining_tp,
rms_norm_eps=hf_config.rms_norm_eps,
rope_scaling=hf_config.rope_scaling,
rope_theta=hf_config.rope_theta,
rope_interleaved=False,
tie_word_embeddings=hf_config.tie_word_embeddings,
use_cache=hf_config.use_cache,
vocab_size=hf_config.vocab_size,
)

# Init Llama3-8B Nanotron model
log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0)
nanotron_model = build_model(
model_builder=lambda: LlamaForTraining(
config=nanotron_llama_config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=None,
),
parallel_context=parallel_context,
dtype=TORCH_DTYPE,
device=DEVICE,
)

mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context)
sanity_check(root_module=nanotron_model)

# Copy params from HF to Nanotron
log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0)
# Token embeddings
log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0)
assert (
nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape
== hf_model.model.embed_tokens.weight.shape
)
with torch.no_grad():
nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_(
hf_model.model.embed_tokens.weight
)

# Decoder layers
for i in tqdm(
range(nanotron_llama_config.num_hidden_layers),
desc="Copying Hidden Layers",
total=nanotron_llama_config.num_hidden_layers,
):
# Input layer norm
assert (
hf_model.model.layers[i].input_layernorm.weight.shape
== nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape
)
with torch.no_grad():
nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_(
hf_model.model.layers[i].input_layernorm.weight
)

# Self attn
## QKV
tmp_qkv_proj = torch.cat(
[
hf_model.model.layers[i].self_attn.q_proj.weight,
hf_model.model.layers[i].self_attn.k_proj.weight,
hf_model.model.layers[i].self_attn.v_proj.weight,
],
dim=0,
)
assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape
with torch.no_grad():
nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj)

## O
assert (
hf_model.model.layers[i].self_attn.o_proj.weight.shape
== nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape
)
with torch.no_grad():
nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_(
hf_model.model.layers[i].self_attn.o_proj.weight
)

# MLP
## Gate Up Proj
tmp_gate_up_proj = torch.cat(
[
hf_model.model.layers[i].mlp.gate_proj.weight,
hf_model.model.layers[i].mlp.up_proj.weight,
],
dim=0,
)

assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape
with torch.no_grad():
nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj)

## Down Proj
assert (
hf_model.model.layers[i].mlp.down_proj.weight.shape
== nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape
)
with torch.no_grad():
nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_(
hf_model.model.layers[i].mlp.down_proj.weight
)

# Post attn layer norm
assert (
hf_model.model.layers[i].post_attention_layernorm.weight.shape
== nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape
)
with torch.no_grad():
nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_(
hf_model.model.layers[i].post_attention_layernorm.weight
)

# Last layer norm
log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0)
assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape
with torch.no_grad():
nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight)

# LM_Head
log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0)
assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape
with torch.no_grad():
nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight)

log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0)
# Store weights
nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path)
save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path)

# Store metadata
log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0)
training_metadata = TrainingMetadata(
last_train_step=0,
consumed_train_samples=0,
data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)],
)
save_meta(
root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata
)
# Store Tokenizer into Nanotron Checkpoint folder
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
tokenizer.save_pretrained(nanotron_checkpoint_path)

# Store Config and Model Config files
with open(nanotron_checkpoint_path / "config.yaml", "w") as f:
config = Config(
general=GeneralArgs(project="Nanotron", run="Llama3"),
parallelism=parallel_config,
model=ModelArgs(
init_method=ExistingCheckpointInit(nanotron_checkpoint_path),
model_config=nanotron_llama_config,
),
tokenizer=TokenizerArgs(nanotron_checkpoint_path),
)
log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0)
yaml.dump(config.as_dict(), f)

with open(nanotron_checkpoint_path / "model_config.json", "w") as f:
log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0)
json.dump(asdict(nanotron_llama_config), f)

log_rank(
f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}",
logger=logger,
level=logging.INFO,
rank=0,
)


if __name__ == "__main__":
_args = get_args()
main(_args)