Skip to content

Commit

Permalink
Ensure values passed to jax.numpy functions are arrays rather than li…
Browse files Browse the repository at this point in the history
…sts.

Why? This will soon be a requirement in JAX; see jax-ml/jax#7737

PiperOrigin-RevId: 394106413
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Aug 31, 2021
1 parent 9d9e0db commit 1389ceb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions trax/layers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None,
for indices in indices_list:
start, end = indices
if seq_tensor is None:
new_inp = inp.take(indices=range(start, end), axis=seq_axis)
new_inp = inp.take(indices=np.arange(start, end), axis=seq_axis)
else:
new_inp = list(inp)
new_inp[seq_tensor] = new_inp[seq_tensor].take(
indices=range(start, end), axis=seq_axis)
indices=np.arange(start, end), axis=seq_axis)

output_predict = model_predict(new_inp, rng=rng)
if not isinstance(output_predict, (tuple, list)):
Expand All @@ -123,8 +123,8 @@ def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None,
np.testing.assert_equal(len(output_predict), len(output_eval))
for outp, oute in zip(output_predict, output_eval):
np.testing.assert_array_almost_equal(
oute.take(indices=range(start, end), axis=seq_axis),
outp.take(indices=range(0, end-start), axis=seq_axis),
oute.take(indices=np.arange(start, end), axis=seq_axis),
outp.take(indices=np.arange(0, end-start), axis=seq_axis),
decimal=5,
err_msg='Error on element {} out of {}.{}'.format(indices, length,
message))
Expand Down

0 comments on commit 1389ceb

Please sign in to comment.