From dfdce2b6022ee5328c1a0d7305bdd3e3e4d3bc75 Mon Sep 17 00:00:00 2001 From: Czxck001 <10724409+Czxck001@users.noreply.github.com> Date: Mon, 5 Aug 2024 10:26:15 -0700 Subject: [PATCH] Add the MMDiT model of Stable Diffusion 3 (#2397) * add mmdit of stable diffusion 3 lint add comments * correct a misplaced comment * fix cargo fmt * fix clippy error * use bail! instead of assert! * use get_on_dim in splitting qkv --- .../src/models/mmdit/blocks.rs | 294 ++++++++++++++++++ .../src/models/mmdit/embedding.rs | 197 ++++++++++++ candle-transformers/src/models/mmdit/mod.rs | 4 + candle-transformers/src/models/mmdit/model.rs | 173 +++++++++++ .../src/models/mmdit/projections.rs | 94 ++++++ candle-transformers/src/models/mod.rs | 1 + 6 files changed, 763 insertions(+) create mode 100644 candle-transformers/src/models/mmdit/blocks.rs create mode 100644 candle-transformers/src/models/mmdit/embedding.rs create mode 100644 candle-transformers/src/models/mmdit/mod.rs create mode 100644 candle-transformers/src/models/mmdit/model.rs create mode 100644 candle-transformers/src/models/mmdit/projections.rs diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs new file mode 100644 index 0000000000..e2b924a013 --- /dev/null +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -0,0 +1,294 @@ +use candle::{Module, Result, Tensor, D}; +use candle_nn as nn; + +use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections}; + +pub struct ModulateIntermediates { + gate_msa: Tensor, + shift_mlp: Tensor, + scale_mlp: Tensor, + gate_mlp: Tensor, +} + +pub struct DiTBlock { + norm1: LayerNormNoAffine, + attn: AttnProjections, + norm2: LayerNormNoAffine, + mlp: Mlp, + ada_ln_modulation: nn::Sequential, +} + +pub struct LayerNormNoAffine { + eps: f64, +} + +impl LayerNormNoAffine { + pub fn new(eps: f64) -> Self { + Self { eps } + } +} + +impl Module for LayerNormNoAffine { + fn forward(&self, x: &Tensor) -> Result { + nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x) + } +} + +impl DiTBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + // {'hidden_size': 1536, 'num_heads': 24} + let norm1 = LayerNormNoAffine::new(1e-6); + let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?; + let norm2 = LayerNormNoAffine::new(1e-6); + let mlp_ratio = 4; + let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?; + let n_mods = 6; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + n_mods * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm1, + attn, + norm2, + mlp, + ada_ln_modulation, + }) + } + + pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(6, D::Minus1)?; + let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = ( + chunks[0].clone(), + chunks[1].clone(), + chunks[2].clone(), + chunks[3].clone(), + chunks[4].clone(), + chunks[5].clone(), + ); + + let norm_x = self.norm1.forward(x)?; + let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?; + let qkv = self.attn.pre_attention(&modulated_x)?; + + Ok(( + qkv, + ModulateIntermediates { + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + }, + )) + } + + pub fn post_attention( + &self, + attn: &Tensor, + x: &Tensor, + mod_interm: &ModulateIntermediates, + ) -> Result { + let attn_out = 
self.attn.post_attention(attn)?; + let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?; + + let norm_x = self.norm2.forward(&x)?; + let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?; + let mlp_out = self.mlp.forward(&modulated_x)?; + let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?; + + Ok(x) + } +} + +pub struct QkvOnlyDiTBlock { + norm1: LayerNormNoAffine, + attn: QkvOnlyAttnProjections, + ada_ln_modulation: nn::Sequential, +} + +impl QkvOnlyDiTBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let norm1 = LayerNormNoAffine::new(1e-6); + let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?; + let n_mods = 2; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + n_mods * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm1, + attn, + ada_ln_modulation, + }) + } + + pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(2, D::Minus1)?; + let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone()); + + let norm_x = self.norm1.forward(x)?; + let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?; + self.attn.pre_attention(&modulated_x) + } +} + +pub struct FinalLayer { + norm_final: LayerNormNoAffine, + linear: nn::Linear, + ada_ln_modulation: nn::Sequential, +} + +impl FinalLayer { + pub fn new( + hidden_size: usize, + patch_size: usize, + out_channels: usize, + vb: nn::VarBuilder, + ) -> Result { + let norm_final = LayerNormNoAffine::new(1e-6); + let linear = nn::linear( + hidden_size, + patch_size * patch_size * out_channels, + vb.pp("linear"), + )?; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + 2 * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm_final, + linear, + ada_ln_modulation, + }) + } + + pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(2, D::Minus1)?; + let (shift, scale) = (chunks[0].clone(), chunks[1].clone()); + + let norm_x = self.norm_final.forward(x)?; + let modulated_x = modulate(&norm_x, &shift, &scale)?; + let output = self.linear.forward(&modulated_x)?; + + Ok(output) + } +} + +fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result { + let shift = shift.unsqueeze(1)?; + let scale = scale.unsqueeze(1)?; + let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?; + shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?) 
+} + +pub struct JointBlock { + x_block: DiTBlock, + context_block: DiTBlock, + num_heads: usize, +} + +impl JointBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; + let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; + + Ok(Self { + x_block, + context_block, + num_heads, + }) + } + + pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { + let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; + let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; + let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let context_out = + self.context_block + .post_attention(&context_attn, context, &context_interm)?; + let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; + Ok((context_out, x_out)) + } +} + +pub struct ContextQkvOnlyJointBlock { + x_block: DiTBlock, + context_block: QkvOnlyDiTBlock, + num_heads: usize, +} + +impl ContextQkvOnlyJointBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; + let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; + Ok(Self { + x_block, + context_block, + num_heads, + }) + } + + pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result { + let context_qkv = self.context_block.pre_attention(context, c)?; + let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; + + let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + + let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; + Ok(x_out) + } +} + +// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn +// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim) +fn flash_compatible_attention( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, +) -> Result { + let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec(); + let rank = q_dims_for_matmul.len(); + let q = q.transpose(1, 2)?.flatten_to(rank - 3)?; + let k = k.transpose(1, 2)?.flatten_to(rank - 3)?; + let v = v.transpose(1, 2)?.flatten_to(rank - 3)?; + let attn_weights = (q.matmul(&k.t()?)? 
* softmax_scale as f64)?; + let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?; + attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2) +} + +fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> { + let qkv = Qkv { + q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?, + k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?, + v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?, + }; + + let (batch_size, seqlen, _) = qkv.q.dims3()?; + let qkv = Qkv { + q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?, + k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?, + v: qkv.v, + }; + + let headdim = qkv.q.dim(D::Minus1)?; + let softmax_scale = 1.0 / (headdim as f64).sqrt(); + // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?; + let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?; + + let attn = attn.reshape((batch_size, seqlen, ()))?; + let context_qkv_seqlen = context_qkv.q.dim(1)?; + let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?; + let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?; + + Ok((context_attn, x_attn)) +} diff --git a/candle-transformers/src/models/mmdit/embedding.rs b/candle-transformers/src/models/mmdit/embedding.rs new file mode 100644 index 0000000000..6e200b18bd --- /dev/null +++ b/candle-transformers/src/models/mmdit/embedding.rs @@ -0,0 +1,197 @@ +use candle::{bail, DType, Module, Result, Tensor}; +use candle_nn as nn; + +pub struct PatchEmbedder { + proj: nn::Conv2d, +} + +impl PatchEmbedder { + pub fn new( + patch_size: usize, + in_channels: usize, + embed_dim: usize, + vb: nn::VarBuilder, + ) -> Result { + let proj = nn::conv2d( + in_channels, + embed_dim, + patch_size, + nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }, + vb.pp("proj"), + )?; + + Ok(Self { proj }) + } +} + +impl Module for PatchEmbedder { + fn forward(&self, x: &Tensor) -> Result { + let x = self.proj.forward(x)?; + + // flatten spatial dim and transpose to channels last + let (b, c, h, w) = x.dims4()?; + x.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +pub struct Unpatchifier { + patch_size: usize, + out_channels: usize, +} + +impl Unpatchifier { + pub fn new(patch_size: usize, out_channels: usize) -> Result { + Ok(Self { + patch_size, + out_channels, + }) + } + + pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result { + let h = (h + 1) / self.patch_size; + let w = (w + 1) / self.patch_size; + + let x = x.reshape(( + x.dim(0)?, + h, + w, + self.patch_size, + self.patch_size, + self.out_channels, + ))?; + let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq" + x.reshape(( + x.dim(0)?, + self.out_channels, + self.patch_size * h, + self.patch_size * w, + )) + } +} + +pub struct PositionEmbedder { + pos_embed: Tensor, + patch_size: usize, + pos_embed_max_size: usize, +} + +impl PositionEmbedder { + pub fn new( + hidden_size: usize, + patch_size: usize, + pos_embed_max_size: usize, + vb: nn::VarBuilder, + ) -> Result { + let pos_embed = vb.get( + (1, pos_embed_max_size * pos_embed_max_size, hidden_size), + "pos_embed", + )?; + Ok(Self { + pos_embed, + patch_size, + pos_embed_max_size, + }) + } + pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result { + let h = (h + 1) / self.patch_size; + let w = (w + 1) / self.patch_size; + + if h > self.pos_embed_max_size || w > self.pos_embed_max_size { + bail!("Input size is too large for the position embedding") + } + + let top = 
(self.pos_embed_max_size - h) / 2; + let left = (self.pos_embed_max_size - w) / 2; + + let pos_embed = + self.pos_embed + .reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?; + let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?; + pos_embed.reshape((1, h * w, ())) + } +} + +pub struct TimestepEmbedder { + mlp: nn::Sequential, + frequency_embedding_size: usize, +} + +impl TimestepEmbedder { + pub fn new( + hidden_size: usize, + frequency_embedding_size: usize, + vb: nn::VarBuilder, + ) -> Result { + let mlp = nn::seq() + .add(nn::linear( + frequency_embedding_size, + hidden_size, + vb.pp("mlp.0"), + )?) + .add(nn::Activation::Silu) + .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?); + + Ok(Self { + mlp, + frequency_embedding_size, + }) + } + + fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result { + if dim % 2 != 0 { + bail!("Embedding dimension must be even") + } + + if t.dtype() != DType::F32 && t.dtype() != DType::F64 { + bail!("Input tensor must be floating point") + } + + let half = dim / 2; + let freqs = Tensor::arange(0f32, half as f32, t.device())? + .to_dtype(candle::DType::F32)? + .mul(&Tensor::full( + (-f64::ln(max_period) / half as f64) as f32, + half, + t.device(), + )?)? + .exp()?; + + let args = t + .unsqueeze(1)? + .to_dtype(candle::DType::F32)? + .matmul(&freqs.unsqueeze(0)?)?; + let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?; + embedding.to_dtype(candle::DType::F16) + } +} + +impl Module for TimestepEmbedder { + fn forward(&self, t: &Tensor) -> Result { + let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?; + self.mlp.forward(&t_freq) + } +} + +pub struct VectorEmbedder { + mlp: nn::Sequential, +} + +impl VectorEmbedder { + pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result { + let mlp = nn::seq() + .add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?) + .add(nn::Activation::Silu) + .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?); + + Ok(Self { mlp }) + } +} + +impl Module for VectorEmbedder { + fn forward(&self, x: &Tensor) -> Result { + self.mlp.forward(x) + } +} diff --git a/candle-transformers/src/models/mmdit/mod.rs b/candle-transformers/src/models/mmdit/mod.rs new file mode 100644 index 0000000000..9c4db6e085 --- /dev/null +++ b/candle-transformers/src/models/mmdit/mod.rs @@ -0,0 +1,4 @@ +pub mod blocks; +pub mod embedding; +pub mod model; +pub mod projections; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs new file mode 100644 index 0000000000..1523836c7f --- /dev/null +++ b/candle-transformers/src/models/mmdit/model.rs @@ -0,0 +1,173 @@ +// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206). +// This follows the implementation of the MMDiT model in the ComfyUI repository. 
+// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1 +use candle::{Module, Result, Tensor, D}; +use candle_nn as nn; + +use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock}; +use super::embedding::{ + PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder, +}; + +#[derive(Debug, Clone)] +pub struct Config { + pub patch_size: usize, + pub in_channels: usize, + pub out_channels: usize, + pub depth: usize, + pub head_size: usize, + pub adm_in_channels: usize, + pub pos_embed_max_size: usize, + pub context_embed_size: usize, + pub frequency_embedding_size: usize, +} + +impl Config { + pub fn sd3() -> Self { + Self { + patch_size: 2, + in_channels: 16, + out_channels: 16, + depth: 24, + head_size: 64, + adm_in_channels: 2048, + pos_embed_max_size: 192, + context_embed_size: 4096, + frequency_embedding_size: 256, + } + } +} + +pub struct MMDiT { + core: MMDiTCore, + patch_embedder: PatchEmbedder, + pos_embedder: PositionEmbedder, + timestep_embedder: TimestepEmbedder, + vector_embedder: VectorEmbedder, + context_embedder: nn::Linear, + unpatchifier: Unpatchifier, +} + +impl MMDiT { + pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result { + let hidden_size = cfg.head_size * cfg.depth; + let core = MMDiTCore::new( + cfg.depth, + hidden_size, + cfg.depth, + cfg.patch_size, + cfg.out_channels, + vb.clone(), + )?; + let patch_embedder = PatchEmbedder::new( + cfg.patch_size, + cfg.in_channels, + hidden_size, + vb.pp("x_embedder"), + )?; + let pos_embedder = PositionEmbedder::new( + hidden_size, + cfg.patch_size, + cfg.pos_embed_max_size, + vb.clone(), + )?; + let timestep_embedder = TimestepEmbedder::new( + hidden_size, + cfg.frequency_embedding_size, + vb.pp("t_embedder"), + )?; + let vector_embedder = + VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?; + let context_embedder = nn::linear( + cfg.context_embed_size, + hidden_size, + vb.pp("context_embedder"), + )?; + let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?; + + Ok(Self { + core, + patch_embedder, + pos_embedder, + timestep_embedder, + vector_embedder, + context_embedder, + unpatchifier, + }) + } + + pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result { + // Following the convention of the ComfyUI implementation. + // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919 + // + // Forward pass of DiT. + // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + // t: (N,) tensor of diffusion timesteps + // y: (N,) tensor of class labels + let h = x.dim(D::Minus2)?; + let w = x.dim(D::Minus1)?; + let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?; + let x = self + .patch_embedder + .forward(x)? 
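+ // add the cropped positional embedding to the patch tokens (it broadcasts over the batch dimension)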
+ .broadcast_add(&cropped_pos_embed)?; + let c = self.timestep_embedder.forward(t)?; + let y = self.vector_embedder.forward(y)?; + let c = (c + y)?; + let context = self.context_embedder.forward(context)?; + + let x = self.core.forward(&context, &x, &c)?; + let x = self.unpatchifier.unpatchify(&x, h, w)?; + x.narrow(2, 0, h)?.narrow(3, 0, w) + } +} + +pub struct MMDiTCore { + joint_blocks: Vec, + context_qkv_only_joint_block: ContextQkvOnlyJointBlock, + final_layer: FinalLayer, +} + +impl MMDiTCore { + pub fn new( + depth: usize, + hidden_size: usize, + num_heads: usize, + patch_size: usize, + out_channels: usize, + vb: nn::VarBuilder, + ) -> Result { + let mut joint_blocks = Vec::with_capacity(depth - 1); + for i in 0..depth - 1 { + joint_blocks.push(JointBlock::new( + hidden_size, + num_heads, + vb.pp(format!("joint_blocks.{}", i)), + )?); + } + + Ok(Self { + joint_blocks, + context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new( + hidden_size, + num_heads, + vb.pp(format!("joint_blocks.{}", depth - 1)), + )?, + final_layer: FinalLayer::new( + hidden_size, + patch_size, + out_channels, + vb.pp("final_layer"), + )?, + }) + } + + pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result { + let (mut context, mut x) = (context.clone(), x.clone()); + for joint_block in &self.joint_blocks { + (context, x) = joint_block.forward(&context, &x, c)?; + } + let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?; + self.final_layer.forward(&x, c) + } +} diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs new file mode 100644 index 0000000000..1077398f5c --- /dev/null +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -0,0 +1,94 @@ +use candle::{Module, Result, Tensor}; +use candle_nn as nn; + +pub struct Qkv { + pub q: Tensor, + pub k: Tensor, + pub v: Tensor, +} + +pub struct Mlp { + fc1: nn::Linear, + act: nn::Activation, + fc2: nn::Linear, +} + +impl Mlp { + pub fn new( + in_features: usize, + hidden_features: usize, + vb: candle_nn::VarBuilder, + ) -> Result { + let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?; + let act = nn::Activation::GeluPytorchTanh; + let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?; + + Ok(Self { fc1, act, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, x: &Tensor) -> Result { + let x = self.fc1.forward(x)?; + let x = self.act.forward(&x)?; + self.fc2.forward(&x) + } +} + +pub struct QkvOnlyAttnProjections { + qkv: nn::Linear, + head_dim: usize, +} + +impl QkvOnlyAttnProjections { + pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + // {'dim': 1536, 'num_heads': 24} + let head_dim = dim / num_heads; + let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; + Ok(Self { qkv, head_dim }) + } + + pub fn pre_attention(&self, x: &Tensor) -> Result { + let qkv = self.qkv.forward(x)?; + split_qkv(&qkv, self.head_dim) + } +} + +pub struct AttnProjections { + head_dim: usize, + qkv: nn::Linear, + proj: nn::Linear, +} + +impl AttnProjections { + pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let head_dim = dim / num_heads; + let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; + let proj = nn::linear(dim, dim, vb.pp("proj"))?; + Ok(Self { + head_dim, + qkv, + proj, + }) + } + + pub fn pre_attention(&self, x: &Tensor) -> Result { + let qkv = self.qkv.forward(x)?; + split_qkv(&qkv, self.head_dim) + } + + pub fn post_attention(&self, x: &Tensor) -> Result { + self.proj.forward(x) + } +} + 
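+// Split a fused qkv projection of shape (batch, seq_len, 3 * dim) into separate
+// q, k and v tensors. q and k are reshaped back to (batch, seq_len, dim), while v
+// keeps the (batch, seq_len, num_heads, head_dim) layout that the attention in
+// `joint_attn` consumes directly.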
+fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
+ let (batch_size, seq_len, _) = qkv.dims3()?;
+ let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
+ let q = qkv.get_on_dim(2, 0)?;
+ let q = q.reshape((batch_size, seq_len, ()))?;
+ let k = qkv.get_on_dim(2, 1)?;
+ let k = k.reshape((batch_size, seq_len, ()))?;
+ let v = qkv.get_on_dim(2, 2)?;
+ Ok(Qkv { q, k, v })
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 18c8833d79..c0de550bf7 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -32,6 +32,7 @@ pub mod metavoice;
 pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
+pub mod mmdit;
 pub mod mobilenetv4;
 pub mod mobileone;
 pub mod moondream;
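A minimal usage sketch for the new module follows; it is not part of the patch. The file name `mmdit.safetensors`, the function name `run_mmdit_once` and the input shapes are assumptions: it presumes a standalone weight file whose tensor names match the prefixes used above (`x_embedder`, `pos_embed`, `t_embedder`, `y_embedder`, `context_embedder`, `joint_blocks.*`, `final_layer`), and a full Stable Diffusion 3 pipeline would additionally wrap this call with the text encoders, a sampling loop and the VAE decoder.

// Usage sketch only (not part of the patch); the weight path and dummy shapes are assumptions.
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mmdit::model::{Config, MMDiT};

fn run_mmdit_once() -> Result<Tensor> {
    let device = Device::Cpu;
    // Hypothetical standalone weight file; a full SD3 checkpoint may store these
    // tensors under a different prefix.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mmdit.safetensors"], DType::F16, &device)?
    };
    let cfg = Config::sd3();
    let mmdit = MMDiT::new(&cfg, vb)?;

    // Dummy inputs: a 16-channel 128x128 latent, one timestep, a pooled text
    // vector of size adm_in_channels and a 77-token context of size context_embed_size.
    let x = Tensor::zeros((1, cfg.in_channels, 128, 128), DType::F16, &device)?;
    let t = Tensor::new(&[500f32], &device)?;
    let y = Tensor::zeros((1, cfg.adm_in_channels), DType::F16, &device)?;
    let context = Tensor::zeros((1, 77, cfg.context_embed_size), DType::F16, &device)?;

    // The prediction has the same shape as the input latent: (1, 16, 128, 128).
    mmdit.forward(&x, &t, &y, &context)
}

On CUDA builds with the flash-attn feature enabled, the commented-out candle_flash_attn::flash_attn call in joint_attn suggests it can stand in for flash_compatible_attention, since q, k and v are already laid out as (batch_size, seqlen, nheads, headdim).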