add init method
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Oct 23, 2024
1 parent 47a895f commit 16f046e
Showing 1 changed file with 7 additions and 4 deletions.
nemo/collections/llm/peft/lora.py (7 additions, 4 deletions)
@@ -72,7 +72,7 @@ def forward(self, x):


 class LinearAdapter(nn.Module):
-    def __init__(self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post'):
+    def __init__(self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post', lora_A_init_method='xavier'):
         super(LoraLinear, self).__init__()
         assert isinstance(orig_linear, nn.Linear)
@@ -91,8 +91,11 @@ def __init__(self, dim=8, alpha=32, dropout=0.1, dropout_position='
         dtype = self.orig_linear.weight.dtype
         self.lora_a = nn.Parameter(torch.zeros((in_features, dim), dtype=dtype, device=device))
         self.lora_b = nn.Parameter(torch.zeros((dim, out_features), dtype=dtype, device=device))
-        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_b)
+        if lora_A_init_method == 'xavier':
+            torch.nn.init.uniform_(self.lora_a)
+        else:
+            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))

         self.dropout = nn.Dropout(p=dropout)
         self.dropout_position = dropout_position
@@ -208,7 +211,7 @@ def wildcard_match(pattern, key):
             in_features = m.input_size
             out_features = m.output_size
         elif isinstance(m, nn.Linear):
-            return LinearAdapter(m, dim=self.dim, lora_alpha=self.alpha, lora_dropout=self.dropout)
+            return LinearAdapter(m, dim=self.dim, lora_alpha=self.alpha, lora_dropout=self.dropout, lora_A_init_method=self.lora_A_init_method)
         else:
             raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")
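
For context, here is a small, self-contained sketch of the initialization dispatch this commit introduces. It is not NeMo's actual LinearAdapter: the class name TinyLoraAdapter, the toy shapes, and the simplified forward pass are illustrative assumptions; only the lora_A_init_method branch mirrors the diff above.

# Hypothetical, simplified stand-in for the LinearAdapter change above;
# only the lora_A_init_method dispatch follows the diff, everything else is illustrative.
import math
import torch
import torch.nn as nn

class TinyLoraAdapter(nn.Module):
    def __init__(self, orig_linear, dim=8, alpha=32, lora_A_init_method='xavier'):
        super().__init__()
        assert isinstance(orig_linear, nn.Linear)
        self.orig_linear = orig_linear
        self.scale = alpha / dim
        in_features, out_features = orig_linear.in_features, orig_linear.out_features
        # lora_b stays zero so the adapted layer starts out identical to the original one
        self.lora_a = nn.Parameter(torch.zeros(in_features, dim))
        self.lora_b = nn.Parameter(torch.zeros(dim, out_features))
        if lora_A_init_method == 'xavier':
            nn.init.uniform_(self.lora_a)  # the new default branch added by this commit
        else:
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))  # the previous behaviour, kept as a fallback

    def forward(self, x):
        return self.orig_linear(x) + (x @ self.lora_a @ self.lora_b) * self.scale

base = nn.Linear(16, 32)
adapter = TinyLoraAdapter(base, dim=4, lora_A_init_method='xavier')
print(adapter(torch.randn(2, 16)).shape)  # torch.Size([2, 32])

Zero-initializing lora_b (here via torch.zeros) is what keeps the adapter a no-op at the start of fine-tuning, regardless of which init method is chosen for lora_a.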
