Skip to content

Commit

Permalink
TST: Use on*_(stop,continue) and logging tools in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aragilar committed May 26, 2019
1 parent eabc46f commit 94e3670
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 110 deletions.
5 changes: 5 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[pytest]
log_cli = 1
log_cli_level = INFO
log_cli_format = %(asctime)s %(levelname)s %(message)s
log_cli_date_format = %H:%M:%S
4 changes: 2 additions & 2 deletions scikits/odes/tests/test_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from numpy.testing import TestCase, run_module_suite
from scipy.integrate import ode as Iode
from scikits.odes import ode,dae
from scikits.odes.sundials.common_defs import DTYPE
from .. import ode, dae
from ..sundials.common_defs import DTYPE

class TestDae(TestCase):
"""
Expand Down
4 changes: 2 additions & 2 deletions scikits/odes/tests/test_dop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from numpy.testing import (
assert_, TestCase, run_module_suite, assert_array_almost_equal,
assert_raises, assert_allclose, assert_array_equal, assert_equal)
from scikits.odes import ode
from scikits.odes.dopri5 import StatusEnumDOP
from .. import ode
from ..dopri5 import StatusEnumDOP


class SimpleOscillator():
Expand Down
12 changes: 9 additions & 3 deletions scikits/odes/tests/test_get_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import print_function
import numpy as np
import unittest
from scikits.odes import ode
from .. import ode
from ..sundials import log_error_handler

COMMON_ARGS = {
"old_api": False,
"err_handler": log_error_handler
}


xs = np.linspace(1, 10, 10)
Expand All @@ -20,7 +26,7 @@ def rhs(x, y, ydot):

class GetInfoTest(unittest.TestCase):
def setUp(self):
self.ode = ode('cvode', rhs, old_api=False)
self.ode = ode('cvode', rhs, **COMMON_ARGS)
self.solution = self.ode.solve(xs, np.array([1]))

def test_we_integrated_correctly(self):
Expand All @@ -47,7 +53,7 @@ def test_ode_exposes_num_rhs_evals(self):

class GetInfoTestSpils(unittest.TestCase):
def setUp(self):
self.ode = ode('cvode', rhs, linsolver="spgmr", old_api=False)
self.ode = ode('cvode', rhs, linsolver="spgmr", **COMMON_ARGS)
self.solution = self.ode.solve(xs, np.array([1]))

def test_ode_exposes_num_njtimes_evals(self):
Expand Down
5 changes: 3 additions & 2 deletions scikits/odes/tests/test_odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
assert_, TestCase, run_module_suite, assert_array_almost_equal,
assert_raises, assert_allclose, assert_array_equal, assert_equal)

from scikits.odes.odeint import odeint
from scikits.odes.sundials.common_defs import DTYPE
from ..odeint import odeint
from ..sundials import log_error_handler
from ..sundials.common_defs import DTYPE

TEST_LAPACK = DTYPE == np.double

Expand Down
52 changes: 23 additions & 29 deletions scikits/odes/tests/test_on_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@

from numpy.testing import TestCase, run_module_suite

from scikits.odes import ode
from scikits.odes.sundials.cvode import StatusEnum
from scikits.odes.sundials.common_defs import DTYPE
from .. import ode
from ..sundials.cvode import StatusEnum
from ..sundials.common_defs import DTYPE
from ..sundials import log_error_handler, ontstop_stop, onroot_stop

COMMON_ARGS = {
"old_api": False,
"err_handler": log_error_handler
}

#data
g = 9.81 # gravitational constant
Expand Down Expand Up @@ -63,12 +69,6 @@ def onroot_va(t, y, solver):

return 0

def onroot_vb(t, y, solver):
"""
onroot function to stop solver when root is found
"""
return 1

def onroot_vc(t, y, solver):
"""
onroot function to reset the solver back at the start, but keep the current
Expand Down Expand Up @@ -103,12 +103,6 @@ def ontstop_va(t, y, solver):

return 0

def ontstop_vb(t, y, solver):
"""
ontstop function to stop solver when tstop is reached
"""
return 1

