Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial nnx port #3

Merged
merged 34 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
89aaaeb
chore: inital commit for porting
ariG23498 Sep 2, 2024
e5c0c2d
adding torch code from original repository
ariG23498 Sep 2, 2024
14398a3
dot pdt attn in jnp
ariG23498 Sep 3, 2024
1a00d0c
more changes to init
ariG23498 Sep 3, 2024
3dfd926
start porting
ariG23498 Sep 3, 2024
3bcd923
adding everything to uv
ariG23498 Sep 4, 2024
3e09e36
adding ruff and formating
ariG23498 Sep 4, 2024
92e2c5b
fix(modules/autoencoder): correct invocation of GroupNorm and Conv
SauravMaheshkar Sep 8, 2024
d894b3a
fix(modules/layers): correct invocation of nnx modules
SauravMaheshkar Sep 8, 2024
0414a02
fix: jaxify codebase
SauravMaheshkar Sep 13, 2024
38e97b2
chore: allow for non GPU installation
SauravMaheshkar Sep 13, 2024
363f74e
feat(ci): use uv in CI
SauravMaheshkar Sep 13, 2024
4a3948e
fix: tmp mark test as xfail
SauravMaheshkar Sep 13, 2024
29ad2a0
fix: use Sequential for the middle blocks
SauravMaheshkar Sep 13, 2024
b4b3c37
chore: drop devcontainer
SauravMaheshkar Sep 13, 2024
8c6a590
feat: add loop to cli
SauravMaheshkar Sep 13, 2024
393429a
chore: fix just cmd
SauravMaheshkar Sep 13, 2024
76c05b2
style(mypy): disable no-redef
SauravMaheshkar Sep 13, 2024
25c2bf1
fix: return numpy tensors from tokenizer
SauravMaheshkar Sep 13, 2024
52464ef
feat: use Array from chex
SauravMaheshkar Sep 13, 2024
eadbc6d
docs: update docstrings + use chex
SauravMaheshkar Sep 13, 2024
cb9a5a7
fix: nnx modules use __call__
SauravMaheshkar Sep 13, 2024
1bfa089
feat: jaxify prepare fn
SauravMaheshkar Sep 13, 2024
919aace
fix: nnx modules use __call__
SauravMaheshkar Sep 13, 2024
6c97c8d
docs: add docstrings to denoise fn
SauravMaheshkar Sep 13, 2024
df0e52b
feat: to_device >> device_put
SauravMaheshkar Sep 13, 2024
09d1e08
feat: add dtypes and param dtypes
SauravMaheshkar Sep 13, 2024
52776d6
docs: docstrings for Identity module
SauravMaheshkar Sep 13, 2024
8a64274
feat: explicitly specify the dtypes for QKNorm in SelfAttention module
SauravMaheshkar Sep 13, 2024
3d9dc3f
feat: add tests for embedding layer
SauravMaheshkar Sep 15, 2024
b48c3cc
feat: use official flux as optional deps
SauravMaheshkar Sep 16, 2024
7775066
style: enforce isort
SauravMaheshkar Sep 16, 2024
af24690
feat: add tests for layers
SauravMaheshkar Sep 16, 2024
0c59aa1
feat: add tests for modulation and self-attn
SauravMaheshkar Sep 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 0 additions & 65 deletions .devcontainer/devcontainer.json

This file was deleted.

7 changes: 0 additions & 7 deletions .devcontainer/requirements.txt

This file was deleted.

17 changes: 17 additions & 0 deletions .github/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# FL(U/A)X

JAX Implementation of Black Forest Labs' Flux.1 family of models


## Installation

```shell
$ uv sync
```

## Running

```shell
$ uv jflux
```

## References

* Original Implementation: [black-forest-labs/flux](https://github.com/black-forest-labs/flux)
26 changes: 12 additions & 14 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,22 @@ jobs:

steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5

- name: Install uv
uses: astral-sh/setup-uv@v2
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: ".devcontainer/requirements.txt"
enable-cache: true
cache-dependency-glob: "uv.lock"

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m venv .venv && export PATH=".venv/bin:$PATH"
python -m pip install uv
python -m uv pip install --upgrade wheel setuptools
python -m uv pip install -r .devcontainer/requirements.txt
run: uv sync --all-extras --dev

- name: Ruff
run: |
python -m venv .venv && export PATH=".venv/bin:$PATH"
python -m ruff check src
uv run ruff check jflux
- name: Test with PyTest
run: |
python -m venv .venv && export PATH=".venv/bin:$PATH"
python -m pytest -v .
uv run pytest -v .
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ repos:
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
hooks:
- id: mypy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.2
hooks:
Expand Down
File renamed without changes.
4 changes: 4 additions & 0 deletions jflux/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from jflux.cli import app

if __name__ == "__main__":
app()
Loading