diff --git a/nemo/collections/llm/gpt/model/starcoder.py b/nemo/collections/llm/gpt/model/starcoder.py index ab3968a5c71c..5e461618d75b 100644 --- a/nemo/collections/llm/gpt/model/starcoder.py +++ b/nemo/collections/llm/gpt/model/starcoder.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from torch import nn +import torch from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel, torch_dtype_from_mcore_config from nemo.collections.llm.utils import Config