Address PR comments
Taylor Robie committed Apr 29, 2024
1 parent 841df23 commit 107832a
Showing 2 changed files with 10 additions and 3 deletions.
8 changes: 5 additions & 3 deletions litgpt/lora.py
@@ -338,11 +338,13 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         x = x.transpose(0, 1)
         result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
         result = result.view(-1, self.linear.out_features)  # (4096, 384)

-        result = result.index_copy(
-            1, torch.tensor(lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
-        )  # (4096, 256)
+        # `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize)
+        # every time this method is called. So instead we simply cache a copy on each device that needs it.
+        if (lora_ind := self._lora_ind_cache.get(result.device)) is None:
+            self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)

+        result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes)))  # (4096, 256)
         return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)

     def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
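
The lora.py change above caches the LoRA index tensor per device so the hot path no longer rebuilds it with torch.tensor(..., device=...), which copies the indices host-to-device and can trigger a cudaStreamSynchronize on every call. Below is a minimal, self-contained sketch of the same pattern; the ZeroPad module, its lora_ind/out_features arguments, and the toy shapes are illustrative assumptions, not litgpt code or part of this commit.

import torch


class ZeroPad(torch.nn.Module):
    """Scatter a narrow update into a wider zero tensor along the last dim (sketch only)."""

    def __init__(self, lora_ind: list, out_features: int):
        super().__init__()
        self._lora_ind = torch.tensor(lora_ind)  # built once, on CPU
        self._lora_ind_cache = {self._lora_ind.device: self._lora_ind}
        self.out_features = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = x.new_zeros((*x.shape[:-1], self.out_features)).view(-1, self.out_features)
        # Look up (or create exactly once per device) the index tensor instead of calling
        # torch.tensor(..., device=result.device) on every forward pass.
        if (lora_ind := self._lora_ind_cache.get(result.device)) is None:
            self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)
        result = result.index_copy(1, lora_ind, x.reshape(-1, x.shape[-1]))
        return result.view((*x.shape[:-1], self.out_features))


# Toy usage: place a (2, 3, 4) update into columns [0, 1, 4, 5] of a width-8 output.
pad = ZeroPad(lora_ind=[0, 1, 4, 5], out_features=8)
print(pad(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 3, 8])
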
5 changes: 5 additions & 0 deletions litgpt/utils.py
@@ -272,6 +272,7 @@ def chunked_cross_entropy(
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
         non_masked_elems = (targets != ignore_index).sum()
+        # See [non_masked_elems div note]
         return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))

     # no chunking at all
@@ -288,6 +289,10 @@ def chunked_cross_entropy(
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
     non_masked_elems = (targets != ignore_index).sum()
+    # [non_masked_elems div note]:
+    # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
+    # results in a python int which is then passed back to torch division. By using the
+    # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
     return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))


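As a concrete illustration of the [non_masked_elems div note] above: the elementwise maximum keeps the clamp on the same device as the loss, so the division never needs a device-to-host round trip, whereas max(1, non_masked_elems) makes Python evaluate a tensor comparison on the host. The sketch below uses an illustrative safe_denominator helper and toy tensors that are not part of litgpt.

import torch


def safe_denominator(non_masked_elems: torch.Tensor) -> torch.Tensor:
    # Clamp to at least 1 entirely on non_masked_elems' device: no Python-int round trip,
    # hence no forced synchronization when the tensor lives on a GPU.
    return non_masked_elems.maximum(torch.ones_like(non_masked_elems))


ignore_index = -100
targets = torch.tensor([5, -100, 7, -100])
non_masked_elems = (targets != ignore_index).sum()

print(safe_denominator(non_masked_elems))                    # tensor(2)
print(safe_denominator(torch.zeros_like(non_masked_elems)))  # tensor(1) -- protects against 0/0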