
formatting changes and more dynamic 'to(device)'
jkbhagatio committed Apr 21, 2024
1 parent 414c16c commit 9288c3a
Showing 1 changed file with 37 additions and 19 deletions.
56 changes: 37 additions & 19 deletions nanogpt.py
@@ -27,13 +27,12 @@ def __init__(self, head_sz, emb_dim):

def forward(self, x):
"""Compute self-attention output."""
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_batch_sz, ctx_len, _emb_dim = x.shape
q = self.query(x)
k = self.key(x) # -> [batch_sz, ctx_len, head_sz]
v = self.value(x)
- k_q_sim = q @ k.transpose(2, 1) / np.sqrt(self.head_sz) # scaled attention to preserve k, q var
- tril = torch.tril(torch.ones(ctx_len, ctx_len)).to(device) # mask to prevent access to future info
+ k_q_sim = q @ k.transpose(2, 1) / np.sqrt(self.head_sz) # scaled attn to preserve k, q var
+ tril = torch.tril(torch.ones(ctx_len, ctx_len, device=x.device)) # mask: can't see future
k_q_sim = k_q_sim.masked_fill(tril == 0, float("-inf"))
attn_weights = F.softmax(k_q_sim, dim=2)
attn_out = attn_weights @ v # weighted sum of values
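
The hunk above is the heart of the commit: rather than constructing a torch.device in every forward pass and moving the causal mask with .to(device), the mask is now created directly on the input's device via device=x.device. For reference, a minimal standalone sketch of the same masked, scaled dot-product attention step (shapes and scaling taken from the lines above; the snippet is illustrative, not part of the commit):

    import torch
    import torch.nn.functional as F

    def masked_attention(q, k, v, head_sz):
        """Causal scaled dot-product attention; q, k, v: [batch_sz, ctx_len, head_sz]."""
        ctx_len = q.shape[1]
        scores = q @ k.transpose(2, 1) / head_sz**0.5  # scale to preserve variance of q, k
        tril = torch.tril(torch.ones(ctx_len, ctx_len, device=q.device))  # mask built on q's device
        scores = scores.masked_fill(tril == 0, float("-inf"))  # block attention to future positions
        weights = F.softmax(scores, dim=2)
        return weights @ v  # weighted sum of values
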
@@ -51,7 +50,7 @@ def __init__(self, n_heads, head_sz, emb_dim):
super().__init__()
self.n_heads, self.head_sz, self.emb_dim = n_heads, head_sz, emb_dim
self.heads = nn.ModuleList([Head(head_sz, emb_dim) for _ in range(n_heads)])
- self.proj = nn.Linear(self.n_heads * self.head_sz, self.emb_dim) # project back to `emb_dim`
+ self.proj = nn.Linear(self.n_heads * self.head_sz, self.emb_dim) # projct back to `emb_dim`

def forward(self, x):
"""Compute multi-head self-attention output."""
@@ -90,7 +89,7 @@ class Block(nn.Module):
# - Dropout

def __init__(self, n_heads, head_sz, emb_dim, ff_dim, dropout):
"""Self-attention followed by position-wise feedforward, each sandwiched by layer norm & dropout."""
"""Self-attention -> position-wise feedforward, each sandwiched by layer norm & dropout."""
super().__init__()
self.n_heads, self.head_sz, self.emb_dim, self.ff_dim = n_heads, head_sz, emb_dim, ff_dim
self.self_attn_ln = nn.LayerNorm(emb_dim) # layer norm pre self-attention
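
Only the first layer norm of the block is visible here; a sketch of the pre-norm residual wiring the docstring describes, with self-attention and the position-wise feedforward each preceded by layer norm and followed by dropout (module names in the sketch are assumptions, not taken from the file):

    def block_forward(x, self_attn, self_attn_ln, ff, ff_ln, dropout):
        """Pre-norm residual block: x + dropout(attn(ln(x))), then x + dropout(ff(ln(x)))."""
        x = x + dropout(self_attn(self_attn_ln(x)))  # self-attention sub-layer with residual
        x = x + dropout(ff(ff_ln(x)))                # feedforward sub-layer with residual
        return x
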
@@ -111,13 +110,13 @@ def forward(self, x):

"""Create NanoGPT: Decoder-only Transformer."""

- # In addition to our Transformer blocks, we need token embedding and positional embedding layers, to compute
- # the positional encodings that get passed to the attention units in the transformer blocks.
+ # In addition to our Transformer blocks, we need token embedding and positional embedding layers, to
+ # compute the positional encodings that get passed to the attention units in the transformer blocks.

# We'll also apply weight init.

- # We want our output to be [batch_sz, ctx_len, n_tokens], because we want to predict the next token for each
- # token in the context.
+ # We want our output to be [batch_sz, ctx_len, n_tokens], because we want to predict the next token
+ # for each token in the context.


class NanoGPT(nn.Module):
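
The comments above pin down the embedding setup and the output shape; a minimal sketch of that computation (layer names and the example sizes are assumptions, only the shapes come from the comments):

    import torch
    import torch.nn as nn

    n_tokens, ctx_len, emb_dim, batch_sz = 65, 256, 384, 8  # example sizes, not from the file
    tok_emb = nn.Embedding(n_tokens, emb_dim)  # token embedding
    pos_emb = nn.Embedding(ctx_len, emb_dim)   # positional embedding

    x = torch.randint(n_tokens, (batch_sz, ctx_len))   # token ids: [batch_sz, ctx_len]
    h = tok_emb(x) + pos_emb(torch.arange(ctx_len))    # [batch_sz, ctx_len, emb_dim]
    # Transformer blocks, a final layer norm, and a linear head then map h to logits of
    # shape [batch_sz, ctx_len, n_tokens]: one next-token prediction per position in the context.
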
@@ -134,7 +133,7 @@ def __init__(
ff_dim=4,
dropout=0.1,
):
"""Initialize token and positional embeddings, transformer blocks, and final norm and out layers."""
"""Initialize token & positional embeddings, transformer blocks, & norm and out layers."""
super().__init__()
(
self.n_tokens,
@@ -216,7 +215,9 @@ def apply_gradient_centralization(optimizer):
for param in group["params"]:
if param.grad is not None:
# Compute the mean of the gradient
- grad_mean = param.grad.data.mean(dim=tuple(range(1, len(param.grad.shape))), keepdim=True)
+ grad_mean = param.grad.data.mean(
+ dim=tuple(range(1, len(param.grad.shape))), keepdim=True
+ )
# Centralize the gradient
param.grad.data -= grad_mean
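
apply_gradient_centralization subtracts from each gradient its mean over all dimensions but the first, re-centering it before the optimizer update. A sketch of where such a call typically sits in a training step, after the backward pass and before optimizer.step() (the surrounding step is illustrative and assumed, not copied from the file's training loop):

    # Hypothetical single training step showing where gradient centralization fits;
    # assumes model, optimizer, x_train, y_train, and F (torch.nn.functional) are defined.
    optimizer.zero_grad()
    logits = model(x_train)  # [batch_sz, ctx_len, n_tokens]
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y_train.view(-1))
    loss.backward()
    apply_gradient_centralization(optimizer)  # subtract each gradient's mean before stepping
    optimizer.step()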

@@ -232,8 +233,8 @@ def train(
val_chk_interval: int = 200, # check val loss every `val_chk_interval` batches and print losses
val_iter: int = 5, # number of batches on val_loader to run and avg when computing val loss
patience_thresh: int = 1e9, # consecutive batches without val loss decrease for early stopping
- save_chkpt_dir: str = "", # dir to save model checkpoint
- save_chkpt_thresh: float = 0.5, # save model checkpoint every `save_chkpt_interval` loss decrease
+ save_chkpt_dir: str = "", # dir to save model checkpoints
+ save_chkpt_thresh: float = 0.5, # save model chkpnt every `save_chkpt_interval` loss decrease
) -> tuple[torch.Tensor, np.ndarray, np.ndarray]: # -> loss, train_losses, val_losses
"""Trains a model, returns loss."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -246,7 +247,14 @@ def print_losses(epoch, batch_i, train_losses_avg, val_losses_avg):
)

@torch.no_grad()
- def estimate_losses(model, val_loader, val_losses, val_losses_avg, train_losses, train_losses_avg):
+ def estimate_losses(
+ model,
+ val_loader,
+ val_losses,
+ val_losses_avg,
+ train_losses,
+ train_losses_avg
+ ):
"""Estimate losses on val_loader, and return val loss and train loss avg."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
@@ -273,7 +281,7 @@ def estimate_losses(model, val_loader, val_losses, val_losses_avg, train_losses,

# <s Training loop
for epoch in range(max_epochs):
- pbar = tqdm(enumerate(train_loader), total=batch_lim, desc="Batch progression") # tqdm progress bar
+ pbar = tqdm(enumerate(train_loader), total=batch_lim, desc="Batch progression")
for batch_i, (x_train, y_train) in pbar:
# <ss Model training.
optimizer.zero_grad()
@@ -310,7 +318,8 @@ def estimate_losses(model, val_loader, val_losses, val_losses_avg, train_losses,
# Save checkpoint check.
if (Path(save_chkpt_dir).exists()) and (init_loss - loss.item() > save_chkpt_thresh):
torch.save(
- model.state_dict(), Path(save_chkpt_dir) / f"model_chkpt_loss{loss.item():.3f}.pth"
+ model.state_dict(),
+ Path(save_chkpt_dir) / f"model_chkpt_loss{loss.item():.3f}.pth"
)
init_loss = loss.item()
# /ss> /s>
@@ -333,7 +342,16 @@ def print_model_summary(model):
print(f"\n{n_params_tot / 1e6} M total parameters")


- def generate(model, tokens, in_txt=None, n_tokens=100, temp=1.0, top_k=None, seed=42, print_gen=True):
+ def generate(
+ model,
+ tokens,
+ in_txt=None,
+ n_tokens=100,
+ temp=1.0,
+ top_k=None,
+ seed=42,
+ print_gen=True
+ ):
"""Generate text from a nanoGPT model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set a random seed for generation
@@ -366,7 +384,7 @@ def generate(model, tokens, in_txt=None, n_tokens=100, temp=1.0, top_k=None, see
first_gen_idx, last_gen_idx = input_len - 1, input_len + n_tokens - 1
for i in range(first_gen_idx, last_gen_idx): # start gen after `input_len`
model_first_ctx = 0 if i < model.ctx_len else i - model.ctx_len + 1
- logits = model(x[model_first_ctx:(i + 1)].unsqueeze(0)) # feed in `x` with a batch_sz of 1
+ logits = model(x[model_first_ctx:(i + 1)].unsqueeze(0)) # feed in `x` w/ batch_sz 1
# Get logits for just `len(tokens)` (squeeze out ctx_len), and scale by temp
logits = logits[:, -1, :] / temp
if top_k is not None: # limit to top_k most likely tokens
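
The lines above scale the final position's logits by temp and optionally restrict them to the top_k most likely tokens. A sketch of that sampling step, assuming the usual softmax-plus-multinomial draw (the actual sampling call is not visible in this diff):

    import torch
    import torch.nn.functional as F

    def sample_next_token(logits, temp=1.0, top_k=None):
        """logits: [1, n_tokens] for the last context position; returns a sampled token id."""
        logits = logits / temp  # temp < 1 sharpens the distribution, temp > 1 flattens it
        if top_k is not None:
            kth_vals, _ = torch.topk(logits, top_k)
            logits[logits < kth_vals[:, [-1]]] = float("-inf")  # drop everything below the k-th logit
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # [1, 1] sampled token id
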
@@ -398,7 +416,7 @@ def generate(model, tokens, in_txt=None, n_tokens=100, temp=1.0, top_k=None, see
# - `seed` for generate
parser = argparse.ArgumentParser(description="Generate text with NanoGPT.")
parser.add_argument(
"--model-dir", type=str, required=True, help="Path to model, model config, and tokens files."
"--model-dir", type=str, required=True, help="Path to model, model config, & tokens files."
)
parser.add_argument("--in-txt", type=str, default=None, help="Input text for generation.")
parser.add_argument("--n-tokens", type=int, required=True, help="Number of tokens to generate.")
