From 651ef5ce9993803feb964aba50e3bc625dfdf107 Mon Sep 17 00:00:00 2001 From: Dmitry Kabanov Date: Tue, 12 Mar 2024 14:19:30 +0100 Subject: [PATCH] Add `print_stats` method to `ode` and implement for `CVODE` integrator `CVODE` solver has function `CVodePrintAllStats` that prints statistics about integrator such as number of right-hand-side function evaluations and number of nonlinear solves. --- .../src/scikits_odes_sundials/c_cvode.pxd | 1 + .../src/scikits_odes_sundials/c_sundials.pxd | 4 ++++ .../src/scikits_odes_sundials/cvode.pyx | 7 ++++++- packages/scikits-odes/src/scikits_odes/ode.py | 6 ++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd index 9863cd38..22d40f0f 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_cvode.pxd @@ -120,6 +120,7 @@ cdef extern from "cvode/cvode.h": int CVodeGetNumNonlinSolvConvFails(void *cvode_mem, long int *nncfails) int CVodeGetNonlinSolvStats(void *cvode_mem, long int *nniters, long int *nncfails) + int CVodePrintAllStats(void* cvode_mem, FILE* outfile, SUNOutputFormat fmt) char *CVodeGetReturnFlagName(long int flag) void CVodeFree(void **cvode_mem) diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd index 259648f9..5061e4ad 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/c_sundials.pxd @@ -4,6 +4,10 @@ cdef extern from "sundials/sundials_types.h": ctypedef float sunrealtype ctypedef unsigned int sunbooleantype ctypedef long sunindextype + + cdef enum SUNOutputFormat: + SUN_OUTPUTFORMAT_TABLE, + SUN_OUTPUTFORMAT_CSV cdef extern from "sundials/sundials_context.h": struct _SUNContext: diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx b/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx index 4382a8ff..e1ae46b8 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/cvode.pyx @@ -10,13 +10,15 @@ include "sundials_config.pxi" import numpy as np cimport numpy as np +from libc cimport stdio + from . import ( CVODESolveFailed, CVODESolveFoundRoot, CVODESolveReachedTSTOP, _get_num_args, ) from .c_sundials cimport ( - sunrealtype, N_Vector, SUNContext_Create, SUNContext_Free, + sunrealtype, N_Vector, SUNContext_Create, SUNContext_Free, SUN_OUTPUTFORMAT_TABLE ) from .c_nvector_serial cimport * from .c_sunmatrix cimport * @@ -2048,6 +2050,9 @@ cdef class CVODE: 'NumRhsEvalsJtimesFD': nfevalsLS}) return info + + def print_stats(self): + CVodePrintAllStats(self._cv_mem, stdio.stdout, SUN_OUTPUTFORMAT_TABLE) def __dealloc__(self): if self._cv_mem is not NULL: CVodeFree(&self._cv_mem) diff --git a/packages/scikits-odes/src/scikits_odes/ode.py b/packages/scikits-odes/src/scikits_odes/ode.py index a6b53379..edca67f4 100644 --- a/packages/scikits-odes/src/scikits_odes/ode.py +++ b/packages/scikits-odes/src/scikits_odes/ode.py @@ -306,6 +306,12 @@ def get_info(self): else: return {} + def print_stats(self): + if hasattr(self._integrator, "print_stats"): + self._integrator.print_stats() + else: + print(f"Method `print_stats` is not implemented for integrator {self._integrator}") + #------------------------------------------------------------------------------ # ODE integrators #------------------------------------------------------------------------------