diff --git a/jitterbug/model/dscribe/local.py b/jitterbug/model/dscribe/local.py index d3b0d66..1087a6a 100644 --- a/jitterbug/model/dscribe/local.py +++ b/jitterbug/model/dscribe/local.py @@ -58,7 +58,7 @@ def __init__(self, models: dict[int, torch.nn.Module]): ) def forward(self, element: torch.IntTensor, desc: torch.Tensor) -> torch.Tensor: - output = torch.empty((desc.shape[0],), dtype=desc.dtype) + output = torch.empty_like(desc[:, 0]) for elem, model in self.models.items(): elem_id = int(elem) mask = element == elem_id @@ -136,7 +136,7 @@ def train_model(model: torch.nn.Module, n_conf, n_atoms = train_x.shape[:2] train_x = torch.from_numpy(train_x) train_y = torch.from_numpy(train_y) - train_e = torch.from_numpy(train_e) + train_e = torch.from_numpy(train_e).to(device) # Duplicate the elements per batch size train_e = train_e.repeat(batch_size)