
Commit

Apply isort and black reformatting
Signed-off-by: akoumpa <[email protected]>
akoumpa committed Oct 23, 2024
1 parent 9d39bae commit c324d48
Showing 3 changed files with 23 additions and 14 deletions.
22 changes: 13 additions & 9 deletions examples/llm/peft/hf.py
@@ -20,7 +20,8 @@


 def mk_hf_dataset(tokenizer):
-    EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
+    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
     def formatting_prompts_func(examples):
         alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -33,18 +34,22 @@ def formatting_prompts_func(examples):
 ### Response:
 {}"""
         instruction = examples["context"]
         input = examples["question"]
         output = examples["answers"]['text']
         if isinstance(output, list):
             output = output[0]
         text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
         ans = tokenizer(text)
         tokens = ans['input_ids']
-        return {'tokens': tokens, 'labels': tokens[1:] + [tokens[-1]], }
+        return {
+            'tokens': tokens,
+            'labels': tokens[1:] + [tokens[-1]],
+        }

-    from datasets import load_dataset
-    dataset = load_dataset("rajpurkar/squad", split = "train")
-    dataset = dataset.map(formatting_prompts_func, batched = False, batch_size = 2)
+
+    dataset = load_dataset("rajpurkar/squad", split="train")
+    dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2)
     return dataset


@@ -81,8 +86,7 @@ def formatting_prompts_func(examples):
     llm.api.finetune(
         model=llm.HfAutoModelForCausalLM(args.model),
         data=llm.HfDatasetDataModule(
-            mk_hf_dataset(tokenizer.tokenizer),
-            pad_token_id=tokenizer.tokenizer.eos_token_id
+            mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
         ),
         trainer=nl.Trainer(
             devices=args.devices,
@@ -100,4 +104,4 @@ def formatting_prompts_func(examples):
         optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
         log=None,
         peft=lora,
     )
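
For context, a quick sketch of how the reformatted mk_hf_dataset could be exercised on its own, outside llm.api.finetune. The tokenizer checkpoint here ("gpt2") is an arbitrary assumption for illustration, not something the diff specifies:

# Hypothetical standalone check of the dataset factory above.
# Assumes `transformers` and `datasets` are installed and mk_hf_dataset is importable.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer; any causal-LM tokenizer with an eos_token works
dataset = mk_hf_dataset(tokenizer)

sample = dataset[0]
# 'labels' is 'tokens' shifted left by one position (with the last token repeated), so the lengths match.
print(len(sample["tokens"]), len(sample["labels"]))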
@@ -16,9 +16,9 @@
 import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM
-from nemo.collections.llm import fn

 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.llm import fn
 from nemo.lightning import io


13 changes: 9 additions & 4 deletions nemo/collections/llm/peft/lora.py
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import re
-import torch
 import math
+import re
 from dataclasses import dataclass, field
 from typing import List, Literal

+import torch
 from megatron.core import parallel_state
 from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
 from torch import nn
@@ -72,7 +72,9 @@ def forward(self, x):


 class LinearAdapter(nn.Module):
-    def __init__(self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post', lora_A_init_method='xavier'):
+    def __init__(
+        self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post', lora_A_init_method='xavier'
+    ):
         super(LinearAdapter, self).__init__()
         assert isinstance(orig_linear, nn.Linear)

@@ -111,6 +113,7 @@ def forward(self, x):
         lora_res = self.dropout(lora_res)
         return res + lora_res

+
 @dataclass
 class LoRA(PEFT):
     """
@@ -212,7 +215,9 @@ def wildcard_match(pattern, key):
             in_features = m.input_size
             out_features = m.output_size
         elif isinstance(m, nn.Linear):
-            return LinearAdapter(m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method)
+            return LinearAdapter(
+                m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method
+            )
         else:
             raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")

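The hunks above show only fragments of LinearAdapter, so here is a rough, self-contained sketch of the adapter pattern it follows: a frozen base nn.Linear plus a trainable low-rank update scaled by alpha/dim. Names, initialization, and dropout placement are illustrative assumptions, not the file's exact implementation:

# Minimal LoRA-style linear sketch (assumed names; not the NeMo code).
import math

import torch
from torch import nn


class TinyLoRALinear(nn.Module):
    """Frozen base linear plus a trainable low-rank update."""

    def __init__(self, orig_linear: nn.Linear, dim: int = 8, alpha: int = 32, dropout: float = 0.1):
        super().__init__()
        self.base = orig_linear
        for p in self.base.parameters():
            p.requires_grad_(False)  # base weights stay frozen
        in_f, out_f = orig_linear.in_features, orig_linear.out_features
        self.lora_a = nn.Parameter(torch.empty(in_f, dim))
        self.lora_b = nn.Parameter(torch.zeros(dim, out_f))  # zero init => no change at step 0
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scale = alpha / dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        res = self.base(x)
        lora_res = (x @ self.lora_a) @ self.lora_b * self.scale
        return res + self.dropout(lora_res)


# Usage: wrap an ordinary linear layer and train only the adapter parameters.
layer = TinyLoRALinear(nn.Linear(16, 16))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

Only lora_a and lora_b receive gradients, so the trainable parameter count per wrapped layer drops from in_features * out_features to dim * (in_features + out_features).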