refactor wip
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed Oct 9, 2024
1 parent d2f9cf7 commit c527914
Showing 7 changed files with 238 additions and 5 deletions.
73 changes: 73 additions & 0 deletions my/nemo2_lora_merge.py
@@ -0,0 +1,73 @@
# import nemo_run as run
from nemo import lightning as nl
from nemo.collections import llm
from megatron.core.optimizer import OptimizerConfig
import torch
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl

def logger() -> nl.NeMoLogger:
# ckpt = None
ckpt = nl.ModelCheckpoint(
#save_best_model=True,
save_last=True,
every_n_train_steps=10,
monitor="reduced_train_loss",
save_top_k=1,
save_on_train_epoch_end=True,
save_optim_on_train_end=True,
)

wandb = None
# wandb = WandbLogger(
# project="nemo2-squad",
# name='llama3-8b_lora-attn_test_api_wandb',
# )

return nl.NeMoLogger(
name="nemo2_peft",
log_dir="/workspace/peftmerge/exp/peft_iomixin0",
use_datetime_version=False, # must be false if using auto resume
ckpt=ckpt,
wandb=wandb
)

def trainer(devices=1) -> nl.Trainer:
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
)

return nl.Trainer(
        devices=devices,
max_steps=40,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
log_every_n_steps=1,
limit_val_batches=2,
val_check_interval=2,
num_sanity_val_steps=0,
)

def resume() -> nl.AutoResume:
return nl.AutoResume(
restore_config=nl.RestoreConfig(
path="hf://meta-llama/Meta-Llama-3-8B",
),
resume_if_exists=True,
# resume_ignore_no_checkpoint=True,
)

def llama3_8b() -> pl.LightningModule:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
return llm.LlamaModel(llm.Llama3Config8B(), tokenizer=tokenizer)


if __name__ == '__main__':
llm.peft.merge_lora(
model=llama3_8b(),
trainer=trainer(),
log=logger(),
resume=resume(),
)
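
The script above runs the merge on a single GPU with tensor_model_parallel_size=1. Because the merge logic added to nemo/collections/llm/peft/lora.py below also handles tensor-parallel adapters, a minimal sketch of a 2-way tensor-parallel variant of the trainer helper follows; it assumes two GPUs are available and reuses only the APIs already shown in this script.

from nemo import lightning as nl  # same import as the script above

def trainer_tp2(devices: int = 2) -> nl.Trainer:
    # Sketch only: identical settings to trainer() above, but sharding each
    # weight matrix across two tensor-parallel ranks.
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=2,
    )
    return nl.Trainer(
        devices=devices,  # one process per GPU
        max_steps=40,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        log_every_n_steps=1,
        limit_val_batches=2,
        val_check_interval=2,
        num_sanity_val_steps=0,
    )
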
4 changes: 4 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -293,6 +293,10 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

return self.forward_step(batch)

def predict_step(self, batch, batch_idx=None) -> torch.Tensor:

return self.forward_step(batch)

@property
def training_loss_reduction(self) -> MaskedTokenLossReduction:
if not self._training_loss_reduction:
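
For context on the predict_step added above: trainer.predict(...) invokes predict_step once per batch, which is how merge_lora drives the merged model through a forward pass without training. A minimal, NeMo-free PyTorch Lightning sketch of that control flow, with illustrative names only:

import pytorch_lightning as pl
import torch

class TinyPredictModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx=None):
        # called by trainer.predict() for every batch of the dataloader
        return self.layer(batch)

loader = torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=4)
outputs = pl.Trainer(logger=False).predict(TinyPredictModel(), dataloaders=loader)  # list of per-batch tensors
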
4 changes: 2 additions & 2 deletions nemo/collections/llm/peft/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.peft.api import gpt_lora, merge_lora
from nemo.collections.llm.peft.lora import LoRA

__all__ = ["LoRA", "gpt_lora"]
__all__ = ["LoRA", "gpt_lora", "merge_lora"]
51 changes: 50 additions & 1 deletion nemo/collections/llm/peft/api.py
@@ -17,9 +17,58 @@
from nemo.lightning.pytorch.callbacks.peft import PEFT


import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Optional, Union

import nemo_run as run
import pytorch_lightning as pl
from typing_extensions import Annotated

from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging
from nemo.lightning.io import load_context, ModelConnector
from nemo.collections.llm.api import _set_with_io
from nemo.collections import llm




@factory
def gpt_lora() -> PEFT:
return LoRA()

def merge_lora(
model: pl.LightningModule,
trainer: Trainer,
log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], run.Config[AutoResume]] = None,
):
_log = log or NeMoLogger()
    # the logger will set up paths in the trainer
app_state = _log.setup(
trainer,
resume_if_exists=getattr(resume, "resume_if_exists", False),
task_config=None,
)
    # resume will import the HF LLM to the default path if it doesn't already exist:
    #   if it already exists -> ok;
    #   if it doesn't exist -> it is downloaded to a new (possibly default) dir;
    #   if the new dir == the old dir -> ok, otherwise an error is thrown
resume.setup(trainer, model)
lora = load_context(resume.get_context_path(), "model.model_transform")
if lora:
_set_with_io(model, "model_transform", lora)
trainer.callbacks.append(lora)
import pdb; pdb.set_trace()
    # TODO: initialize the LoRA transform from the checkpoint dir

predict_dataloader = llm.SquadDataModule(seq_length=2048, micro_batch_size=2, global_batch_size=8, num_workers=0)
trainer.predict(model, dataloaders=predict_dataloader)


__all__ = ["gpt_lora"]
__all__ = ["gpt_lora",
"merge_lora"]
102 changes: 102 additions & 0 deletions nemo/collections/llm/peft/lora.py
@@ -20,6 +20,19 @@

from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper
from nemo.utils import logging
from nemo.lightning.io.mixin import IOMixin
from pytorch_lightning.trainer.states import TrainerFn
import torch

from nemo import lightning as nl
from nemo.collections import llm
from typing import Any, Dict, List
import torch
from nemo.lightning.io import load_context, ModelConnector
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.utils.get_rank import is_global_rank_zero
from pathlib import Path
from nemo.utils import logging


class AdapterParallelAdd(AdapterWrapper):
@@ -156,3 +169,92 @@ def transform(self, m: nn.Module, name=None, prefix=None):
)
return AdapterParallelAdd(m, adapter)
return m

def apply_transform(self, trainer):
super().apply_transform(trainer)
import pdb; pdb.set_trace()
if trainer.state.fn == TrainerFn.PREDICTING:
base_sharded_dict = {k:v for k,v in trainer.model.state_dict().items() if 'adapter' not in k and 'extra_state' not in k }
lora_sharded_dict = {k:v.data.data for k, v in trainer.model.sharded_state_dict().items() if 'adapter' in k and 'extra_state' not in k}
            merged_weights = self._merge_lora_weights(
                base_model_state_dict=base_sharded_dict, lora_state_dict=lora_sharded_dict,
                num_layers=trainer.model._modules['0'].config.num_layers,
                tp_size=trainer.strategy.tensor_model_parallel_size, rank=torch.distributed.get_rank(),
            )
trainer.model.load_state_dict(merged_weights)

def _merge_lora_weights(self, base_model_state_dict: Dict[str, Any],
lora_state_dict: Dict[str, Any],
num_layers: int,
tp_size: int,
rank: int):
        mcore_layer_to_lora = {}
        # Adapter weight names produced by the LoRA transform, per transformer layer:
        #   'self_attention.linear_qkv.adapter.linear_in.weight'
        #   'self_attention.linear_qkv.adapter.linear_out.weight'
        #   'self_attention.linear_proj.adapter.linear_in.weight'
        #   'self_attention.linear_proj.adapter.linear_out.weight'
        #   'mlp.linear_fc1.adapter.linear_in.weight'
        #   'mlp.linear_fc1.adapter.linear_out.weight'
        #   'mlp.linear_fc2.adapter.linear_in.weight'
        #   'mlp.linear_fc2.adapter.linear_out.weight'

mcore_layer_to_lora["attention_qkv"] = {
"base_model_layer": "self_attention.linear_qkv.weight",
"lora_in": "self_attention.linear_qkv.adapter.linear_in.weight",
"lora_out": "self_attention.linear_qkv.adapter.linear_out.weight",
}
mcore_layer_to_lora["attention_dense"] = {
"base_model_layer": "self_attention.linear_proj.weight",
"lora_in": "self_attention.linear_proj.adapter.linear_in.weight",
"lora_out": "self_attention.linear_proj.adapter.linear_out.weight",
}
mcore_layer_to_lora["mlp_fc1"] = {
"base_model_layer": "mlp.linear_fc1.weight",
"lora_in": "mlp.linear_fc1.adapter.linear_in.weight",
"lora_out": "mlp.linear_fc1.adapter.linear_out.weight",
}
mcore_layer_to_lora["mlp_fc2"] = {
"base_model_layer": "mlp.linear_fc2.weight",
"lora_in": "mlp.linear_fc2.adapter.linear_in.weight",
"lora_out": "mlp.linear_fc2.adapter.linear_out.weight",
}

        for layer_number in range(num_layers):
            for key in mcore_layer_to_lora.keys():
                ## TODO: prefix should be model or module or 0.module?
                key_base = f'0.module.decoder.layers.{layer_number}.{mcore_layer_to_lora[key]["base_model_layer"]}'
                key_lora_in = f'module.decoder.layers.{layer_number}.{mcore_layer_to_lora[key]["lora_in"]}'
                key_lora_out = f'module.decoder.layers.{layer_number}.{mcore_layer_to_lora[key]["lora_out"]}'
if key_lora_in in lora_state_dict and key_lora_out in lora_state_dict:
if tp_size > 1:
gathered_lora_in = [torch.zeros_like(lora_state_dict[key_lora_in]) for _ in range(tp_size)]
gathered_lora_out = [torch.zeros_like(lora_state_dict[key_lora_out]) for _ in range(tp_size)]
torch.distributed.all_gather(gathered_lora_in, lora_state_dict[key_lora_in])
torch.distributed.all_gather(gathered_lora_out, lora_state_dict[key_lora_out])

if is_global_rank_zero():
print(f"RANK{torch.distributed.get_rank()} has {key_lora_in} shape {lora_state_dict[key_lora_in].shape}") #gathered lorain{gathered_lora_in}")
print(f"RANK{torch.distributed.get_rank()} has {key_lora_out} shape {lora_state_dict[key_lora_out].shape}") #gathered loraout {gathered_lora_out}")
## TODO: Who decides what dim they split?
tp_dim_lora_in = 1 if key in ["attention_dense", 'mlp_fc2'] else 0
wt_lora_in = torch.cat(gathered_lora_in, dim=tp_dim_lora_in).float()
wt_lora_out = torch.cat(gathered_lora_out, dim=0).float()
wt_lora = wt_lora_out @ wt_lora_in
tp_dim_base = 0 if key in ["attention_qkv", "mlp_fc1"] else 1
wt_lora_current_rank = torch.chunk(wt_lora, tp_size, dim=tp_dim_base)[rank]
else: #when tp==1
wt_lora_in = lora_state_dict[key_lora_in]
wt_lora_out = lora_state_dict[key_lora_out]
wt_lora = wt_lora_out @ wt_lora_in
wt_lora_current_rank = wt_lora

wt_base = base_model_state_dict[key_base]
logging.info(f"Full {key_base} wt_lora_in {wt_lora_in.shape}, wt_lora_out {wt_lora_out.shape}, wt_lora {wt_lora.shape}, wt_base {wt_base.shape}")


base_model_state_dict[key_base] = (wt_base.float() + wt_lora_current_rank.to(wt_base.device)).type_as(wt_base)
logging.info(f'merging for weight {key_base}')

return base_model_state_dict
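
At its core, _merge_lora_weights applies the standard LoRA merge: the low-rank product of the adapter's linear_out and linear_in weights is added onto the frozen base weight, after which the adapter can be dropped. A minimal pure-PyTorch sketch of that arithmetic for the tp_size == 1 path, with illustrative shapes and no extra scaling factor (matching the code above):

import torch

d_out, d_in, r = 16, 32, 4                  # illustrative shapes; r is the LoRA rank
base_w   = torch.randn(d_out, d_in)         # e.g. self_attention.linear_proj.weight
lora_in  = torch.randn(r, d_in)             # adapter.linear_in.weight
lora_out = torch.randn(d_out, r)            # adapter.linear_out.weight

# Merged weight: W' = W + (linear_out @ linear_in), i.e. wt_base + wt_lora above.
merged_w = (base_w.float() + lora_out.float() @ lora_in.float()).type_as(base_w)

# After merging, a plain forward pass reproduces base output + adapter output.
x = torch.randn(2, d_in)
assert torch.allclose(x @ merged_w.T, x @ base_w.T + x @ (lora_out @ lora_in).T, atol=1e-4)
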
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/callbacks/model_transform.py
@@ -79,6 +79,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._maybe_apply_transform(trainer)

def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._maybe_apply_transform(trainer)

def _maybe_apply_transform(self, trainer):
if self._needs_to_call:
self.apply_transform(trainer)
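
on_predict_start is a standard PyTorch Lightning callback hook that fires once per trainer.predict(...) call, before the first predict_step; wiring _maybe_apply_transform into it is what lets the PEFT transform (and the predict-time LoRA merge) run when merge_lora calls trainer.predict. A minimal standalone sketch of the hook, illustrative only:

import pytorch_lightning as pl

class ApplyOnPredict(pl.Callback):
    def on_predict_start(self, trainer, pl_module):
        # runs once, before any predict_step of this predict call
        print("on_predict_start: a model transform could be applied here")

# Registered the usual way, e.g.:
# pl.Trainer(callbacks=[ApplyOnPredict()]).predict(model, dataloaders=loader)
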
6 changes: 4 additions & 2 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -27,6 +27,8 @@
from nemo.lightning.io.pl import ckpt_to_dir
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.utils import logging
from nemo.lightning.io.mixin import IOMixin


if TYPE_CHECKING:
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -35,7 +37,7 @@
_ADAPTER_META_FILENAME = "adapter_metadata.json"


class PEFT(ABC, ModelTransform):
class PEFT(IOMixin, ABC, ModelTransform):
"""Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.
This class defines the interface for PEFT methods, which are used to fine-tune
@@ -119,7 +121,7 @@ def apply_transform(self, trainer):
k: v for k, v in trainer.model.sharded_state_dict().items() if self.adapter_key_filter(k)
}

if hasattr(trainer.strategy, "init_model_parallel"):
if trainer.state.fn == TrainerFn.FITTING and hasattr(trainer.strategy, "init_model_parallel"):
logging.info("Initializing model parallel")
trainer.strategy.init_model_parallel()
