Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemma 2 #9672

Merged
merged 12 commits into from
Jul 12, 2024
Merged

Gemma 2 #9672

merged 12 commits into from
Jul 12, 2024

Conversation

cuichenx
Copy link
Collaborator

@cuichenx cuichenx commented Jul 10, 2024

What does this PR do ?

Duplicate of #9587 for release branch
Also includes transformers version update from #9606

Collection: [Note which collection this PR will affect]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@cuichenx cuichenx changed the title Gemma2 (release branch) Gemma 2 Jul 10, 2024
@github-actions github-actions bot added the NLP label Jul 10, 2024
@cuichenx cuichenx mentioned this pull request Jul 10, 2024
8 tasks
@cuichenx cuichenx requested a review from yaoyu-33 July 10, 2024 18:38
@Emperorizzis
Copy link

Emperorizzis commented Jul 11, 2024

Hi! Thank you for your work!

When continue pretrain the model using your updated gemma-related code, I found that the initial loss is around 9.x, while hf was about 2.x. Are there still some differences that are not aligned?

@cuichenx
Copy link
Collaborator Author

Hi! Thank you for your work!

When continue pretrain the model using your updated gemma-related code, I found that the initial loss is around 9.x, while hf was about 2.x. Are there still some differences that are not aligned?

Hi @Emperorizzis, thanks for your interest! It should work the same as HF but there could be a bug somewhere. Do you mind sharing your config and/or run command?

@Emperorizzis
Copy link

Hi! Thank you for your work!
When continue pretrain the model using your updated gemma-related code, I found that the initial loss is around 9.x, while hf was about 2.x. Are there still some differences that are not aligned?

Hi @Emperorizzis, thanks for your interest! It should work the same as HF but there could be a bug somewhere. Do you mind sharing your config and/or run command?

Hi @cuichenx, I used your modeling file, as well as the native megatron (f3a3020031f384ddafd9b7e9f3a587798c0aea21) for training (with a few additional arguments). Below are my megatron arguments configuration during training:

