diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 67db99e8..e3637212 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -49,6 +49,7 @@ def __init__( force_loss_ratio: float = 1, stress_loss_ratio: float = 0.1, mag_loss_ratio: float = 0.1, + allow_missing_labels: bool = True, optimizer: str = "Adam", scheduler: str = "CosLR", criterion: str = "MSE", @@ -78,6 +79,9 @@ def __init__( Default = 0.1 mag_loss_ratio (float): magmom loss ratio in loss function Default = 0.1 + allow_missing_labels (bool): whether to allow missing labels in the dataset, + missed target will not contribute to loss and MAEs + Default = True optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW", "RAdam". Default = 'Adam' scheduler (str): learning rate scheduler. Can be "CosLR", "ExponentialLR", @@ -209,6 +213,7 @@ def __init__( force_loss_ratio=force_loss_ratio, stress_loss_ratio=stress_loss_ratio, mag_loss_ratio=mag_loss_ratio, + allow_missing_labels=allow_missing_labels, **kwargs, ) self.epochs = epochs @@ -726,6 +731,7 @@ def __init__( stress_loss_ratio: float = 0.1, mag_loss_ratio: float = 0.1, delta: float = 0.1, + allow_missing_labels: bool = True, ) -> None: """Initialize the combined loss. @@ -745,6 +751,8 @@ def __init__( mag_loss_ratio (float): magmom loss ratio in loss function Default = 0.1 delta (float): delta for torch.nn.HuberLoss. Default = 0.1 + allow_missing_labels (bool): whether to allow missing labels in the dataset, + missed target will not contribute to loss and MAEs """ super().__init__() # Define loss criterion @@ -771,6 +779,7 @@ def __init__( self.mag_loss_ratio = 0 else: self.mag_loss_ratio = mag_loss_ratio + self.allow_missing_labels = allow_missing_labels def forward( self, @@ -791,25 +800,37 @@ def forward( out = {"loss": 0.0} # Energy if "e" in self.target_str: - if self.is_intensive: - out["loss"] += self.energy_loss_ratio * self.criterion( - targets["e"], prediction["e"] - ) - out["e_MAE"] = mae(targets["e"], prediction["e"]) - out["e_MAE_size"] = prediction["e"].shape[0] + if self.allow_missing_labels: + valid_value_indices = ~torch.isnan(targets["e"]) + valid_e_target = targets["e"][valid_value_indices] + valid_atoms_per_graph = prediction["atoms_per_graph"][ + valid_value_indices + ] + valid_e_pred = prediction["e"][valid_value_indices] + if valid_e_pred.shape == torch.Size([]): + valid_e_pred = valid_e_pred.view(1) else: - e_per_atom_target = targets["e"] / prediction["atoms_per_graph"] - e_per_atom_pred = prediction["e"] / prediction["atoms_per_graph"] - out["loss"] += self.energy_loss_ratio * self.criterion( - e_per_atom_target, e_per_atom_pred - ) - out["e_MAE"] = mae(e_per_atom_target, e_per_atom_pred) - out["e_MAE_size"] = prediction["e"].shape[0] + valid_e_target = targets["e"] + valid_atoms_per_graph = prediction["atoms_per_graph"] + valid_e_pred = prediction["e"] + if self.is_intensive: + valid_e_target = valid_e_target / valid_atoms_per_graph + valid_e_pred = valid_e_pred / valid_atoms_per_graph + + out["loss"] += self.energy_loss_ratio * self.criterion( + valid_e_target, valid_e_pred + ) + out["e_MAE"] = mae(valid_e_target, valid_e_pred) + out["e_MAE_size"] = prediction["e"].shape[0] # Force if "f" in self.target_str: forces_pred = torch.cat(prediction["f"], dim=0) forces_target = torch.cat(targets["f"], dim=0) + if self.allow_missing_labels: + valid_value_indices = ~torch.isnan(forces_target) + forces_target = forces_target[valid_value_indices] + forces_pred = forces_pred[valid_value_indices] out["loss"] += self.force_loss_ratio * self.criterion( forces_target, forces_pred ) @@ -820,6 +841,10 @@ def forward( if "s" in self.target_str: stress_pred = torch.cat(prediction["s"], dim=0) stress_target = torch.cat(targets["s"], dim=0) + if self.allow_missing_labels: + valid_value_indices = ~torch.isnan(stress_target) + stress_target = stress_target[valid_value_indices] + stress_pred = stress_pred[valid_value_indices] out["loss"] += self.stress_loss_ratio * self.criterion( stress_target, stress_pred ) @@ -832,7 +857,12 @@ def forward( m_mae_size = 0 for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True): # exclude structures without magmom labels - if mag_target is not None: + if self.allow_missing_labels: + if mag_target is not None and not np.isnan(mag_target).any(): + mag_preds.append(mag_pred) + mag_targets.append(mag_target) + m_mae_size += mag_target.shape[0] + else: mag_preds.append(mag_pred) mag_targets.append(mag_target) m_mae_size += mag_target.shape[0] diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 89bf16b4..de1d1497 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -21,7 +21,7 @@ coords = [[0, 0, 0], [0.5, 0.5, 0.5]] NaCl = Structure(lattice, species, coords) structures, energies, forces, stresses, magmoms = [], [], [], [], [] -for _ in range(100): +for _ in range(20): struct = NaCl.copy() struct.perturb(0.1) structures.append(struct) @@ -30,15 +30,22 @@ stresses.append(np.random.random([3, 3])) magmoms.append(np.random.random(2)) +# Create some missing labels +energies[10] = np.nan +forces[4] = (np.nan * np.ones((len(structures[4]), 3))).tolist() +stresses[6] = (np.nan * np.ones((3, 3))).tolist() +magmoms[8] = (np.nan * np.ones((len(structures[8]), 1))).tolist() + data = StructureData( structures=structures, energies=energies, forces=forces, stresses=stresses, magmoms=magmoms, + shuffle=False, ) train_loader, val_loader, _test_loader = get_train_val_test_loader( - data, batch_size=16, train_ratio=0.9, val_ratio=0.05 + data, batch_size=4, train_ratio=0.9, val_ratio=0.05 ) chgnet = CHGNet.load() @@ -55,6 +62,7 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: wandb_path="test/run", wandb_init_kwargs=dict(anonymous="must"), extra_run_config=extra_run_config, + allow_missing_labels=True, ) trainer.train( train_loader, @@ -66,7 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: for param in chgnet.composition_model.parameters(): assert param.requires_grad is False assert tmp_path.is_dir(), "Training dir was not created" - + for target_str in ["e", "f", "s", "m"]: + assert ~np.isnan(trainer.training_history[target_str]["train"]).any() + assert ~np.isnan(trainer.training_history[target_str]["val"]).any() output_files = [file.name for file in tmp_path.iterdir()] for prefix in ("epoch", "bestE_", "bestF_"): n_matches = sum(file.startswith(prefix) for file in output_files) @@ -147,6 +157,7 @@ def test_wandb_init(mock_wandb): "wandb_path": "test-project/test-run", "wandb_init_kwargs": {"tags": ["test"]}, "extra_run_config": None, + "allow_missing_labels": True, } mock_wandb.init.assert_called_once_with( project="test-project", name="test-run", config=expected_config, tags=["test"]