style: pre-commit fixes
pre-commit-ci[bot] committed Dec 24, 2023
1 parent a0ca1fd commit 4f3bcbe
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 11 additions & 6 deletions docs/source/examples/notebooks/solvers/idaklu-jax-interface.ipynb
@@ -228,10 +228,13 @@
     "print(data)\n",
     "\n",
     "# Isolate two variables from the solver\n",
-    "data = jax_solver.get_vars(f, [\n",
-    "    \"Voltage [V]\",\n",
-    "    \"Current [A]\",\n",
-    "])(t_eval, inputs)\n",
+    "data = jax_solver.get_vars(\n",
+    "    f,\n",
+    "    [\n",
+    "        \"Voltage [V]\",\n",
+    "        \"Current [A]\",\n",
+    "    ],\n",
+    ")(t_eval, inputs)\n",
     "print(f\"\\nIsolating two variables returns an array of shape {data.shape}\")\n",
     "print(data)"
    ]
@@ -350,10 +353,10 @@
     "t_start = time.time()\n",
     "data = jax.vmap(\n",
     "    jax.grad(\n",
-    "        jax_solver.get_var(f,\"Voltage [V]\"),\n",
+    "        jax_solver.get_var(f, \"Voltage [V]\"),\n",
     "        argnums=1,  # take derivative with respect to `inputs`\n",
     "    ),\n",
-    "    in_axes=(0, None)  # map time over the 0th dimension and do not map inputs\n",
+    "    in_axes=(0, None),  # map time over the 0th dimension and do not map inputs\n",
     ")(t_eval, inputs)\n",
     "print(f\"Gradient method ran in {time.time()-t_start:0.3} secs\")\n",
     "print(data)"
@@ -393,11 +396,13 @@
     "# Simulate some experimental data using our original parameter settings\n",
     "data = sim[\"Voltage [V]\"](t_eval)\n",
     "\n",
+    "\n",
     "# Sum-of-squared errors\n",
     "def sse(t, inputs):\n",
     "    modelled = jax_solver.get_var(f, \"Voltage [V]\")(t_eval, inputs)\n",
     "    return jnp.sum((modelled - data) ** 2)\n",
     "\n",
+    "\n",
     "# Provide some predicted model inputs (these could come from a fitting procedure)\n",
     "inputs_pred = {\n",
     "    \"Current function [A]\": 0.150,\n",
4 changes: 1 addition & 3 deletions tests/unit/test_solvers/test_idaklu_jax.py
@@ -706,9 +706,7 @@ def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrappe
                 idaklu_jax_solver.get_var(f, outvar),
                 argnums=1,
             ),
-        )(
-            t_eval[k], inputs
-        )  # output should be a dictionary of inputs
+        )(t_eval[k], inputs)  # output should be a dictionary of inputs
         print(out)
         flat_out, _ = tree_flatten(out)
         flat_out = np.array([f for f in flat_out]).flatten()
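
The "dictionary of inputs" comment holds because jax.grad taken with respect to a dict-valued argument returns a pytree with the same structure, which is why the test flattens it with tree_flatten. A self-contained sketch with a hypothetical toy function:

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten

def toy(t, inputs):
    return inputs["a"] * t + inputs["b"]

out = jax.grad(toy, argnums=1)(1.0, {"a": 2.0, "b": 3.0})
print(out)  # {'a': Array(1., ...), 'b': Array(1., ...)}
flat_out, _ = tree_flatten(out)
print(jnp.array(flat_out).flatten())  # flattened array, as in the test above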
