Skip to content

Commit

Permalink
Missed a few ".to(device)" steps
Browse files · Browse the repository at this point in the history
  • Loading branch information
WardLT committed Sep 11, 2024
1 parent 09305e0 commit 01b979a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cascade/learning/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def train(self,
if reset_weights:
config = extract_config_mace_model(model)
model = ScaleShiftMACE(**config)
model.to(device)
model.to(device)

# Unpin weights
for p in model.parameters():
Expand All @@ -142,8 +142,8 @@ def train(self,
# Update the shift of the energy scale
errors = []
for _, batch in zip(range(4), train_loader): # Use only the first 4 batches, for computational efficiency
num_atoms = batch.ptr[1:] - batch.ptr[:-1] # Taken from loss function, still don't understand it
batch = batch.to(device)
num_atoms = batch.ptr[1:] - batch.ptr[:-1] # Use the offsets to compute the number of atoms per inference
ml = model(
batch,
training=False,
Expand Down

0 comments on commit 01b979a

Please sign in to comment.