Skip to content

Commit

Permalink
Update tests to handle small floating point differences
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Oct 12, 2023
1 parent b0780fa commit 85cba19
Showing 1 changed file with 65 additions and 43 deletions.
108 changes: 65 additions & 43 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

import numpy as np
import pytest
from deepdiff import DeepDiff
from monty.serialization import dumpfn, loadfn
from pymatgen.electronic_structure.core import Spin

from pytaser import generator,tas
from pytaser import dasgenerator
from pytaser import generator, tas


def test_gaussian(datapath_gaas):
Expand Down Expand Up @@ -370,8 +368,14 @@ def test_generate_tas(generated_class, light, dark, tas_object, conditions):
)

assert tas_class.tas_total.all() == tas_object.tas_total.all()
assert DeepDiff(tas_class.jdos_diff_if, tas_object.jdos_diff_if) == {}
assert DeepDiff(tas_class.jdos_light_if, tas_object.jdos_light_if) == {}
for if_tuple in tas_class.jdos_diff_if:
np.testing.assert_array_almost_equal(
tas_class.jdos_diff_if[if_tuple], tas_object.jdos_diff_if[if_tuple]
)
np.testing.assert_array_almost_equal(
tas_class.jdos_light_if[if_tuple],
tas_object.jdos_light_if[if_tuple],
)
assert (
tas_class.jdos_light_total.all() == tas_object.jdos_light_total.all()
)
Expand Down Expand Up @@ -449,7 +453,9 @@ def test_from_mpid(mocker, datapath_gaas, generated_class, conditions):
assert gaas2534.vb == generated_class.vb
assert gaas2534.cb == generated_class.cb

#Tests for DASGenerator class

# Tests for DASGenerator class


