Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

error for example "Train the Policy Using BPTT" #6

Open
zhw970623 opened this issue Nov 4, 2024 · 3 comments
Open

error for example "Train the Policy Using BPTT" #6

zhw970623 opened this issue Nov 4, 2024 · 3 comments

Comments

@zhw970623
Copy link

An error occurs when executing the following code in the example train_bptt_state.ipynb:

time_start = time.time()
res_dict = bptt.train(
    env,
    train_state,
    num_epochs=100,
    num_steps_per_epoch=env.max_steps_in_episode,
    num_envs=100,
    key=key_bptt,
)
time_train = time.time() - time_start
print(f"Training time: {time_train}")


JaxStackTraceBeforeTransformation Traceback (most recent call last)
File :198, in _run_module_as_main()

File :88, in _run_code()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel_launcher.py:18
16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelapp.py:739, in start()
738 try:
--> 739 self.io_loop.start()
740 except KeyboardInterrupt:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/tornado/platform/asyncio.py:205, in start()
204 def start(self) -> None:
--> 205 self.asyncio_loop.run_forever()

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:608, in run_forever()
607 while True:
--> 608 self._run_once()
609 if self._stopping:

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:1936, in _run_once()
1935 else:
-> 1936 handle._run()
1937 handle = None

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/events.py:84, in _run()
83 try:
---> 84 self._context.run(self._callback, *self._args)
85 except (SystemExit, KeyboardInterrupt):

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue()
544 try:
--> 545 await self.process_one()
546 except Exception:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:534, in process_one()
533 return
--> 534 await dispatch(*args)

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
436 if inspect.isawaitable(result):
--> 437 await result
438 except Exception:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/ipkernel.py:362, in execute_request()
361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:778, in execute_request()
777 if inspect.isawaitable(reply_content):
--> 778 reply_content = await reply_content
780 # Flush output before sending the reply.

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/ipkernel.py:449, in do_execute()
448 if accepts_params["cell_id"]:
--> 449 res = shell.run_cell(
450 code,
451 store_history=store_history,
452 silent=silent,
453 cell_id=cell_id,
454 )
455 else:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3075, in run_cell()
3074 try:
-> 3075 result = self._run_cell(
3076 raw_cell, store_history, silent, shell_futures, cell_id
3077 )
3078 finally:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell()
3129 try:
-> 3130 result = runner(coro)
3131 except BaseException as e:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner()
127 try:
--> 128 coro.send(None)
129 except StopIteration as exc:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async()
3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3335 interactivity=interactivity, compiler=compiler, result=result)
3337 self.last_execution_succeeded = not has_raised

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes()
3516 asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
3518 return True

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577, in run_code()
3576 else:
-> 3577 exec(code_obj, self.user_global_ns, self.user_ns)
3578 finally:
3579 # Reset our crash handler in place

Cell In[9], line 11
9 return train_state.apply_fn(train_state.params, obs)
---> 11 transitions = get_rollouts(env, policy, 10, jax.random.key(3))

Cell In[9], line 4, in get_rollouts()
3 rollout_keys = jax.random.split(key, num_rollouts)
----> 4 transitions = parallel_rollout(env, rollout_keys, policy)
5 return transitions

File ~/reach/rpg_flightning/flightning/envs/env_base.py:135, in rollout()
134 keys_steps = jax.random.split(key, num_steps)
--> 135 _, transitions = jax.lax.scan(step_fn, (state, obs), keys_steps)
136 # concatenate all transitions

File ~/reach/rpg_flightning/flightning/envs/env_base.py:131, in step_fn()
130 else:
--> 131 trans = env._step(env_state, action, key_step)
132 return (trans.state, trans.obs), trans

File ~/reach/rpg_flightning/flightning/envs/wrappers.py:212, in _step()
210 @partial(jax.jit, static_argnums=(0,))
211 def _step(self, state, action, key) -> EnvTransition:
--> 212 transition = self._env._step(state, action, key)
213 obs = normalize(transition.obs, self._obs_min, self._obs_max)

File ~/reach/rpg_flightning/flightning/envs/hovering_state_env.py:163, in _step()
162 f_1, omega_1 = action_1[0], action_1[1:]
--> 163 quadrotor_state = self.quadrotor.step(
164 state.quadrotor_state, f_1, omega_1, dt_1
165 )
167 if self.delay > 0:
168 # 2 step

File ~/reach/rpg_flightning/flightning/objects/quadrotor_obj.py:309, in step()
307 return state_new, state_dot_new
--> 309 return _step(state, f_d, omega_d, dt)

JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal key[] with tangent key[], expecting tangent ShapedArray(float0[])

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

TypeError Traceback (most recent call last)
Cell In[11], line 2
1 time_start = time.time()
----> 2 res_dict = bptt.train(
3 env,
4 train_state,
5 num_epochs=100,
6 num_steps_per_epoch=env.max_steps_in_episode,
7 num_envs=100,
8 key=key_bptt,
9 )
10 time_train = time.time() - time_start
11 print(f"Training time: {time_train}")

File ~/reach/rpg_flightning/flightning/algos/bptt.py:155, in train(env, train_state, num_epochs, num_steps_per_epoch, num_envs, key)
152 env_state, obs = env.reset(key_reset, None)
153 runner_state = RunnerState(train_state, env_state, obs, key, epoch_idx=0)
--> 155 return jax.jit(_train)(runner_state)

[... skipping hidden 11 frame]

File ~/reach/rpg_flightning/flightning/algos/bptt.py:143, in train.._train(runner_state)
140 return epoch_state, loss
142 # run epochs
--> 143 runner_state_final, losses = jax.lax.scan(
144 epoch_fn, runner_state, None, num_epochs
145 )
147 return {"runner_state": runner_state_final, "metrics": losses}

[... skipping hidden 9 frame]

File ~/reach/rpg_flightning/flightning/algos/bptt.py:122, in train.._train..epoch_fn(epoch_state, _unused)
120 # compute reward
121 train_state = epoch_state.train_state
--> 122 (loss, epoch_state), grad = loss_fn(
123 train_state.params, epoch_state
124 )
125 # update params
126 train_state = train_state.apply_gradients(grads=grad)

[... skipping hidden 8 frame]

File ~/reach/rpg_flightning/flightning/algos/bptt.py:116, in train.._train..epoch_fn..loss_fn(params, runner_state)
113 return runner_state, trajectory
115 # collect data
--> 116 runner_state, trajectory = rollout(runner_state)
117 loss = -trajectory.reward.sum() / num_envs
118 return loss, runner_state

File ~/reach/rpg_flightning/flightning/algos/bptt.py:110, in train.._train..epoch_fn..loss_fn..rollout(runner_state)
101 runner_state = RunnerState(
102 train_state, env_state, obs, key, epoch_idx
103 )
105 return (
106 runner_state,
107 TrajectoryState(reward=reward),
108 )
--> 110 runner_state, trajectory = jax.lax.scan(
111 step_fn, runner_state, None, num_steps_per_epoch
112 )
113 return runner_state, trajectory

[... skipping hidden 31 frame]

[... skipping similar frames: _jvp_jaxpr at line 685 (2 times), WrappedFun.call_wrapped at line 193 (2 times), eval_jaxpr at line 508 (2 times), jaxpr_as_fun at line 260 (2 times), jvp_jaxpr at line 675 (2 times), trace_to_jaxpr_dynamic at line 2278 (2 times), trace_to_subjaxpr_dynamic at line 2301 (2 times), annotate_function.<locals>.wrapper at line 333 (2 times), _pjit_jvp at line 2045 (1 times), AxisPrimitive.bind at line 2803 (1 times), Primitive.bind_with_trace at line 442 (1 times), JVPTrace.process_primitive at line 302 (1 times)]

[... skipping hidden 25 frame]

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
344 msg = ("Custom JVP rule must produce primal and tangent outputs with "
345 "corresponding shapes and dtypes, but got:\n{}")
346 disagreements = (
347 f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
348 for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
349 if av_et != av_t)
--> 351 raise TypeError(msg.format('\n'.join(disagreements)))
352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal key[] with tangent key[], expecting tangent ShapedArray(float0[])

@zhw970623
Copy link
Author

JaxStackTraceBeforeTransformation Traceback (most recent call last)
File :198, in _run_module_as_main()

File :88, in _run_code()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel_launcher.py:18
16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelapp.py:739, in start()
738 try:
--> 739 self.io_loop.start()
740 except KeyboardInterrupt:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/tornado/platform/asyncio.py:205, in start()
204 def start(self) -> None:
--> 205 self.asyncio_loop.run_forever()

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:608, in run_forever()
607 while True:
...
--> 351 raise TypeError(msg.format('\n'.join(disagreements)))
352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal uint32[2] with tangent uint32[2], expecting tangent ShapedArray(float0[2])

@patricksharlow
Copy link

Getting the same error

@joheeg
Copy link
Contributor

joheeg commented Nov 17, 2024

This specific error is due to a change in JAX. It is related to jax-ml/jax#24262. If you are using jax 0.4.34 or later, try downgrading to 0.4.33.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants