forked from huggingface/candle
Commit b0442ef
Sketch the stable-lm model. (huggingface#1045)
1 parent 4631c48

Showing 2 changed files with 365 additions and 0 deletions.
@@ -0,0 +1,364 @@
#![allow(unused)]
use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder};
use std::sync::Arc;

// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) rope_pct: f64,
    pub(crate) rope_theta: f64,
    pub(crate) max_position_embeddings: usize,
    pub(crate) norm_eps: f64,
    pub(crate) use_cache: bool,
}

impl Config {
    pub fn stablelm_3b_4e1t() -> Self {
        Self {
            vocab_size: 50304,
            intermediate_size: 6912,
            hidden_size: 2560,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            hidden_act: Activation::Silu,
            rope_pct: 0.25,
            rope_theta: 10_000.,
            max_position_embeddings: 4096,
            norm_eps: 1e-5,
            use_cache: true,
        }
    }

    // 2560 / 32 = 80 for the 3b-4e1t configuration.
    fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    // Number of dimensions per head covered by the rotary embedding,
    // 80 * 0.25 = 20 for the 3b-4e1t configuration.
    fn rotary_ndims(&self) -> usize {
        (self.head_dim() as f64 * self.rope_pct) as usize
    }

    // 32 / 32 = 1 here, i.e. plain multi-head attention without key/value sharing.
    fn num_kv_groups(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }
}

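// Cos/sin tables precomputed once for every position up to
// max_position_embeddings; a single instance is shared by all layers via an Arc.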
#[derive(Debug)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

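// GPT-NeoX style rotation: split the last dimension into two halves and map
// (x1, x2) to (-x2, x1).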
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    Tensor::cat(&[&xs[1].neg()?, &xs[0]], D::Minus1)
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.rotary_ndims();
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, rotary_ndims)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, rotary_ndims)
        // Only the first rotary_ndims dimensions of each head are rotated, the
        // remaining head_dim - rotary_ndims dimensions are passed through
        // unchanged (rope_pct < 1).
        let rot_dim = cos.dim(D::Minus1)?;
        let rope = |xs: &Tensor| -> Result<Tensor> {
            let xs_rot = xs.narrow(D::Minus1, 0, rot_dim)?;
            let xs_pass = xs.narrow(D::Minus1, rot_dim, n_embd - rot_dim)?;
            let xs_rot =
                (xs_rot.broadcast_mul(&cos)? + rotate_half(&xs_rot)?.broadcast_mul(&sin)?)?;
            Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
        };
        Ok((rope(q)?, rope(k)?))
    }
}

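// Feed-forward block of each decoder layer:
// down_proj(act_fn(gate_proj(xs)) * up_proj(xs)), with a SiLU activation in the
// 3b-4e1t configuration.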
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

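// Multi-head self-attention with rotary position embeddings and an optional
// key/value cache for incremental decoding.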
#[derive(Debug)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
    use_cache: bool,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let head_dim = cfg.head_dim();
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups: cfg.num_kv_groups(),
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
            use_cache: cfg.use_cache,
        })
    }

    // Expand the key/value heads so that they match the number of query heads,
    // a no-op when num_kv_groups == 1.
    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        // (b_sz, q_len, num_heads * head_dim) -> (b_sz, num_heads, q_len, head_dim)
        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        // Concatenate the new keys/values after those from the cache.
        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        if self.use_cache {
            self.kv_cache = Some((key_states.clone(), value_states.clone()));
        }

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;

        // Scaled dot-product attention.
        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }
}

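// A single transformer block: pre-norm self-attention followed by a pre-norm
// MLP, each wrapped in a residual connection.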
#[derive(Debug)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = candle_nn::layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        residual + xs
    }
}

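// The full model: token embeddings, a stack of decoder layers, a final
// LayerNorm and the lm_head projection back onto the vocabulary.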
#[derive(Debug)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Sliding window mask?
        // Causal mask: position i may only attend to positions j <= i among the
        // new tokens, while every position of the cached prefix stays visible.
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        // Only the logits for the last position are returned, which is what is
        // needed for autoregressive sampling.
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}
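
As a usage reference, here is a minimal decoding sketch, not part of this diff. It assumes the new file is exposed as candle_transformers::models::stable_lm and that the stablelm-3b-4e1t weights are available locally as model.safetensors (both are assumptions); the token ids would normally come from the matching tokenizer.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::stable_lm::{Config, Model};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let cfg = Config::stablelm_3b_4e1t();
    // Hypothetical local weight file; in practice it would be fetched from the
    // stabilityai/stablelm-3b-4e1t repository on the Hugging Face hub.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = Model::new(&cfg, vb)?;

    // Process the whole prompt in one call with a zero offset; the kv-cache is
    // filled as a side effect and the logits for the last prompt token come back
    // with shape (1, 1, vocab_size).
    let prompt: Vec<u32> = vec![12, 345, 6789]; // placeholder token ids
    let input = Tensor::new(prompt.as_slice(), &device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?;

    // Then decode one token at a time, passing the number of tokens already seen
    // as seqlen_offset so that the rotary embedding and the mask stay aligned
    // with the cache.
    let next_token = logits
        .squeeze(0)?
        .squeeze(0)?
        .argmax(0)?
        .to_scalar::<u32>()?;
    let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
    let _logits = model.forward(&input, prompt.len())?;
    Ok(())
}

Passing the running seqlen_offset on every call is what keeps the kv-cache, the causal mask and the rotary embedding consistent from one step to the next.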