def ontstop_vc(t, y, solver):
"""
ontstop function to reset the solver back at the start, but keep the current
Expand All @@ -132,7 +126,7 @@ def test_cvode_rootfn_noroot(self):
#test calling sequence. End is reached before root is found
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -143,7 +137,7 @@ def test_cvode_rootfn(self):
#test root finding and stopping: End is reached at a root
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Root not found!"
assert allclose([soln.roots.t[0], soln.roots.y[0,0], soln.roots.y[0,1]],
Expand All @@ -155,7 +149,7 @@ def test_cvode_rootfnacc(self):
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_va,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -170,8 +164,8 @@ def test_cvode_rootfn_stop(self):
#test root finding and stopping: End is reached at a root with a function
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vb,
old_api=False)
onroot=onroot_stop,
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Root not found!"
assert allclose([soln.roots.t[-1], soln.roots.y[-1,0], soln.roots.y[-1,1]],
Expand All @@ -183,7 +177,7 @@ def test_cvode_rootfn_test(self):
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vc,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Not sufficient root found"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -199,7 +193,7 @@ def test_cvode_rootfn_two(self):
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=2, rootfn=root_fn2,
onroot=onroot_vc,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Not sufficient root found"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -215,7 +209,7 @@ def test_cvode_rootfn_end(self):
tspan = np.arange(0, 30 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn3,
onroot=onroot_vc,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Not sufficient root found"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -232,7 +226,7 @@ def test_cvode_tstopfn_notstop(self):
n = 0
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1+1, ontstop=ontstop_va,
old_api=False)
**COMMON_ARGS)

soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred"
Expand All @@ -246,7 +240,7 @@ def test_cvode_tstopfn(self):
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Tstop not found!"
assert allclose([soln.tstop.t[0], soln.tstop.y[0,0], soln.tstop.y[0,1]],
Expand All @@ -262,7 +256,7 @@ def test_cvode_tstopfnacc(self):
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_va,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -278,8 +272,8 @@ def test_cvode_tstopfn_stop(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vb,
old_api=False)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_stop,
**COMMON_ARGS)

soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Error occurred"
Expand All @@ -299,7 +293,7 @@ def test_cvode_tstopfn_test(self):
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vc,
old_api=False)
**COMMON_ARGS)

soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Error occurred"
Expand Down
52 changes: 23 additions & 29 deletions scikits/odes/tests/test_on_funcs_ida.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@

from numpy.testing import TestCase, run_module_suite

from scikits.odes import dae
from scikits.odes.sundials.ida import StatusEnumIDA
from scikits.odes.sundials.common_defs import DTYPE
from .. import dae
from ..sundials import log_error_handler, ontstop_stop, onroot_stop
from ..sundials.ida import StatusEnumIDA
from ..sundials.common_defs import DTYPE

COMMON_ARGS = {
"old_api": False,
"err_handler": log_error_handler
}

#data
g = 9.81 # gravitational constant
Expand Down Expand Up @@ -65,12 +71,6 @@ def onroot_va(t, y, ydot, solver):

return 0

def onroot_vb(t, y, ydot, solver):
"""
onroot function to stop solver when root is found
"""
return 1

def onroot_vc(t, y, ydot, solver):
"""
onroot function to reset the solver back at the start, but keep the current
Expand Down Expand Up @@ -105,12 +105,6 @@ def ontstop_va(t, y, ydot, solver):

return 0

def ontstop_vb(t, y, ydot, solver):
"""
ontstop function to stop solver when tstop is reached
"""
return 1

def ontstop_vc(t, y, ydot, solver):
"""
ontstop function to reset the solver back at the start, but keep the current
Expand All @@ -134,7 +128,7 @@ def test_ida_rootfn_noroot(self):
#test calling sequence. End is reached before root is found
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -145,7 +139,7 @@ def test_ida_rootfn(self):
#test root finding and stopping: End is reached at a root
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Root not found!"
assert allclose([soln.roots.t[0], soln.roots.y[0,0], soln.roots.y[0,1]],
Expand All @@ -157,7 +151,7 @@ def test_ida_rootfnacc(self):
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_va,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -172,8 +166,8 @@ def test_ida_rootfn_stop(self):
#test root finding and stopping: End is reached at a root with a function
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vb,
old_api=False)
onroot=onroot_stop,
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Root not found!"
assert allclose([soln.roots.t[-1], soln.roots.y[-1,0], soln.roots.y[-1,1]],
Expand All @@ -185,7 +179,7 @@ def test_ida_rootfn_test(self):
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vc,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Not sufficient root found"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -201,7 +195,7 @@ def test_ida_rootfn_two(self):
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=2, rootfn=root_fn2,
onroot=onroot_vc,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Not sufficient root found"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -217,7 +211,7 @@ def test_ida_rootfn_end(self):
tspan = np.arange(0, 30 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn3,
onroot=onroot_vc,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Not sufficient root found"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -234,7 +228,7 @@ def test_ida_tstopfn_notstop(self):
n = 0
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1+1, ontstop=ontstop_va,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -247,7 +241,7 @@ def test_ida_tstopfn(self):
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Tstop not found!"
assert allclose([soln.tstop.t[0], soln.tstop.y[0,0], soln.tstop.y[0,1]],
Expand All @@ -263,7 +257,7 @@ def test_ida_tstopfnacc(self):
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_va,
old_api=False)
**COMMON_ARGS)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]],
Expand All @@ -279,8 +273,8 @@ def test_ida_tstopfn_stop(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vb,
old_api=False)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_stop,
**COMMON_ARGS)

soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Error occurred"
Expand All @@ -300,7 +294,7 @@ def test_ida_tstopfn_test(self):
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vc,
old_api=False)
**COMMON_ARGS)

soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Error occurred"
Expand Down
Loading

0 comments on commit 94e3670

Please sign in to comment.