Skip to content

Commit

Permalink
set tensor dtypes (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque authored Dec 18, 2024
1 parent 6efc99b commit 5959f5c
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/fairchem/core/preprocessing/atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ def convert(self, atoms: ase.Atoms, sid=None):
cell = np.array(atoms.get_cell(complete=True), copy=True)
positions = wrap_positions(positions, cell, pbc=pbc, eps=0)

atomic_numbers = torch.Tensor(atoms.get_atomic_numbers())
atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8)
positions = torch.from_numpy(positions).float()
cell = torch.from_numpy(cell).view(1, 3, 3).float()
natoms = positions.shape[0]

# initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags
tags = torch.Tensor(atoms.get_tags())
tags = torch.tensor(atoms.get_tags(), dtype=torch.int)

# put the minimum data in torch geometric data object
data = Data(
Expand Down Expand Up @@ -228,10 +228,15 @@ def convert(self, atoms: ase.Atoms, sid=None):
energy = atoms.get_potential_energy(apply_constraint=False)
data.energy = energy
if self.r_forces:
forces = torch.Tensor(atoms.get_forces(apply_constraint=False))
forces = torch.tensor(
atoms.get_forces(apply_constraint=False), dtype=torch.float32
)
data.forces = forces
if self.r_stress:
stress = torch.Tensor(atoms.get_stress(apply_constraint=False, voigt=False))
stress = torch.tensor(
atoms.get_stress(apply_constraint=False, voigt=False),
dtype=torch.float32,
)
data.stress = stress
if self.r_distances and self.r_edges:
data.distances = edge_distances
Expand All @@ -245,13 +250,13 @@ def convert(self, atoms: ase.Atoms, sid=None):
fixed_idx[constraint.index] = 1
data.fixed = fixed_idx
if self.r_pbc:
data.pbc = torch.tensor(atoms.pbc)
data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool)
if self.r_data_keys is not None:
for data_key in self.r_data_keys:
data[data_key] = (
atoms.info[data_key]
if isinstance(atoms.info[data_key], (int, float, str))
else torch.Tensor(atoms.info[data_key])
else torch.tensor(atoms.info[data_key])
)

return data
Expand Down

0 comments on commit 5959f5c

Please sign in to comment.