diff --git a/trax/layers/test_utils.py b/trax/layers/test_utils.py index 50988c7c7..505577e79 100644 --- a/trax/layers/test_utils.py +++ b/trax/layers/test_utils.py @@ -109,11 +109,11 @@ def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None, for indices in indices_list: start, end = indices if seq_tensor is None: - new_inp = inp.take(indices=range(start, end), axis=seq_axis) + new_inp = inp.take(indices=np.arange(start, end), axis=seq_axis) else: new_inp = list(inp) new_inp[seq_tensor] = new_inp[seq_tensor].take( - indices=range(start, end), axis=seq_axis) + indices=np.arange(start, end), axis=seq_axis) output_predict = model_predict(new_inp, rng=rng) if not isinstance(output_predict, (tuple, list)): @@ -123,8 +123,8 @@ def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None, np.testing.assert_equal(len(output_predict), len(output_eval)) for outp, oute in zip(output_predict, output_eval): np.testing.assert_array_almost_equal( - oute.take(indices=range(start, end), axis=seq_axis), - outp.take(indices=range(0, end-start), axis=seq_axis), + oute.take(indices=np.arange(start, end), axis=seq_axis), + outp.take(indices=np.arange(0, end-start), axis=seq_axis), decimal=5, err_msg='Error on element {} out of {}.{}'.format(indices, length, message))