diff --git a/litgpt/lora.py b/litgpt/lora.py index 71d4abf79a..8fee63cbb6 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -338,11 +338,13 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: x = x.transpose(0, 1) result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384) result = result.view(-1, self.linear.out_features) # (4096, 384) + + # `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize) + # every time this method is called. So instead we simply cache a copy on each device that needs it. if (lora_ind := self._lora_ind_cache.get(result.device)) is None: self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device) - result = result.index_copy( - 1, torch.tensor(lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes)) - ) # (4096, 256) + + result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes))) # (4096, 256) return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384) def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: diff --git a/litgpt/utils.py b/litgpt/utils.py index beb52dcdc1..6eb7efbff4 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -272,6 +272,7 @@ def chunked_cross_entropy( for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) ] non_masked_elems = (targets != ignore_index).sum() + # See [non_masked_elems div note] return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems)) # no chunking at all @@ -288,6 +289,10 @@ def chunked_cross_entropy( for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) ] non_masked_elems = (targets != ignore_index).sum() + # [non_masked_elems div note]: + # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that + # results in a python int which is then passed back to torch division. By using the + # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize. return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))