diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index fa679a262..27dc664a3 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -185,14 +185,14 @@ def convert(self, atoms: ase.Atoms, sid=None): cell = np.array(atoms.get_cell(complete=True), copy=True) positions = wrap_positions(positions, cell, pbc=pbc, eps=0) - atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()) + atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8) positions = torch.from_numpy(positions).float() cell = torch.from_numpy(cell).view(1, 3, 3).float() natoms = positions.shape[0] # initialized to torch.zeros(natoms) if tags missing. # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags - tags = torch.Tensor(atoms.get_tags()) + tags = torch.tensor(atoms.get_tags(), dtype=torch.int) # put the minimum data in torch geometric data object data = Data( @@ -228,10 +228,15 @@ def convert(self, atoms: ase.Atoms, sid=None): energy = atoms.get_potential_energy(apply_constraint=False) data.energy = energy if self.r_forces: - forces = torch.Tensor(atoms.get_forces(apply_constraint=False)) + forces = torch.tensor( + atoms.get_forces(apply_constraint=False), dtype=torch.float32 + ) data.forces = forces if self.r_stress: - stress = torch.Tensor(atoms.get_stress(apply_constraint=False, voigt=False)) + stress = torch.tensor( + atoms.get_stress(apply_constraint=False, voigt=False), + dtype=torch.float32, + ) data.stress = stress if self.r_distances and self.r_edges: data.distances = edge_distances @@ -245,13 +250,13 @@ def convert(self, atoms: ase.Atoms, sid=None): fixed_idx[constraint.index] = 1 data.fixed = fixed_idx if self.r_pbc: - data.pbc = torch.tensor(atoms.pbc) + data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool) if self.r_data_keys is not None: for data_key in self.r_data_keys: data[data_key] = ( atoms.info[data_key] if isinstance(atoms.info[data_key], (int, float, str)) - else torch.Tensor(atoms.info[data_key]) + else torch.tensor(atoms.info[data_key]) ) return data