From df831b309173b65acaef7cadd7a97ba9c24d6c4c Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Tue, 7 Nov 2023 23:27:14 -0600 Subject: [PATCH 01/11] bump flash-attn version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 95b5401..4c1bcde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ transformers >= 4.31.0 torch >= 2.0.0 -flash-attn >= 2.0.0 +flash-attn >= 2.3.3 datasets >= 2.14.0 nltk >= 3.8.0 sentencepiece >= 0.1.0 From 3e189d375b983e9b9c231b12c9ccbc54e1963307 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Fri, 10 Nov 2023 23:52:29 +0000 Subject: [PATCH 02/11] cherry pick wip mistral --- megatron/arguments.py | 3 ++ megatron/model/mistral_model.py | 46 +++++++++++++++++++++++++ megatron/model/transformer.py | 39 ++++++++++++++++++--- weights_conversion/hf_to_megatron.py | 51 +++++++++++++++++++++++++++- 4 files changed, 134 insertions(+), 5 deletions(-) create mode 100644 megatron/model/mistral_model.py diff --git a/megatron/arguments.py b/megatron/arguments.py index 4464859..ad18fbf 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -475,6 +475,9 @@ def _add_network_size_args(parser): group.add_argument("--no_tie_embed_logits", action="store_false", dest="tie_embed_logits", help=("If set, the weights of the word embedding and lm_head " "are not tied")) + # Added for Mistral + group.add_argument("--sliding_window", type=int, default=None, + help="Whether to use sliding window attention for Mistral") return parser diff --git a/megatron/model/mistral_model.py b/megatron/model/mistral_model.py new file mode 100644 index 0000000..76a1927 --- /dev/null +++ b/megatron/model/mistral_model.py @@ -0,0 +1,46 @@ +"""Mistral Model.""" + +import warnings + +from megatron import get_args +from .enums import PositionEmbeddingType +from . 
import GPTModel + + +class MistralModel(GPTModel): + def __init__(self, + num_tokentypes: int = 0, + parallel_output: bool = True, + pre_process: bool = True, + post_process: bool = True, + model_type=None + ): + + args = get_args() + + # mandatory arguments + assert args.position_embedding_type == PositionEmbeddingType.rotary, \ + f"Mistral uses rotary embedding, not {args.position_embedding_type}" + assert not args.use_post_ln, "Mistral does not use post_ln" + assert args.glu_activation == "swiglu", "Mistral works with swiglu activation" + assert not args.use_bias, "Mistral does not use bias" + assert not args.parallel_attn, "Mistral does not use parallel_attn" + assert args.use_rms_norm, "Mistral uses rms_norm" + assert not args.tie_embed_logits , "Mistral unties embedding and lm_head weights" + assert args.sliding_window == 4096, "Mistral uses sliding window attention (sliding_window=4096)" + + # recomended arguments + if not args.use_flash_attn: + warnings.warn("Mistral should use flash attn (for sliding window local attention)") + + if args.bias_gelu_fusion: + warnings.warn("Mistral is not intended to use bias_gelu_fusion") + if args.bias_dropout_fusion: + warnings.warn("Mistral is not intended to use bias_dropout_fusion") + if args.hidden_dropout > 0.0 and not args.lima_dropout: + warnings.warn("Mistral is not intended to use dropout") + if args.attention_dropout > 0.0: + warnings.warn("Mistral is not intended to use dropout") + super().__init__(num_tokentypes=num_tokentypes, parallel_output=parallel_output, + pre_process=pre_process, post_process=post_process, + model_type=model_type) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index a9691ce..fe8c69f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -23,7 +23,7 @@ # Extracted from: https://github.com/bigscience-workshop/Megatron-DeepSpeed from .glu_activations import GLU_ACTIVATIONS from megatron.model.positional_embeddings import precompute_freqs_cis, apply_rotary_emb - +from flash_attn.bert_padding import pad_input, unpad_input_for_concatenated_sequences """ We use the following notation throughout this file: h: hidden size @@ -301,14 +301,24 @@ def __init__(self, self.params_dtype = args.params_dtype self.sequence_parallel = args.sequence_parallel self.use_flash_attn = args.use_flash_attn + self.sliding_window_size = args.sliding_window_size self.num_attention_heads_kv = args.num_attention_heads_kv self.num_attention_heads = args.num_attention_heads self.seq_length = args.seq_length + self.packed_input = args.packed_input if self.use_flash_attn: assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' 'self-attention for now') assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' 'supports causal mask for now') + # If sliding window is enabled, we need to make sure that the sliding window is supported. + if self.sliding_window_size is not None: + import inspect + # https://github.com/huggingface/transformers/blob/7e1eff7600085814eac65876d4d8a0e38c2f6ccc/src/transformers/models/mistral/modeling_mistral.py#L50C5-L50C32 + assert "window_size" in list(inspect.signature( + flash_attn.flash_attn_func + ).parameters), "The current flash attention version does not support sliding window attention, please update to the latest version." + assert self.use_flash_attn, "Sliding window attention is only supported with flash attention for now." 
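As an illustration separate from the patch itself: outside Megatron, the two new pieces in this hunk, the inspect.signature probe for window_size and the windowed flash-attention call, look roughly like the sketch below. The helper names and tensor shapes are assumptions made for the example; only flash_attn_func with its causal and window_size arguments comes from the flash-attn API the hunk relies on, and it requires flash-attn >= 2.3 (hence the version bump in PATCH 01).

import inspect

import torch
from flash_attn import flash_attn_func


def supports_sliding_window() -> bool:
    # flash-attn gained the `window_size` argument in the 2.3.x line, so probe
    # the signature rather than trusting the package version string.
    return "window_size" in inspect.signature(flash_attn_func).parameters


def sliding_window_attention(q, k, v, window: int):
    # q: (batch, seq_len, n_q_heads, head_dim); k and v may carry fewer heads
    # (grouped-query attention). Tensors must be fp16/bf16 on a CUDA device.
    assert supports_sliding_window(), "update flash-attn to >= 2.3.3"
    # (window, window) caps how far back each query can attend; with
    # causal=True the forward half of the window is never used anyway.
    return flash_attn_func(q, k, v, causal=True, window_size=(window, window))


if __name__ == "__main__":
    q = torch.randn(1, 8192, 32, 128, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 8192, 8, 128, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 8192, 8, 128, device="cuda", dtype=torch.bfloat16)
    out = sliding_window_attention(q, k, v, window=4096)  # (1, 8192, 32, 128)

This mirrors the Hugging Face Mistral implementation the patch links to: the window is only passed once the key/value sequence length actually exceeds sliding_window_size, which is what the forward-pass change later in this patch does.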
projection_size = args.kv_channels * args.num_attention_heads @@ -513,13 +523,34 @@ def forward(self, context_layer = self.core_attention( query_layer, key_layer, value_layer, attention_mask) else: + flash_attn_extra_kwargs = {} + # check if we need to use sliding window attention + # https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/mistral/modeling_mistral.py#L353 + if self.sliding_window_size is not None: + kv_seq_len = key_layer.shape[0] + if kv_seq_len > self.sliding_window_size: + # https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/mistral/modeling_mistral.py#L510C21-L510C89 + flash_attn_extra_kwargs["window_size"] = ( + self.sliding_window_size, self.sliding_window_size + ) + # It will be truncated to the actual sequence length inside flash attention + # https://github.com/Dao-AILab/flash-attention/blob/83aef842beec1037eb8c1d9c3ef3ed8aae80b091/csrc/flash_attn/src/softmax.h#L159-L161 + q, k, v = [rearrange(x, "s b n h -> b s n h").contiguous() - for x in (query_layer, key_layer, value_layer)] + for x in (query_layer, key_layer, value_layer)] if not self.sequence_parallel: with megatron.core.tensor_parallel.get_cuda_rng_tracker().fork(): - context_layer = self.core_attention_flash(q, k, v, causal=True) + context_layer = self.core_attention_flash( + q, k, v, + causal=True, + **flash_attn_extra_kwargs + ) else: - context_layer = self.core_attention_flash(q, k, v, causal=True) + context_layer = self.core_attention_flash( + q, k, v, + causal=True, + **flash_attn_extra_kwargs + ) context_layer = rearrange(context_layer, 'b s n h -> s b (n h)').contiguous() # ================= diff --git a/weights_conversion/hf_to_megatron.py b/weights_conversion/hf_to_megatron.py index 4a1f777..a0e4855 100644 --- a/weights_conversion/hf_to_megatron.py +++ b/weights_conversion/hf_to_megatron.py @@ -195,6 +195,14 @@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, trust_remote_code=True, cache_dir=cache_dir) hf_weights = model.state_dict() + elif model_name == "mistral": + print("Fetching weights from huggingface") + if model_path is None: + model_path = "mistralai/Mistral-7B-v0.1" + model = AutoModelForCausalLM.from_pretrained(model_path, + trust_remote_code=True, + cache_dir=cache_dir) + hf_weights = model.state_dict() else: print("Getting llama...") version = 2 if "2" in model_name else 1 @@ -222,6 +230,30 @@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, "hidden_dropout": 0.0, "parallel_attn": True, "max_position_embeddings": 2048, "seq_length": 2048}) + elif model_name == "mistral": + assert size == 7 + # mistral-7b mostly uses the same args as llama-7b + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + args = { + "num_layers": 32, + "hidden_size": 4096, + "num_attention_heads": 32, + "num_attention_heads_kv": 8, # except this - GroupedAttention + "ffn_hidden_size": 14336, # except this + "parallel_attn": False, + "make_vocab_size_divisible_by": 128, + "glu_activation": "swiglu", # == silu + "padded_vocab_size": 32000, + "use_rms_norm": True, + "tie_embed_logits": False, + "tokenizer_type": "SentencePieceTokenizer", + + "max_position_embeddings": 32768, + "seq_length": 32768, + "layernorm_epsilon": 1e-5, + "rope_theta": 10000.0, + # "sliding_window": 4096, + } else: # llama1, llama2, codellama args = {"num_layers": llama_s2layer[size], "hidden_size": llama_s2hidden[size], @@ -290,6 +322,21 
@@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, vocab_file = tokenizer.vocab_file shutil.copy(vocab_file, token_path) print("Saved tokenizer.model in", token_path) + elif model_name == "mistral": + tokenizer = None + if model_path is not None: + try: + tokenizer = LlamaTokenizer.from_pretrained(model_path, cache_dir=cache_dir) + except OSError: + warnings.warn(f"Model path {model_path} does not have a " + "tokenizer, using default tokenizer instead") + if tokenizer is None: + tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", + cache_dir=cache_dir) + token_path = out/"tokenizer.model" + vocab_file = tokenizer.vocab_file + shutil.copy(vocab_file, token_path) + print("Saved tokenizer.model in", token_path) print("Done") @@ -297,7 +344,7 @@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, if __name__ == "__main__": parser = ArgumentParser(description="Convert Huggingface llama or falcon weights to " "megatron-compatible weights") - parser.add_argument("model", choices={"falcon", "llama", "llama2", "codellama"}) + parser.add_argument("model", choices={"falcon", "llama", "llama2", "codellama", "mistral"}) parser.add_argument("--size", default=7, choices={7, 13, 30, 34, 40, 65, 70}, type=int, help="The size of the model") parser.add_argument("--out", type=Path, @@ -317,6 +364,8 @@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, assert args.size in {7, 13, 30, 65} elif args.model == "codellama": assert args.size in {7, 13, 34} + elif args.model == "mistral": + assert args.size in {7} else: assert args.size in {7, 13, 70} From a805d7b9bd73009eb2b2570f4d9a6d86921a6acc Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Fri, 10 Nov 2023 18:59:58 -0600 Subject: [PATCH 03/11] support ht to megatron for mistral --- tools/checkpoint_loader_megatron.py | 2 + weights_conversion/hf_to_megatron.py | 77 +++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/tools/checkpoint_loader_megatron.py b/tools/checkpoint_loader_megatron.py index 6fa6b68..77bc521 100644 --- a/tools/checkpoint_loader_megatron.py +++ b/tools/checkpoint_loader_megatron.py @@ -185,6 +185,8 @@ def _get_models(count, dtype, pre_process, post_process): md.glu_activation = margs.glu_activation md.tie_embed_logits = margs.tie_embed_logits md.params_dtype = margs.params_dtype + if hasattr(margs, "sliding_window_size"): + md.sliding_window_size = margs.sliding_window_size if margs.position_embedding_type == PositionEmbeddingType.absolute: md.position_embedding_type = "absolute" elif margs.position_embedding_type == PositionEmbeddingType.rotary: diff --git a/weights_conversion/hf_to_megatron.py b/weights_conversion/hf_to_megatron.py index a0e4855..2471aba 100644 --- a/weights_conversion/hf_to_megatron.py +++ b/weights_conversion/hf_to_megatron.py @@ -181,6 +181,79 @@ def rearrange_qkv(wq, wk, wv): "lm_head": lm_head} +def mistral_to_megatron( + weights: dict, + size: int +) -> dict: + assert size == 7 + def permute(qkv_w): + # if source == "hf": + # by default, we pull mistrals weights from huggingface + return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads) + # return qkv_w + + def rearrange_qkv(wq, wk, wv): + wq = torch.split(wq, n_hidden_per_head, dim=0) + wk = torch.split(wk, n_hidden_per_head, dim=0) + wv = torch.split(wv, n_hidden_per_head, dim=0) + assert len(wq) == n_heads + assert len(wk) == n_kv_heads + assert len(wv) == n_kv_heads + n_qs_per_kv = n_heads//n_kv_heads + w_qkv = [] + for i 
in range(n_kv_heads): + w_qkv += [wq[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)] + w_qkv += [wk[i], wv[i]] + return permute(torch.concat(w_qkv)) + + # config + if size == 7: + n_layer = 32 + hidden = 4096 + n_heads = 32 + n_kv_heads = 8 + n_hidden_per_head = hidden // n_heads + + # weights independent of layers + embedding = {"word_embeddings.weight": weights["model.embed_tokens.weight"]} + transformer = {"final_layernorm.weight": weights["model.norm.weight"]} + lm_head = weights["lm_head.weight"] + + # get all the other weights + for layer in trange(n_layer, desc="Converting weights"): + prefix = f"model.layers.{layer}" + # identical weights + transformer[f"{prefix}.attention.dense.weight"] = \ + weights[f"{prefix}.self_attn.o_proj.weight"] + transformer[f"{prefix}.post_attention_layernorm.weight"] = \ + weights[f"{prefix}.post_attention_layernorm.weight"] + transformer[f"{prefix}.input_layernorm.weight"] = \ + weights[f"{prefix}.input_layernorm.weight"] + transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \ + weights[f"{prefix}.mlp.down_proj.weight"] + # concatenate up, gate mlp weights + transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([ + weights[f"{prefix}.mlp.up_proj.weight"], # w3 + weights[f"{prefix}.mlp.gate_proj.weight"] # w1 + ]) + # finally, qkv requires serious manipulation to get right (probably same as llama-2) + transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv( + weights[f"{prefix}.self_attn.q_proj.weight"], + weights[f"{prefix}.self_attn.k_proj.weight"], + weights[f"{prefix}.self_attn.v_proj.weight"] + ) + + # release references to original weights (free mem) + del weights[f"{prefix}.mlp.up_proj.weight"] + del weights[f"{prefix}.mlp.gate_proj.weight"] + del weights[f"{prefix}.self_attn.q_proj.weight"] + del weights[f"{prefix}.self_attn.k_proj.weight"] + del weights[f"{prefix}.self_attn.v_proj.weight"] + + return {"embedding": embedding, "transformer": transformer, + "lm_head": lm_head} + + def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, cache_dir: Optional[Path] = None, model_path: Optional[str] = None): if out is None: @@ -212,6 +285,8 @@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, # convert state dict to be megatron-compatible if model_name == "falcon": megatron_weights = falcon_to_megatron(hf_weights, size) + elif model_name == "mistral": + megatron_weights = mistral_to_megatron(hf_weights, size) else: megatron_weights = llama_to_megatron(hf_weights, size, llama_source, version=1 if model_name == "llama" else 2) @@ -252,7 +327,7 @@ def main(model_name: str = "falcon", size: int = 7, out: Optional[Path] = None, "seq_length": 32768, "layernorm_epsilon": 1e-5, "rope_theta": 10000.0, - # "sliding_window": 4096, + "sliding_window_size": 4096, } else: # llama1, llama2, codellama args = {"num_layers": llama_s2layer[size], From e38f1953108032c3900adecd0d9fe79a258aed32 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Fri, 10 Nov 2023 19:06:22 -0600 Subject: [PATCH 04/11] adjust arg name to make compatible w/ ckpt --- finetune.py | 8 +++++--- megatron/arguments.py | 4 ++-- megatron/checkpointing.py | 1 + megatron/model/__init__.py | 1 + megatron/model/mistral_model.py | 2 +- tools/checkpoint_loader_megatron.py | 3 +-- 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/finetune.py b/finetune.py index 22e6b7e..7ef5bec 100644 --- a/finetune.py +++ b/finetune.py @@ -9,7 +9,7 @@ from megatron.training import pretrain from megatron.core import tensor_parallel 
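Stepping briefly outside the diff: the rearrange_qkv helper introduced just above interleaves each group of query heads with the key/value head they share, which is the layout Megatron expects for the fused query_key_value weight. A toy-sized sketch of that interleaving follows; the shapes are made up for readability (real Mistral-7B uses 32 query heads, 8 kv heads and head_dim 128), and the final permute_qkv rotary permutation applied by the converter is omitted.

import torch


def interleave_qkv(wq, wk, wv, n_heads, n_kv_heads, head_dim):
    # Split the Hugging Face projection matrices into per-head row blocks.
    q = torch.split(wq, head_dim, dim=0)   # n_heads blocks
    k = torch.split(wk, head_dim, dim=0)   # n_kv_heads blocks
    v = torch.split(wv, head_dim, dim=0)   # n_kv_heads blocks
    group = n_heads // n_kv_heads          # query heads sharing one kv head
    blocks = []
    for i in range(n_kv_heads):
        blocks += [q[i * group + j] for j in range(group)]  # q_0 .. q_{group-1}
        blocks += [k[i], v[i]]                               # then their k, v pair
    return torch.cat(blocks, dim=0)


if __name__ == "__main__":
    hidden, n_heads, n_kv_heads = 16, 4, 2
    head_dim = hidden // n_heads
    wq = torch.randn(n_heads * head_dim, hidden)
    wk = torch.randn(n_kv_heads * head_dim, hidden)
    wv = torch.randn(n_kv_heads * head_dim, hidden)
    fused = interleave_qkv(wq, wk, wv, n_heads, n_kv_heads, head_dim)
    print(fused.shape)  # torch.Size([32, 16]): (n_heads + 2 * n_kv_heads) * head_dim rows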
from megatron.core.parallel_state import get_data_parallel_group -from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel +from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel, MistralModel from megatron.utils import get_ltor_masks_and_position_ids, average_losses_across_data_parallel_group from megatron.data.gpt_dataset import build_train_valid_test_datasets as gpt_build_datasets from megatron.data.instruction_dataset import instruction_collator @@ -35,8 +35,10 @@ def model_provider(pre_process: bool = True, post_process: bool = True): cls = FalconModel elif args.model_name in {"llama", "llama2", "codellama"}: cls = partial(LlamaModel, version=1 if args.model_name == "llama" else 2) + elif args.model_name == "mistral": + cls = MistralModel else: - raise KeyError(f"Unkown model {other}") + raise KeyError(f"Unkown model {args.model_name}") if isinstance(args.model_type, ModelType): model_type = args.model_type @@ -238,7 +240,7 @@ def extra_args(parser): """Text generation arguments.""" group = parser.add_argument_group(title='validation set') group.add_argument("--model_name", - choices={"gpt", "llama", "falcon", "llama2", "codellama"}, + choices={"gpt", "llama", "falcon", "llama2", "codellama", "mistral"}, default="gpt") group.add_argument("--model_type", choices={"encoder_or_decoder", "encoder_and_decoder"}, default="encoder_or_decoder") diff --git a/megatron/arguments.py b/megatron/arguments.py index ad18fbf..2a384f6 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -476,8 +476,8 @@ def _add_network_size_args(parser): help=("If set, the weights of the word embedding and lm_head " "are not tied")) # Added for Mistral - group.add_argument("--sliding_window", type=int, default=None, - help="Whether to use sliding window attention for Mistral") + group.add_argument("--sliding_window_size", type=int, default=None, + help="Whether to use sliding window attention for Mistral. 
Default is None, which means no sliding window attention.") return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 3136082..62ecee1 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -549,6 +549,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('tie_embed_logits', force=True) _set_arg('make_vocab_size_divisible_by', force=True) _set_arg('train_iters') + _set_arg('sliding_window_size') if checkpoint_version < 3.0: _set_arg('tensor_model_parallel_size', 'model_parallel_size') diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 118ad93..83dc27b 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -8,6 +8,7 @@ from .gpt_model import GPTModel from .llama_model import LlamaModel from .falcon_model import FalconModel +from .mistral_model import MistralModel from .t5_model import T5Model from .module import Float16Module from .enums import ModelType diff --git a/megatron/model/mistral_model.py b/megatron/model/mistral_model.py index 76a1927..8c05dc5 100644 --- a/megatron/model/mistral_model.py +++ b/megatron/model/mistral_model.py @@ -27,7 +27,7 @@ def __init__(self, assert not args.parallel_attn, "Mistral does not use parallel_attn" assert args.use_rms_norm, "Mistral uses rms_norm" assert not args.tie_embed_logits , "Mistral unties embedding and lm_head weights" - assert args.sliding_window == 4096, "Mistral uses sliding window attention (sliding_window=4096)" + assert args.sliding_window_size == 4096, "Mistral uses sliding window attention (sliding_window=4096)" # recomended arguments if not args.use_flash_attn: diff --git a/tools/checkpoint_loader_megatron.py b/tools/checkpoint_loader_megatron.py index 77bc521..a2db57a 100644 --- a/tools/checkpoint_loader_megatron.py +++ b/tools/checkpoint_loader_megatron.py @@ -185,8 +185,7 @@ def _get_models(count, dtype, pre_process, post_process): md.glu_activation = margs.glu_activation md.tie_embed_logits = margs.tie_embed_logits md.params_dtype = margs.params_dtype - if hasattr(margs, "sliding_window_size"): - md.sliding_window_size = margs.sliding_window_size + md.sliding_window_size = margs.sliding_window_size if margs.position_embedding_type == PositionEmbeddingType.absolute: md.position_embedding_type = "absolute" elif margs.position_embedding_type == PositionEmbeddingType.rotary: From c972d8d478d01ff0e4652724620a151e62fa9f7d Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Fri, 10 Nov 2023 19:06:44 -0600 Subject: [PATCH 05/11] support verify correctness for mistral --- verify_correctness.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/verify_correctness.py b/verify_correctness.py index d771c9b..6f4143b 100644 --- a/verify_correctness.py +++ b/verify_correctness.py @@ -7,7 +7,7 @@ import torch import llama from torch import nn -from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer +from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, MistralForCausalLM from fairscale.nn.model_parallel.initialize import initialize_model_parallel from megatron import get_args, update_num_microbatches @@ -72,6 +72,16 @@ def hf_provider(name: str, cache_dir: Optional[Path], device: str, model = LlamaForCausalLM.from_pretrained( f"meta-llama/Llama-2-{size}b-hf", cache_dir=cache_dir ) + elif name == "mistral": + assert size == 7, "Mistral only supports 7B model" + try: + model = MistralForCausalLM.from_pretrained(cache_dir) + except OSError: + 
print(f"Cache dir {cache_dir} does not look like a huggingface " + "checkpoint, assuming cache_dir instead") + model = MistralForCausalLM.from_pretrained( + f"mistralai/Mistral-{size}B-v0.1", cache_dir=cache_dir + ) else: raise KeyError(f"Model {name} not implemented") return model.eval().requires_grad_(False).to(device) @@ -114,7 +124,7 @@ def verify_step(our_forward, our_model, base_forward, base_model, batch): our_logits, our_loss = our_forward(our_model, batch) base_logits, base_loss = base_forward(base_model, batch) assert our_logits.size() == base_logits.size(), \ - f"ours={logits1.size()}, true={logits2.size()}" + f"ours={our_logits.size()}, true={base_logits.size()}" our_logits = our_logits.cpu() base_logits = base_logits.cpu() abs_error = torch.abs(our_logits - base_logits) @@ -192,9 +202,9 @@ def extra_extra_args(parser): if __name__ == "__main__": defaults = {"micro_batch_size": 1, "use_checkpoint_args": True, "train_iters": 10, "lr": 1.0} - if not is_megatron_path(parse_args(extra_extra_args).load): - defaults.update({"encoder_num_layers": 1, "hidden_size": 1, - "num_attention_heads": 1, "seq_length": 2048, - "max_position_embeddings": 2048}) + # if not is_megatron_path(parse_args(extra_extra_args).load): + # defaults.update({"encoder_num_layers": 1, "hidden_size": 1, + # "num_attention_heads": 1, "seq_length": 2048, + # "max_position_embeddings": 2048}) initialize_megatron(extra_extra_args, args_defaults=defaults) main() From 6470ca63b7b64f925ce61fd7d4ca571168ef6459 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Sat, 11 Nov 2023 03:58:45 +0000 Subject: [PATCH 06/11] fix megatron conversion for mistral --- weights_conversion/hf_to_megatron.py | 31 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/weights_conversion/hf_to_megatron.py b/weights_conversion/hf_to_megatron.py index 2471aba..4d5fb90 100644 --- a/weights_conversion/hf_to_megatron.py +++ b/weights_conversion/hf_to_megatron.py @@ -221,34 +221,35 @@ def rearrange_qkv(wq, wk, wv): # get all the other weights for layer in trange(n_layer, desc="Converting weights"): - prefix = f"model.layers.{layer}" + prefix = f"layers.{layer}" + hf_prefix = f"model.{prefix}" # identical weights transformer[f"{prefix}.attention.dense.weight"] = \ - weights[f"{prefix}.self_attn.o_proj.weight"] + weights[f"{hf_prefix}.self_attn.o_proj.weight"] transformer[f"{prefix}.post_attention_layernorm.weight"] = \ - weights[f"{prefix}.post_attention_layernorm.weight"] + weights[f"{hf_prefix}.post_attention_layernorm.weight"] transformer[f"{prefix}.input_layernorm.weight"] = \ - weights[f"{prefix}.input_layernorm.weight"] + weights[f"{hf_prefix}.input_layernorm.weight"] transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \ - weights[f"{prefix}.mlp.down_proj.weight"] + weights[f"{hf_prefix}.mlp.down_proj.weight"] # concatenate up, gate mlp weights transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([ - weights[f"{prefix}.mlp.up_proj.weight"], # w3 - weights[f"{prefix}.mlp.gate_proj.weight"] # w1 + weights[f"{hf_prefix}.mlp.up_proj.weight"], # w3 + weights[f"{hf_prefix}.mlp.gate_proj.weight"] # w1 ]) # finally, qkv requires serious manipulation to get right (probably same as llama-2) transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv( - weights[f"{prefix}.self_attn.q_proj.weight"], - weights[f"{prefix}.self_attn.k_proj.weight"], - weights[f"{prefix}.self_attn.v_proj.weight"] + weights[f"{hf_prefix}.self_attn.q_proj.weight"], + weights[f"{hf_prefix}.self_attn.k_proj.weight"], 
+ weights[f"{hf_prefix}.self_attn.v_proj.weight"] ) # release references to original weights (free mem) - del weights[f"{prefix}.mlp.up_proj.weight"] - del weights[f"{prefix}.mlp.gate_proj.weight"] - del weights[f"{prefix}.self_attn.q_proj.weight"] - del weights[f"{prefix}.self_attn.k_proj.weight"] - del weights[f"{prefix}.self_attn.v_proj.weight"] + del weights[f"{hf_prefix}.mlp.up_proj.weight"] + del weights[f"{hf_prefix}.mlp.gate_proj.weight"] + del weights[f"{hf_prefix}.self_attn.q_proj.weight"] + del weights[f"{hf_prefix}.self_attn.k_proj.weight"] + del weights[f"{hf_prefix}.self_attn.v_proj.weight"] return {"embedding": embedding, "transformer": transformer, "lm_head": lm_head} From 485c7a632e9bad2109b6675b4cf7ccb9cefed7a0 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Sat, 11 Nov 2023 03:59:42 +0000 Subject: [PATCH 07/11] support bf16 model verify --- verify_correctness.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/verify_correctness.py b/verify_correctness.py index 6f4143b..69ec592 100644 --- a/verify_correctness.py +++ b/verify_correctness.py @@ -48,21 +48,26 @@ def is_meta_llama2_path(path: Optional[Path]) -> bool: def hf_provider(name: str, cache_dir: Optional[Path], device: str, - size: int = 7): + size: int = 7, bf16: bool = False): print("Getting huggingface model...") + extra_kwargs = {} + if bf16: + extra_kwargs = {"torch_dtype": torch.bfloat16} if name == "falcon": model = AutoModelForCausalLM.from_pretrained( f"tiiuae/falcon-{size}b", cache_dir=cache_dir, - trust_remote_code=True + trust_remote_code=True, + **extra_kwargs ) elif name == "llama": try: - model = LlamaForCausalLM.from_pretrained(cache_dir) + model = LlamaForCausalLM.from_pretrained(cache_dir, **extra_kwargs) except OSError: print(f"Cache dir {cache_dir} does not look like a huggingface " "checkpoint, assuming cache_dir instead") model = LlamaForCausalLM.from_pretrained( - f"decapoda-research/llama-{size}b-hf", cache_dir=cache_dir + f"decapoda-research/llama-{size}b-hf", cache_dir=cache_dir, + **extra_kwargs ) elif name == "llama2" and is_meta_llama2_path(cache_dir): print(f"baseline path {cache_dir} does not look like a huggingface, " @@ -70,17 +75,19 @@ def hf_provider(name: str, cache_dir: Optional[Path], device: str, model = Llama2Wrapper(cache_dir) elif name == "llama2": model = LlamaForCausalLM.from_pretrained( - f"meta-llama/Llama-2-{size}b-hf", cache_dir=cache_dir + f"meta-llama/Llama-2-{size}b-hf", cache_dir=cache_dir, + **extra_kwargs ) elif name == "mistral": assert size == 7, "Mistral only supports 7B model" try: - model = MistralForCausalLM.from_pretrained(cache_dir) + model = MistralForCausalLM.from_pretrained(cache_dir, **extra_kwargs) except OSError: print(f"Cache dir {cache_dir} does not look like a huggingface " "checkpoint, assuming cache_dir instead") model = MistralForCausalLM.from_pretrained( - f"mistralai/Mistral-{size}B-v0.1", cache_dir=cache_dir + f"mistralai/Mistral-{size}B-v0.1", cache_dir=cache_dir, + **extra_kwargs ) else: raise KeyError(f"Model {name} not implemented") @@ -157,7 +164,7 @@ def main(): else: print("NOTE: The given path does not look like a megatron checkpoint, " f"assuming it's a huggingface checkpoint instead (path={args.load})") - our_model = hf_our_provider(args.model_name, args.load, "cuda:0") + our_model = hf_our_provider(args.model_name, args.load, "cuda:0", bf16=args.bf16) our_forward = hf_forward args.iteration = 0 From 76a4780e821c7bc99c5c92e5a761536ed81c4916 Mon Sep 17 00:00:00 2001 From: Xingyao Wang 
Date: Sat, 11 Nov 2023 04:20:55 +0000 Subject: [PATCH 08/11] remove packed_input arg that should not appear here --- megatron/model/transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index fe8c69f..a3247f0 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -305,7 +305,6 @@ def __init__(self, self.num_attention_heads_kv = args.num_attention_heads_kv self.num_attention_heads = args.num_attention_heads self.seq_length = args.seq_length - self.packed_input = args.packed_input if self.use_flash_attn: assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' 'self-attention for now') From 1b506a086cf82cf85295ad56d33aa58852248c09 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Fri, 10 Nov 2023 23:16:05 -0600 Subject: [PATCH 09/11] add mistral to checkpoint util and loader --- tools/checkpoint_loader_megatron.py | 2 +- tools/checkpoint_util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/checkpoint_loader_megatron.py b/tools/checkpoint_loader_megatron.py index a2db57a..18f776b 100644 --- a/tools/checkpoint_loader_megatron.py +++ b/tools/checkpoint_loader_megatron.py @@ -90,7 +90,7 @@ def check_for_arg(arg_name): if args.model_type == 'GPT': from pretrain_gpt import model_provider margs.model_type = ModelType.encoder_or_decoder - elif args.model_type in {"falcon", "llama", "llama2", "codellama"}: + elif args.model_type in {"falcon", "llama", "llama2", "codellama", "mistral"}: from finetune import model_provider margs.model_name = args.model_type margs.model_type = ModelType.encoder_or_decoder diff --git a/tools/checkpoint_util.py b/tools/checkpoint_util.py index 773c2bb..019526c 100644 --- a/tools/checkpoint_util.py +++ b/tools/checkpoint_util.py @@ -109,7 +109,7 @@ def main(): allow_abbrev=False, conflict_handler='resolve') parser.add_argument('--model_type', type=str, required=True, - choices=['GPT', 'BERT', 'falcon', 'llama', 'llama2', 'codellama'], + choices=['GPT', 'BERT', 'falcon', 'llama', 'llama2', 'codellama', 'mistral'], help='Type of the model') parser.add_argument('--loader', type=str, default='megatron', help='Module name to load checkpoint, should be on python path') From 136a9eb196dfdf0434d92e91638075b83a587cb4 Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Fri, 10 Nov 2023 23:30:41 -0600 Subject: [PATCH 10/11] fix mistral sliding window size not loaded --- finetune.py | 3 +++ tools/checkpoint_saver_megatron.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/finetune.py b/finetune.py index 7ef5bec..da9adcb 100644 --- a/finetune.py +++ b/finetune.py @@ -37,6 +37,9 @@ def model_provider(pre_process: bool = True, post_process: bool = True): cls = partial(LlamaModel, version=1 if args.model_name == "llama" else 2) elif args.model_name == "mistral": cls = MistralModel + if args.sliding_window_size != 4096: + print_rank_0("Mistral uses sliding window attention (set sliding_window=4096)") + args.sliding_window_size = 4096 else: raise KeyError(f"Unkown model {args.model_name}") diff --git a/tools/checkpoint_saver_megatron.py b/tools/checkpoint_saver_megatron.py index 3f19667..85d0487 100644 --- a/tools/checkpoint_saver_megatron.py +++ b/tools/checkpoint_saver_megatron.py @@ -154,7 +154,7 @@ def check_message(msg): elif md.model_type == 'BERT': from pretrain_bert import model_provider margs.model_type = ModelType.encoder_or_decoder - elif md.model_type in {'falcon', 'llama', 'llama2', 'codellama'}: + elif 
md.model_type in {'falcon', 'llama', 'llama2', 'codellama', 'mistral'}: from finetune import model_provider margs.model_name = args.model_type margs.model_type = ModelType.encoder_or_decoder From bc354d2ccfa997dfcfe7f5abe7404584965abdde Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Mon, 13 Nov 2023 15:48:23 -0600 Subject: [PATCH 11/11] support convert megatron to hf for mistral --- weights_conversion/megatron_to_hf.py | 152 ++++++++++++++++++++++++++- 1 file changed, 149 insertions(+), 3 deletions(-) diff --git a/weights_conversion/megatron_to_hf.py b/weights_conversion/megatron_to_hf.py index b0cf596..7618adc 100644 --- a/weights_conversion/megatron_to_hf.py +++ b/weights_conversion/megatron_to_hf.py @@ -32,7 +32,7 @@ import torch from tqdm.auto import trange -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizerFast, FalconConfig, FalconForCausalLM, AutoTokenizer +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizerFast, FalconConfig, FalconForCausalLM, AutoTokenizer, MistralConfig, MistralForCausalLM from utils.permute_qkv import permute_qkv @@ -193,6 +193,142 @@ def write_llama_model(model_path, max_num_params_per_shard = param_count*2 // max(1,(num_output_shards-1)) model.save_pretrained(model_path, max_shard_size=max_num_params_per_shard) +def write_mistral_model( + model_path, + input_base_path, + num_output_shards: int=2, + norm_eps: float=1e-5, + rope_theta: float=10000.0, + vocab_size: int=None, +): + + # Preliminaries + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + os.makedirs(model_path, exist_ok=True) + with open(os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')) as f: + iteration = f.read() + if iteration != "release": + iteration = f"iter_{int(iteration):07d}" + print(f"Fetching iteration {iteration}") + + # Load weights + base_path = Path(input_base_path)/iteration + assert len(list(base_path.glob("mp_rank_*"))) == 1, "Unshard your model with checkpoint_util.py first!" 
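As context for the loading boilerplate being added here (an aside, not part of the patch): the iteration lookup plus the single-shard check can be read as the small standalone helper sketched below. The path in the usage comment is a placeholder; the patch keeps this logic inline in write_mistral_model, matching the existing llama converter.

from pathlib import Path


def resolve_megatron_checkpoint(root: Path) -> Path:
    # Megatron writes either the literal string "release" or an iteration number
    # into latest_checkpointed_iteration.txt next to the iteration directories.
    tag = (root / "latest_checkpointed_iteration.txt").read_text().strip()
    subdir = tag if tag == "release" else f"iter_{int(tag):07d}"
    ranks = sorted((root / subdir).glob("mp_rank_*"))
    if len(ranks) != 1:
        raise RuntimeError("unshard the checkpoint with checkpoint_util.py first")
    return ranks[0] / "model_optim_rng.pt"


# e.g. state = torch.load(resolve_megatron_checkpoint(Path("ckpts/mistral-7b")),
#                         map_location="cpu")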
+ loaded = torch.load(base_path/"mp_rank_00"/"model_optim_rng.pt", map_location="cpu") + args = loaded['args'] + + loaded = loaded['model']['language_model'] + if 'transformer' not in loaded: # normalize key names + loaded["transformer"] = loaded.pop("encoder") + for key in list(loaded["transformer"].keys()): + loaded["transformer"][key.replace("self_attention", "attention")] = loaded["transformer"].pop(key) + loaded["embedding"]["word_embeddings.weight"] = loaded["embedding"].pop("word_embeddings")["weight"] + args.num_layers = args.encoder_num_layers + + # Load arguments + n_layers = args.num_layers + n_heads = args.num_attention_heads + n_heads_kv = getattr(args, "num_attention_heads_kv", n_heads) + n_dense = args.ffn_hidden_size + n_hidden = args.hidden_size + hidden_per_head = n_hidden // n_heads + intermediate_size = args.ffn_hidden_size + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head)) + + print('Mistral-Megatron Loaded!') + param_count = 0 + index_dict = {"weight_map": {}} + + # Start conversion + with TemporaryDirectory() as tmp_model_path: + print(f'Weighted Converting for {n_layers} layers...') + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + wq_proj, wk_proj, wv_proj = convert_wqkv(llama_mega=loaded, + layer_idx=layer_i, n_heads=n_heads, + n_heads_kv=n_heads_kv) + ffn_w1, ffn_w3 = convert_ffn(llama_mega=loaded, + layer_idx=layer_i, + n_dense=n_dense) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": wq_proj, + f"model.layers.{layer_i}.self_attn.k_proj.weight": wk_proj, + f"model.layers.{layer_i}.self_attn.v_proj.weight": wv_proj, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded["transformer"][f"layers.{layer_i}.attention.dense.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": ffn_w1, + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded["transformer"][f"layers.{layer_i}.mlp.dense_4h_to_h.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": ffn_w3, + f"model.layers.{layer_i}.input_layernorm.weight": loaded["transformer"][f"layers.{layer_i}.input_layernorm.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded["transformer"][f"layers.{layer_i}.post_attention_layernorm.weight"], + f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq": inv_freq + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f'Sharded file saved to {filename}') + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + state_dict = { + "model.norm.weight": loaded["transformer"]['final_layernorm.weight'], + "lm_head.weight": loaded['lm_head'], + "model.embed_tokens.weight": loaded['embedding']["word_embeddings.weight"] + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch_dtype = state_dict["lm_head.weight"].dtype + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f'Sharded file saved to {filename}') + + # Write configs and save + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + # load mistral config from huggingface + config = MistralConfig.from_pretrained( + "mistralai/Mistral-7B-v0.1" + ) + # assert configuration matches + assert config.hidden_size == n_hidden + assert config.intermediate_size == intermediate_size + 
assert config.num_attention_heads == n_heads + assert config.num_hidden_layers == n_layers + assert config.rms_norm_eps == norm_eps + assert config.num_key_value_heads == n_heads_kv + # Set vocab size + config.vocab_size = args.padded_vocab_size + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + if vocab_size is None: + vocab_size = args.padded_vocab_size + else: + print(f"Using vocab size {vocab_size} from tokenizer and not {args.padded_vocab_size} from args.") + # update config + config.vocab_size = vocab_size + + print("Loading the checkpoint in a Llama model...") + model = MistralForCausalLM.from_pretrained( + tmp_model_path, + torch_dtype=torch_dtype + ) + model.config.vocab_size = vocab_size + # resizes the embedding layer to the correct size + model.resize_token_embeddings(vocab_size) + # Avoid saving this as part of the config. + del model.config._name_or_path + + print("Saving in the Transformers format.") + max_num_params_per_shard = param_count*2 // max(1,(num_output_shards-1)) + model.save_pretrained(model_path, max_shard_size=max_num_params_per_shard) + def write_falcon_model( model_path: str, @@ -338,7 +474,8 @@ def permute(qkv_w): def write_tokenizer(args: Namespace): - if args.model in {"llama", "llama2", "codellama"}: + if args.model in {"llama", "llama2", "codellama", "mistral"}: + # mistral also use LlamaTokenizerFast args.tokenizer_type = "SentencePieceTokenizer" if args.vocab_file: # prevent "single file or url is deprecated and won't be possible anymore in v5" warning, @@ -351,6 +488,8 @@ def write_tokenizer(args: Namespace): else: if args.model == "codellama": hf_repo_name = "TheBloke/CodeLlama-13B-fp16" + elif args.model == "mistral": + hf_repo_name = "mistralai/Mistral-7B-v0.1" else: hf_repo_name = "meta-llama/Llama-2-7b-hf" try: # try loading from huggingface @@ -441,7 +580,7 @@ def main(): parser.add_argument("--input_dir", help="Location of Megatron weights", required=True) parser.add_argument("--num_output_shards", type=int, default=1) - parser.add_argument("--model", choices={"falcon", "llama", "llama2", "codellama"}, + parser.add_argument("--model", choices={"falcon", "llama", "llama2", "codellama", "mistral"}, default="llama2") parser.add_argument("--output_dir", help="Location to write HF model and tokenizer", required=True) @@ -465,6 +604,13 @@ def main(): norm_eps=eps, rope_theta=rope_theta, ) + elif args.model == "mistral": + write_mistral_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + num_output_shards=args.num_output_shards, + vocab_size=vocab_size, + ) elif args.model == "falcon": write_falcon_model( model_path=args.output_dir,
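With the export path complete (hf_to_megatron for ingestion, checkpoint_util.py for unsharding, megatron_to_hf for the round trip back), a minimal smoke test of the converted directory might look like the sketch below. The output path and prompt are placeholders, transformers >= 4.34 is assumed so that MistralForCausalLM is available, and the reference tokenizer is pulled from the Hugging Face hub.

import torch
from transformers import AutoTokenizer, MistralForCausalLM

# Directory passed as --output_dir to megatron_to_hf.py (placeholder path).
converted_dir = "out/mistral-7b-hf"
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = MistralForCausalLM.from_pretrained(
    converted_dir, torch_dtype=torch.bfloat16
).cuda().eval()

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))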