------------------------ arguments ------------------------
  accumulate_allreduce_grads_in_fp32 .............. True
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.95
  adam_eps ........................................ 1e-08
  adaptive_seq_len ................................ False
  add_bias_attn_fc ................................ True
  add_bias_linear ................................. False
  add_bias_linear_fc .............................. True
  add_position_embedding .......................... True
  add_qkv_bias .................................... False
  adlr_autoresume ................................. False
  adlr_autoresume_interval ........................ 1000
  apply_layernorm_1p .............................. True
  apply_query_key_layer_scaling ................... False
  apply_residual_connection_post_layernorm ........ False
  apply_rope_fusion ............................... False
  async_tensor_model_parallel_allreduce ........... False
  attention_dropout ............................... 0.0
  attention_head_type ............................. None
  attention_softmax_in_fp32 ....................... False
  auto_detect_ckpt_format ......................... False
  barrier_with_L1_time ............................ True
  bert_binary_head ................................ True
  bert_embedder_type .............................. megatron
  bert_load ....................................... None
  bf16 ............................................ True
  bias_dropout_fusion ............................. True
  bias_gelu_fusion ................................ False
  bias_swiglu_fusion .............................. True
  biencoder_projection_dim ........................ 0
  biencoder_shared_query_context_model ............ False
  block_data_path ................................. None
  check_for_nan_in_loss_and_grad .................. True
  ckpt_fully_parallel_save ........................ False
  ckpt_step ....................................... None
  classes_fraction ................................ 1.0
  clip_grad ....................................... 1.0
  clone_scatter_output_in_embedding ............... True
  consumed_train_samples .......................... 0
  consumed_valid_samples .......................... 0
  context_parallel_size ........................... 1
  create_attention_mask_in_dataloader ............. True
  cvcuda_image_processing ......................... False
  data_cache_path ................................. None
  data_dir ........................................ None
  data_impl ....................................... mmap
  data_parallel_random_init ....................... False
  data_parallel_size .............................. 2
  data_path ....................................... ['/basedir/mcore_data/gemma2/testing/testing']
  data_per_class_fraction ......................... 1.0
  data_sharding ................................... True
  dataloader_type ................................. single
  dataset ......................................... None
  decoder_num_layers .............................. None
  decoder_seq_length .............................. None
  decoupled_lr .................................... None
  decoupled_min_lr ................................ None
  delay_grad_reduce ............................... True
  delay_param_gather .............................. False
  dino_bottleneck_size ............................ 256
  dino_freeze_last_layer .......................... 1
  dino_head_hidden_size ........................... 2048
  dino_local_crops_number ......................... 10
  dino_local_img_size ............................. 96
  dino_norm_last_layer ............................ False
  dino_teacher_temp ............................... 0.07
  dino_warmup_teacher_temp ........................ 0.04
  dino_warmup_teacher_temp_epochs ................. 30
  dist_ckpt_format ................................ torch_dist
  distribute_saved_activations .................... False
  distributed_backend ............................. nccl
  distributed_timeout_minutes ..................... 60
  dummy_load ...................................... 
  embed_layernorm ................................. False
  embedding_path .................................. None
  empty_unused_memory_level ....................... 0
  enable_one_logger ............................... False
  enable_parallel_output .......................... True
  enable_shared_expert ............................ False
  encoder_num_layers .............................. 42
  encoder_seq_length .............................. 64
  end_weight_decay ................................ 0.1
  eod_mask_loss ................................... False
  epochs .......................................... None
  eval_dev ........................................ False
  eval_fp32 ....................................... False
  eval_interval ................................... 1000000000
  eval_iters ...................................... 1
  evidence_data_path .............................. None
  exit_duration_in_mins ........................... None
  exit_interval ................................... None
  exit_on_missing_checkpoint ...................... False
  exit_signal_handler ............................. False
  expert_interval ................................. 2
  expert_model_parallel_size ...................... 1
  expert_tensor_parallelism ....................... False
  extra_vocab_size ................................ 0
  ffn_hidden_size ................................. 14336
  finetune ........................................ False
  fp16 ............................................ False
  fp16_lm_cross_entropy ........................... False
  fp32_residual_connection ........................ False
  fp8 ............................................. None
  fp8_amax_compute_algo ........................... most_recent
  fp8_amax_history_len ............................ 1
  fp8_interval .................................... 1
  fp8_margin ...................................... 0
  fp8_wgrad ....................................... True
  freeze_clip_vision_tower ........................ False
  freeze_llm ...................................... False
  generation_length ............................... None
  global_batch_size ............................... 8
  glu_activation .................................. None
  gradient_accumulation_fusion .................... True
  group_query_attention ........................... True
  head_lr_mult .................................... 1.0
  hidden_dropout .................................. 0.0
  hidden_size ..................................... 3584
  hysteresis ...................................... 2
  ict_head_size ................................... None
  ict_load ........................................ None
  image_aspect_ratio .............................. square
  image_folder .................................... 
  image_size ...................................... None
  img_h ........................................... 224
  img_w ........................................... 224
  indexer_batch_size .............................. 128
  indexer_log_interval ............................ 1000
  inference_batch_times_seqlen_threshold .......... 512
  init_method_std ................................. 0.02
  init_method_xavier_uniform ...................... False
  initial_loss_scale .............................. 4294967296
  input_len ....................................... 1
  intermediate_size ............................... None
  iter_per_epoch .................................. 1250
  keep_last ....................................... False
  kv_channels ..................................... 256
  lazy_mpu_init ................................... None
  load ............................................ /basedir/mcore_models/gemma-2-9b-to-mcore-tp4-pp1-te-nemo
  local_rank ...................................... None
  log_batch_size_to_tensorboard ................... False
  log_interval .................................... 1
  log_learning_rate_to_tensorboard ................ True
  log_loss_scale_to_tensorboard ................... True
  log_memory_to_tensorboard ....................... False
  log_num_zeros_in_grad ........................... False
  log_params_norm ................................. False
  log_progress .................................... False
  log_throughput .................................. True
  log_timers_to_tensorboard ....................... False
  log_validation_ppl_to_tensorboard ............... False
  log_world_size_to_tensorboard ................... False
  loss_scale ...................................... None
  loss_scale_window ............................... 1000
  lr .............................................. 3e-05
  lr_decay_iters .................................. 1953125
  lr_decay_samples ................................ None
  lr_decay_style .................................. cosine
  lr_warmup_fraction .............................. None
  lr_warmup_init .................................. 0.0
  lr_warmup_iters ................................. 0
  lr_warmup_samples ............................... 0
  make_vocab_size_divisible_by .................... 128
  manual_gc ....................................... False
  manual_gc_eval .................................. True
  manual_gc_interval .............................. 0
  mask_factor ..................................... 1.0
  mask_prob ....................................... 0.15
  mask_type ....................................... random
  masked_softmax_fusion ........................... False
  max_padding_length .............................. None
  max_position_embeddings ......................... 64
  max_tokens_to_oom ............................... 12000
  merge_file ...................................... None
  micro_batch_size ................................ 1
  min_loss_scale .................................. 1.0
  min_lr .......................................... 3e-06
  mm_projector_type ............................... None
  mm_use_im_patch_token ........................... False
  mm_use_im_start_end ............................. False
  mm_vision_select_layer .......................... None
  mmap_bin_files .................................. True
  mock_data ....................................... False
  moe ............................................. False
  moe_aux_loss_coeff .............................. 0.0
  moe_eval_capacity_factor ........................ 1.0
  moe_expert_parallel_size ........................ None
  moe_ffn_hidden_size ............................. None
  moe_grouped_gemm ................................ False
  moe_input_feature_slicing ....................... False
  moe_input_jitter_eps ............................ None
  moe_loss_coeff .................................. 0.01
  moe_min_capacity ................................ 4
  moe_per_layer_logging ........................... False
  moe_router_load_balancing_type .................. aux_loss
  moe_router_topk ................................. 2
  moe_token_dispatcher_type ....................... allgather
  moe_token_dropping .............................. False
  moe_topk ........................................ 1
  moe_train_capacity_factor ....................... 1.0
  moe_z_loss_coeff ................................ None
  n_head_kv ....................................... None
  nccl_communicator_config_path ................... None
  no_load_optim ................................... None
  no_load_rng ..................................... None
  no_persist_layer_norm ........................... False
  no_save_optim ................................... None
  no_save_rng ..................................... None
  norm_epsilon .................................... 1e-06
  normalization ................................... RMSNorm
  num_attention_heads ............................. 16
  num_channels .................................... 3
  num_classes ..................................... 1000
  num_experts ..................................... None
  num_fewshot ..................................... None
  num_layers ...................................... 42
  num_layers_per_virtual_pipeline_stage ........... None
  num_query_groups ................................ 8
  num_workers ..................................... 8
  one_logger_entity ............................... hwinf_dcm
  one_logger_project .............................. e2e-tracking
  one_logger_run_name ............................. None
  onnx_safe ....................................... None
  openai_gelu ..................................... False
  optimizer ....................................... adam
  out_seq_length .................................. 1024
  output_bert_embeddings .......................... False
  overlap_grad_reduce ............................. False
  overlap_p2p_comm ................................ False
  overlap_param_gather ............................ False
  override_opt_param_scheduler .................... False
  params_dtype .................................... torch.bfloat16
  patch_dim ....................................... 16
  patch_size ...................................... None
  patch_tokenizer_type ............................ Gemma2Tokenizer
  perform_initialization .......................... True
  pipeline_model_parallel_size .................... 1
  pipeline_model_parallel_split_rank .............. None
  position_embedding_type ......................... rope
  position_encoding_2d ............................ False
  pretrained_checkpoint ........................... None
  profile ......................................... False
  profile_ranks ................................... [0]
  profile_step_end ................................ 12
  profile_step_start .............................. 10
  qk_layernorm .................................... False
  query_in_block_prob ............................. 0.1
  rampup_batch_size ............................... None
  rank ............................................ 0
  recompute_granularity ........................... None
  recompute_method ................................ uniform
  recompute_num_layers ............................ None
  repetition_penalty .............................. 1.1
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  retriever_report_topk_accuracies ................ []
  retriever_score_scaling ......................... False
  retriever_seq_length ............................ 256
  retro_add_retriever ............................. False
  retro_attention_gate ............................ 1
  retro_cyclic_train_iters ........................ None
  retro_encoder_attention_dropout ................. 0.1
  retro_encoder_hidden_dropout .................... 0.1
  retro_encoder_layers ............................ 2
  retro_num_neighbors ............................. 2
  retro_num_retrieved_chunks ...................... 2
  retro_project_dir ............................... None
  retro_verify_neighbor_count ..................... True
  rotary_base ..................................... 10000
  rotary_interleaved .............................. False
  rotary_percent .................................. 1.0
  rotary_scale_factor ............................. 1
  rotary_seq_len_interpolation_factor ............. None
  router_type ..................................... topk
  sample_rate ..................................... 1.0
  save ............................................ /basedir/output_mcore_models/gemma-2-9b-tp4-pp1-for-test
  save_interval ................................... 1000000000
  scatter_gather_tensors_in_pipeline .............. True
  seed ............................................ 1234
  seq_length ...................................... 64
  sequence_parallel ............................... True
  sgd_momentum .................................... 0.9
  shared_moe_ffn_hidden_size ...................... None
  short_seq_prob .................................. 0.1
  skip_train ...................................... False
  sliding_window .................................. None
  source_seq_len .................................. None
  spec ............................................ None
  split ........................................... 100,0,0
  squared_relu .................................... False
  standalone_embedding_stage ...................... False
  start_weight_decay .............................. 0.1
  swiglu .......................................... False
  swin_backbone_type .............................. tiny
  target_seq_len .................................. None
  task_list ....................................... all
  temperature ..................................... 1.0
  tensor_model_parallel_size ...................... 4
  tensorboard_dir ................................. /basedir/testing/tensorboard
  tensorboard_log_interval ........................ 1
  tensorboard_queue_size .......................... 1000
  test_data_path .................................. None
  test_mode ....................................... False
  text_generate_gt_file ........................... 
  text_generate_input_file ........................ 
  text_generate_output_file ....................... 
  time ............................................ False
  timing_log_level ................................ 0
  timing_log_option ............................... minmax
  titles_data_path ................................ None
  tokenizer_model ................................. /basedir/hf_models/gemma-2-9b
  tokenizer_type .................................. NullTokenizer
  top_k ........................................... 0
  top_p ........................................... 0.0
  tp_comm_bulk_dgrad .............................. True
  tp_comm_bulk_wgrad .............................. True
  tp_comm_overlap ................................. False
  tp_comm_overlap_ag .............................. True
  tp_comm_overlap_cfg ............................. None
  tp_comm_overlap_rs .............................. True
  tp_comm_split_ag ................................ True
  tp_comm_split_rs ................................ True
  train_data ...................................... None
  train_data_path ................................. None
  train_iters ..................................... 1953125
  train_samples ................................... None
  transformer_impl ................................ transformer_engine
  transformer_pipeline_model_parallel_size ........ 1
  transformer_timers .............................. False
  transformer_type ................................ megatron
  tune_mm_mlp_adapter ............................. False
  untie_embeddings_and_output_weights ............. False
  use_alibi_mask .................................. False
  use_checkpoint_args ............................. False
  use_checkpoint_opt_param_scheduler .............. False
  use_cpu_initialization .......................... None
  use_dist_ckpt ................................... False
  use_distributed_optimizer ....................... True
  use_flash_attn .................................. False
  use_llama2_rotary_position_embeddings ........... False
  use_mcore_models ................................ True
  use_mistral_rotary_position_embeddings .......... False
  use_normhead .................................... False
  use_one_sent_docs ............................... False
  use_ring_exchange_p2p ........................... False
  use_rotary_position_embeddings .................. False
  use_tutel ....................................... False
  valid_data ...................................... None
  valid_data_path ................................. None
  variable_seq_lengths ............................ False
  verbosity ....................................... INFO
  version ......................................... plain
  virtual_pipeline_model_parallel_size ............ None
  vision_backbone_type ............................ vit
  vision_pretraining .............................. False
  vision_pretraining_type ......................... classify
  vision_tower .................................... 
  vocab_extra_ids ................................. 0
  vocab_file ...................................... None
  vocab_size ...................................... -1
  wandb_exp_name .................................. 
  wandb_project ................................... 
  wandb_save_dir .................................. 
  weight_decay .................................... 0.1
  weight_decay_incr_style ......................... constant
  world_size ...................................... 8
  yaml_cfg ........................................ None
  z_loss_weight ................................... 0.0
