diff --git a/.github/README.md b/.github/README.md
index 3552076..21f9871 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -145,7 +145,7 @@ There are a few challenges we’re actively working to resolve. We encourage con
 
 ## Roadmap
 
-- [ ] Add TPU support (https://github.com/ml-gde/jflux/issues/19)
+- [x] Add TPU support (https://github.com/ml-gde/jflux/issues/19) (fixed in https://github.com/ml-gde/jflux/pull/25)
 - [ ] Optimize VRAM usage with gradient checkpointing (https://github.com/ml-gde/jflux/issues/20)
 - [ ] Explore further optimizations for image generation time
 - [ ] Improve the handling of bfloat16 tensors with JAX
diff --git a/jflux/cli.py b/jflux/cli.py
index acc4b20..917dd2e 100644
--- a/jflux/cli.py
+++ b/jflux/cli.py
@@ -19,22 +19,24 @@
 
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
 
+
 def get_device_type():
-    """Returns the type of JAX device being used.
-
-    Returns:
-        str: "gpu", "tpu", or "cpu"
-    """
-    try:
-        device_kind = jax.devices()[0].device_kind
-        if "gpu" in device_kind.lower():
-            return "gpu"
-        elif "tpu" in device_kind.lower():
-            return "tpu"
-        else:
-            return "cpu"
-    except IndexError:
-        return "cpu"  # No devices found, likely using CPU
+    """Returns the type of JAX device being used.
+
+    Returns:
+        str: "gpu", "tpu", or "cpu"
+    """
+    try:
+        device_kind = jax.devices()[0].device_kind
+        if "gpu" in device_kind.lower():
+            return "gpu"
+        elif "tpu" in device_kind.lower():
+            return "tpu"
+        else:
+            return "cpu"
+    except IndexError:
+        return "cpu"  # No devices found, likely using CPU
+
 
 @dataclass
 class SamplingOptions:
@@ -180,7 +182,10 @@ def main(
     idx = 0
 
     # init t5 and clip on the gpu (torch models)
-    t5 = load_t5(device="cuda" if device_type == "gpu" else "cpu", max_length=256 if name == "flux-schnell" else 512)
+    t5 = load_t5(
+        device="cuda" if device_type == "gpu" else "cpu",
+        max_length=256 if name == "flux-schnell" else 512,
+    )
     clip = load_clip(device="cuda" if device_type == "gpu" else "cpu")
 
     # init flux and ae on the cpu
@@ -223,7 +228,7 @@
     # move t5 and clip to cpu
     t5, clip = t5.cpu(), clip.cpu()
     if device_type == "gpu":
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
     # load model to device
     model_state = nnx.state(model)
@@ -247,7 +252,9 @@
 
     # move ae decoder to gpu
     ae_decoder_state = nnx.state(ae.decoder)
-    ae_decoder_state = jax.device_put(ae_decoder_state, jax.devices(device_type)[0])
+    ae_decoder_state = jax.device_put(
+        ae_decoder_state, jax.devices(device_type)[0]
+    )
     nnx.update(ae.decoder, ae_decoder_state)
 
    jax.clear_caches()