style: ruff
SauravMaheshkar committed Oct 20, 2024
1 parent 5892764 commit 6026e35
Showing 2 changed files with 26 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/README.md
@@ -145,7 +145,7 @@ There are a few challenges we’re actively working to resolve. We encourage con

## Roadmap

- [ ] Add TPU support (https://github.com/ml-gde/jflux/issues/19)
- [x] Add TPU support (https://github.com/ml-gde/jflux/issues/19) (fixed in https://github.com/ml-gde/jflux/pull/25)
- [ ] Optimize VRAM usage with gradient checkpointing (https://github.com/ml-gde/jflux/issues/20)
- [ ] Explore further optimizations for image generation time
- [ ] Improve the handling of bfloat16 tensors with JAX
43 changes: 25 additions & 18 deletions jflux/cli.py
@@ -19,22 +19,24 @@

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def get_device_type():
"""Returns the type of JAX device being used.
Returns:
str: "gpu", "tpu", or "cpu"
"""
try:
device_kind = jax.devices()[0].device_kind
if "gpu" in device_kind.lower():
return "gpu"
elif "tpu" in device_kind.lower():
return "tpu"
else:
return "cpu"
except IndexError:
return "cpu" # No devices found, likely using CPU
"""Returns the type of JAX device being used.
Returns:
str: "gpu", "tpu", or "cpu"
"""
try:
device_kind = jax.devices()[0].device_kind
if "gpu" in device_kind.lower():
return "gpu"
elif "tpu" in device_kind.lower():
return "tpu"
else:
return "cpu"
except IndexError:
return "cpu" # No devices found, likely using CPU


@dataclass
class SamplingOptions:
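As context for the hunk above: `get_device_type()` keys off `jax.devices()[0].device_kind`. A minimal sketch (not part of this commit) to inspect what a given host actually reports:

```python
import jax

# Print the backend name and hardware string for every visible device.
# .platform is the backend ("cpu", "tpu", or a GPU backend), while
# .device_kind is a free-form hardware description that get_device_type()
# searches for the "gpu"/"tpu" substrings.
for d in jax.devices():
    print(d.platform, d.device_kind)
```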
Expand Down Expand Up @@ -180,7 +182,10 @@ def main(
idx = 0

# init t5 and clip on the gpu (torch models)
t5 = load_t5(device="cuda" if device_type == "gpu" else "cpu", max_length=256 if name == "flux-schnell" else 512)
t5 = load_t5(
device="cuda" if device_type == "gpu" else "cpu",
max_length=256 if name == "flux-schnell" else 512,
)
clip = load_clip(device="cuda" if device_type == "gpu" else "cpu")

# init flux and ae on the cpu
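The `device="cuda" if device_type == "gpu" else "cpu"` expressions in this hunk bridge the JAX-side device detection and the Torch-side text encoders. Factored out, that mapping might look like the following sketch (`torch_device` is a hypothetical helper, not in the repository):

```python
import torch

def torch_device(device_type: str) -> str:
    # Map get_device_type()'s "gpu"/"tpu"/"cpu" result onto a torch device
    # string; the Torch encoders have no TPU path here, so anything other
    # than a usable CUDA GPU falls back to the CPU.
    return "cuda" if device_type == "gpu" and torch.cuda.is_available() else "cpu"
```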
@@ -223,7 +228,7 @@ def main(
# move t5 and clip to cpu
t5, clip = t5.cpu(), clip.cpu()
if device_type == "gpu":
torch.cuda.empty_cache()
torch.cuda.empty_cache()

# load model to device
model_state = nnx.state(model)
@@ -247,7 +252,9 @@

# move ae decoder to gpu
ae_decoder_state = nnx.state(ae.decoder)
ae_decoder_state = jax.device_put(ae_decoder_state, jax.devices(device_type)[0])
ae_decoder_state = jax.device_put(
ae_decoder_state, jax.devices(device_type)[0]
)
nnx.update(ae.decoder, ae_decoder_state)
jax.clear_caches()
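The hunk above moves `ae.decoder` onto the accelerator with a three-step Flax NNX pattern: pull the module state out as a pytree, place it on the target device, then write it back in place. A generalized sketch of that pattern (assuming `module` is any `nnx.Module` and `device_type` comes from `get_device_type()`):

```python
import jax
from flax import nnx

def to_device(module: nnx.Module, device_type: str) -> None:
    # Extract the module's parameters as a pytree, transfer them to the first
    # device of the requested backend, then update the module in place
    # (the same steps the diff shows for ae.decoder).
    state = nnx.state(module)
    state = jax.device_put(state, jax.devices(device_type)[0])
    nnx.update(module, state)
```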

