Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix assert error for type of keep_step #310

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

mstoelzle
Copy link

When I am running a normal integration such as

import diffrax
ode_term = ODETerm(ode_fn)
sol = diffrax.diffeqsolve(
    ode_term,
    diffrax.Euler(),
    0.0,  # initial time
    1.0,  # final time
    1e-4,  # time step
    x_init_bt.astype(jnp.float64)[0, :],  # initial state
    max_steps=20000,
)

I will get an error similar to

  File "/home/mstolzle/sources/learning-representations-from-first-principle-dynamics/src/tasks/fp_dynamics.py", line 251, in forward_fn
    sol = ode_solve_fn(
          ^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/equinox/_jit.py", line 107, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/equinox/_jit.py", line 103, in _call
    out = self._cached(dynamic, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 824, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
                             ^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/adjoint.py", line 286, in loop
    final_state = self._loop(
                  ^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 424, in loop
    filter_state = eqx.filter_eval_shape(body_fun, init_state)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 252, in body_fun
    assert jnp.result_type(keep_step) is jnp.dtype(bool)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When I print jnp.result_type(keep_step), I get bool instead of jnp.dtype(bool).

I would like to stress that this issue only appears for certain ode_fn. I haven't quite figured out yet which change/property of the ode_fn causes this error to occur.

Still, this backwards-compatible change should work for any case.

@patrick-kidger
Copy link
Owner

What version of JAX and what version of NumPy are you using?

@mstoelzle
Copy link
Author

What version of JAX and what version of NumPy are you using?

I am using python 3.11, numpy 1.23.5, jax 0.4.14, equinox 0.10.11, diffrax 0.4.1

@patrick-kidger
Copy link
Owner

Hmm. I'm not able to easily reproduce this with those versions. It should always be the case that jnp.result_type returns a numpy dtype.
You say this only arises for certain ode_fn. Can you provide a MWE?

@VincentStimper
Copy link

Hi @mstoelzle and @patrick-kidger,

I had the same error. assert jnp.result_type(keep_step) == jnp.dtype(bool) passed while assert jnp.result_type(keep_step) is jnp.dtype(bool) threw an error. It only occurred when I loaded my model via pickle, so I figured that through the pickle.load some code was loaded that caused the two types not to be identical objects.

However, when I initialized the model beforehand as I did during training, loading the checkpoint and subsequently using diffrax worked without any errors. I guess this might be because in this case no additional code needed to be loaded through the pickle.load call.

I hope this helps to resolve the issues you are having, @mstoelzle.

Best regards,
Vincent

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants