
Commit

test
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Oct 23, 2024
1 parent 8019edc commit bd8fee0
Showing 1 changed file with 71 additions and 13 deletions.
examples/llm/sft/hf.py: 71 additions, 13 deletions
@@ -21,30 +21,85 @@
 from nemo.collections import llm


-class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
+class SquadDataModuleWithMbs(llm.SquadDataModule):
     def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
-        return DataLoader(
-            dataset,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            persistent_workers=self.persistent_workers,
-            collate_fn=dataset.collate_fn,
-            batch_size=self.micro_batch_size,
-            **kwargs,
+        from nemo.lightning.data import add_megatron_sampler
+
+        kwargs1 = {
+            'consumed_samples': 0,
+            'dataloader_type': 'single',
+            'drop_last': True,
+            'pad_samples_to_global_batch_size': False,
+        }
+        return add_megatron_sampler(
+            DataLoader(
+                dataset,
+                num_workers=self.num_workers,
+                pin_memory=self.pin_memory,
+                persistent_workers=self.persistent_workers,
+                collate_fn=dataset.collate_fn,
+                **kwargs,
+            ),
+            self.micro_batch_size,
+            self.global_batch_size,
+            **kwargs1,
         )


+def mk_hf_dataset(tokenizer):
+    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
+    def formatting_prompts_func(examples):
+        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+### Instruction:
+{}
+### Input:
+{}
+### Response:
+{}"""
+        instruction = examples["context"]
+        input = examples["question"]
+        output = examples["answers"]['text']
+        if isinstance(output, list):
+            output = output[0]
+        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
+        ans = tokenizer(text)
+        tokens = ans['input_ids']
+        return {
+            'tokens': tokens,
+            'labels': tokens[1:] + [tokens[-1]],
+        }
+
+    from datasets import load_dataset
+
+    dataset = load_dataset("rajpurkar/squad", split="train")
+    dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2)
+    return dataset
+
+
 def squad(tokenizer) -> pl.LightningDataModule:
-    return SquadDataModuleWithPthDataloader(
+    return SquadDataModuleWithMbs(
         tokenizer=tokenizer,
         seq_length=2048,
         micro_batch_size=2,
-        global_batch_size=128,  # assert gbs == mbs * accumulate_grad_batches
+        global_batch_size=128,
         num_workers=0,
         sanity_check_dist_workers=False,
     )


+class HfAutoModelPeft(llm.HfAutoModel):
+    def configure_model(self):
+        super().configure_model()
+        self.model.eval()
+        from lora import apply_lora_to_model
+
+        apply_lora_to_model(self.model)
+
+
 if __name__ == '__main__':
     import argparse

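For reference, formatting_prompts_func above joins each SQuAD record's context, question, and first answer into an Alpaca-style prompt and emits next-token labels: the input ids shifted left by one position, with the last id repeated so both lists keep the same length. A minimal sketch of that layout, using a toy whitespace tokenizer purely as a stand-in for the model's Hugging Face tokenizer:

    # Toy stand-in for the model's Hugging Face tokenizer (illustration only).
    def toy_tokenize(text):
        vocab = {}
        return {'input_ids': [vocab.setdefault(word, len(vocab)) for word in text.split()]}

    example = {
        "context": "NeMo is a toolkit for conversational AI.",
        "question": "What is NeMo?",
        "answers": {"text": ["a toolkit for conversational AI"]},
    }

    prompt = "### Instruction:\n{}\n### Input:\n{}\n### Response:\n{}".format(
        example["context"], example["question"], example["answers"]["text"][0]
    ) + "</s>"  # stand-in for tokenizer.eos_token

    tokens = toy_tokenize(prompt)['input_ids']
    labels = tokens[1:] + [tokens[-1]]  # shift left by one; repeat the last id
    assert len(tokens) == len(labels)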
@@ -69,10 +124,13 @@ def squad(tokenizer) -> pl.LightningDataModule:
     # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
     grad_clip = None
     use_dist_samp = False
+    tokenizer = llm.HfAutoModel.configure_tokenizer(args.model)

     llm.api.finetune(
-        model=llm.HfAutoModelForCausalLM(args.model),
-        data=squad(llm.HfAutoModelForCausalLM.configure_tokenizer(args.model)),
+        model=HfAutoModelPeft(args.model),
+        data=llm.HfDatasetDataModule(
+            mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
+        ),
         trainer=nl.Trainer(
             devices=args.devices,
             max_steps=args.max_steps,
