PEFT Inference (#11030) (#11044)
* adapter inference first commit
* Apply isort and black reformatting
* Fix yaml serialization
* add copyright header
* Apply isort and black reformatting
* revert accidental commit

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Signed-off-by: Hemil Desai <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: Hemil Desai <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
5 people authored Oct 25, 2024
1 parent fe4d09b commit 4b614d7
Showing 5 changed files with 95 additions and 18 deletions.
49 changes: 42 additions & 7 deletions nemo/collections/llm/inference/base.py
@@ -1,5 +1,20 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

import pytorch_lightning as pl
import torch
@@ -15,8 +30,9 @@
from pytorch_lightning.trainer.states import TrainerFn

import nemo.lightning as nl
+from nemo.collections.llm.peft import LoRA
from nemo.lightning import io
-from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

@@ -39,11 +55,21 @@ def tokenize(self, prompt):
def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
    assert isinstance(trainer.strategy, MegatronStrategy), "Only MegatronStrategy is supported for trainer.strategy."
    assert trainer.strategy.context_parallel_size <= 1, "Context parallelism is not supported for inference."
-    restore_config = RestoreConfig(
-        path=path,
-        load_model_state=True,
-        load_optim_state=False,
-    )
+    if (adapter_meta_path := ckpt_to_weights_subdir(path) / ADAPTER_META_FILENAME).exists():
+        with open(adapter_meta_path, "r") as f:
+            metadata = json.load(f)
+        restore_config = RestoreConfig(
+            path=metadata['model_ckpt_path'],
+            load_model_state=True,
+            load_optim_state=False,
+        )
+    else:
+        restore_config = RestoreConfig(
+            path=path,
+            load_model_state=True,
+            load_optim_state=False,
+        )

    trainer.strategy.restore_config = restore_config
    trainer.strategy._setup_optimizers = False
    trainer.ckpt_path = None
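For reference, the adapter metadata consulted above is a one-key JSON file written at fine-tuning time. A minimal sketch of the resolution logic, using hypothetical checkpoint paths:

import json
from pathlib import Path

# Hypothetical layout: a LoRA checkpoint whose weights dir carries the
# metadata pointing back at the base model checkpoint.
ckpt = Path("/checkpoints/llama3-lora")
adapter_meta_path = ckpt / "weights" / "adapter_metadata.json"

if adapter_meta_path.exists():
    # PEFT checkpoint: base weights live at the recorded path,
    # e.g. {"model_ckpt_path": "/checkpoints/llama3-base"}.
    with open(adapter_meta_path) as f:
        restore_path = json.load(f)["model_ckpt_path"]
else:
    # Full checkpoint: restore weights directly from the given path.
    restore_path = ckpt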
@@ -60,6 +86,15 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
    trainer.strategy.trainer = trainer
    trainer.strategy.selective_restore()

+    lora: Union[io.TrainerContext, LoRA] = io.load_context(ckpt_to_context_subdir(path), "model.model_transform")
+    if isinstance(lora, LoRA):
+        model = lora(model)
+        adapter_sharded_state_dict = {k: v for k, v in model.sharded_state_dict().items() if ".adapter." in k}
+        adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
+            ckpt_to_weights_subdir(path), sharded_state_dict=adapter_sharded_state_dict
+        )
+        trainer.strategy.load_model_state_dict(adapter_state, strict=False)
+

def setup_model_and_tokenizer(
    path: Path,
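After the base weights are restored, the LoRA transform is re-applied and only the adapter entries of the state dict are loaded with strict=False. A toy stand-alone sketch of that filter-and-merge step in plain PyTorch (the module and names here are illustrative, not NeMo API):

import torch.nn as nn

class ToyLoRALinear(nn.Module):
    # Stand-in for a LoRA-wrapped layer: base weights plus an "adapter" child.
    def __init__(self, dim: int):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.adapter = nn.Linear(dim, dim, bias=False)

model = ToyLoRALinear(4)

# Mirror of the diff's filter (there: ".adapter." on sharded keys).
adapter_state = {k: v for k, v in model.state_dict().items() if "adapter" in k}

# strict=False lets the adapter subset load without base weights present,
# matching load_model_state_dict(adapter_state, strict=False) above.
model.load_state_dict(adapter_state, strict=False)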
15 changes: 15 additions & 0 deletions nemo/lightning/ckpt_utils.py
@@ -1,10 +1,25 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from pathlib import Path
from typing import Union

# NeMo2 checkpoint structure is a checkpoint directory, with a WEIGHTS_PATH and CONTEXT_PATH subdirectory structure.
# WEIGHTS_PATH stores the weights while CONTEXT_PATH stores the hyper-parameters.
WEIGHTS_PATH: str = "weights"
CONTEXT_PATH: str = "context"
+ADAPTER_META_FILENAME = "adapter_metadata.json"


def idempotent_path_append(base_dir: Union[str, Path], suffix) -> Path:
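These constants encode the NeMo2 checkpoint layout: weights/ for the distributed weights, context/ for hyper-parameters, with adapter_metadata.json present only in PEFT checkpoints. A small hypothetical helper built on them (is_peft_checkpoint is not part of the module, just an illustration):

from pathlib import Path

from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, WEIGHTS_PATH

def is_peft_checkpoint(ckpt_dir: Path) -> bool:
    # A checkpoint is a PEFT/adapter checkpoint iff the metadata file
    # sits next to its weights.
    return (Path(ckpt_dir) / WEIGHTS_PATH / ADAPTER_META_FILENAME).exists()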
33 changes: 31 additions & 2 deletions nemo/lightning/io/mixin.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import functools
import inspect
import json
@@ -66,6 +80,21 @@ def _partial_representer_with_defaults(dumper, data):
    return _config_representer_with_defaults(dumper, data, type_name="Partial")


+def _safe_object_representer(dumper, data):
+    if not inspect.isclass(data):
+        cls = data.__class__
+        call = True
+    else:
+        cls = data
+        call = False
+
+    value = {
+        "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}",  # type: ignore
+        "_call_": call,
+    }
+    return dumper.represent_data(value)
+
+
class IOMixin:
"""
A mixin class designed to capture the arguments passed to the `__init__` method,
@@ -208,14 +237,14 @@ def _io_dump_yaml(self, io: config_lib.Config, attrs: list[str]):
        original_representers = yaml.SafeDumper.yaml_representers.copy()

        from nemo_run.config import Config, Partial
-        from nemo_run.core.serialization.yaml import YamlSerializer, _function_representer
+        from nemo_run.core.serialization.yaml import YamlSerializer

        yaml.SafeDumper.add_representer(config_lib.Config, _config_representer_with_defaults)
        yaml.SafeDumper.add_representer(partial.Partial, _partial_representer_with_defaults)
        yaml.SafeDumper.add_representer(Config, _config_representer_with_defaults)
        yaml.SafeDumper.add_representer(Partial, _partial_representer_with_defaults)

-        yaml.SafeDumper.add_multi_representer(object, _function_representer)
+        yaml.SafeDumper.add_multi_representer(object, _safe_object_representer)

        serializer = YamlSerializer()
        result = {}
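_safe_object_representer replaces nemo_run's _function_representer as the catch-all multi-representer, so any otherwise-unhandled object serializes to an importable _target_ plus a _call_ flag (true for instances, which must be constructed on load; false for bare classes). A self-contained illustration of the emitted YAML:

import inspect
import yaml

def _safe_object_representer(dumper, data):
    # Same logic as the representer added in this commit.
    if not inspect.isclass(data):
        cls, call = data.__class__, True
    else:
        cls, call = data, False
    value = {
        "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}",
        "_call_": call,
    }
    return dumper.represent_data(value)

class Dummy:
    pass

yaml.SafeDumper.add_multi_representer(object, _safe_object_representer)
print(yaml.safe_dump({"instance": Dummy(), "cls": Dummy}))
# The instance entry serializes with _target_: __main__.Dummy and _call_: true;
# the class entry gets the same _target_ but _call_: false.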
11 changes: 5 additions & 6 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -25,6 +25,8 @@
from pytorch_lightning.trainer.states import TrainerFn
from typing_extensions import override

+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
+from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.io.pl import ckpt_to_dir
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.utils import logging
@@ -34,10 +36,7 @@
    from megatron.core.dist_checkpointing.mapping import ShardedStateDict


-_ADAPTER_META_FILENAME = "adapter_metadata.json"
-
-
-class PEFT(ABC, ModelTransform):
+class PEFT(IOMixin, ABC, ModelTransform):
    """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.

    This class defines the interface for PEFT methods, which are used to fine-tune
@@ -298,7 +297,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options

        if is_global_rank_zero():
            metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
-            adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME
+            adapter_meta_path = ckpt_to_dir(path) / ADAPTER_META_FILENAME
            with open(adapter_meta_path, "w") as f:
                json.dump(metadata, f)
        return request
@@ -332,7 +331,7 @@ def load_checkpoint(

        assert self.checkpoint_io is not None

-        adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME
+        adapter_meta_path = ckpt_to_dir(path) / ADAPTER_META_FILENAME
        adapter_ckpt = None
        if getattr(path, "base_model_path", None):
            ## PEFT Resume, FIRST TIME
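On the save side, the metadata is written once on global rank zero so resume and inference can locate the base model later. A minimal sketch of the round trip, with hypothetical paths:

import json
from pathlib import Path

ckpt_dir = Path("/checkpoints/llama3-lora")  # hypothetical adapter checkpoint

# Write (mirrors the rank-zero write in PEFT's save_checkpoint)...
with open(ckpt_dir / "adapter_metadata.json", "w") as f:
    json.dump({"model_ckpt_path": "/checkpoints/llama3-base"}, f)

# ...and read back (mirrors resume and inference).
with open(ckpt_dir / "adapter_metadata.json") as f:
    base_model_path = json.load(f)["model_ckpt_path"]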
5 changes: 2 additions & 3 deletions nemo/lightning/resume.py
@@ -23,6 +23,7 @@

from nemo.lightning import io
from nemo.lightning.base import NEMO_MODELS_CACHE
+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.utils import logging
from nemo.utils.app_state import AppState
@@ -279,9 +280,7 @@ def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
        if self.adapter_path:
            return AdapterPath(Path(self.adapter_path), base_model_path=checkpoint)
        else:
-            from nemo.lightning.pytorch.callbacks.peft import _ADAPTER_META_FILENAME
-
-            adapter_meta_path = checkpoint / _ADAPTER_META_FILENAME
+            adapter_meta_path = checkpoint / ADAPTER_META_FILENAME
            if adapter_meta_path.exists():
                base_model_path = self._resume_peft(adapter_meta_path, model)
                return AdapterPath(checkpoint, base_model_path=base_model_path)
