From f090074d865a4f98fba32bd88460373438cb4f24 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 11 Apr 2024 13:23:27 -0700 Subject: [PATCH] Avoid 'from jax import config' imports In some environments this appears to import the config module rather than the config object. --- benchmarks/api_benchmark.py | 4 +--- benchmarks/shape_poly_benchmark.py | 4 ++-- docs/debugging/flags.md | 16 +++++++------- docs/debugging/index.md | 4 ++-- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 18 ++++++++-------- docs/notebooks/Common_Gotchas_in_JAX.md | 18 ++++++++-------- docs/rank_promotion_warning.rst | 4 ++-- examples/examples_test.py | 4 ++-- examples/gaussian_process_regression.py | 5 +++-- jax/_src/internal_test_util/lax_test_util.py | 4 ++-- .../array_serialization/serialization_test.py | 4 +--- .../jax2tf/examples/keras_reuse_main_test.py | 4 ++-- .../jax2tf/tests/back_compat_tf_test.py | 4 ++-- jax/experimental/jax2tf/tests/call_tf_test.py | 21 +++++++++---------- .../jax2tf/tests/control_flow_ops_test.py | 3 +-- .../jax2tf/tests/cross_compilation_check.py | 5 ++--- .../jax2tf/tests/savedmodel_test.py | 3 +-- tests/ann_test.py | 4 +--- tests/aot_test.py | 3 +-- tests/api_util_test.py | 4 ++-- tests/array_test.py | 3 +-- tests/batching_test.py | 3 +-- tests/clear_backends_test.py | 3 +-- tests/custom_linear_solve_test.py | 3 +-- tests/custom_object_test.py | 4 ++-- tests/custom_root_test.py | 3 +-- tests/debug_nans_test.py | 17 +++++++-------- tests/debugger_test.py | 3 +-- tests/debugging_primitives_test.py | 3 +-- tests/dynamic_api_test.py | 3 +-- tests/extend_test.py | 3 +-- tests/for_loop_test.py | 3 +-- tests/generated_fun_test.py | 4 ++-- tests/heap_profiler_test.py | 3 +-- tests/host_callback_test.py | 3 +-- tests/image_test.py | 4 +--- tests/infeed_test.py | 3 +-- tests/jet_test.py | 3 +-- tests/key_reuse_test.py | 3 +-- tests/lax_autodiff_test.py | 3 +-- tests/lax_control_flow_test.py | 3 +-- tests/lax_numpy_einsum_test.py | 3 +-- tests/lax_numpy_ufuncs_test.py | 3 +-- tests/lax_numpy_vectorize_test.py | 3 +-- tests/lax_scipy_special_functions_test.py | 3 +-- tests/lax_scipy_spectral_dac_test.py | 4 ++-- tests/lax_scipy_test.py | 3 +-- tests/lax_vmap_op_test.py | 3 +-- tests/lax_vmap_test.py | 3 +-- tests/lobpcg_test.py | 3 +-- tests/logging_test.py | 11 +++++----- tests/metadata_test.py | 3 +-- tests/mock_gpu_test.py | 3 +-- tests/mosaic_test.py | 4 ++-- tests/multi_device_test.py | 3 +-- tests/multibackend_test.py | 3 +-- tests/multiprocess_gpu_test.py | 3 +-- tests/name_stack_test.py | 3 +-- tests/ode_test.py | 3 +-- tests/optimizers_test.py | 3 +-- tests/pgle_test.py | 3 +-- tests/pickle_test.py | 3 +-- tests/polynomial_test.py | 4 ++-- tests/profiler_test.py | 3 +-- tests/scipy_fft_test.py | 5 ++--- tests/scipy_interpolate_test.py | 4 ++-- tests/scipy_ndimage_test.py | 4 ++-- tests/scipy_optimize_test.py | 4 ++-- tests/scipy_signal_test.py | 4 ++-- tests/scipy_spatial_test.py | 3 +-- tests/scipy_stats_test.py | 3 +-- tests/shard_alike_test.py | 3 +-- tests/source_info_test.py | 3 +-- tests/sparse_bcoo_bcsr_test.py | 5 ++--- tests/sparse_test.py | 3 +-- tests/sparsify_test.py | 4 ++-- tests/stack_test.py | 4 ++-- tests/stax_test.py | 4 ++-- tests/third_party/scipy/line_search_test.py | 3 +-- tests/transfer_guard_test.py | 4 +--- tests/util_test.py | 4 ++-- tests/x64_context_test.py | 7 +++---- tests/xmap_test.py | 19 ++++++++--------- 83 files changed, 162 insertions(+), 224 deletions(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 30fb04ace8e3..75cd38d10c10 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -33,9 +33,7 @@ import jax.numpy as jnp import numpy as np -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() partial = functools.partial diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index bd8dd42d1052..b1b6b625ccca 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -15,12 +15,12 @@ import google_benchmark as benchmark -from jax import config +import jax from jax import core from jax._src.numpy import lax_numpy from jax.experimental import export -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @benchmark.register diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 8434384c4eb5..90a6cb3bbfbd 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily. If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by: * setting the `JAX_DEBUG_NANS=True` environment variable; -* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file; -* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; +* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; ### Example(s) ```python -from jax import config -config.update("jax_debug_nans", True) +import jax +jax.config.update("jax_debug_nans", True) def f(x, y): return x / y @@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! You can disable JIT-compilation by: * setting the `JAX_DISABLE_JIT=True` environment variable; -* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file; -* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`; +* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`; ### Examples ```python -from jax import config -config.update("jax_disable_jit", True) +import jax +jax.config.update("jax_disable_jit", True) def f(x): y = jnp.log(x) diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 9a020b360b81..35e0f68950c4 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more! **TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python -from jax import config -config.update("jax_debug_nans", True) +import jax +jax.config.update("jax_debug_nans", True) def f(x, y): return x / y diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 13fdd572b642..d8dffdb8a2f1 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -1946,9 +1946,9 @@ "\n", "* setting the `JAX_DEBUG_NANS=True` environment variable;\n", "\n", - "* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", + "* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", "\n", - "* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", + "* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", "\n", "This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n", "\n", @@ -2141,24 +2141,24 @@ "\n", " ```python\n", " # again, this only works on startup!\n", - " from jax import config\n", - " config.update(\"jax_enable_x64\", True)\n", + " import jax\n", + " jax.config.update(\"jax_enable_x64\", True)\n", " ```\n", "\n", "3. You can parse command-line flags with `absl.app.run(main)`\n", "\n", " ```python\n", - " from jax import config\n", - " config.config_with_absl()\n", + " import jax\n", + " jax.config.config_with_absl()\n", " ```\n", "\n", "4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n", "\n", " ```python\n", - " from jax import config\n", + " import jax\n", " if __name__ == '__main__':\n", - " # calls config.config_with_absl() *and* runs absl parsing\n", - " config.parse_flags_with_absl()\n", + " # calls jax.config.config_with_absl() *and* runs absl parsing\n", + " jax.config.parse_flags_with_absl()\n", " ```\n", "\n", "Note that #2-#4 work for _any_ of JAX's configuration options.\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 0e5af8b04ad6..e63d64d94e77 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo * setting the `JAX_DEBUG_NANS=True` environment variable; -* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file; +* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; -* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; +* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time. @@ -1087,24 +1087,24 @@ There are a few ways to do this: ```python # again, this only works on startup! - from jax import config - config.update("jax_enable_x64", True) + import jax + jax.config.update("jax_enable_x64", True) ``` 3. You can parse command-line flags with `absl.app.run(main)` ```python - from jax import config - config.config_with_absl() + import jax + jax.config.config_with_absl() ``` 4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use ```python - from jax import config + import jax if __name__ == '__main__': - # calls config.config_with_absl() *and* runs absl parsing - config.parse_flags_with_absl() + # calls jax.config.config_with_absl() *and* runs absl parsing + jax.config.parse_flags_with_absl() ``` Note that #2-#4 work for _any_ of JAX's configuration options. diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index e81509e2a941..5e4e7ec65cbc 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code: .. code-block:: python - from jax import config - config.update("jax_numpy_rank_promotion", "warn") + import jax + jax.config.update("jax_numpy_rank_promotion", "warn") You can also set the option using the environment variable :code:`JAX_NUMPY_RANK_PROMOTION`, for example as diff --git a/examples/examples_test.py b/examples/examples_test.py index b8b4d11e273d..c9cb2991c030 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -22,6 +22,7 @@ import numpy as np +import jax from jax import lax from jax import random import jax.numpy as jnp @@ -30,8 +31,7 @@ from examples import kernel_lsq sys.path.pop() -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py index c42a024d42aa..75f7398d12ca 100644 --- a/examples/gaussian_process_regression.py +++ b/examples/gaussian_process_regression.py @@ -17,10 +17,11 @@ from absl import app from functools import partial + +import jax from jax import grad from jax import jit from jax import vmap -from jax import config import jax.numpy as jnp import jax.random as random import jax.scipy as scipy @@ -125,5 +126,5 @@ def train_step(params, momentums, scales, x, y): mu.flatten() - std * 2, mu.flatten() + std * 2) if __name__ == "__main__": - config.config_with_absl() + jax.config.config_with_absl() app.run(main) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 5cc08eb05231..b57b7d0852a9 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -23,6 +23,7 @@ import itertools from typing import Union, cast +import jax from jax import lax from jax._src import dtypes from jax._src import test_util @@ -30,8 +31,7 @@ import numpy as np -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index b9e3192f9b8f..6cb5347e9686 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -24,16 +24,14 @@ from absl.testing import parameterized import jax from jax._src import test_util as jtu -from jax import config from jax._src import array from jax.sharding import NamedSharding, GSPMDSharding from jax.sharding import PartitionSpec as P from jax.experimental.array_serialization import serialization import numpy as np import tensorstore as ts -import unittest -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index 562369cdb9df..2934842912f0 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -16,13 +16,13 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized +import jax from jax._src import test_util as jtu -from jax import config from jax.experimental.jax2tf.examples import keras_reuse_main from jax.experimental.jax2tf.tests import tf_test_util -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 47c5c8360cf5..bd31c19ba120 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -27,7 +27,7 @@ from typing import Callable, Optional from absl.testing import absltest -from jax import config +import jax from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu from jax._src.lib import xla_extension @@ -37,7 +37,7 @@ import tensorflow as tf -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def serialize_directory(directory_path): diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 57eea5f6a35d..3a0bdffd0f70 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -23,7 +23,6 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from jax import config from jax import dlpack from jax import dtypes from jax import lax @@ -42,7 +41,7 @@ except ImportError: tf = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _maybe_jit(with_jit: bool, func: Callable) -> Callable: @@ -1151,15 +1150,15 @@ def setUp(self): super().setUp() def override_serialization_version(self, version_override: int): - version = config.jax_serialization_version + version = jax.config.jax_serialization_version if version != version_override: - self.addCleanup(partial(config.update, + self.addCleanup(partial(jax.config.update, "jax_serialization_version", version_override)) - config.update("jax_serialization_version", version_override) + jax.config.update("jax_serialization_version", version_override) logging.info( "Using JAX serialization version %s", - config.jax_serialization_version) + jax.config.jax_serialization_version) def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX @@ -1275,7 +1274,7 @@ def fun_tf(x): # x:i32[3] @_parameterized_jit def test_shape_poly_static_output_shape(self, with_jit=True): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([0.7, 0.8], dtype=np.float32) @@ -1289,7 +1288,7 @@ def fun_tf(x): @_parameterized_jit def test_shape_poly(self, with_jit=False): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([7, 8, 9, 10], dtype=np.float32) def fun_jax(x): @@ -1308,7 +1307,7 @@ def fun_jax(x): @_parameterized_jit def test_shape_poly_pytree_result(self, with_jit=True): - if config.jax2tf_default_native_serialization: + if jax.config.jax2tf_default_native_serialization: raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") x = np.array([7, 8, 9, 10], dtype=np.float32) def fun_jax(x): @@ -1394,7 +1393,7 @@ def fun_jax(x): if kind == "bad_dim" and with_jit: # TODO: in jit more the error pops up later, at AddV2 expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" - if kind == "bad_dim" and config.jax2tf_default_native_serialization: + if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization: # TODO(b/268386622): call_tf with shape polymorphism and native serialization. expect_error = "Error compiling TensorFlow function" fun_tf_rt = _maybe_tf_jit(with_jit, @@ -1432,7 +1431,7 @@ def test_several_round_trips(self, f4_function=False, f4_saved_model=False): if (f2_saved_model and f4_saved_model and - not config.jax2tf_default_native_serialization): + not jax.config.jax2tf_default_native_serialization): # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients # when saving f4, but only with non-native serialization. raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients") diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py index 253a5ffc68a7..c66a6d696e89 100644 --- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py +++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py @@ -23,8 +23,7 @@ from jax.experimental.jax2tf.tests import tf_test_util -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase): diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index 63e8928ee371..0a4bf61f8847 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -39,12 +39,11 @@ import numpy.random as npr -import jax -from jax import config # Must import before TF +import jax # Must import before TF from jax.experimental import jax2tf # Defines needed flags from jax._src import test_util # Defines needed flags -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Import after parsing flags from jax.experimental.jax2tf.tests import primitive_harness diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index 86ae81f8373a..37e7eb24fd14 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -25,8 +25,7 @@ from jax.experimental.jax2tf.tests import tf_test_util from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class SavedModelTest(tf_test_util.JaxToTfTestCase): diff --git a/tests/ann_test.py b/tests/ann_test.py index ab35ce0c5392..1d704c725c61 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -23,9 +23,7 @@ from jax import lax from jax._src import test_util as jtu -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() ignore_jit_of_pmap_warning = partial( jtu.ignore_warning,message=".*jit-of-pmap.*") diff --git a/tests/aot_test.py b/tests/aot_test.py index dacfa620c628..bca0d66ed384 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -17,7 +17,6 @@ import unittest from absl.testing import absltest import jax -from jax import config from jax._src import core from jax._src import test_util as jtu from jax._src.lib import xla_client as xc @@ -31,7 +30,7 @@ from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/api_util_test.py b/tests/api_util_test.py index 7b7a479dbf14..f78b5948f4e7 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -16,12 +16,12 @@ import itertools as it from absl.testing import absltest from absl.testing import parameterized +import jax from jax._src import api_util from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ApiUtilTest(jtu.JaxTestCase): diff --git a/tests/array_test.py b/tests/array_test.py index 7c8d4c355333..0d8dba0bd6b0 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -40,8 +40,7 @@ from jax._src import array from jax._src import prng -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/batching_test.py b/tests/batching_test.py index afbe9cf707ba..36e686443ac7 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -37,8 +37,7 @@ from jax.interpreters import batching from jax.tree_util import register_pytree_node -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # These are 'manual' tests for batching (vmap). The more exhaustive, more diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py index f8d5271ce402..9ea9cac3a72c 100644 --- a/tests/clear_backends_test.py +++ b/tests/clear_backends_test.py @@ -15,12 +15,11 @@ from absl.testing import absltest import jax -from jax import config from jax._src import api from jax._src import test_util as jtu from jax._src import xla_bridge as xb -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ClearBackendsTest(jtu.JaxTestCase): diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 2c3d2a258a56..830526826059 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -28,8 +28,7 @@ import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def high_precision_dot(a, b): diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 036f912a3c7a..75ff39630705 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -18,9 +18,9 @@ import numpy as np +import jax import jax.numpy as jnp from jax import jit, lax, make_jaxpr -from jax import config from jax.interpreters import mlir from jax.interpreters import xla @@ -34,7 +34,7 @@ xc = xla_client xb = xla_bridge -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the # dictionaries associated with the following objects. diff --git a/tests/custom_root_test.py b/tests/custom_root_test.py index 88dee90aad9c..6a7eaab17657 100644 --- a/tests/custom_root_test.py +++ b/tests/custom_root_test.py @@ -25,8 +25,7 @@ import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def high_precision_dot(a, b): diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 8dc9818f86eb..e5743944401b 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -26,19 +26,18 @@ from jax.experimental import pjit from jax._src.maps import xmap -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class DebugNaNsTest(jtu.JaxTestCase): def setUp(self): super().setUp() - self.cfg = config._read("jax_debug_nans") - config.update("jax_debug_nans", True) + self.cfg = jax.config._read("jax_debug_nans") + jax.config.update("jax_debug_nans", True) def tearDown(self): - config.update("jax_debug_nans", self.cfg) + jax.config.update("jax_debug_nans", self.cfg) super().tearDown() def testSinc(self): @@ -67,7 +66,7 @@ def testJitComputationNaN(self): ans.block_until_ready() def testJitComputationNaNContextManager(self): - config.update("jax_debug_nans", False) + jax.config.update("jax_debug_nans", False) A = jnp.array(0.) f = jax.jit(lambda x: 0. / x) ans = f(A) @@ -210,11 +209,11 @@ class DebugInfsTest(jtu.JaxTestCase): def setUp(self): super().setUp() - self.cfg = config._read("jax_debug_infs") - config.update("jax_debug_infs", True) + self.cfg = jax.config._read("jax_debug_infs") + jax.config.update("jax_debug_infs", True) def tearDown(self): - config.update("jax_debug_infs", self.cfg) + jax.config.update("jax_debug_infs", self.cfg) super().tearDown() def testSingleResultPrimitiveNoInf(self): diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 5e6d4388f130..66488feb85da 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -21,14 +21,13 @@ from absl.testing import absltest import jax -from jax import config from jax.experimental import pjit from jax._src import debugger from jax._src import test_util as jtu import jax.numpy as jnp import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]: fake_stdin = io.StringIO() diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index ce0e5e3b27ab..51c91d9aa740 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -19,7 +19,6 @@ from absl.testing import absltest import jax from jax import lax -from jax import config from jax.experimental import pjit from jax.interpreters import pxla from jax._src import ad_checkpoint @@ -35,7 +34,7 @@ except ModuleNotFoundError: rich = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() debug_print = debugging.debug_print diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index c704d7e10b16..13e9cc5bb415 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -23,7 +23,6 @@ import jax import jax.numpy as jnp from jax import lax -from jax import config from jax.interpreters import batching import jax._src.lib @@ -31,7 +30,7 @@ from jax._src import core from jax._src import test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") diff --git a/tests/extend_test.py b/tests/extend_test.py index b49c1ac09214..a926861ebece 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -24,8 +24,7 @@ from jax._src import prng from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ExtendTest(jtu.JaxTestCase): diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 641f12ff0e1c..cbbe56a62094 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -24,8 +24,7 @@ from jax._src.lax.control_flow import for_loop import jax.numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def remat_of_for_loop(nsteps, body, state, **kwargs): return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state, diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index e96f100b46d4..a288e1a5f19a 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -22,11 +22,11 @@ import itertools as it import jax.numpy as jnp +import jax from jax import jit, jvp, vjp import jax._src.test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() npr.seed(0) diff --git a/tests/heap_profiler_test.py b/tests/heap_profiler_test.py index 6d3468e95ac7..240eec1c8fba 100644 --- a/tests/heap_profiler_test.py +++ b/tests/heap_profiler_test.py @@ -17,11 +17,10 @@ import jax import jax._src.xla_bridge as xla_bridge -from jax import config import jax._src.test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class HeapProfilerTest(unittest.TestCase): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 9c5ab78cbd9e..99d34f30e634 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -30,7 +30,6 @@ import jax from jax import ad_checkpoint -from jax import config from jax import dtypes from jax import lax from jax import numpy as jnp @@ -46,7 +45,7 @@ import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class _TestingOutputStream: diff --git a/tests/image_test.py b/tests/image_test.py index 6204ec91cd5c..f3cd56ed7622 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -24,8 +24,6 @@ from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config - # We use TensorFlow and PIL as reference implementations. try: import tensorflow as tf @@ -37,7 +35,7 @@ except ImportError: PIL_Image = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.all_floating inexact_dtypes = jtu.dtypes.inexact diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 572920fa4d3b..ba47d2417f94 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -19,7 +19,6 @@ from absl.testing import absltest import jax from jax import lax, numpy as jnp -from jax import config from jax.experimental import host_callback as hcb from jax._src import core from jax._src import xla_bridge @@ -27,7 +26,7 @@ import jax._src.test_util as jtu import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class InfeedTest(jtu.JaxTestCase): diff --git a/tests/jet_test.py b/tests/jet_test.py index c72057246ebc..5661197509ec 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -29,8 +29,7 @@ from jax.experimental.jet import jet, fact, zero_series from jax import lax -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def jvp_taylor(fun, primals, series): # Computes the Taylor series the slow way, with nested jvp. diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 885b08224dcd..d98984be5c96 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -29,8 +29,7 @@ Source, Sink, Forward, KeyReuseSignature) from jax.experimental.key_reuse import _core -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() key = jax.eval_shape(jax.random.key, 0) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 630b08cc3694..ab3a183177f6 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -31,8 +31,7 @@ from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() compatible_shapes = [[(3,)], diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 5c737cc4dcd6..0a17a1421556 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -42,8 +42,7 @@ from jax._src.lax.control_flow import for_loop from jax._src.maps import xmap -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # Some tests are useful for testing both lax.cond and lax.switch. This function diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index 92259c8f4342..423289f3d001 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -27,8 +27,7 @@ import jax.numpy as jnp import jax._src.test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class EinsumTest(jtu.JaxTestCase): diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 0f40e9d4d97e..40c9eb3bc4f4 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -24,8 +24,7 @@ from jax._src import test_util as jtu from jax._src.numpy.ufunc_api import get_if_single_primitive -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def scalar_add(x, y): diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index cb0d9a0dcf64..edc344467d7c 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -21,8 +21,7 @@ from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class VectorizeTest(jtu.JaxTestCase): diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 564ecca86f63..be10f03fb938 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -26,8 +26,7 @@ from jax._src import test_util as jtu from jax.scipy import special as lsp_special -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py index 2d353d5909e2..a09dcac5371c 100644 --- a/tests/lax_scipy_spectral_dac_test.py +++ b/tests/lax_scipy_spectral_dac_test.py @@ -14,6 +14,7 @@ import unittest +import jax from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu @@ -21,8 +22,7 @@ from absl.testing import absltest -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() linear_sizes = [16, 97, 128] diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index e9f2e6bb9776..cf3edbfd397c 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -34,8 +34,7 @@ from jax.scipy import special as lsp_special from jax.scipy import cluster as lsp_cluster -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/lax_vmap_op_test.py b/tests/lax_vmap_op_test.py index 5d30281327d6..c7059a29343c 100644 --- a/tests/lax_vmap_op_test.py +++ b/tests/lax_vmap_op_test.py @@ -26,8 +26,7 @@ from jax._src.internal_test_util import lax_test_util from jax._src import util -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 0d22d801d085..37d51c04f8de 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -35,8 +35,7 @@ from jax._src.lib import xla_client from jax._src.util import safe_map, safe_zip -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index 3a4a2196c20f..1953114cb2b3 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -30,7 +30,6 @@ import scipy.sparse as sps import jax -from jax import config from jax._src import test_util as jtu from jax.experimental.sparse import linalg, bcoo import jax.numpy as jnp @@ -433,5 +432,5 @@ def testCallableMatricesF64(self, matrix_name): if __name__ == '__main__': - config.parse_flags_with_absl() + jax.config.parse_flags_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/logging_test.py b/tests/logging_test.py index 05bb31015c1a..6b02432ce2b1 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -22,7 +22,6 @@ import unittest import jax -from jax import config import jax._src.test_util as jtu from jax._src import xla_bridge @@ -33,7 +32,7 @@ # parsing to work correctly with bazel (otherwise we could avoid importing # absltest/absl logging altogether). from absl.testing import absltest -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @contextlib.contextmanager @@ -96,27 +95,27 @@ def test_debug_logging(self): self.assertEmpty(log_output.getvalue()) # Turn on all debug logging. - config.update("jax_debug_log_modules", "jax") + jax.config.update("jax_debug_log_modules", "jax") with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertIn("Finished tracing + transforming", log_output.getvalue()) self.assertIn("Compiling ", log_output.getvalue()) # Turn off all debug logging. - config.update("jax_debug_log_modules", None) + jax.config.update("jax_debug_log_modules", None) with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) # Turn on one module. - config.update("jax_debug_log_modules", "jax._src.dispatch") + jax.config.update("jax_debug_log_modules", "jax._src.dispatch") with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertIn("Finished tracing + transforming", log_output.getvalue()) self.assertNotIn("Compiling ", log_output.getvalue()) # Turn everything off again. - config.update("jax_debug_log_modules", None) + jax.config.update("jax_debug_log_modules", None) with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 3511595d6c5a..e01ba538b342 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -23,8 +23,7 @@ from jax._src.lib.mlir import ir from jax import numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def module_to_string(module: ir.Module) -> str: diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index b955f0398e0c..ba735775beab 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -17,14 +17,13 @@ from absl.testing import absltest import jax -from jax import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class MockGPUTest(jtu.JaxTestCase): diff --git a/tests/mosaic_test.py b/tests/mosaic_test.py index 518766c1e7d1..03c8f1ce36f0 100644 --- a/tests/mosaic_test.py +++ b/tests/mosaic_test.py @@ -14,9 +14,9 @@ from absl.testing import absltest from jax._src import test_util as jtu -from jax import config +import jax -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ImportTest(jtu.JaxTestCase): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 0060df9deada..85386566885b 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -26,8 +26,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 40cbb6630411..f498d788b08c 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -25,8 +25,7 @@ from jax._src import test_util as jtu from jax import numpy as jnp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() npr.seed(0) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index bbe79ecff969..76ed03890614 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -26,7 +26,6 @@ import numpy as np import jax -from jax import config from jax._src import core from jax._src import distributed from jax._src import maps @@ -40,7 +39,7 @@ except ImportError: portpicker = None -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index e6ac29e7088b..5f6dc95b9a5b 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -20,12 +20,11 @@ from jax import lax from jax._src.pjit import pjit from jax._src import linear_util as lu -from jax import config from jax._src import test_util as jtu from jax._src.lib import xla_client from jax._src import ad_checkpoint -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _get_hlo(f): def wrapped(*args, **kwargs): diff --git a/tests/ode_test.py b/tests/ode_test.py index 2d2bcc971434..834745e1cf1c 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -24,8 +24,7 @@ import scipy.integrate as osp_integrate -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ODETest(jtu.JaxTestCase): diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 3fb3101c4142..b7710d9b94c2 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -26,8 +26,7 @@ from jax import lax from jax.example_libraries import optimizers -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class OptimizerTests(jtu.JaxTestCase): diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 188b56c8cf78..3dbf0232fbcf 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import jax -from jax import config from jax._src import test_util as jtu from jax.sharding import NamedSharding from jax.experimental import profiler as exp_profiler @@ -29,7 +28,7 @@ from jax.sharding import PartitionSpec as P import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() @jtu.pytest_mark_if_available('multiaccelerator') diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 8fa6613cf895..1dede34d2bf1 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -26,14 +26,13 @@ import jax from jax import numpy as jnp -from jax import config from jax.interpreters import pxla from jax._src import test_util as jtu from jax._src.lib import xla_client as xc import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _get_device_by_id(device_id: int) -> xc.Device: diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index ccba4c2ef11f..3eeaec482719 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -19,12 +19,12 @@ from absl.testing import absltest +import jax from jax._src import dtypes from jax import numpy as jnp from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex diff --git a/tests/profiler_test.py b/tests/profiler_test.py index c232c3afd699..b67b078aec02 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -26,7 +26,6 @@ import jax import jax.numpy as jnp import jax.profiler -from jax import config import jax._src.test_util as jtu from jax._src import profiler @@ -50,7 +49,7 @@ except ImportError: pass -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class ProfilerTest(unittest.TestCase): diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 77ee057d29b4..17c1e9c2d1d0 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -15,13 +15,12 @@ from absl.testing import absltest +import jax from jax._src import test_util as jtu import jax.scipy.fft as jsp_fft import scipy.fft as osp_fft -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.floating real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean diff --git a/tests/scipy_interpolate_test.py b/tests/scipy_interpolate_test.py index ee905b7f0112..1fead634ab7b 100644 --- a/tests/scipy_interpolate_test.py +++ b/tests/scipy_interpolate_test.py @@ -18,13 +18,13 @@ from functools import reduce import numpy as np +import jax from jax._src import test_util as jtu import scipy.interpolate as sp_interp import jax.scipy.interpolate as jsp_interp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class LaxBackedScipyInterpolateTests(jtu.JaxTestCase): diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 7ce0df8736cd..b206c77d0351 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -21,13 +21,13 @@ from absl.testing import absltest import scipy.ndimage as osp_ndimage +import jax from jax import grad from jax._src import test_util as jtu from jax import dtypes from jax.scipy import ndimage as lsp_ndimage -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() float_dtypes = jtu.dtypes.floating diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index e07455e06f81..70a00e14c468 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -17,13 +17,13 @@ import scipy import scipy.optimize +import jax from jax import numpy as jnp from jax._src import test_util as jtu from jax import jit -from jax import config import jax.scipy.optimize -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def rosenbrock(np): diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 70a367a04e74..11923257a9dd 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -21,14 +21,14 @@ import numpy as np import scipy.signal as osp_signal +import jax from jax import lax import jax.numpy as jnp from jax._src import dtypes from jax._src import test_util as jtu import jax.scipy.signal as jsp_signal -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() onedim_shapes = [(1,), (2,), (5,), (10,)] twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)] diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index f51ad49adc22..5acbdc0ddb6b 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -25,9 +25,8 @@ import jax.numpy as jnp import numpy as onp -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 786d4ae039f7..1ab0bb9e5ed1 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -27,8 +27,7 @@ from jax.scipy import stats as lsp_stats from jax.scipy.special import expit -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() scipy_version = jtu.parse_version(scipy.version.version) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 291ee5360864..8b7f11e319e0 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -25,8 +25,7 @@ from jax.experimental.shard_map import shard_map from jax._src.lib import xla_extension_version -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() prev_xla_flags = None diff --git a/tests/source_info_test.py b/tests/source_info_test.py index aaa3abf552d8..0f876de1c20f 100644 --- a/tests/source_info_test.py +++ b/tests/source_info_test.py @@ -19,11 +19,10 @@ import jax from jax import lax -from jax import config from jax._src import source_info_util from jax._src import test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class SourceInfoTest(jtu.JaxTestCase): diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 441bee4efdb3..ba0ad5cb02c5 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -22,7 +22,6 @@ from absl.testing import absltest import jax -from jax import config from jax import jit from jax import lax from jax import vmap @@ -40,7 +39,7 @@ from jax.util import split_list import numpy as np -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() COMPATIBLE_SHAPE_PAIRS = [ [(), ()], @@ -151,7 +150,7 @@ def _is_required_cuda_version_satisfied(cuda_version): class BCOOTest(sptu.SparseTestCase): def gpu_matmul_warning_context(self, msg): - if config.jax_bcoo_cusparse_lowering: + if jax.config.jax_bcoo_cusparse_lowering: return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg) return contextlib.nullcontext() diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 2522befa9f67..49438f411ff5 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -22,7 +22,6 @@ import jax import jax.random -from jax import config from jax import dtypes from jax.experimental import sparse from jax.experimental.sparse import coo as sparse_coo @@ -43,7 +42,7 @@ import numpy as np import scipy.sparse -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 998ce1c4067d..46086511d8b5 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -22,7 +22,7 @@ import numpy as np import jax -from jax import config, jit, lax +from jax import jit, lax import jax.numpy as jnp import jax._src.test_util as jtu from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer @@ -31,7 +31,7 @@ from jax.experimental.sparse.util import CuSparseEfficiencyWarning from jax.experimental.sparse import test_util as sptu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): def _rand_sparse(shape, dtype, nse=nse): diff --git a/tests/stack_test.py b/tests/stack_test.py index acefc0630018..655a42571b01 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -17,13 +17,13 @@ from absl.testing import absltest +import jax import jax.numpy as jnp from jax._src.lax.stack import Stack from jax._src import test_util as jtu -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class StackTest(jtu.JaxTestCase): diff --git a/tests/stax_test.py b/tests/stax_test.py index 351a0fdb3d71..6850f36a02ea 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -18,13 +18,13 @@ import numpy as np +import jax from jax._src import test_util as jtu from jax import random from jax.example_libraries import stax from jax import dtypes -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def random_inputs(rng, input_shape): diff --git a/tests/third_party/scipy/line_search_test.py b/tests/third_party/scipy/line_search_test.py index 5e7d9a943352..9b2480053d33 100644 --- a/tests/third_party/scipy/line_search_test.py +++ b/tests/third_party/scipy/line_search_test.py @@ -3,13 +3,12 @@ import jax from jax import grad -from jax import config import jax.numpy as jnp import jax._src.test_util as jtu from jax._src.scipy.optimize.line_search import line_search -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class TestLineSearch(jtu.JaxTestCase): diff --git a/tests/transfer_guard_test.py b/tests/transfer_guard_test.py index fa08c52b6aff..b6d9058db385 100644 --- a/tests/transfer_guard_test.py +++ b/tests/transfer_guard_test.py @@ -25,9 +25,7 @@ import jax._src.test_util as jtu import jax.numpy as jnp -from jax import config - -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() def _host_to_device_funcs(): diff --git a/tests/util_test.py b/tests/util_test.py index e06df8b3fa70..5f07d2f50880 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -16,13 +16,13 @@ from absl.testing import absltest +import jax from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util -from jax import config from jax._src.util import weakref_lru_cache -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() try: from jax._src.lib import utils as jaxlib_utils diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 75919de8f2f7..58cf4a2baae3 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -24,12 +24,11 @@ import jax from jax import lax from jax import random -from jax import config from jax.experimental import enable_x64, disable_x64 import jax.numpy as jnp import jax._src.test_util as jtu -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() class X64ContextTests(jtu.JaxTestCase): @@ -49,12 +48,12 @@ def test_make_array(self, jit): ) def test_correctly_capture_default(self, jit, enable_or_disable): # The fact we defined a jitted function with a block with a different value - # of `config.enable_x64` has no impact on the output. + # of `jax.config.enable_x64` has no impact on the output. with enable_or_disable(): func = jit(lambda: jnp.array(np.float64(0))) func() - expected_dtype = "float64" if config._read("jax_enable_x64") else "float32" + expected_dtype = "float64" if jax.config._read("jax_enable_x64") else "float32" self.assertEqual(func().dtype, expected_dtype) with enable_x64(): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 91b63488af16..0d11bb878d55 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -53,8 +53,7 @@ from jax._src.sharding_impls import NamedSharding from jax._src.util import unzip2 -from jax import config -config.parse_flags_with_absl() +jax.config.parse_flags_with_absl() # TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py @@ -248,10 +247,10 @@ class SPMDTestMixin: def setUp(self): super().setUp() self.spmd_lowering = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) + jax.config.update('experimental_xmap_spmd_lowering', True) def tearDown(self): - config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) + jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) class ManualSPMDTestMixin: @@ -261,12 +260,12 @@ def setUp(self): super().setUp() self.spmd_lowering = maps.SPMD_LOWERING.value self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value - config.update('experimental_xmap_spmd_lowering', True) - config.update('experimental_xmap_spmd_lowering_manual', True) + jax.config.update('experimental_xmap_spmd_lowering', True) + jax.config.update('experimental_xmap_spmd_lowering_manual', True) def tearDown(self): - config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) - config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering) + jax.config.update('experimental_xmap_spmd_lowering', self.spmd_lowering) + jax.config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering) @jtu.pytest_mark_if_available('multiaccelerator') @@ -845,13 +844,13 @@ def testFixedSharding(self): # TODO(apaszke): Add support for extracting XLA computations generated by # xmap and make this less of a smoke test. try: - config.update("experimental_xmap_ensure_fixed_sharding", True) + jax.config.update("experimental_xmap_ensure_fixed_sharding", True) f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')), in_axes=['i'], out_axes={}, axis_resources={'i': 'x'}) x = jnp.arange(20, dtype=jnp.float32) f(x) finally: - config.update("experimental_xmap_ensure_fixed_sharding", False) + jax.config.update("experimental_xmap_ensure_fixed_sharding", False) @jtu.with_mesh([('x', 2)]) def testConstantsInLowering(self):