-------------------- end of arguments ---------------------

And below are the losses of first 5 steps:

... lm loss: 8.237988E+00 | loss scale: 1.0 | grad norm: 354.350 ...
... lm loss: 4.174296E+00 | loss scale: 1.0 | grad norm: 68.632 ...
... lm loss: 4.116720E+00 | loss scale: 1.0 | grad norm: 43.852 ...
... lm loss: 3.366329E+00 | loss scale: 1.0 | grad norm: 28.838 ...
... lm loss: 3.328519E+00 | loss scale: 1.0 | grad norm: 26.727 ...

The data should be fine; I sampled 1 million articles from enwiki and tokenized them using the tokenizer from gemma2-9b.

@cuichenx
Copy link
Collaborator Author

@Emperorizzis I verified inference and finetuning performance with the code in nemo framework, and the accuracy looked okay.
Maybe you can double check the new components are included in your implementation? These are 1) logit soft capping, 2) alternating SWA, 3) post attn/mlp layernorm, 4) custom scale in attention

@cuichenx cuichenx merged commit 80e6927 into r2.0.0rc1 Jul 12, 2024
115 checks passed
@cuichenx cuichenx deleted the chcui/gemma2_release branch July 12, 2024 05:54
github-actions bot pushed a commit that referenced this pull request Jul 12, 2024
* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
@Emperorizzis
Copy link

