Commit

Apply isort and black reformatting
Signed-off-by: akoumpa <[email protected]>
akoumpa committed Oct 23, 2024
1 parent 47a895f commit 73913d2
Showing 7 changed files with 42 additions and 30 deletions.
25 changes: 17 additions & 8 deletions examples/llm/sft/hf.py
@@ -45,8 +45,10 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
**kwargs1,
)


def mk_hf_dataset(tokenizer):
- EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
+ EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

def formatting_prompts_func(examples):
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -59,20 +61,25 @@ def formatting_prompts_func(examples):
### Response:
{}"""
instruction = examples["context"]
- input = examples["question"]
- output = examples["answers"]['text']
+ input = examples["question"]
+ output = examples["answers"]['text']
if isinstance(output, list):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = tokenizer(text)
tokens = ans['input_ids']
- return {'tokens': tokens, 'labels': tokens[1:] + [tokens[-1]], }
+ return {
+     'tokens': tokens,
+     'labels': tokens[1:] + [tokens[-1]],
+ }

from datasets import load_dataset
- dataset = load_dataset("rajpurkar/squad", split = "train")
- dataset = dataset.map(formatting_prompts_func, batched = False, batch_size = 2)

+ dataset = load_dataset("rajpurkar/squad", split="train")
+ dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2)
return dataset


def squad(tokenizer) -> pl.LightningDataModule:
return SquadDataModuleWithMbs(
tokenizer=tokenizer,
@@ -83,13 +90,16 @@ def squad(tokenizer) -> pl.LightningDataModule:
sanity_check_dist_workers=False,
)


class HfAutoModelPeft(llm.HfAutoModel):
def configure_model(self):
super().configure_model()
self.model.eval()
from lora import apply_lora_to_model

apply_lora_to_model(self.model)


if __name__ == '__main__':
import argparse

@@ -119,8 +129,7 @@ def configure_model(self):
llm.api.finetune(
model=HfAutoModelPeft(args.model),
data=llm.HfDatasetDataModule(
- mk_hf_dataset(tokenizer.tokenizer),
- pad_token_id=tokenizer.tokenizer.eos_token_id
+ mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
),
trainer=nl.Trainer(
devices=args.devices,
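A side note on the labels construction in mk_hf_dataset above, not part of the commit: the labels are the input tokens shifted left by one position, with the last token repeated so both lists keep the same length. A minimal sketch with made-up token ids:

tokens = [101, 7592, 2088, 102]      # hypothetical token ids produced by the tokenizer
labels = tokens[1:] + [tokens[-1]]   # next-token targets; the final position repeats the last token
print(tokens)  # [101, 7592, 2088, 102]
print(labels)  # [7592, 2088, 102, 102]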
10 changes: 7 additions & 3 deletions examples/llm/sft/lora.py
@@ -1,7 +1,9 @@
- import torch.nn as nn
- import torch
import math

+ import torch
+ import torch.nn as nn


class LoraLinear(nn.Module):
def __init__(self, orig_linear, r=8, lora_alpha=32, lora_dropout=0.1):
super(LoraLinear, self).__init__()
@@ -31,6 +33,7 @@ def forward(self, x):
lora_res = lora_res @ self.lora_b.t()
return res + lora_res * self.scale


# Helper funcs
def get_parent_module(model, module_name):
print('get_parent_module module_name= ' + str(module_name))
@@ -40,11 +43,12 @@ def get_parent_module(model, module_name):
parent = getattr(parent, name)
return parent


def apply_lora_to_model(model, r=8, lora_alpha=32, lora_dropout=0.1):
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and '_proj' in name:
parent_module = get_parent_module(model, name)
target_attr = name.split('.')[-1]
orig_lin = getattr(parent_module, target_attr)
lora_linear = LoraLinear(orig_lin, r, lora_alpha, lora_dropout)
- setattr(parent_module, target_attr, lora_linear)
+ setattr(parent_module, target_attr, lora_linear)
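A minimal usage sketch for the apply_lora_to_model helper above, not part of the commit. It assumes the helper can be imported from this lora.py, and it uses a toy module whose attribute name contains '_proj' so it matches the replacement filter:

import torch
import torch.nn as nn

from lora import apply_lora_to_model  # helper shown in the diff above


class TinyBlock(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)  # name contains '_proj', so it gets wrapped
        self.fc = nn.Linear(dim, dim)      # no '_proj' in the name, left untouched


model = TinyBlock()
apply_lora_to_model(model, r=4, lora_alpha=16, lora_dropout=0.0)
print(type(model.q_proj).__name__)  # LoraLinear
print(type(model.fc).__name__)      # Linear
print(model.q_proj(torch.randn(2, 16)).shape)  # torch.Size([2, 16])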
2 changes: 1 addition & 1 deletion nemo/collections/llm/__init__.py
@@ -21,10 +21,10 @@
from nemo.collections.llm.gpt.data import (
DollyDataModule,
FineTuningDataModule,
+ HfDatasetDataModule,
MockDataModule,
PreTrainingDataModule,
SquadDataModule,
- HfDatasetDataModule,
)
from nemo.collections.llm.gpt.data.api import dolly, mock, squad
from nemo.collections.llm.gpt.model import (
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -14,16 +14,16 @@

from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
+ from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
- from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule

__all__ = [
"FineTuningDataModule",
"SquadDataModule",
"DollyDataModule",
"MockDataModule",
"PreTrainingDataModule",
"HfDatasetDataModule"
"HfDatasetDataModule",
]
23 changes: 10 additions & 13 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -21,14 +21,14 @@ class HfDatasetDataModule(pl.LightningDataModule):
def __init__(
self,
dataset,
- num_workers = 2,
- pin_memory = True,
- persistent_workers = True,
- micro_batch_size = 2,
- global_batch_size = 2,
- pad_token_id = 0,
- use_mcore_sampler = False,
- mcore_dataloader_type = 'cyclic',
+ num_workers=2,
+ pin_memory=True,
+ persistent_workers=True,
+ micro_batch_size=2,
+ global_batch_size=2,
+ pad_token_id=0,
+ use_mcore_sampler=False,
+ mcore_dataloader_type='cyclic',
) -> None:
super().__init__()
assert pad_token_id is not None
@@ -56,10 +56,8 @@ def extract_key_from_dicts(batch, key):

def pad_within_micro(batch, pad_token_id):
max_len = max(map(len, batch))
- return [
-     item + [pad_token_id] * (max_len - len(item))
-     for item in batch
- ]
+ return [item + [pad_token_id] * (max_len - len(item)) for item in batch]

return {
key: batchify(
torch.LongTensor(
@@ -103,4 +101,3 @@ def train_dataloader(self, collate_fn=None):
rank=rank,
world_size=world_size,
)
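
To illustrate the reformatted padding helper above, here is a standalone sketch of its behaviour (in the module it is an inner function of the collate logic): every sequence in a micro-batch is right-padded with pad_token_id up to the longest sequence in that batch.

def pad_within_micro(batch, pad_token_id):
    max_len = max(map(len, batch))
    return [item + [pad_token_id] * (max_len - len(item)) for item in batch]


print(pad_within_micro([[5, 6, 7], [8, 9], [10]], pad_token_id=0))
# [[5, 6, 7], [8, 9, 0], [10, 0, 0]]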

(additional changed file; file header not captured)
@@ -18,8 +18,8 @@
from transformers import AutoModelForCausalLM

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
- from nemo.lightning import io
from nemo.collections.llm import fn
+ from nemo.lightning import io


def _extract_non_bias_params(model):
@@ -66,6 +66,7 @@ def configure_model(self):
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype='auto')
else:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(self.model_name)
self.model = AutoModelForCausalLM.from_config(config)
self.model.train()
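For context on the branch touched above: AutoModelForCausalLM.from_config builds the architecture with freshly initialized weights, whereas from_pretrained would also load the checkpoint. A minimal sketch, not part of the commit, assuming "gpt2" as a stand-in for self.model_name:

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")        # fetches only the architecture description
model = AutoModelForCausalLM.from_config(config)   # weights are randomly initialized, no checkpoint download
print(sum(p.numel() for p in model.parameters()))  # parameter count of the fresh model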
5 changes: 3 additions & 2 deletions nemo/collections/llm/peft/lora.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import re
- import torch
import math
+ import re
from dataclasses import dataclass, field
from typing import List, Literal

+ import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from torch import nn
@@ -107,6 +107,7 @@ def forward(self, x):
lora_res = self.dropout(lora_res)
return res + lora_res


@dataclass
class LoRA(PEFT):
"""
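As background for the forward tail shown above: LoRA adds a trainable low-rank correction on top of a frozen linear layer, roughly W x + (alpha / r) * B (A x); exact scaling and dropout placement vary by implementation. A standalone sketch of that computation, not taken from this file:

import torch

d_in, d_out, r, alpha = 16, 16, 4, 16
x = torch.randn(2, d_in)
W = torch.randn(d_out, d_in)       # frozen base weight
A = torch.randn(r, d_in) * 0.01    # trainable low-rank factor (small init)
B = torch.zeros(d_out, r)          # zero init, so training starts from the base output
out = x @ W.t() + (alpha / r) * ((x @ A.t()) @ B.t())
print(out.shape)  # torch.Size([2, 16])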
