Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Port boot_dqn from optix to optax
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 330726981
Change-Id: Ib7fa620772c45cc4746627b271acc2dbb5455283
  • Loading branch information
mtthss authored and copybara-github committed Sep 9, 2020
1 parent 828f9bf commit f2e6c21
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bsuite/baselines/jax/boot_dqn/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

import haiku as hk
from jax import lax
from jax.experimental import optix
import jax.numpy as jnp
import optax

# Internal imports.

Expand Down Expand Up @@ -70,7 +70,7 @@ def network(inputs: jnp.ndarray) -> jnp.ndarray:
x = hk.Flatten()(inputs)
return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

optimizer = optix.adam(learning_rate=1e-3)
optimizer = optax.adam(learning_rate=1e-3)

agent = boot_dqn.BootstrappedDqn(
obs_spec=env.observation_spec(),
Expand Down

0 comments on commit f2e6c21

Please sign in to comment.