Emperorizzis commented Jul 12, 2024

Hi! Thank you for your response! After our testing, we found that there seems to be two bugs :

1) Sliding Window

def get_swa(seq_q, seq_kv, w):
    """Create the equivalent attention mask fro SWA in [seq_q, seq_kv] shape"""
    m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
    ### original
    # mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
    ### after modified
    mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0] + 1)
    ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
    ml = ~ml
    return ml

For example:

【example】
seq_len=4
sliding_window=2

get_swa(4, 4, (2, 0))

original answer (×)
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [ False, False, False,  True],
        [ True,  False, False, False]])

after modified (aligned with hf, √):
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [ True, False, False,  True],
        [ True,  True, False, False]])

2) The odd and even layers of the sliding window are reversed

huggingface/transformers@a695c18

... ...

self.layer_number = max(1, layer_number)

# self.window_size = None
# if self.layer_number % 2 == 0:
#     self.window_size = config.window_size
self.sliding_window_size = None
### original
# if self.layer_number % 2 == 0:
### after modified
if self.layer_number % 2 != 0:
    self.sliding_window_size = (config.sliding_window_size, 0)

... ...

And we also found that adding or not adding "<bos>" token when continue pretraining Gemma Base model has a significant impact on the initial loss (possibly a difference of up to double).
Additionally, we have discovered that the Gemma Base model is highly sensitive to the data and the length of the data.

