From f93b19164e9966e3ea913a3a4033d753f16a532f Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 10 Dec 2024 18:55:52 +0100 Subject: [PATCH] Import some internal changes. --- rust/moshi-core/src/lib.rs | 1 + rust/moshi-core/src/lm.rs | 12 ++ rust/moshi-core/src/quantized_transformer.rs | 18 +- rust/moshi-core/src/transformer.rs | 116 ++++++++--- rust/moshi-core/src/tts_streaming.rs | 197 +++++++++++++++++++ 5 files changed, 310 insertions(+), 34 deletions(-) create mode 100644 rust/moshi-core/src/tts_streaming.rs diff --git a/rust/moshi-core/src/lib.rs b/rust/moshi-core/src/lib.rs index 9ec41ea..c58a6c2 100644 --- a/rust/moshi-core/src/lib.rs +++ b/rust/moshi-core/src/lib.rs @@ -17,6 +17,7 @@ pub mod seanet; pub mod streaming; pub mod transformer; pub mod tts; +pub mod tts_streaming; pub mod wav; #[derive(Debug, Copy, Clone, PartialEq, Eq)] diff --git a/rust/moshi-core/src/lm.rs b/rust/moshi-core/src/lm.rs index 5fd286d..5ee1e20 100644 --- a/rust/moshi-core/src/lm.rs +++ b/rust/moshi-core/src/lm.rs @@ -515,6 +515,18 @@ impl LmModel { } } + pub fn forward_ca( + &mut self, + text_ids: Option, + audio_ids: Vec>, + ca_src: &Tensor, + ) -> candle::Result<(Tensor, Tensor)> { + match self { + Self::Lm(m) => m.forward_ca(text_ids, audio_ids, ca_src), + Self::QuantizedLm(m) => m.forward_ca(text_ids, audio_ids, ca_src), + } + } + pub fn depformer_sample( &mut self, step_idx: usize, diff --git a/rust/moshi-core/src/quantized_transformer.rs b/rust/moshi-core/src/quantized_transformer.rs index b35f1e5..b2768bd 100644 --- a/rust/moshi-core/src/quantized_transformer.rs +++ b/rust/moshi-core/src/quantized_transformer.rs @@ -3,7 +3,7 @@ // LICENSE file in the root directory of this source tree. use crate::streaming::{StreamTensor, StreamingModule}; -use crate::transformer::{get_mask, CrossAttention, PositionalEmbedding, RotaryEmbedding}; +use crate::transformer::{get_mask, PositionalEmbedding, RotaryEmbedding}; use candle::{DType, IndexOp, Module, Result, Tensor, D}; use candle_transformers::quantized_nn::{layer_norm, linear_b, Linear}; @@ -169,7 +169,7 @@ pub struct StreamingMultiheadCrossAttention { } impl StreamingMultiheadCrossAttention { - pub fn new(_ca: CrossAttention, _cfg: &Config, _vb: VarBuilder) -> Result { + pub fn new(_cfg: &Config, _vb: VarBuilder) -> Result { candle::bail!("cross-attn is not supported at the moment") } @@ -347,14 +347,12 @@ impl StreamingTransformerLayer { } }; let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?; - let cross_attn = match cfg.cross_attention { - Some(ca) => { - let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?; - let cross_attn = - StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?; - Some((norm_cross, cross_attn)) - } - None => None, + let cross_attn = if cfg.cross_attention.is_some() { + let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?; + let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?; + Some((norm_cross, cross_attn)) + } else { + None }; Ok(Self { self_attn, diff --git a/rust/moshi-core/src/transformer.rs b/rust/moshi-core/src/transformer.rs index e4c4b2b..6df4246 100644 --- a/rust/moshi-core/src/transformer.rs +++ b/rust/moshi-core/src/transformer.rs @@ -22,10 +22,86 @@ pub enum PositionalEmbedding { None, } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum CrossAttention { +#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum CrossAttentionGating { + // Configure Type of gating used at the output of vision cross-attention layers + Normal, + ConstantGatedTanh, + ConstantGatedSigmoid, + ConditionalGatedTanh, + ConditionalGatedSigmoid, +} + +#[derive(Debug, Clone)] +pub enum XaGate { + // Corresponding module Normal, - Gated, + ConstantGated { + alpha: Tensor, + }, + ConditionalGated { + in_proj: Linear, + out_proj: Linear, + activation: candle_nn::init::NonLinearity, + }, +} + +impl XaGate { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + match cfg.cross_attention { + // no cross attention - shouldn't occur here + None => candle::bail!("Invalid cross-attention config specified."), + // no gating + Some(CrossAttentionGating::Normal) => Ok(Self::Normal), + // constant (per-layer parameter) passed through tanh or sigmoid + Some(CrossAttentionGating::ConstantGatedTanh) => { + let alpha = vb.get((1, 1, 1), "gate.alpha")?.tanh()?; + Ok(Self::ConstantGated { alpha }) + } + Some(CrossAttentionGating::ConstantGatedSigmoid) => { + let alpha = candle_nn::ops::sigmoid(&(vb.get((1, 1, 1), "gate.alpha")? - 4.0)?)?; + Ok(Self::ConstantGated { alpha }) + } + // input conditional (small MLP) + Some(CrossAttentionGating::ConditionalGatedTanh) + | Some(CrossAttentionGating::ConditionalGatedSigmoid) => { + let dim = cfg.d_model; + let hidden_dims = (0.125 * dim as f32).floor() as usize; + let in_proj = linear(dim, hidden_dims, false, vb.pp("gate.alpha.0"))?; + let out_proj = linear(hidden_dims, dim, false, vb.pp("gate.alpha.2"))?; + let activation = match cfg.cross_attention { + Some(CrossAttentionGating::ConditionalGatedTanh) => { + candle_nn::init::NonLinearity::Tanh + } + Some(CrossAttentionGating::ConditionalGatedSigmoid) => { + candle_nn::init::NonLinearity::Sigmoid + } + _ => candle::bail!("Invalid cross-attention config specified."), + }; + Ok(Self::ConditionalGated { in_proj, out_proj, activation }) + } + } + } +} + +impl Module for XaGate { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Normal => Ok(xs.clone()), + Self::ConstantGated { alpha } => xs.broadcast_mul(alpha), + Self::ConditionalGated { in_proj, out_proj, activation } => { + let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?; + let alpha = match activation { + candle_nn::init::NonLinearity::Tanh => alpha.tanh(), + candle_nn::init::NonLinearity::Sigmoid => { + candle_nn::ops::sigmoid(&(alpha - 4.0)?) + } + _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"), + }; + xs * alpha? + } + } + } } #[derive(Debug, Clone)] @@ -40,7 +116,7 @@ pub struct Config { pub layer_scale: Option, pub positional_embedding: PositionalEmbedding, pub use_conv_block: bool, - pub cross_attention: Option, + pub cross_attention: Option, pub conv_kernel_size: usize, pub use_conv_bias: bool, pub gating: Option, @@ -248,12 +324,12 @@ pub struct StreamingMultiheadCrossAttention { kv_repeat: usize, num_heads: usize, neg_inf: Tensor, - tanh_gate_alpha: Option, + gate: XaGate, span: tracing::Span, } impl StreamingMultiheadCrossAttention { - pub fn new(ca: CrossAttention, cfg: &Config, vb: VarBuilder) -> Result { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let embed_dim = cfg.d_model; let num_kv = cfg.num_heads / cfg.kv_repeat; let kv_dim = num_kv * (embed_dim / cfg.num_heads); @@ -276,10 +352,7 @@ impl StreamingMultiheadCrossAttention { let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v); let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?; let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?; - let tanh_gate_alpha = match ca { - CrossAttention::Gated => Some(vb.get((1, 1, 1), "tanh_gate.alpha")?.tanh()?), - CrossAttention::Normal => None, - }; + let gate = XaGate::new(cfg, vb)?; Ok(Self { in_proj_q, in_proj_k, @@ -288,7 +361,7 @@ impl StreamingMultiheadCrossAttention { kv_repeat: cfg.kv_repeat, num_heads: cfg.num_heads, neg_inf, - tanh_gate_alpha, + gate, span: tracing::span!(tracing::Level::TRACE, "mhca"), }) } @@ -331,11 +404,8 @@ impl StreamingMultiheadCrossAttention { let xs = xs .transpose(1, 2)? // b,t,h,d .reshape((b, t, hd))? - .apply(&self.out_proj)?; - let xs = match self.tanh_gate_alpha.as_ref() { - None => xs, - Some(alpha) => xs.broadcast_mul(alpha)?, - }; + .apply(&self.out_proj)? + .apply(&self.gate)?; Ok(xs) } } @@ -517,14 +587,12 @@ impl StreamingTransformerLayer { } }; let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?; - let cross_attn = match cfg.cross_attention { - Some(ca) => { - let norm_cross = LayerNorm::new_no_bias(cfg.d_model, 1e-5, vb.pp("norm_cross"))?; - let cross_attn = - StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?; - Some((norm_cross, cross_attn)) - } - None => None, + let cross_attn = if cfg.cross_attention.is_some() { + let norm_cross = LayerNorm::new_no_bias(cfg.d_model, 1e-5, vb.pp("norm_cross"))?; + let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?; + Some((norm_cross, cross_attn)) + } else { + None }; Ok(Self { self_attn, diff --git a/rust/moshi-core/src/tts_streaming.rs b/rust/moshi-core/src/tts_streaming.rs new file mode 100644 index 0000000..d57db77 --- /dev/null +++ b/rust/moshi-core/src/tts_streaming.rs @@ -0,0 +1,197 @@ +// Copyright (c) Kyutai, all rights reserved. +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +use candle::{IndexOp, Tensor}; +use candle_transformers::generation::LogitsProcessor; + +pub const UNGENERATED: u32 = u32::MAX; + +#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] +pub struct Config { + pub generated_audio_codebooks: usize, + pub audio_vocab_size: usize, + pub acoustic_delay: usize, + pub text_pad_token: u32, + pub text_bos_token: u32, + pub text_eos_token: u32, + pub text_eop_token: u32, + pub text_start_token: u32, + pub text_audio_delay_in_tokens: usize, +} + +impl Config { + pub fn v0_1() -> Self { + Self { + generated_audio_codebooks: 16, + audio_vocab_size: 2049, + acoustic_delay: 2, + text_eop_token: 0, + text_bos_token: 1, + text_eos_token: 2, + text_pad_token: 3, + text_start_token: 32000, + text_audio_delay_in_tokens: 25, + } + } + + pub fn audio_pad_token(&self) -> u32 { + self.audio_vocab_size as u32 - 1 + } + + pub fn total_audio_codebooks(&self) -> usize { + self.generated_audio_codebooks + } +} + +pub struct State { + model: crate::lm::LmModel, + ca_src: Option, + audio_tokens: Vec>, + text_tokens: Vec, + audio_lp: LogitsProcessor, + text_lp: LogitsProcessor, + step_idx: usize, + config: Config, +} + +impl State { + pub fn new( + model: crate::lm::LmModel, + ca_src: Option, + max_step_idx: usize, + audio_lp: LogitsProcessor, + text_lp: LogitsProcessor, + config: Config, + ) -> Self { + let audio_tokens: Vec> = vec![ + vec![UNGENERATED; config.total_audio_codebooks()]; + max_step_idx + config.acoustic_delay + ]; + let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay]; + Self { model, ca_src, audio_tokens, text_tokens, audio_lp, text_lp, step_idx: 0, config } + } + + pub fn step_idx(&self) -> usize { + self.step_idx + } + + fn audio_pad_token(&self) -> u32 { + self.config.audio_pad_token() + } + + pub fn config(&self) -> &Config { + &self.config + } + + // The acoustic tokens are written with a delay, so this can create "gaps" of UNGENERATED + // tokens in the case where we call `step_audio_prompt` *after* `step`. + pub fn step( + &mut self, + text_token: u32, + possible_text_token: Option, + ) -> candle::Result { + let mut codes = Vec::with_capacity(self.config.total_audio_codebooks()); + let dev = self.model.device(); + for codebook in 0..self.config.total_audio_codebooks() { + let t = if codebook == 0 || codebook == 8 { + if self.step_idx == 0 { + self.audio_pad_token() + } else { + self.audio_tokens[self.step_idx - 1][codebook] + } + } else if self.step_idx <= self.config.acoustic_delay { + self.audio_pad_token() + } else { + self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook] + }; + if t == UNGENERATED { + candle::bail!("internal error, ungenerated {}", self.step_idx) + } + let t = Tensor::new(&[t], dev)?.unsqueeze(0)?; + codes.push(Some(t)) + } + let text_token = Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?); + let (text_logits, ys) = match self.ca_src.as_ref() { + None => self.model.forward(text_token, codes)?, + Some(ca_src) => self.model.forward_ca(text_token, codes, ca_src)?, + }; + let text_logits = text_logits.i((0, 0))?; + let text_token = self.text_lp.sample_f(&text_logits, |prs| { + let mut sum_p = 0.; + for (idx, pr) in prs.iter_mut().enumerate() { + let idx = idx as u32; + if idx != self.config.text_pad_token && idx != self.config.text_eop_token { + sum_p += *pr; + *pr = 0. + } + } + if let Some(tt) = possible_text_token { + prs[tt as usize] = sum_p + } + })?; + self.text_tokens[self.step_idx] = text_token; + let last_audio_tokens = self.model.depformer_sample( + self.step_idx, + &ys, + Some(text_token), + &mut self.audio_lp, + )?; + let audio_pad_token = self.audio_pad_token(); + for c_idx in 0..self.config.generated_audio_codebooks { + let delay = if c_idx == 0 || c_idx == 8 { 0 } else { self.config.acoustic_delay }; + let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx]; + match last_audio_tokens.as_ref() { + Some(lat) => { + if *pos == UNGENERATED { + *pos = lat[c_idx] + } + } + None => { + if *pos == UNGENERATED { + *pos = audio_pad_token + } + } + } + } + self.step_idx += 1; + if self.step_idx >= self.audio_tokens.len() { + candle::bail!("max step-idx reached") + } + Ok(text_token) + } + + /// If include_all is set, all the time steps are returned. Otherwise only the timesteps that + /// have been generated are handled. + pub fn audio_tokens(&self, include_all: bool) -> &[Vec] { + if include_all { + &self.audio_tokens + } else { + let max_idx = usize::min(self.step_idx, self.audio_tokens.len()); + &self.audio_tokens[..max_idx] + } + } + + pub fn text_tokens(&self, include_all: bool) -> &[u32] { + if include_all { + &self.text_tokens + } else { + let max_idx = usize::min(self.step_idx, self.text_tokens.len()); + &self.text_tokens[..max_idx] + } + } + + pub fn last_audio_tokens(&self) -> Option> { + if self.step_idx <= self.config.acoustic_delay { + None + } else { + // step_idx is in advance by 1 + there is a 2 token delay on audio tokens. + let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1]; + if audio_tokens.iter().any(|v| *v as usize >= self.config.audio_vocab_size - 1) { + None + } else { + Some(audio_tokens.clone()) + } + } + } +}