From b8e4beb9459b1523c11f5998367598df4306eb29 Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Mon, 25 Dec 2023 13:49:04 -0500 Subject: [PATCH] Scale scaling constants by number of atoms --- jitterbug/model/dscribe/local.py | 2 +- tests/models/test_soap.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jitterbug/model/dscribe/local.py b/jitterbug/model/dscribe/local.py index 1087a6a..c8b7c35 100644 --- a/jitterbug/model/dscribe/local.py +++ b/jitterbug/model/dscribe/local.py @@ -316,6 +316,6 @@ def train_calculator(self, data: list[Atoms]) -> Calculator: model=model, desc=self.descriptors, desc_scaling=(offset_x, scale_x), - energy_scaling=(offset_y, scale_y), + energy_scaling=(offset_y / len(data[0]), scale_y / len(data[0])), device=self.device ) diff --git a/tests/models/test_soap.py b/tests/models/test_soap.py index b2410d1..0f64be0 100644 --- a/tests/models/test_soap.py +++ b/tests/models/test_soap.py @@ -117,6 +117,10 @@ def test_model(soap, train_set): # Run the fitting calcs = model.train(train_set) + # Make sure the forces are reasonable + eng = calcs[0].get_potential_energy(train_set[0]) + assert np.isclose(eng, train_set[0].get_potential_energy(), atol=1e-2) + # Test the mean hessian function mean_hess = model.mean_hessian(calcs) assert mean_hess.shape == (9, 9), 'Wrong shape'