Merge pull request #88 from xingyaoww/mistral-pr-branch
Add Mistral Model
AleHD authored Nov 29, 2023
2 parents dead8d2 + bc354d2 commit a8822f8
Showing 13 changed files with 402 additions and 27 deletions.
11 changes: 8 additions & 3 deletions finetune.py
@@ -9,7 +9,7 @@
from megatron.training import pretrain
from megatron.core import tensor_parallel
from megatron.core.parallel_state import get_data_parallel_group
from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel
from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel, MistralModel
from megatron.utils import get_ltor_masks_and_position_ids, average_losses_across_data_parallel_group
from megatron.data.gpt_dataset import build_train_valid_test_datasets as gpt_build_datasets
from megatron.data.instruction_dataset import instruction_collator
@@ -35,8 +35,13 @@ def model_provider(pre_process: bool = True, post_process: bool = True):
cls = FalconModel
elif args.model_name in {"llama", "llama2", "codellama"}:
cls = partial(LlamaModel, version=1 if args.model_name == "llama" else 2)
elif args.model_name == "mistral":
cls = MistralModel
if args.sliding_window_size != 4096:
print_rank_0("Mistral uses sliding window attention (set sliding_window=4096)")
args.sliding_window_size = 4096
else:
raise KeyError(f"Unkown model {other}")
raise KeyError(f"Unkown model {args.model_name}")

if isinstance(args.model_type, ModelType):
model_type = args.model_type
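The dispatch above uses functools.partial so that LlamaModel can be bound with a version argument while every branch still yields a callable with the same constructor signature. A condensed illustration with stand-in classes (the real model classes take more arguments and read get_args() internally):

from functools import partial

class LlamaModel:
    def __init__(self, version, pre_process=True, post_process=True):
        self.version, self.pre_process, self.post_process = version, pre_process, post_process

class MistralModel:
    def __init__(self, pre_process=True, post_process=True):
        self.pre_process, self.post_process = pre_process, post_process

def pick_cls(model_name):
    if model_name in {"llama", "llama2", "codellama"}:
        return partial(LlamaModel, version=1 if model_name == "llama" else 2)
    if model_name == "mistral":
        return MistralModel
    raise KeyError(f"Unknown model {model_name}")

cls = pick_cls("mistral")
model = cls(pre_process=True, post_process=True)  # same call works for every branch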
@@ -238,7 +243,7 @@ def extra_args(parser):
"""Text generation arguments."""
group = parser.add_argument_group(title='validation set')
group.add_argument("--model_name",
choices={"gpt", "llama", "falcon", "llama2", "codellama"},
choices={"gpt", "llama", "falcon", "llama2", "codellama", "mistral"},
default="gpt")
group.add_argument("--model_type", choices={"encoder_or_decoder", "encoder_and_decoder"},
default="encoder_or_decoder")
3 changes: 3 additions & 0 deletions megatron/arguments.py
@@ -475,6 +475,9 @@ def _add_network_size_args(parser):
group.add_argument("--no_tie_embed_logits", action="store_false", dest="tie_embed_logits",
help=("If set, the weights of the word embedding and lm_head "
"are not tied"))
# Added for Mistral
group.add_argument("--sliding_window_size", type=int, default=None,
help="Whether to use sliding window attention for Mistral. Default is None, which means no sliding window attention.")
return parser
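For reference, a minimal standalone sketch of how the new flag behaves under argparse (the surrounding Megatron argument groups are omitted and the example value is illustrative):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="network size")
group.add_argument("--sliding_window_size", type=int, default=None,
                   help="Size of the attention sliding window; None disables it.")

args = parser.parse_args([])                                   # default: no sliding window
assert args.sliding_window_size is None
args = parser.parse_args(["--sliding_window_size", "4096"])    # the Mistral setting
assert args.sliding_window_size == 4096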


1 change: 1 addition & 0 deletions megatron/checkpointing.py
@@ -556,6 +556,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('tie_embed_logits', force=True)
_set_arg('make_vocab_size_divisible_by', force=True)
_set_arg('train_iters')
_set_arg('sliding_window_size')
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
1 change: 1 addition & 0 deletions megatron/model/__init__.py
@@ -8,6 +8,7 @@
from .gpt_model import GPTModel
from .llama_model import LlamaModel
from .falcon_model import FalconModel
from .mistral_model import MistralModel
from .t5_model import T5Model
from .module import Float16Module
from .enums import ModelType
46 changes: 46 additions & 0 deletions megatron/model/mistral_model.py
@@ -0,0 +1,46 @@
"""Mistral Model."""

import warnings

from megatron import get_args
from .enums import PositionEmbeddingType
from . import GPTModel


class MistralModel(GPTModel):
def __init__(self,
num_tokentypes: int = 0,
parallel_output: bool = True,
pre_process: bool = True,
post_process: bool = True,
model_type=None
):

args = get_args()

# mandatory arguments
assert args.position_embedding_type == PositionEmbeddingType.rotary, \
f"Mistral uses rotary embedding, not {args.position_embedding_type}"
assert not args.use_post_ln, "Mistral does not use post_ln"
assert args.glu_activation == "swiglu", "Mistral works with swiglu activation"
assert not args.use_bias, "Mistral does not use bias"
assert not args.parallel_attn, "Mistral does not use parallel_attn"
assert args.use_rms_norm, "Mistral uses rms_norm"
assert not args.tie_embed_logits , "Mistral unties embedding and lm_head weights"
assert args.sliding_window_size == 4096, "Mistral uses sliding window attention (sliding_window=4096)"

# recommended arguments
if not args.use_flash_attn:
warnings.warn("Mistral should use flash attn (for sliding window local attention)")

if args.bias_gelu_fusion:
warnings.warn("Mistral is not intended to use bias_gelu_fusion")
if args.bias_dropout_fusion:
warnings.warn("Mistral is not intended to use bias_dropout_fusion")
if args.hidden_dropout > 0.0 and not args.lima_dropout:
warnings.warn("Mistral is not intended to use dropout")
if args.attention_dropout > 0.0:
warnings.warn("Mistral is not intended to use dropout")
super().__init__(num_tokentypes=num_tokentypes, parallel_output=parallel_output,
pre_process=pre_process, post_process=post_process,
model_type=model_type)
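Taken together, the assertions above pin down the configuration MistralModel accepts. A compact summary (flag names follow megatron/arguments.py; the values are exactly the ones checked in __init__):

# Configuration required by MistralModel.__init__ (from the assertions above).
MISTRAL_REQUIRED_ARGS = {
    "position_embedding_type": "rotary",   # rotary embeddings only
    "use_post_ln": False,                  # pre-LN, not post-LN
    "glu_activation": "swiglu",
    "use_bias": False,
    "parallel_attn": False,
    "use_rms_norm": True,
    "tie_embed_logits": False,             # embedding and lm_head weights untied
    "sliding_window_size": 4096,
}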
38 changes: 34 additions & 4 deletions megatron/model/transformer.py
@@ -23,7 +23,7 @@
# Extracted from: https://github.com/bigscience-workshop/Megatron-DeepSpeed
from .glu_activations import GLU_ACTIVATIONS
from megatron.model.positional_embeddings import precompute_freqs_cis, apply_rotary_emb

from flash_attn.bert_padding import pad_input, unpad_input_for_concatenated_sequences

""" We use the following notation throughout this file:
h: hidden size
@@ -301,6 +301,7 @@ def __init__(self,
self.params_dtype = args.params_dtype
self.sequence_parallel = args.sequence_parallel
self.use_flash_attn = args.use_flash_attn
self.sliding_window_size = args.sliding_window_size
self.num_attention_heads_kv = args.num_attention_heads_kv
self.num_attention_heads = args.num_attention_heads
self.seq_length = args.seq_length
@@ -309,6 +310,14 @@
'self-attention for now')
assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
'supports causal mask for now')
# If sliding window is enabled, we need to make sure that the sliding window is supported.
if self.sliding_window_size is not None:
import inspect
# https://github.com/huggingface/transformers/blob/7e1eff7600085814eac65876d4d8a0e38c2f6ccc/src/transformers/models/mistral/modeling_mistral.py#L50C5-L50C32
assert "window_size" in list(inspect.signature(
flash_attn.flash_attn_func
).parameters), "The current flash attention version does not support sliding window attention, please update to the latest version."
assert self.use_flash_attn, "Sliding window attention is only supported with flash attention for now."

projection_size = args.kv_channels * args.num_attention_heads
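The same capability check can be run standalone to see whether an installed flash-attn build exposes the window_size argument (assumes flash-attn is importable; this mirrors the inspect.signature test above):

import inspect

import flash_attn

supports_sliding_window = "window_size" in inspect.signature(
    flash_attn.flash_attn_func
).parameters
print("flash-attn sliding-window support:", supports_sliding_window)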

@@ -513,13 +522,34 @@ def forward(self,
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
else:
flash_attn_extra_kwargs = {}
# check if we need to use sliding window attention
# https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/mistral/modeling_mistral.py#L353
if self.sliding_window_size is not None:
kv_seq_len = key_layer.shape[0]
if kv_seq_len > self.sliding_window_size:
# https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/mistral/modeling_mistral.py#L510C21-L510C89
flash_attn_extra_kwargs["window_size"] = (
self.sliding_window_size, self.sliding_window_size
)
# It will be truncated to the actual sequence length inside flash attention
# https://github.com/Dao-AILab/flash-attention/blob/83aef842beec1037eb8c1d9c3ef3ed8aae80b091/csrc/flash_attn/src/softmax.h#L159-L161

q, k, v = [rearrange(x, "s b n h -> b s n h").contiguous()
for x in (query_layer, key_layer, value_layer)]
for x in (query_layer, key_layer, value_layer)]
if not self.sequence_parallel:
with megatron.core.tensor_parallel.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(q, k, v, causal=True)
context_layer = self.core_attention_flash(
q, k, v,
causal=True,
**flash_attn_extra_kwargs
)
else:
context_layer = self.core_attention_flash(q, k, v, causal=True)
context_layer = self.core_attention_flash(
q, k, v,
causal=True,
**flash_attn_extra_kwargs
)
context_layer = rearrange(context_layer, 'b s n h -> s b (n h)').contiguous()

# =================
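A minimal sketch of the underlying call, using flash_attn_func directly rather than Megatron's core_attention_flash wrapper (assumes flash-attn >= 2.3.3, a CUDA device, and illustrative shapes):

import torch
from flash_attn import flash_attn_func

batch, seqlen, n_heads, head_dim = 1, 8192, 32, 128   # illustrative sizes only
sliding_window_size = 4096

q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# window_size=(left, right): each query attends to at most `left` previous keys;
# with causal=True the right half of the window has no effect.
out = flash_attn_func(q, k, v, causal=True,
                      window_size=(sliding_window_size, sliding_window_size))
print(out.shape)   # (batch, seqlen, n_heads, head_dim)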
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
transformers >= 4.31.0
torch >= 2.0.0
flash-attn >= 2.0.0
flash-attn >= 2.3.3
datasets >= 2.14.0
nltk >= 3.8.0
sentencepiece >= 0.1.0
3 changes: 2 additions & 1 deletion tools/checkpoint_loader_megatron.py
@@ -92,7 +92,7 @@ def check_for_arg(arg_name):
if args.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif args.model_type in {"falcon", "llama", "llama2", "codellama"}:
elif args.model_type in {"falcon", "llama", "llama2", "codellama", "mistral"}:
from finetune import model_provider
margs.model_name = args.model_type
margs.model_type = ModelType.encoder_or_decoder
@@ -187,6 +187,7 @@ def _get_models(count, dtype, pre_process, post_process):
md.glu_activation = margs.glu_activation
md.tie_embed_logits = margs.tie_embed_logits
md.params_dtype = margs.params_dtype
md.sliding_window_size = margs.sliding_window_size
if margs.position_embedding_type == PositionEmbeddingType.absolute:
md.position_embedding_type = "absolute"
elif margs.position_embedding_type == PositionEmbeddingType.rotary:
2 changes: 1 addition & 1 deletion tools/checkpoint_saver_megatron.py
@@ -154,7 +154,7 @@ def check_message(msg):
elif md.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif md.model_type in {'falcon', 'llama', 'llama2', 'codellama'}:
elif md.model_type in {'falcon', 'llama', 'llama2', 'codellama', 'mistral'}:
from finetune import model_provider
margs.model_name = args.model_type
margs.model_type = ModelType.encoder_or_decoder
2 changes: 1 addition & 1 deletion tools/checkpoint_util.py
@@ -109,7 +109,7 @@ def main():
allow_abbrev=False, conflict_handler='resolve')

parser.add_argument('--model_type', type=str, required=True,
choices=['GPT', 'BERT', 'falcon', 'llama', 'llama2', 'codellama'],
choices=['GPT', 'BERT', 'falcon', 'llama', 'llama2', 'codellama', 'mistral'],
help='Type of the model')
parser.add_argument('--loader', type=str, default='megatron',
help='Module name to load checkpoint, should be on python path')
41 changes: 29 additions & 12 deletions verify_correctness.py
@@ -7,7 +7,7 @@
import torch
import llama
from torch import nn
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, MistralForCausalLM
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from megatron import get_args, update_num_microbatches
@@ -48,30 +48,47 @@ def is_meta_llama2_path(path: Optional[Path]) -> bool:


def hf_provider(name: str, cache_dir: Optional[Path], device: str,
size: int = 7):
size: int = 7, bf16: bool = False):
print("Getting huggingface model...")
extra_kwargs = {}
if bf16:
extra_kwargs = {"torch_dtype": torch.bfloat16}
if name == "falcon":
model = AutoModelForCausalLM.from_pretrained(
f"tiiuae/falcon-{size}b", cache_dir=cache_dir,
trust_remote_code=True
trust_remote_code=True,
**extra_kwargs
)
elif name == "llama":
try:
model = LlamaForCausalLM.from_pretrained(cache_dir)
model = LlamaForCausalLM.from_pretrained(cache_dir, **extra_kwargs)
except OSError:
print(f"Cache dir {cache_dir} does not look like a huggingface "
"checkpoint, assuming cache_dir instead")
model = LlamaForCausalLM.from_pretrained(
f"decapoda-research/llama-{size}b-hf", cache_dir=cache_dir
f"decapoda-research/llama-{size}b-hf", cache_dir=cache_dir,
**extra_kwargs
)
elif name == "llama2" and is_meta_llama2_path(cache_dir):
print(f"baseline path {cache_dir} does not look like a huggingface, "
"assuming it's raw llama2 weights instead")
model = Llama2Wrapper(cache_dir)
elif name == "llama2":
model = LlamaForCausalLM.from_pretrained(
f"meta-llama/Llama-2-{size}b-hf", cache_dir=cache_dir
f"meta-llama/Llama-2-{size}b-hf", cache_dir=cache_dir,
**extra_kwargs
)
elif name == "mistral":
assert size == 7, "Mistral only supports 7B model"
try:
model = MistralForCausalLM.from_pretrained(cache_dir, **extra_kwargs)
except OSError:
print(f"Cache dir {cache_dir} does not look like a huggingface "
"checkpoint, assuming cache_dir instead")
model = MistralForCausalLM.from_pretrained(
f"mistralai/Mistral-{size}B-v0.1", cache_dir=cache_dir,
**extra_kwargs
)
else:
raise KeyError(f"Model {name} not implemented")
return model.eval().requires_grad_(False).to(device)
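For comparison, the direct Hugging Face load that the mistral branch above falls back to when no local checkpoint is found (assumes a sufficiently recent transformers release with MistralForCausalLM and network access to the Hub):

import torch
from transformers import MistralForCausalLM

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",        # f"mistralai/Mistral-{size}B-v0.1" with size=7
    torch_dtype=torch.bfloat16,         # what extra_kwargs adds when bf16=True
)
model = model.eval().requires_grad_(False)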
@@ -114,7 +131,7 @@ def verify_step(our_forward, our_model, base_forward, base_model, batch):
our_logits, our_loss = our_forward(our_model, batch)
base_logits, base_loss = base_forward(base_model, batch)
assert our_logits.size() == base_logits.size(), \
f"ours={logits1.size()}, true={logits2.size()}"
f"ours={our_logits.size()}, true={base_logits.size()}"
our_logits = our_logits.cpu()
base_logits = base_logits.cpu()
abs_error = torch.abs(our_logits - base_logits)
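A toy illustration of the comparison performed here (random tensors stand in for the real logits; the vocabulary size is illustrative):

import torch

our_logits = torch.randn(1, 8, 32000)
base_logits = our_logits + 1e-3 * torch.randn_like(our_logits)
assert our_logits.size() == base_logits.size(), \
    f"ours={our_logits.size()}, true={base_logits.size()}"
abs_error = torch.abs(our_logits - base_logits)
print(f"max abs error: {abs_error.max():.2e}, mean abs error: {abs_error.mean():.2e}")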
@@ -147,7 +164,7 @@ def main():
else:
print("NOTE: The given path does not look like a megatron checkpoint, "
f"assuming it's a huggingface checkpoint instead (path={args.load})")
our_model = hf_our_provider(args.model_name, args.load, "cuda:0")
our_model = hf_our_provider(args.model_name, args.load, "cuda:0", bf16=args.bf16)
our_forward = hf_forward
args.iteration = 0

@@ -192,9 +209,9 @@ def extra_extra_args(parser):
if __name__ == "__main__":
defaults = {"micro_batch_size": 1, "use_checkpoint_args": True, "train_iters": 10,
"lr": 1.0}
if not is_megatron_path(parse_args(extra_extra_args).load):
defaults.update({"encoder_num_layers": 1, "hidden_size": 1,
"num_attention_heads": 1, "seq_length": 2048,
"max_position_embeddings": 2048})
# if not is_megatron_path(parse_args(extra_extra_args).load):
# defaults.update({"encoder_num_layers": 1, "hidden_size": 1,
# "num_attention_heads": 1, "seq_length": 2048,
# "max_position_embeddings": 2048})
initialize_megatron(extra_extra_args, args_defaults=defaults)
main()