From 4b614d70aa098bcf2f0ecc677fceaf7da9a24e9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?oliver=20k=C3=B6nig?=
Date: Sat, 26 Oct 2024 01:48:24 +0200
Subject: [PATCH] PEFT Inference (#11030) (#11044)

* adapter inference first commit

* Apply isort and black reformatting

* Fix yaml serialization

* add copyright header

* Apply isort and black reformatting

* revert accidental commit

---------

Signed-off-by: Chen Cui
Signed-off-by: cuichenx
Signed-off-by: Hemil Desai
Co-authored-by: Chen Cui
Co-authored-by: cuichenx
Co-authored-by: Hemil Desai
Co-authored-by: Pablo Garay
---
 nemo/collections/llm/inference/base.py   | 49 ++++++++++++++++++++----
 nemo/lightning/ckpt_utils.py             | 15 ++++++++
 nemo/lightning/io/mixin.py               | 33 +++++++++++++++-
 nemo/lightning/pytorch/callbacks/peft.py | 11 +++---
 nemo/lightning/resume.py                 |  5 +--
 5 files changed, 95 insertions(+), 18 deletions(-)

diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py
index 0171f1c2dd5c..9c4da7940b70 100644
--- a/nemo/collections/llm/inference/base.py
+++ b/nemo/collections/llm/inference/base.py
@@ -1,5 +1,20 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 
 import pytorch_lightning as pl
 import torch
@@ -15,8 +30,9 @@
 from pytorch_lightning.trainer.states import TrainerFn
 
 import nemo.lightning as nl
+from nemo.collections.llm.peft import LoRA
 from nemo.lightning import io
-from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir, ckpt_to_weights_subdir
 from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 
@@ -39,11 +55,21 @@ def tokenize(self, prompt):
 def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
     assert isinstance(trainer.strategy, MegatronStrategy), "Only MegatronStrategy is supported for trainer.strategy."
     assert trainer.strategy.context_parallel_size <= 1, "Context parallelism is not supported for inference."
-    restore_config = RestoreConfig(
-        path=path,
-        load_model_state=True,
-        load_optim_state=False,
-    )
+    if (adapter_meta_path := ckpt_to_weights_subdir(path) / ADAPTER_META_FILENAME).exists():
+        with open(adapter_meta_path, "r") as f:
+            metadata = json.load(f)
+        restore_config = RestoreConfig(
+            path=metadata['model_ckpt_path'],
+            load_model_state=True,
+            load_optim_state=False,
+        )
+    else:
+        restore_config = RestoreConfig(
+            path=path,
+            load_model_state=True,
+            load_optim_state=False,
+        )
+
     trainer.strategy.restore_config = restore_config
     trainer.strategy._setup_optimizers = False
     trainer.ckpt_path = None
@@ -60,6 +86,15 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl
     trainer.strategy.trainer = trainer
     trainer.strategy.selective_restore()
 
+    lora: Union[io.TrainerContext, LoRA] = io.load_context(ckpt_to_context_subdir(path), "model.model_transform")
+    if isinstance(lora, LoRA):
+        model = lora(model)
+        adapter_sharded_state_dict = {k: v for k, v in model.sharded_state_dict().items() if ".adapter." in k}
+        adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
+            ckpt_to_weights_subdir(path), sharded_state_dict=adapter_sharded_state_dict
+        )
+        trainer.strategy.load_model_state_dict(adapter_state, strict=False)
+
 
 def setup_model_and_tokenizer(
     path: Path,
diff --git a/nemo/lightning/ckpt_utils.py b/nemo/lightning/ckpt_utils.py
index 7d0c1735f030..a7b823f2b230 100644
--- a/nemo/lightning/ckpt_utils.py
+++ b/nemo/lightning/ckpt_utils.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from pathlib import Path
 from typing import Union
 
@@ -5,6 +19,7 @@
 # WEIGHTS_PATH stores the weights while CONTEXT_PATH stores the hyper-parameters.
 WEIGHTS_PATH: str = "weights"
 CONTEXT_PATH: str = "context"
+ADAPTER_META_FILENAME = "adapter_metadata.json"
 
 
 def idempotent_path_append(base_dir: Union[str, Path], suffix) -> Path:
diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py
index 2d1162bb2156..656b2a2f970a 100644
--- a/nemo/lightning/io/mixin.py
+++ b/nemo/lightning/io/mixin.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 import inspect
 import json
@@ -66,6 +80,21 @@ def _partial_representer_with_defaults(dumper, data):
     return _config_representer_with_defaults(dumper, data, type_name="Partial")
 
 
+def _safe_object_representer(dumper, data):
+    if not inspect.isclass(data):
+        cls = data.__class__
+        call = True
+    else:
+        cls = data
+        call = False
+
+    value = {
+        "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}",  # type: ignore
+        "_call_": call,
+    }
+    return dumper.represent_data(value)
+
+
 class IOMixin:
     """
     A mixin class designed to capture the arguments passed to the `__init__` method,
@@ -208,14 +237,14 @@ def _io_dump_yaml(self, io: config_lib.Config, attrs: list[str]):
         original_representers = yaml.SafeDumper.yaml_representers.copy()
 
         from nemo_run.config import Config, Partial
-        from nemo_run.core.serialization.yaml import YamlSerializer, _function_representer
+        from nemo_run.core.serialization.yaml import YamlSerializer
 
         yaml.SafeDumper.add_representer(config_lib.Config, _config_representer_with_defaults)
         yaml.SafeDumper.add_representer(partial.Partial, _partial_representer_with_defaults)
         yaml.SafeDumper.add_representer(Config, _config_representer_with_defaults)
         yaml.SafeDumper.add_representer(Partial, _partial_representer_with_defaults)
 
-        yaml.SafeDumper.add_multi_representer(object, _function_representer)
+        yaml.SafeDumper.add_multi_representer(object, _safe_object_representer)
 
         serializer = YamlSerializer()
         result = {}
diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index 15d0dd8ac2ab..3d70062c48bd 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -25,6 +25,8 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from typing_extensions import override
 
+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
+from nemo.lightning.io.mixin import IOMixin
 from nemo.lightning.io.pl import ckpt_to_dir
 from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
 from nemo.utils import logging
@@ -34,10 +36,7 @@
     from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 
 
-_ADAPTER_META_FILENAME = "adapter_metadata.json"
-
-
-class PEFT(ABC, ModelTransform):
+class PEFT(IOMixin, ABC, ModelTransform):
     """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.
 
     This class defines the interface for PEFT methods, which are used to fine-tune
@@ -298,7 +297,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
 
         if is_global_rank_zero():
             metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
-            adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME
+            adapter_meta_path = ckpt_to_dir(path) / ADAPTER_META_FILENAME
             with open(adapter_meta_path, "w") as f:
                 json.dump(metadata, f)
         return request
@@ -332,7 +331,7 @@ def load_checkpoint(
 
         assert self.checkpoint_io is not None
 
-        adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME
+        adapter_meta_path = ckpt_to_dir(path) / ADAPTER_META_FILENAME
         adapter_ckpt = None
         if getattr(path, "base_model_path", None):
             ## PEFT Resume, FIRST TIME
diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py
index 40b4aa704575..412ca8665b84 100644
--- a/nemo/lightning/resume.py
+++ b/nemo/lightning/resume.py
@@ -23,6 +23,7 @@
 
 from nemo.lightning import io
 from nemo.lightning.base import NEMO_MODELS_CACHE
+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
@@ -279,9 +280,7 @@ def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Op
             if self.adapter_path:
                 return AdapterPath(Path(self.adapter_path), base_model_path=checkpoint)
             else:
-                from nemo.lightning.pytorch.callbacks.peft import _ADAPTER_META_FILENAME
-
-                adapter_meta_path = checkpoint / _ADAPTER_META_FILENAME
+                adapter_meta_path = checkpoint / ADAPTER_META_FILENAME
                 if adapter_meta_path.exists():
                     base_model_path = self._resume_peft(adapter_meta_path, model)
                     return AdapterPath(checkpoint, base_model_path=base_model_path)
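
For reference, the restore routing that the patched _setup_trainer_and_restore_model performs can be sketched with the standard library alone. This is an illustrative sketch, not part of the patch: the checkpoint root below is hypothetical, and the "weights"/"adapter_metadata.json" names mirror WEIGHTS_PATH and ADAPTER_META_FILENAME from ckpt_utils.py above.

import json
from pathlib import Path

# Hypothetical PEFT checkpoint layout consumed by this patch:
#   /ckpts/llama-lora/
#     weights/adapter_metadata.json  -> {"model_ckpt_path": "/ckpts/llama-base"}
#     weights/...                    -> sharded ".adapter." tensors
#     context/                       -> serialized config, incl. model.model_transform (the LoRA object)
ckpt = Path("/ckpts/llama-lora")
adapter_meta = ckpt / "weights" / "adapter_metadata.json"  # WEIGHTS_PATH / ADAPTER_META_FILENAME

if adapter_meta.exists():
    # PEFT checkpoint: base weights are restored from the recorded base-model
    # path; only the ".adapter." keys are then loaded from this checkpoint.
    restore_path = json.loads(adapter_meta.read_text())["model_ckpt_path"]
else:
    # Regular checkpoint: everything is restored from the checkpoint itself.
    restore_path = ckpt

print(f"base weights restored from: {restore_path}")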