def test_DAS_from_vasprun(
tio2_das_conditions,
Expand All @@ -461,36 +467,47 @@ def test_DAS_from_vasprun(
AssertionError,
):
np.testing.assert_array_almost_equal(
das_class_vr_only.das_total, das_class_with_waveder.das_total, decimal=1
das_class_vr_only.das_total,
das_class_with_waveder.das_total,
decimal=1,
)

assert das_class_vr_only.jdos_newSys_total.size
np.testing.assert_array_almost_equal(
das_class_vr_only.jdos_newSys_total.any(), das_class_with_waveder.jdos_newSys_total.any()
np.testing.assert_allclose(
das_class_vr_only.jdos_newSys_total.any(),
das_class_with_waveder.jdos_newSys_total.any(),
rtol=1e-4,
)

assert das_class_vr_only.jdos_newSys_if
for key, array in das_class_vr_only.jdos_newSys_if.items():
np.testing.assert_array_almost_equal(
array, das_class_with_waveder.jdos_newSys_if[key]
np.testing.assert_allclose(
array, das_class_with_waveder.jdos_newSys_if[key], rtol=1e-4
)

assert das_class_vr_only.jdos_ref_total.size
np.testing.assert_array_almost_equal(
das_class_vr_only.jdos_ref_total, das_class_with_waveder.jdos_ref_total
np.testing.assert_allclose(
das_class_vr_only.jdos_ref_total,
das_class_with_waveder.jdos_ref_total,
rtol=1e-4,
)

assert das_class_vr_only.jdos_ref_if
for key, array in das_class_vr_only.jdos_ref_if.items():
np.testing.assert_array_almost_equal(
array, das_class_with_waveder.jdos_ref_if[key]
np.testing.assert_allclose(
array, das_class_with_waveder.jdos_ref_if[key], rtol=1e-4
)

assert das_class_vr_only.energy_mesh_ev.size
np.testing.assert_array_almost_equal(
das_class_vr_only.energy_mesh_ev, das_class_with_waveder.energy_mesh_ev
np.testing.assert_allclose(
das_class_vr_only.energy_mesh_ev,
das_class_with_waveder.energy_mesh_ev,
rtol=1e-4,
)
assert (
das_class_vr_only.bandgap_newSys
== das_class_with_waveder.bandgap_newSys
)
assert das_class_vr_only.bandgap_newSys == das_class_with_waveder.bandgap_newSys
assert das_class_vr_only.bandgap_ref == das_class_with_waveder.bandgap_ref
assert das_class_vr_only.temp == tio2_das_conditions[2]
assert das_class_vr_only.temp == das_class_with_waveder.temp
Expand All @@ -499,39 +516,43 @@ def test_DAS_from_vasprun(
assert das_class_vr_only.alpha_ref is None



def test_generate_das(das_class_with_waveder, das_object, tio2_das_conditions):

assert das_class_with_waveder.das_total.size
np.testing.assert_array_almost_equal(
das_class_with_waveder.das_total, das_object.das_total
np.testing.assert_allclose(
das_class_with_waveder.das_total, das_object.das_total, rtol=1e-3
)

assert das_class_with_waveder.jdos_newSys_total.size
np.testing.assert_array_almost_equal(
das_class_with_waveder.jdos_newSys_total, das_object.jdos_newSys_total
np.testing.assert_allclose(
das_class_with_waveder.jdos_newSys_total,
das_object.jdos_newSys_total,
rtol=1e-3,
)

assert das_class_with_waveder.jdos_newSys_if
for key, array in das_class_with_waveder.jdos_newSys_if.items():
np.testing.assert_array_almost_equal(
array, das_object.jdos_newSys_if[key]
np.testing.assert_allclose(
array, das_object.jdos_newSys_if[key], rtol=1e-3
)

assert das_class_with_waveder.jdos_ref_total.size
np.testing.assert_array_almost_equal(
das_class_with_waveder.jdos_ref_total, das_object.jdos_ref_total
np.testing.assert_allclose(
das_class_with_waveder.jdos_ref_total,
das_object.jdos_ref_total,
rtol=1e-3,
)

assert das_class_with_waveder.jdos_ref_if
for key, array in das_class_with_waveder.jdos_ref_if.items():
np.testing.assert_array_almost_equal(
array, das_object.jdos_ref_if[key]
np.testing.assert_allclose(
array, das_object.jdos_ref_if[key], rtol=1e-3
)

assert das_class_with_waveder.energy_mesh_ev.size
np.testing.assert_array_almost_equal(
das_class_with_waveder.energy_mesh_ev, das_object.energy_mesh_ev
np.testing.assert_allclose(
das_class_with_waveder.energy_mesh_ev,
das_object.energy_mesh_ev,
rtol=1e-3,
)

assert das_class_with_waveder.bandgap_newSys == das_object.bandgap_newSys
Expand All @@ -540,22 +561,23 @@ def test_generate_das(das_class_with_waveder, das_object, tio2_das_conditions):
assert das_class_with_waveder.temp == das_object.temp

assert das_class_with_waveder.alpha_newSys.size
np.testing.assert_array_almost_equal(
das_class_with_waveder.alpha_newSys, das_object.alpha_newSys
np.testing.assert_allclose(
das_class_with_waveder.alpha_newSys, das_object.alpha_newSys, rtol=1e-3
)

assert das_class_with_waveder.alpha_ref.size
np.testing.assert_array_almost_equal(
das_class_with_waveder.alpha_ref, das_object.alpha_ref)
np.testing.assert_allclose(
das_class_with_waveder.alpha_ref, das_object.alpha_ref, rtol=1e-3
)

assert das_class_with_waveder.weighted_jdos_newSys_if
for key, array in das_class_with_waveder.weighted_jdos_newSys_if.items():
np.testing.assert_array_almost_equal(
array, das_object.weighted_jdos_newSys_if[key]
np.testing.assert_allclose(
array, das_object.weighted_jdos_newSys_if[key], rtol=1e-3
)

assert das_class_with_waveder.weighted_jdos_ref_if
for key, array in das_class_with_waveder.weighted_jdos_ref_if.items():
np.testing.assert_array_almost_equal(
array, das_object.weighted_jdos_ref_if[key]
np.testing.assert_allclose(
array, das_object.weighted_jdos_ref_if[key], rtol=1e-3
)

0 comments on commit 85cba19

Please sign in to comment.