Skip to content

Commit

Permalink
Ensuring matrices are on same devices
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Dec 25, 2023
1 parent a01614c commit 725db57
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jitterbug/model/dscribe/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, models: dict[int, torch.nn.Module]):
)

def forward(self, element: torch.IntTensor, desc: torch.Tensor) -> torch.Tensor:
output = torch.empty((desc.shape[0],), dtype=desc.dtype)
output = torch.empty_like(desc[:, 0])
for elem, model in self.models.items():
elem_id = int(elem)
mask = element == elem_id
Expand Down Expand Up @@ -136,7 +136,7 @@ def train_model(model: torch.nn.Module,
n_conf, n_atoms = train_x.shape[:2]
train_x = torch.from_numpy(train_x)
train_y = torch.from_numpy(train_y)
train_e = torch.from_numpy(train_e)
train_e = torch.from_numpy(train_e).to(device)

# Duplicate the elements per batch size
train_e = train_e.repeat(batch_size)
Expand Down

0 comments on commit 725db57

Please sign in to comment.