From 500c9f288214cf817c8c52c7c9a6187cb279c563 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=94=90=E7=92=9C?= <113148619+donjuanplatinum@users.noreply.github.com>
Date: Mon, 5 Aug 2024 11:48:09 -0400
Subject: [PATCH] add models support and example for THUDM/glm-4 (#2362)

* add models support and example for THUDM/glm-4
* fix the ci report
* fmt
* fix
* Update README.org
* Update README.org
* fmt
* Update README.org
* README.md add codegeex4
* README.md add glm4
* Typo.
* change expect into ?

---------

Co-authored-by: Laurent Mazare
---
 README.md                                |   2 +
 candle-examples/examples/glm4/README.org |  77 +++
 candle-examples/examples/glm4/main.rs    | 255 ++++++++++
 candle-transformers/src/models/glm4.rs   | 595 +++++++++++++++++++++++
 candle-transformers/src/models/mod.rs    |   1 +
 5 files changed, 930 insertions(+)
 create mode 100644 candle-examples/examples/glm4/README.org
 create mode 100644 candle-examples/examples/glm4/main.rs
 create mode 100644 candle-transformers/src/models/glm4.rs

diff --git a/README.md b/README.md
index 54991fae99..543b9ca84c 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,8 @@ We also provide a some command line based examples using state of the art models
 - [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
   the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
+- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion, code interpreter, web search, function calling, and repository-level tasks.
+- [GLM4](./candle-examples/examples/glm4/): Open multilingual multimodal chat LMs from THUDM.
 - [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
 - [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b Griffin based models
   from Google that mix attention with a RNN like state.
diff --git a/candle-examples/examples/glm4/README.org b/candle-examples/examples/glm4/README.org
new file mode 100644
index 0000000000..364f61e8eb
--- /dev/null
+++ b/candle-examples/examples/glm4/README.org
@@ -0,0 +1,77 @@
+* GLM4
+GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.
+
+- [[https://github.com/THUDM/GLM4][GitHub]]
+- [[https://huggingface.co/THUDM/glm-4-9b][Hugging Face]]
+
+** Running with ~cuda~
+
+#+begin_src shell
+  cargo run --example glm4 --release --features cuda
+#+end_src
+
+** Running with ~cpu~
+#+begin_src shell
+  cargo run --example glm4 --release -- --cpu
+#+end_src
+
+** Output Example
+#+begin_src shell
+cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
+    Finished release [optimized] target(s) in 0.24s
+     Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
+avx: true, neon: false, simd128: false, f16c: true
+temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
+cache path .
+retrieved the files in 6.88963ms
+loaded the model in 6.113752297s
+starting the inference loop
+[欢迎使用GLM-4,请输入prompt]
+请你告诉我什么是FFT
+266 tokens generated (34.50 token/s)
+Result:
+。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。
+
+具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
+
+以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
+
+```python
+import numpy as np
+
+# 创建一个时域信号
+t = np.linspace(0, 1, num=100)
+f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
+
+# 对该信号做FFT变换,并计算其幅值谱
+fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
+
+```
+
+在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
+#+end_src
+
+This example reads the prompt from stdin, one prompt per line.
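+
+The prompt can also be piped in for a non-interactive run. A minimal sketch, assuming a CUDA build and the same flags as in the output example above:
+
+#+begin_src shell
+  echo "请你告诉我什么是FFT" | cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
+#+end_src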
+
+* Citation
+#+begin_src
+@misc{glm2024chatglm,
+      title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
+      author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
+      year={2024},
+      eprint={2406.12793},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+#+end_src
+
+#+begin_src
+@misc{wang2023cogvlm,
+      title={CogVLM: Visual Expert for Pretrained Language Models},
+      author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
+      year={2023},
+      eprint={2311.03079},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+}
+#+end_src
diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs
new file mode 100644
index 0000000000..55a27f349e
--- /dev/null
+++ b/candle-examples/examples/glm4/main.rs
@@ -0,0 +1,255 @@
+use candle_transformers::models::glm4::*;
+use clap::Parser;
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    verbose_prompt: bool,
+    dtype: DType,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        verbose_prompt: bool,
+        device: &Device,
+        dtype: DType,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            verbose_prompt,
+            device: device.clone(),
+            dtype,
+        }
+    }
+
+    fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
+        use std::io::BufRead;
+        use std::io::BufReader;
+        use std::io::Write;
+        println!("starting the inference loop");
+        println!("[欢迎使用GLM-4,请输入prompt]");
+        let stdin = std::io::stdin();
+        let reader = BufReader::new(stdin);
+        for line in reader.lines() {
+            let line = line.expect("Failed to read line");
+
+            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
+            if tokens.is_empty() {
+                panic!("Empty prompts are not supported in the chatglm model.")
+            }
+            if self.verbose_prompt {
+                for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                    let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                    println!("{id:7} -> '{token}'");
+                }
+            }
+            let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+                Some(token) => *token,
+                None => panic!("cannot find the endoftext token"),
+            };
+            let mut tokens = tokens.get_ids().to_vec();
+            let mut generated_tokens = 0usize;
+
+            std::io::stdout().flush().expect("output flush error");
+            let start_gen = std::time::Instant::now();
+
+            let mut count = 0;
+            let mut result = vec![];
+            for index in 0..sample_len {
+                count += 1;
+                let context_size = if index > 0 { 1 } else { tokens.len() };
+                let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+                let logits = self.model.forward(&input)?;
+                let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
+                let logits = if self.repeat_penalty == 1. {
+                    logits
+                } else {
+                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                    candle_transformers::utils::apply_repeat_penalty(
+                        &logits,
+                        self.repeat_penalty,
+                        &tokens[start_at..],
+                    )?
+                };
+
+                let next_token = self.logits_processor.sample(&logits)?;
+                tokens.push(next_token);
+                generated_tokens += 1;
+                if next_token == eos_token {
+                    break;
+                }
+                let token = self
+                    .tokenizer
+                    .decode(&[next_token], true)
+                    .expect("Token error");
+                if self.verbose_prompt {
+                    println!(
+                        "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
+                        count, next_token, token
+                    );
+                }
+                result.push(token);
+                std::io::stdout().flush()?;
+            }
+            let dt = start_gen.elapsed();
+            println!(
+                "\n{generated_tokens} tokens generated ({:.2} token/s)",
+                generated_tokens as f64 / dt.as_secs_f64(),
+            );
+            println!("Result:");
+            for tokens in result {
+                print!("{tokens}");
+            }
+            self.model.reset_kv_cache(); // clean the cache
+        }
+        Ok(())
+    }
+}
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(name = "cache", short, long, default_value = ".")]
+    cache_path: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Display the token for the specified prompt.
+    #[arg(long)]
+    verbose_prompt: bool,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 8192)]
+    sample_len: usize,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    weight_file: Option<String>,
+
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.2)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+}
+
+fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.6),
+        args.repeat_penalty,
+        args.repeat_last_n
+    );
+
+    let start = std::time::Instant::now();
+    println!("cache path {}", args.cache_path);
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+        .build()
+        .map_err(anyhow::Error::msg)?;
+
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => "THUDM/glm-4-9b".to_string(),
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => "main".to_string(),
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let tokenizer_filename = match args.tokenizer {
+        Some(file) => std::path::PathBuf::from(file),
+        None => api
+            .model("THUDM/codegeex4-all-9b".to_string())
+            .get("tokenizer.json")
+            .map_err(anyhow::Error::msg)?,
+    };
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
+
+    let start = std::time::Instant::now();
+    let config = Config::glm4();
+    let device = candle_examples::device(args.cpu)?;
+    let dtype = if device.is_cuda() {
+        DType::BF16
+    } else {
+        DType::F32
+    };
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+    let model = Model::new(&config, vb)?;
+
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        args.verbose_prompt,
+        &device,
+        dtype,
+    );
+    pipeline.run(args.sample_len)?;
+    Ok(())
+}
diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs
new file mode 100644
index 0000000000..3b436eaa6d
--- /dev/null
+++ b/candle-transformers/src/models/glm4.rs
@@ -0,0 +1,595 @@
+use crate::models::with_tracing::{linear_b as linear, Linear};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub num_layers: usize,
+    pub padded_vocab_size: usize,
+    pub hidden_size: usize,
+    pub ffn_hidden_size: usize,
+    pub kv_channels: usize,
+    pub num_attention_heads: usize,
+    pub seq_length: usize,
+    pub layernorm_epsilon: f64,
+    pub rmsnorm: bool,
+    pub apply_residual_connection_post_layernorm: bool,
+    pub post_layer_norm: bool,
+    pub add_bias_linear: bool,
+    pub add_qkv_bias: bool,
+    pub bias_dropout_fusion: bool,
+    pub multi_query_attention: bool,
+    pub multi_query_group_num: usize,
+    pub apply_query_key_layer_scaling: bool,
+    pub attention_softmax_in_fp32: bool,
+    pub fp32_residual_connection: bool,
+}
+
+impl Config {
+    pub fn glm4() -> Self {
+        Self {
+            num_layers: 40,
+            padded_vocab_size: 151552,
+            hidden_size: 4096,
+            ffn_hidden_size: 13696,
+            kv_channels: 128,
+            num_attention_heads: 32,
+            seq_length: 8192,
+            layernorm_epsilon: 1e-5,
+            rmsnorm: true,
+            apply_residual_connection_post_layernorm: false,
+            post_layer_norm: true,
+            add_bias_linear: false,
+            add_qkv_bias: true,
+            bias_dropout_fusion: true,
+            multi_query_attention: true,
+            multi_query_group_num: 2,
+            apply_query_key_layer_scaling: true,
+            attention_softmax_in_fp32: true,
+            fp32_residual_connection: false,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+    cache: Tensor,
+}
+
+impl RotaryEmbedding {
+    fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
+        let rotary_dim = cfg.kv_channels;
+        let n_elem = rotary_dim / 2;
+        let inv_freq: Vec<_> = (0..n_elem)
+            .step_by(2)
+            .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
+            .to_dtype(dtype)?
+            .reshape((cfg.seq_length, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
+        Ok(Self { cache })
+    }
+
+    fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let (seqlen, _b, np, _hn) = xs.dims4()?;
+        let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
+        let rot_dim = cache.dim(D::Minus2)? * 2;
+        let (xs, xs_pass) = (
+            xs.narrow(D::Minus1, 0, rot_dim)?,
+            xs.narrow(D::Minus1, rot_dim, rot_dim)?,
+        );
+        let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
+        let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
+        let (xshaped0, xshaped1) = (
+            xshaped.i((.., .., .., .., 0))?,
+            xshaped.i((.., .., .., .., 1))?,
+        );
+        let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
+        let xs_out = Tensor::stack(
+            &[
+                (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
+                (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
+            ],
+            D::Minus1,
+        )?;
+        let xs_out = xs_out.flatten_from(3)?;
+        Tensor::cat(&[xs_out, xs_pass], D::Minus1)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CoreAttention {
+    coeff: Option<f64>,
+    norm_factor: f64,
+    dtype: DType,
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
+    Ok(m)
+}
+
+impl CoreAttention {
+    fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
+        let norm_factor = (cfg.kv_channels as f64).sqrt();
+        let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
+            let coeff = f64::max(1.0, layer_number as f64);
+            (norm_factor * coeff, Some(coeff))
+        } else {
+            (norm_factor, None)
+        };
+        Ok(Self {
+            coeff,
+            norm_factor,
+            dtype,
+        })
+    }
+
+    fn forward(
+        &self,
+        query_layer: &Tensor,
+        key_layer: &Tensor,
+        value_layer: &Tensor,
+        attention_mask: &Option<Tensor>,
+    ) -> Result<Tensor> {
+        let output_size = (
+            query_layer.dim(1)?, // b
+            query_layer.dim(2)?, // np
+            query_layer.dim(0)?, // sq
+            key_layer.dim(0)?,   // sk
+        );
+        let query_layer =
+            query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
+        let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
+        let matmul_result = Tensor::matmul(
+            &query_layer.transpose(0, 1)?.contiguous()?,
+            &key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
+        )?;
+        let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
+        let matmul_result = match self.coeff {
+            None => matmul_result,
+            Some(coeff) => (matmul_result * coeff)?,
+        };
+        let attention_scores = match attention_mask {
+            Some(mask) => masked_fill(
+                &matmul_result,
+                &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
+                f32::NEG_INFINITY,
+                self.dtype,
+            )?,
+            None => matmul_result,
+        };
+        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+        let output_size = (
+            value_layer.dim(1)?,
+            value_layer.dim(2)?,
+            query_layer.dim(0)?,
+            value_layer.dim(3)?,
+        );
+        let value_layer =
+            value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
+        let attention_probs =
+            attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
+        let context_layer = Tensor::matmul(
+            &attention_probs.contiguous()?,
+            &value_layer.transpose(0, 1)?.contiguous()?,
+        )?;
+        let context_layer = context_layer.reshape(output_size)?;
+        let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
+        context_layer.flatten_from(D::Minus2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    query_key_value: Linear,
+    core_attention: CoreAttention,
+    dense: Linear,
+    multi_query_attention: bool,
+    num_attention_heads_per_partition: usize,
+    num_multi_query_groups_per_partition: usize,
+    hidden_size_per_attention_head: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl SelfAttention {
+    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let projection_size = cfg.kv_channels * cfg.num_attention_heads;
+        let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
+        let qkv_hidden_size = if cfg.multi_query_attention {
+            projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
+        } else {
+            3 * projection_size
+        };
+        let query_key_value = linear(
+            cfg.hidden_size,
+            qkv_hidden_size,
+            cfg.add_bias_linear || cfg.add_qkv_bias,
+            vb.pp("query_key_value"),
+        )?;
+        let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
+        let dense = linear(
+            cfg.hidden_size,
+            cfg.hidden_size,
+            cfg.add_bias_linear,
+            vb.pp("dense"),
+        )?;
+        Ok(Self {
+            query_key_value,
+            core_attention,
+            dense,
+            multi_query_attention: cfg.multi_query_attention,
+            num_attention_heads_per_partition: cfg.num_attention_heads,
+            num_multi_query_groups_per_partition: cfg.multi_query_group_num,
+            hidden_size_per_attention_head: cfg.kv_channels,
+            kv_cache: None,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: &Option<Tensor>,
+        rotary_emb: &RotaryEmbedding,
+    ) -> Result<Tensor> {
+        let mixed_x_layer = xs.apply(&self.query_key_value)?;
+        if !self.multi_query_attention {
+            candle::bail!("only multi_query_attention=true is supported")
+        }
+        let hpa = self.hidden_size_per_attention_head;
+        let query_layer =
+            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
+        let key_layer = mixed_x_layer.narrow(
+            D::Minus1,
+            self.num_attention_heads_per_partition * hpa,
+            self.num_multi_query_groups_per_partition * hpa,
+        )?;
+        let value_layer = mixed_x_layer.narrow(
+            D::Minus1,
+            self.num_attention_heads_per_partition * hpa
+                + self.num_multi_query_groups_per_partition * hpa,
+            self.num_multi_query_groups_per_partition * hpa,
+        )?;
+        let query_layer = query_layer.reshape((
+            query_layer.dim(0)?,
+            query_layer.dim(1)?,
+            self.num_attention_heads_per_partition,
+            hpa,
+        ))?;
+        let key_layer = key_layer.reshape((
+            key_layer.dim(0)?,
+            key_layer.dim(1)?,
+            self.num_multi_query_groups_per_partition,
+            hpa,
+        ))?;
+        let value_layer = value_layer.reshape((
+            value_layer.dim(0)?,
+            value_layer.dim(1)?,
+            self.num_multi_query_groups_per_partition,
+            hpa,
+        ))?;
+
+        // Rotary embeddings.
+        let seqlen_offset = match &self.kv_cache {
+            None => 0,
+            Some((prev_k, _)) => prev_k.dim(0)?,
+        };
+        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
+        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
+
+        // KV cache.
+        let (key_layer, value_layer) = match &self.kv_cache {
+            None => (key_layer, value_layer),
+            Some((prev_k, prev_v)) => {
+                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
+                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
+
+        // Repeat KV.
+        let ratio =
+            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
+        let key_layer = {
+            let (d0, d1, d2, d3) = key_layer.dims4()?;
+            key_layer
+                .unsqueeze(D::Minus2)?
+                .expand((d0, d1, d2, ratio, d3))?
+                .reshape((
+                    d0,
+                    d1,
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                ))?
+        };
+        let value_layer = {
+            let (d0, d1, d2, d3) = value_layer.dims4()?;
+            value_layer
+                .unsqueeze(D::Minus2)?
+                .expand((d0, d1, d2, ratio, d3))?
+                .reshape((
+                    d0,
+                    d1,
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                ))?
+        };
+
+        let context_layer =
+            self.core_attention
+                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
+        let output = context_layer.apply(&self.dense)?;
+        Ok(output)
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone)]
+struct MLP {
+    dense_h_to_4h: Linear,
+    dense_4h_to_h: Linear,
+}
+
+impl MLP {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense_h_to_4h = linear(
+            cfg.hidden_size,
+            cfg.ffn_hidden_size * 2,
+            cfg.add_bias_linear,
+            vb.pp("dense_h_to_4h"),
+        )?;
+        let dense_4h_to_h = linear(
+            cfg.ffn_hidden_size,
+            cfg.hidden_size,
+            cfg.add_bias_linear,
+            vb.pp("dense_4h_to_h"),
+        )?;
+        Ok(Self {
+            dense_4h_to_h,
+            dense_h_to_4h,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.dense_h_to_4h)?
+            .apply(&candle_nn::Activation::Swiglu)?
+            .apply(&self.dense_4h_to_h)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    input_layernorm: candle_nn::LayerNorm,
+    self_attention: SelfAttention,
+    post_attention_layernorm: candle_nn::LayerNorm,
+    mlp: MLP,
+    apply_residual_connection_post_layernorm: bool,
+}
+
+impl Block {
+    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let input_layernorm = if cfg.rmsnorm {
+            candle_nn::rms_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("input_layernorm"),
+            )?
+            .into_inner()
+        } else {
+            candle_nn::layer_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("input_layernorm"),
+            )?
+        };
+        let post_attention_layernorm = if cfg.rmsnorm {
+            candle_nn::rms_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("post_attention_layernorm"),
+            )?
+            .into_inner()
+        } else {
+            candle_nn::layer_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("post_attention_layernorm"),
+            )?
+        };
+        let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        Ok(Self {
+            input_layernorm,
+            self_attention,
+            post_attention_layernorm,
+            mlp,
+            apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_attention.reset_kv_cache()
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: &Option<Tensor>,
+        rotary_emb: &RotaryEmbedding,
+    ) -> Result<Tensor> {
+        let layernorm_output = xs.apply(&self.input_layernorm)?;
+        let attention_output =
+            self.self_attention
+                .forward(&layernorm_output, attention_mask, rotary_emb)?;
+        let residual = if self.apply_residual_connection_post_layernorm {
+            &layernorm_output
+        } else {
+            xs
+        };
+        let layernorm_input = (residual + attention_output)?;
+        let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
+        let mlp_output = layernorm_output.apply(&self.mlp)?;
+        let residual = if self.apply_residual_connection_post_layernorm {
+            &layernorm_output
+        } else {
+            &layernorm_input
+        };
+        mlp_output + residual
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Transformer {
+    layers: Vec<Block>,
+    final_layernorm: Option<candle_nn::LayerNorm>,
+    rotary_emb: RotaryEmbedding,
+}
+
+impl Transformer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_l = vb.pp("layers");
+        let mut layers = Vec::with_capacity(cfg.num_layers);
+        for layer_index in 0..cfg.num_layers {
+            let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
+            layers.push(block)
+        }
+        let final_layernorm = if cfg.post_layer_norm {
+            let ln = if cfg.rmsnorm {
+                candle_nn::rms_norm(
+                    cfg.hidden_size,
+                    cfg.layernorm_epsilon,
+                    vb.pp("final_layernorm"),
+                )?
+                .into_inner()
+            } else {
+                candle_nn::layer_norm(
+                    cfg.hidden_size,
+                    cfg.layernorm_epsilon,
+                    vb.pp("final_layernorm"),
+                )?
+            };
+            Some(ln)
+        } else {
+            None
+        };
+        let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
+        Ok(Self {
+            layers,
+            final_layernorm,
+            rotary_emb,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        for block in self.layers.iter_mut() {
+            block.reset_kv_cache()
+        }
+    }
+
+    fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for block in self.layers.iter_mut() {
+            xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
+        }
+        match self.final_layernorm.as_ref() {
+            None => Ok(xs),
+            Some(ln) => xs.apply(ln),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Embedding {
+    word_embeddings: candle_nn::Embedding,
+    fp32_residual_connection: bool,
+}
+
+impl Embedding {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let word_embeddings = candle_nn::embedding(
+            cfg.padded_vocab_size,
+            cfg.hidden_size,
+            vb.pp("word_embeddings"),
+        )?;
+        Ok(Self {
+            word_embeddings,
+            fp32_residual_connection: cfg.fp32_residual_connection,
+        })
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
+        if self.fp32_residual_connection {
+            xs.to_dtype(candle::DType::F32)
+        } else {
+            xs.contiguous()
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embedding: Embedding,
+    encoder: Transformer,
+    output_layer: Linear,
+}
+
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("transformer");
+        let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
+        let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
+        let output_layer = linear(
+            cfg.hidden_size,
+            cfg.padded_vocab_size,
+            false,
+            vb.pp("output_layer"),
+        )?;
+
+        Ok(Self {
+            embedding,
+            encoder,
+            output_layer,
+        })
+    }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.encoder.reset_kv_cache()
+    }
+
+    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_b_size, seq_len) = xs.dims2()?;
+        let input_embeds = xs.apply(&self.embedding)?;
+        let attention_mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
+        let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
+        let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
+        Ok(lm_logits)
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index fa35011994..18c8833d79 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -19,6 +19,7 @@ pub mod eva2;
 pub mod falcon;
 pub mod flux;
 pub mod gemma;
+pub mod glm4;
 pub mod hiera;
 pub mod jina_bert;
 pub mod llama;
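
For downstream use outside of the bundled example, the public surface added by this patch is `Config::glm4()`, `Model::new`, `Model::forward`, and `Model::reset_kv_cache`. A minimal, illustrative sketch of wiring these together is below; the weight path and token ids are placeholders, CPU/F32 stands in for the CUDA/BF16 path used in `main.rs`, and it decodes greedily instead of sampling through `LogitsProcessor`.

```rust
use candle::{DType, Device, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::glm4::{Config, Model};

fn main() -> anyhow::Result<()> {
    // Placeholder weight path; the example binary fetches the real shards from the hub.
    let filenames = vec![std::path::PathBuf::from("glm-4-9b.safetensors")];
    // CPU/F32 for illustration; the example switches to BF16 on CUDA devices.
    let device = Device::Cpu;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let mut model = Model::new(&Config::glm4(), vb)?;

    // Dummy token ids; a real prompt would be encoded with the GLM-4 tokenizer.
    let prompt_ids = [1u32, 2, 3];
    // `forward` takes (batch, seq_len) token ids and returns logits for the last position only.
    let input = Tensor::new(&prompt_ids[..], &device)?.unsqueeze(0)?;
    let logits = model.forward(&input)?.squeeze(0)?;
    // Greedy pick for brevity; main.rs samples through LogitsProcessor instead.
    let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
    println!("next token id: {next_token}");

    // The KV cache persists across forward calls; clear it before an unrelated prompt.
    model.reset_kv_cache();
    Ok(())
}
```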