Hope this finding can help other people.

Thank you again for your work!

@ko3n1g ko3n1g mentioned this pull request Jul 18, 2024
2 tasks
ericharper added a commit that referenced this pull request Aug 19, 2024
* Gemma 2 (#9672)

* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* import in function to fix test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Dido0o0 pushed a commit to Dido0o0/NeMo that referenced this pull request Aug 23, 2024
* Gemma 2 (NVIDIA#9672)

* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* import in function to fix test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
dimapihtar pushed a commit that referenced this pull request Aug 27, 2024
* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
@longxudou
Copy link

longxudou commented Sep 2, 2024

@Emperorizzis @cuichenx I meet the same problem (high initial loss) for continual pre-training gemma2-2B model.
Could you share more details on how to correctly set up the pre-training? Many thanks for that :)

SFT scripts works well (training loss < 3), but pertaining loss (training loss > 20) is extremely high.

Experiments

The following experiments are running with the same docker nvcr.io/nvidia/nemo:24.05.llama3.1 and the same dataset databricks-dolly-15k.

image

SFT training:
Following the official document, I set up the SFT training.
[A] gemma2_2b_sft_bos_rc2:
-codebase: https://github.com/NVIDIA/NeMo/tree/r2.0.0rc1
-data: databricks-dolly-15k/train.jsonl
-add bos: Yes

[B]gemma2_2b_sft_0830_main:
-codebase: https://github.com/NVIDIA/NeMo/tree/1ce9089143b0136523cb08bb37941a35c9b08307
-data: databricks-dolly-15k/train.jsonl
-add bos: Yes

Continual pre-training
I didn't find the pre-training recipe for gemma. Thus I just preprocess the dataset with following cmd.

python ${NEMO_CODE_PATH}/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \
    --input ${DATASET_PATH}/sample.jsonl \
    --json-keys text \
    --tokenizer-library huggingface \
    --dataset-impl mmap \
    --tokenizer-type ${MODEL_PATH}/gemma2_2b \
    --output-prefix ${DATASET_PATH}/sample_gemma2_preprocessed/sample \
    --append-eod \
    --workers=1

[C]gemma2_2b_sft_pretraining
-codebase: https://github.com/NVIDIA/NeMo/tree/1ce9089143b0136523cb08bb37941a35c9b08307
-data: databricks-dolly-15k/train.jsonl
-add bos: No

[D]gemma2_2b_sft_pretraining_bos
-codebase: https://github.com/NVIDIA/NeMo/tree/1ce9089143b0136523cb08bb37941a35c9b08307
-data: databricks-dolly-15k/train.jsonl
-add bos: Yes
I modify preprocess_data_for_megatron.py as following to append the BOS token.

  for sentence in Encoder.splitter.tokenize(text):
      sentence_ids = Encoder.tokenizer.text_to_ids(sentence)
      sentence_ids = [Encoder.tokenizer.bos_id] + sentence_ids
      if len(sentence_ids) > 0:
          doc_ids.append(sentence_ids)
  if len(doc_ids) > 0 and self.args.append_eod:
      doc_ids[-1].append(Encoder.tokenizer.eos_id)
  ids[key] = doc_ids

Observations

(1) SFT tuning is normal, but also differs with the codebase changes. => [A] and [B] adopts the same data and docker, but receives different loss curve.
(2) BOS matters a lot as @Emperorizzis said. It got further studied by https://unsloth.ai/blog/gemma-bugs.

Looking for help

Could @cuichenx provide a stable codebase version tag and the necessary guidance on how to run the NeMo for Gemma2 model for both pre-training? Thanks for your great work!

@cuichenx
Copy link
Collaborator Author

cuichenx commented Sep 4, 2024

Thanks for reporting these issues! I will look into them this week.

adityavavre pushed a commit to adityavavre/NeMo that referenced this pull request Sep 15, 2024
* Gemma 2 (NVIDIA#9672)

* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* import in function to fix test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: adityavavre <[email protected]>
monica-sekoyan pushed a commit that referenced this pull request Oct 14, 2024
* Gemma 2 (#9672)

* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* import in function to fix test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
hainan-xv pushed a commit to hainan-xv/NeMo that referenced this pull request Nov 5, 2024
* Gemma 2 (NVIDIA#9672)

* gemma2 initial commit

Signed-off-by: Chen Cui <[email protected]>

* enable conversion on cpu

Signed-off-by: Chen Cui <[email protected]>

* fix code scanning

Signed-off-by: Chen Cui <[email protected]>

* typo in config

Signed-off-by: Chen Cui <[email protected]>

* fix output layer and add comments

Signed-off-by: Chen Cui <[email protected]>

* refactor model customize to one function

Signed-off-by: Chen Cui <[email protected]>

* unpin transformers version

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* import in function to fix test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants