Skip to content

Commit

Permalink
Avoid 'from jax import config' imports
Browse files Browse the repository at this point in the history
In some environments this appears to import the config module rather than
the config object.
  • Loading branch information
jakevdp committed Apr 11, 2024
1 parent 301c351 commit f090074
Show file tree
Hide file tree
Showing 83 changed files with 162 additions and 224 deletions.
4 changes: 1 addition & 3 deletions benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
import jax.numpy as jnp
import numpy as np

from jax import config

config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


partial = functools.partial
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/shape_poly_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

import google_benchmark as benchmark

from jax import config
import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax.experimental import export

config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


@benchmark.register
Expand Down
16 changes: 8 additions & 8 deletions docs/debugging/flags.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily.

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;

### Example(s)

```python
from jax import config
config.update("jax_debug_nans", True)
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
return x / y
Expand Down Expand Up @@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!

You can disable JIT-compilation by:
* setting the `JAX_DISABLE_JIT=True` environment variable;
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;

### Examples

```python
from jax import config
config.update("jax_disable_jit", True)
import jax
jax.config.update("jax_disable_jit", True)

def f(x):
y = jnp.log(x)
Expand Down
4 changes: 2 additions & 2 deletions docs/debugging/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more!
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.

```python
from jax import config
config.update("jax_debug_nans", True)
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
return x / y
Expand Down
18 changes: 9 additions & 9 deletions docs/notebooks/Common_Gotchas_in_JAX.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1946,9 +1946,9 @@
"\n",
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
"\n",
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"\n",
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"\n",
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
"\n",
Expand Down Expand Up @@ -2141,24 +2141,24 @@
"\n",
" ```python\n",
" # again, this only works on startup!\n",
" from jax import config\n",
" config.update(\"jax_enable_x64\", True)\n",
" import jax\n",
" jax.config.update(\"jax_enable_x64\", True)\n",
" ```\n",
"\n",
"3. You can parse command-line flags with `absl.app.run(main)`\n",
"\n",
" ```python\n",
" from jax import config\n",
" config.config_with_absl()\n",
" import jax\n",
" jax.config.config_with_absl()\n",
" ```\n",
"\n",
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
"\n",
" ```python\n",
" from jax import config\n",
" import jax\n",
" if __name__ == '__main__':\n",
" # calls config.config_with_absl() *and* runs absl parsing\n",
" config.parse_flags_with_absl()\n",
" # calls jax.config.config_with_absl() *and* runs absl parsing\n",
" jax.config.parse_flags_with_absl()\n",
" ```\n",
"\n",
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",
Expand Down
18 changes: 9 additions & 9 deletions docs/notebooks/Common_Gotchas_in_JAX.md
Original file line number Diff line number Diff line change
Expand Up @@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo

* setting the `JAX_DEBUG_NANS=True` environment variable;

* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;

* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;

This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.

Expand Down Expand Up @@ -1087,24 +1087,24 @@ There are a few ways to do this:

```python
# again, this only works on startup!
from jax import config
config.update("jax_enable_x64", True)
import jax
jax.config.update("jax_enable_x64", True)
```

3. You can parse command-line flags with `absl.app.run(main)`

```python
from jax import config
config.config_with_absl()
import jax
jax.config.config_with_absl()
```

4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use

```python
from jax import config
import jax
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()
# calls jax.config.config_with_absl() *and* runs absl parsing
jax.config.parse_flags_with_absl()
```

Note that #2-#4 work for _any_ of JAX's configuration options.
Expand Down
4 changes: 2 additions & 2 deletions docs/rank_promotion_warning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code:

.. code-block:: python
from jax import config
config.update("jax_numpy_rank_promotion", "warn")
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
You can also set the option using the environment variable
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as
Expand Down
4 changes: 2 additions & 2 deletions examples/examples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
Expand All @@ -30,8 +31,7 @@
from examples import kernel_lsq
sys.path.pop()

from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
Expand Down
5 changes: 3 additions & 2 deletions examples/gaussian_process_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

from absl import app
from functools import partial

import jax
from jax import grad
from jax import jit
from jax import vmap
from jax import config
import jax.numpy as jnp
import jax.random as random
import jax.scipy as scipy
Expand Down Expand Up @@ -125,5 +126,5 @@ def train_step(params, momentums, scales, x, y):
mu.flatten() - std * 2, mu.flatten() + std * 2)

if __name__ == "__main__":
config.config_with_absl()
jax.config.config_with_absl()
app.run(main)
4 changes: 2 additions & 2 deletions jax/_src/internal_test_util/lax_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
import itertools
from typing import Union, cast

import jax
from jax import lax
from jax._src import dtypes
from jax._src import test_util
from jax._src.util import safe_map, safe_zip

import numpy as np

from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down
4 changes: 1 addition & 3 deletions jax/experimental/array_serialization/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import config
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.array_serialization import serialization
import numpy as np
import tensorstore as ts
import unittest

config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()

prev_xla_flags = None

Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/examples/keras_reuse_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import config

from jax.experimental.jax2tf.examples import keras_reuse_main
from jax.experimental.jax2tf.tests import tf_test_util

config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
FLAGS = flags.FLAGS


Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/back_compat_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Callable, Optional

from absl.testing import absltest
from jax import config
import jax
from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.lib import xla_extension
Expand All @@ -37,7 +37,7 @@
import tensorflow as tf


config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


def serialize_directory(directory_path):
Expand Down
21 changes: 10 additions & 11 deletions jax/experimental/jax2tf/tests/call_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import dlpack
from jax import dtypes
from jax import lax
Expand All @@ -42,7 +41,7 @@
except ImportError:
tf = None

config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
Expand Down Expand Up @@ -1151,15 +1150,15 @@ def setUp(self):
super().setUp()

def override_serialization_version(self, version_override: int):
version = config.jax_serialization_version
version = jax.config.jax_serialization_version
if version != version_override:
self.addCleanup(partial(config.update,
self.addCleanup(partial(jax.config.update,
"jax_serialization_version",
version_override))
config.update("jax_serialization_version", version_override)
jax.config.update("jax_serialization_version", version_override)
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version)
jax.config.jax_serialization_version)

def test_alternate(self):
# Alternate sin/cos with sin in TF and cos in JAX
Expand Down Expand Up @@ -1275,7 +1274,7 @@ def fun_tf(x): # x:i32[3]

@_parameterized_jit
def test_shape_poly_static_output_shape(self, with_jit=True):
if config.jax2tf_default_native_serialization:
if jax.config.jax2tf_default_native_serialization:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
x = np.array([0.7, 0.8], dtype=np.float32)

Expand All @@ -1289,7 +1288,7 @@ def fun_tf(x):

@_parameterized_jit
def test_shape_poly(self, with_jit=False):
if config.jax2tf_default_native_serialization:
if jax.config.jax2tf_default_native_serialization:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
Expand All @@ -1308,7 +1307,7 @@ def fun_jax(x):

@_parameterized_jit
def test_shape_poly_pytree_result(self, with_jit=True):
if config.jax2tf_default_native_serialization:
if jax.config.jax2tf_default_native_serialization:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
Expand Down Expand Up @@ -1394,7 +1393,7 @@ def fun_jax(x):
if kind == "bad_dim" and with_jit:
# TODO: in jit more the error pops up later, at AddV2
expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
if kind == "bad_dim" and config.jax2tf_default_native_serialization:
if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization:
# TODO(b/268386622): call_tf with shape polymorphism and native serialization.
expect_error = "Error compiling TensorFlow function"
fun_tf_rt = _maybe_tf_jit(with_jit,
Expand Down Expand Up @@ -1432,7 +1431,7 @@ def test_several_round_trips(self,
f4_function=False, f4_saved_model=False):
if (f2_saved_model and
f4_saved_model and
not config.jax2tf_default_native_serialization):
not jax.config.jax2tf_default_native_serialization):
# TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients
# when saving f4, but only with non-native serialization.
raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients")
Expand Down
3 changes: 1 addition & 2 deletions jax/experimental/jax2tf/tests/control_flow_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

from jax.experimental.jax2tf.tests import tf_test_util

from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
Expand Down
5 changes: 2 additions & 3 deletions jax/experimental/jax2tf/tests/cross_compilation_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@

import numpy.random as npr

import jax
from jax import config # Must import before TF
import jax # Must import before TF
from jax.experimental import jax2tf # Defines needed flags
from jax._src import test_util # Defines needed flags

config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()

# Import after parsing flags
from jax.experimental.jax2tf.tests import primitive_harness
Expand Down
3 changes: 1 addition & 2 deletions jax/experimental/jax2tf/tests/savedmodel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src import test_util as jtu

from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()


class SavedModelTest(tf_test_util.JaxToTfTestCase):
Expand Down
Loading

0 comments on commit f090074

Please sign in to comment.