diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd index 9863cd3..22d40f0 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd @@ -120,6 +120,7 @@ cdef extern from "cvode/cvode.h": int CVodeGetNumNonlinSolvConvFails(void *cvode_mem, long int *nncfails) int CVodeGetNonlinSolvStats(void *cvode_mem, long int *nniters, long int *nncfails) + int CVodePrintAllStats(void* cvode_mem, FILE* outfile, SUNOutputFormat fmt) char *CVodeGetReturnFlagName(long int flag) void CVodeFree(void **cvode_mem) diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd index 259648f..5061e4a 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd @@ -4,6 +4,10 @@ cdef extern from "sundials/sundials_types.h": ctypedef float sunrealtype ctypedef unsigned int sunbooleantype ctypedef long sunindextype + + cdef enum SUNOutputFormat: + SUN_OUTPUTFORMAT_TABLE, + SUN_OUTPUTFORMAT_CSV cdef extern from "sundials/sundials_context.h": struct _SUNContext: diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx b/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx index 4382a8f..e1ae46b 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx @@ -10,13 +10,15 @@ include "sundials_config.pxi" import numpy as np cimport numpy as np +from libc cimport stdio + from . import ( CVODESolveFailed, CVODESolveFoundRoot, CVODESolveReachedTSTOP, _get_num_args, ) from .c_sundials cimport ( - sunrealtype, N_Vector, SUNContext_Create, SUNContext_Free, + sunrealtype, N_Vector, SUNContext_Create, SUNContext_Free, SUN_OUTPUTFORMAT_TABLE ) from .c_nvector_serial cimport * from .c_sunmatrix cimport * @@ -2048,6 +2050,9 @@ cdef class CVODE: 'NumRhsEvalsJtimesFD': nfevalsLS}) return info + + def print_stats(self): + CVodePrintAllStats(self._cv_mem, stdio.stdout, SUN_OUTPUTFORMAT_TABLE) def __dealloc__(self): if self._cv_mem is not NULL: CVodeFree(&self._cv_mem) diff --git a/packages/scikits-odes/src/scikits_odes/ode.py b/packages/scikits-odes/src/scikits_odes/ode.py index a6b5337..edca67f 100644 --- a/packages/scikits-odes/src/scikits_odes/ode.py +++ b/packages/scikits-odes/src/scikits_odes/ode.py @@ -306,6 +306,12 @@ def get_info(self): else: return {} + def print_stats(self): + if hasattr(self._integrator, "print_stats"): + self._integrator.print_stats() + else: + print(f"Method `print_stats` is not implemented for integrator {self._integrator}") + #------------------------------------------------------------------------------ # ODE integrators #------------------------------------------------------------------------------