Skip to content

Commit

Permalink
somehow merged
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Oct 7, 2024
2 parents 7b98927 + e43da22 commit 1375b8f
Show file tree
Hide file tree
Showing 13 changed files with 939 additions and 384 deletions.
8 changes: 1 addition & 7 deletions jflux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
from jax.typing import DTypeLike

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import (
configs,
load_ae,
load_clip,
load_flow_model,
load_t5,
)
from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5


@dataclass
Expand Down
42 changes: 3 additions & 39 deletions jflux/math.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,21 @@
import typing

import jax
from chex import Array
from einops import rearrange
from flax import nnx
from jax import numpy as jnp


@typing.no_type_check
def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
    """Apply rotary position embeddings to q/k, then run dot-product attention.

    Args:
        q: Query tensor. Assumed to be laid out as (B, H, L, D), as implied by
            the "B H L D" rearrange patterns this function has always used --
            TODO confirm against callers.
        k: Key tensor, same layout as ``q``.
        v: Value tensor, same layout as ``q``.
        pe: Rotary frequencies forwarded unchanged to ``apply_rope``.

    Returns:
        Attention output with the head and feature axes merged: (B, L, H * D).
    """
    # TODO (ariG23498): Change all usage of attention to use this function
    q, k = apply_rope(q, k, pe)

    # jax.nn.dot_product_attention takes operands as (batch, length, heads,
    # dim), so all three inputs are transposed before the call. The earlier
    # code rearranged an undefined `x` here and fed q/k/v through untouched.
    q, k, v = (rearrange(t, "B H L D -> B L H D") for t in (q, k, v))
    x = jax.nn.dot_product_attention(q, k, v)

    # Merge heads into the feature axis.
    return rearrange(x, "B L H D -> B L (H D)")


def rope(pos: Array, dim: int, theta: int) -> Array:
"""
Generate Rotary Position Embedding (RoPE) for positional encoding.
Args:
pos (Array): Positional values, typically a sequence of positions in an array format.
dim (int): The embedding dimension, which must be an even number.
theta (int): A scaling parameter for RoPE that controls the frequency range of rotations.
Returns:
Array: Rotary embeddings with cosine and sine components for each position and dimension.
"""

# Embedding dimension must be an even number
assert dim % 2 == 0

# Generate the RoPE embeddings
scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = jnp.einsum("...n,d->...nd", pos, omega)
Expand All @@ -45,26 +25,10 @@ def rope(pos: Array, dim: int, theta: int) -> Array:


def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
    """Combine the query and key tensors with the RoPE frequency tensor.

    Each input is cast to float32 and its last axis is regrouped into
    trailing (-1, 1, 2) pairs. The two pair components are weighted by
    ``freqs_cis[..., 0]`` and ``freqs_cis[..., 1]`` respectively and summed;
    the result is reshaped back to the input's shape and cast back to the
    input's dtype.

    Args:
        xq (Array): Query tensor.
        xk (Array): Key tensor.
        freqs_cis (Array): RoPE frequencies; must broadcast against the
            regrouped inputs.

    Returns:
        tuple[Array, Array]: The transformed query and key, each with the
        shape and dtype of its corresponding input.
    """
    weight_first = freqs_cis[..., 0]
    weight_second = freqs_cis[..., 1]

    def _rotate(tensor):
        # Do the arithmetic in float32 regardless of the incoming dtype.
        pairs = tensor.astype(jnp.float32).reshape(*tensor.shape[:-1], -1, 1, 2)
        mixed = weight_first * pairs[..., 0] + weight_second * pairs[..., 1]
        # Restore the caller's shape and dtype.
        return mixed.reshape(*tensor.shape).astype(tensor.dtype)

    return _rotate(xq), _rotate(xk)
Loading

0 comments on commit 1375b8f

Please sign in to comment.