Neftune #8268

Closed
wants to merge 15 commits
@@ -166,6 +166,15 @@ def __init__(
use_flash_attention=False,
seq_len_interpolation_factor=None,
rotary_base=10000,
embedding_noise=False,
embedding_noise_mean=0.0,
embedding_noise_std=0.001,
embedding_noise_type='uniform',
neft=False,
neft_alpha=5.0,
noise_positonal_embedding=False,
adversarial_training=False,
adversarial_training_epsilon=0.01,
):
super(GPTModel, self).__init__(config=config, share_token_embeddings=share_embeddings_and_output_weights)

@@ -246,6 +255,15 @@ def __init__(
use_flash_attention=use_flash_attention,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
embedding_noise=embedding_noise,
embedding_noise_mean=embedding_noise_mean,
embedding_noise_std=embedding_noise_std,
embedding_noise_type=embedding_noise_type,
neft=neft,
neft_alpha=neft_alpha,
noise_positonal_embedding=noise_positonal_embedding,
adversarial_training=adversarial_training,
adversarial_training_epsilon=adversarial_training_epsilon,
)

if self.share_embeddings_and_output_weights:
@@ -358,6 +358,23 @@ def model_provider_func(self, pre_process, post_process):
rotary_percent=self.cfg.get('rotary_percentage', 1.0),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
rotary_base=self.cfg.get('rotary_base', 10000),
embedding_noise=self.cfg.get('embedding_noise', False),
embedding_noise_mean=self.cfg.get('embedding_noise_mean', 0.0),
embedding_noise_std=self.cfg.get('embedding_noise_std', 0.001),
embedding_noise_type=self.cfg.get('embedding_noise_type', 'uniform'),
neft=self.cfg.get('neft', False),
neft_alpha=self.cfg.get('neft_alpha', 5.0),
noise_positonal_embedding=self.cfg.get('noise_positonal_embedding', False),
adversarial_training=self.cfg.get('adversarial_training', False),
adversarial_training_epsilon=self.cfg.get('adversarial_training_epsilon', 0.01),
noise_scheduler_config=self.cfg.get('noise_scheduler_config', None),
cre_adversarial_training=self.cfg.get('cre_adversarial_training', False),
creat_init_var=self.cfg.get('creat_init_var', 1e-2),
creat_num_adv_steps=self.cfg.get('creat_num_adv_steps', 2),
creat_adv_temp=self.cfg.get('creat_adv_temp', 1.0),
creat_lambda=self.cfg.get('creat_lambda', 0.5),
creat_lr=self.cfg.get('creat_lr', 0.1),
creat_max_norm=self.cfg.get('creat_max_norm', 0.1),
)
else:
assert self.cfg.get('num_query_groups', None) is None or self.cfg.get(
@@ -426,6 +443,23 @@ def model_provider_func(self, pre_process, post_process):
megatron_legacy=self.cfg.get('megatron_legacy', False),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
rotary_base=self.cfg.get('rotary_base', 10000),
embedding_noise=self.cfg.get('embedding_noise', False),
embedding_noise_mean=self.cfg.get('embedding_noise_mean', 0.0),
embedding_noise_std=self.cfg.get('embedding_noise_std', 0.001),
embedding_noise_type=self.cfg.get('embedding_noise_type', 'uniform'),
neft=self.cfg.get('neft', False),
neft_alpha=self.cfg.get('neft_alpha', 5.0),
noise_positonal_embedding=self.cfg.get('noise_positonal_embedding', False),
adversarial_training=self.cfg.get('adversarial_training', False),
adversarial_training_epsilon=self.cfg.get('adversarial_training_epsilon', 0.01),
noise_scheduler_config=self.cfg.get('noise_scheduler_config', None),
cre_adversarial_training=self.cfg.get('cre_adversarial_training', False),
creat_init_var=self.cfg.get('creat_init_var', 1e-2),
creat_num_adv_steps=self.cfg.get('creat_num_adv_steps', 2),
creat_adv_temp=self.cfg.get('creat_adv_temp', 1.0),
creat_lambda=self.cfg.get('creat_lambda', 0.5),
creat_adv_lr=self.cfg.get('creat_lr', 0.1),
creat_adv_max_norm=self.cfg.get('creat_max_norm', 0.1),
)
return model

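For context, all of the new switches above are plain config keys read with self.cfg.get(...). A minimal sketch of what the corresponding model-config section could look like, assuming the usual OmegaConf handling used elsewhere in NeMo (the snippet is illustrative and not part of this diff):

# Illustrative only: these keys mirror the self.cfg.get(...) lookups added above.
from omegaconf import OmegaConf

noise_cfg = OmegaConf.create(
    {
        "neft": True,              # enable NEFTune-style scaled uniform noise
        "neft_alpha": 5.0,         # noise magnitude; the NEFTune paper sweeps 5/10/15
        "embedding_noise": False,  # alternative: norm-preserving uniform/normal noise
        "embedding_noise_type": "uniform",
        "embedding_noise_mean": 0.0,
        "embedding_noise_std": 0.001,
        "noise_positonal_embedding": False,  # False: noise the word embeddings only
    }
)

# model_provider_func would pick these up via self.cfg.get('neft', False), etc.
print(OmegaConf.to_yaml(noise_cfg))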
87 changes: 87 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -13,6 +13,7 @@
# limitations under the License.

"""Transformer based language model."""
import math
from ast import Mod

import torch
@@ -126,6 +127,15 @@ def get_language_model(
use_flash_attention=False,
seq_len_interpolation_factor=None,
rotary_base=10000,
embedding_noise=False,
embedding_noise_mean=0.0,
embedding_noise_std=0.001,
embedding_noise_type='uniform',
neft=False,
neft_alpha=5.0,
noise_positonal_embedding=False,
adversarial_training=False,
adversarial_training_epsilon=0.01,
):
"""Build language model and return along with the key to save."""

@@ -202,6 +212,15 @@ def get_language_model(
use_flash_attention=use_flash_attention,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
embedding_noise=embedding_noise,
embedding_noise_mean=embedding_noise_mean,
embedding_noise_std=embedding_noise_std,
embedding_noise_type=embedding_noise_type,
neft=neft,
neft_alpha=neft_alpha,
noise_positonal_embedding=noise_positonal_embedding,
adversarial_training=adversarial_training,
adversarial_training_epsilon=adversarial_training_epsilon,
)
# key used for checkpoints.
language_model_key = 'language_model'
@@ -254,6 +273,13 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
position_embedding_type: position embedding type determines whether we instantiate a learnable position embedding table.
embedding_noise: whether to add norm-preserving noise to the embeddings; applied only during training
embedding_noise_mean: mean of the embedding noise (used as the lower bound when the noise type is 'uniform')
embedding_noise_std: standard deviation of the embedding noise (used as the upper bound when the noise type is 'uniform')
embedding_noise_type: distribution of the embedding noise ('uniform' or 'normal')
neft: whether to use the NEFTune noisy-embedding fine-tuning technique; applied only during training
neft_alpha: alpha parameter controlling the NEFTune noise magnitude
noise_positonal_embedding: whether to apply the noise after the positional embeddings have been added (i.e., to the combined embeddings) instead of to the word embeddings
"""

def __init__(
@@ -268,6 +294,13 @@ def __init__(
fp32_residual_connection=False,
position_embedding_type='learned_absolute',
transpose_batch_sequence=True,
embedding_noise=False,
embedding_noise_mean=0.0,
embedding_noise_std=0.001,
embedding_noise_type='uniform',
neft=False,
neft_alpha=5.0,
noise_positonal_embedding=False,
):
super(Embedding, self).__init__(config=config)

@@ -276,6 +309,14 @@ def __init__(
self.num_tokentypes = num_tokentypes
self.position_embedding_type = position_embedding_type
self.transpose_batch_sequence = transpose_batch_sequence
self.embedding_noise = embedding_noise
self.embedding_noise_mean = embedding_noise_mean
self.embedding_noise_std = embedding_noise_std
self.embedding_noise_type = embedding_noise_type
self.neft = neft
self.neft_alpha = neft_alpha
self.noise_positonal_embedding = noise_positonal_embedding

# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method, config=config,
@@ -318,6 +359,47 @@ def __init__(
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

def _noise(self, embeddings):
"""
Add noise to the embeddings, only works during training.
Noise type (noise/neft) is determined by the self.neft flag.
:param embeddings: embeddings to add noise to
"""
if self.training:
if self.embedding_noise and not self.neft:
# Draw noise from the configured distribution; for 'uniform', (mean, std) are used as the (low, high) bounds
if self.embedding_noise_type == 'uniform':
noise = (
torch.empty_like(embeddings)
.uniform_(self.embedding_noise_mean, self.embedding_noise_std)
.detach()
)
elif self.embedding_noise_type == 'normal':
noise = (
torch.empty_like(embeddings)
.normal_(self.embedding_noise_mean, self.embedding_noise_std)
.detach()
)
else:
raise NotImplementedError(f"embedding noise type {self.embedding_noise_type} not implemented")

# Calculate the norm of the original embeddings
original_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True)

# Apply noise
embeddings = embeddings + noise

# Calculate the norm of the noisy embeddings
noisy_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True)

# Rescale the noisy embeddings to restore the original norm
embeddings = embeddings * (original_norm / noisy_norm)
elif self.neft:
# NEFTune: uniform noise in [-1, 1], scaled by neft_alpha / sqrt(first dim * hidden dim) of the embedding tensor
epsilon = torch.empty_like(embeddings).uniform_(-1, 1).detach()
scaled_noise = (self.neft_alpha / math.sqrt(embeddings.shape[0] * embeddings.shape[-1])) * epsilon
embeddings = embeddings + scaled_noise
return embeddings

def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
@@ -346,6 +428,8 @@ def add_tokentype_embeddings(self, num_tokentypes):
def forward(self, input_ids, position_ids=None, token_type_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
# By default, noise only the word embeddings (before positional embeddings are added)
if not self.noise_positonal_embedding:
words_embeddings = self._noise(words_embeddings)
if self.position_embedding_type == 'learned_absolute':
assert position_ids is not None
position_embeddings = self.position_embeddings(position_ids)
@@ -376,6 +460,9 @@ def forward(self, input_ids, position_ids=None, token_type_ids=None):
else:
embeddings = self.embedding_dropout(embeddings)

# If requested, noise the combined embeddings (word + positional, after dropout) instead
if self.noise_positonal_embedding:
embeddings = self._noise(embeddings)

return embeddings

def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
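For readers unfamiliar with NEFTune, the neft branch of _noise above adds uniform noise in [-1, 1] scaled by neft_alpha / sqrt(L * d) to the embeddings, and only while the module is in training mode. Below is a self-contained sketch of that scaling following the paper's formulation with an assumed [batch, seq_len, hidden] layout (the diff's code takes the scale from the embedding tensor's first and last dimensions instead); it is independent of Megatron and illustrative only:

import math

import torch


def neftune_noise(embeddings: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    """Add NEFTune-style noise to a [batch, seq_len, hidden] embedding tensor.

    Noise is uniform in [-1, 1], scaled by alpha / sqrt(seq_len * hidden);
    it should only be applied during training.
    """
    _, seq_len, hidden = embeddings.shape
    epsilon = torch.empty_like(embeddings).uniform_(-1.0, 1.0)
    return embeddings + (alpha / math.sqrt(seq_len * hidden)) * epsilon


# Quick check: per-element noise magnitude shrinks as seq_len * hidden grows.
emb = torch.randn(2, 128, 4096)
noisy = neftune_noise(emb, alpha=5.0)
print((noisy - emb).abs().max().item())  # bounded by 5 / sqrt(128 * 4096) ~= 0.0069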