diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py
index 18db9b164d6aa..b9fd4be83c61f 100644
--- a/nemo/collections/llm/peft/lora.py
+++ b/nemo/collections/llm/peft/lora.py
@@ -72,7 +72,7 @@ def forward(self, x):
 class LinearAdapter(nn.Module):
-    def __init__(self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post'):
+    def __init__(self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post', lora_A_init_method='xavier'):
         super(LinearAdapter, self).__init__()
         assert isinstance(orig_linear, nn.Linear)
@@ -91,8 +91,11 @@ def __init__(self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='
         dtype = self.orig_linear.weight.dtype
         self.lora_a = nn.Parameter(torch.zeros((in_features, dim), dtype=dtype, device=device))
         self.lora_b = nn.Parameter(torch.zeros((dim, out_features), dtype=dtype, device=device))
-        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_b)
+        if lora_A_init_method == 'xavier':
+            torch.nn.init.uniform_(self.lora_a)
+        else:
+            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
+
         self.dropout = nn.Dropout(p=dropout)
         self.dropout_position = dropout_position
@@ -208,7 +211,7 @@ def wildcard_match(pattern, key):
             in_features = m.input_size
             out_features = m.output_size
         elif isinstance(m, nn.Linear):
-            return LinearAdapter(m, dim=self.dim, lora_alpha=self.alpha, lora_dropout=self.dropout)
+            return LinearAdapter(m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method)
         else:
             raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")
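
For reviewers: a minimal sketch of the new `lora_A_init_method` switch in isolation, not part of the patch itself. It assumes `LinearAdapter` is importable from the module touched above; the layer sizes and the `'kaiming'` string are illustrative (any value other than `'xavier'` takes the fallback branch).

```python
# Illustrative usage sketch of the new `lora_A_init_method` argument.
import torch
import torch.nn as nn

# Assumption: LinearAdapter is exposed by the module patched above.
from nemo.collections.llm.peft.lora import LinearAdapter

# Default branch: 'xavier' routes to torch.nn.init.uniform_ per the diff above.
adapter_xavier = LinearAdapter(nn.Linear(128, 64), dim=8, alpha=32, lora_A_init_method='xavier')

# Any other value falls through to the previous kaiming_uniform_ initialization.
adapter_kaiming = LinearAdapter(nn.Linear(128, 64), dim=8, alpha=32, lora_A_init_method='kaiming')

for name, adapter in [('xavier', adapter_xavier), ('kaiming', adapter_kaiming)]:
    # lora_A is initialized to non-zero values; lora_B stays at zero, so the
    # adapted layer starts out numerically identical to the base layer.
    assert adapter.lora_a.abs().sum() > 0, name
    assert torch.all(adapter.lora_b == 0), name
```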