From 5fc4f177273c0c2435f9faeb6c3c1ea3d92bdf4e Mon Sep 17 00:00:00 2001
From: Juan Gomez
Date: Sat, 21 Sep 2024 05:52:01 -0400
Subject: [PATCH] Adding Granite 7b Instruct model example (#2487)

* Adding Granite 7b Instruct model example

* Minor refactoring to make it a little more idiomatic

* Clippy fixes.

* Adding a README with some information about supported Granite models

* Changing the default prompt to better accommodate the Language modality of the Granite 7b Instruct model

---------

Co-authored-by: Laurent
---
 candle-examples/examples/granite/README.md |  20 +
 candle-examples/examples/granite/main.rs   | 251 +++++++++++
 candle-transformers/src/models/granite.rs  | 458 +++++++++++++++++++++
 candle-transformers/src/models/mod.rs      |   1 +
 4 files changed, 730 insertions(+)
 create mode 100644 candle-examples/examples/granite/README.md
 create mode 100644 candle-examples/examples/granite/main.rs
 create mode 100644 candle-transformers/src/models/granite.rs

diff --git a/candle-examples/examples/granite/README.md b/candle-examples/examples/granite/README.md
new file mode 100644
index 0000000000..b16455bab2
--- /dev/null
+++ b/candle-examples/examples/granite/README.md
@@ -0,0 +1,20 @@
+# candle-granite LLMs from IBM Research
+
+[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business to help drive trust and scalability in AI-driven applications.
+
+## Running the example
+
+```bash
+$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
+  --prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"
+
+  Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.
+
+  In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
+```
+
+## Supported Models
+There are two different modalities for the Granite family models: Language and Code.
+
+### Granite for language
+1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)
diff --git a/candle-examples/examples/granite/main.rs b/candle-examples/examples/granite/main.rs
new file mode 100644
index 0000000000..5697a99e6d
--- /dev/null
+++ b/candle-examples/examples/granite/main.rs
@@ -0,0 +1,251 @@
+// An implementation of different Granite models https://www.ibm.com/granite
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::{bail, Error as E, Result};
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use std::io::Write;
+
+use candle_transformers::models::granite as model;
+use model::{Granite, GraniteConfig};
+
+use std::time::Instant;
+
+const EOS_TOKEN: &str = "</s>";
+const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum GraniteModel {
+    Granite7bInstruct,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(short = 'n', long, default_value_t = 10000)]
+    sample_len: usize,
+
+    /// Disable the key-value cache.
+    #[arg(long)]
+    no_kv_cache: bool,
+
+    /// The initial prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// Use a different dtype than f16.
+    #[arg(long)]
+    dtype: Option<String>,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long, default_value = "granite7b-instruct")]
+    model_type: GraniteModel,
+
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 128)]
+    repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+    use tokenizers::Tokenizer;
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    let device = candle_examples::device(args.cpu)?;
+    let dtype = match args.dtype.as_deref() {
+        Some("f16") => DType::F16,
+        Some("bf16") => DType::BF16,
+        Some("f32") => DType::F32,
+        Some(dtype) => bail!("Unsupported dtype {dtype}"),
+        None => DType::F16,
+    };
+    let (granite, tokenizer_filename, mut cache, config) = {
+        let api = Api::new()?;
+        let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
+            GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
+        });
+        println!("loading the model weights from {model_id}");
+        let revision = args.revision.unwrap_or("main".to_string());
+        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+
+        let tokenizer_filename = api.get("tokenizer.json")?;
+        let config_filename = api.get("config.json")?;
+        let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+        let config = config.into_config(args.use_flash_attn);
+
+        let filenames = match args.model_type {
+            GraniteModel::Granite7bInstruct => {
+                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
+            }
+        };
+        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
+
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        (
+            Granite::load(vb, &config)?,
+            tokenizer_filename,
+            cache,
+            config,
+        )
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::GraniteEosToks::Single)
+    });
+
+    let default_prompt = match args.model_type {
+        GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
+    };
+
+    let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
+    let mut tokens = tokenizer
+        .encode(prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
+
+    println!("Starting the inference loop:");
+    print!("{prompt}");
+    let mut logits_processor = {
+        let temperature = args.temperature;
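+        // Pick the sampling strategy from the CLI flags: greedy decoding (arg-max) when the
+        // temperature is non-positive, otherwise top-k / top-p / plain temperature sampling.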
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
+
+    let mut start_gen = std::time::Instant::now();
+    let mut index_pos = 0;
+    let mut token_generated = 0;
+    let use_cache_kv = cache.use_kv_cache;
+
+    (0..args.sample_len)
+        .inspect(|index| {
+            if *index == 1 {
+                start_gen = Instant::now();
+            }
+        })
+        .try_for_each(|index| -> Result<()> {
+            let (context_size, context_index) = if use_cache_kv && index > 0 {
+                (1, index_pos)
+            } else {
+                (tokens.len(), 0)
+            };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
+            let logits = granite
+                .forward(&input, context_index, &mut cache)?
+                .squeeze(0)?;
+
+            let logits = if args.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    args.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+
+            index_pos += ctxt.len();
+
+            let next_token = logits_processor.sample(&logits)?;
+            token_generated += 1;
+            tokens.push(next_token);
+
+            if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
+                if next_token == eos_tok_id {
+                    return Err(E::msg("EOS token found"));
+                }
+            } else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
+                if eos_ids.contains(&next_token) {
+                    return Err(E::msg("EOS token found"));
+                }
+            }
+
+            if let Some(t) = tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            Ok(())
+        })
+        .unwrap_or(());
+
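+    // Print any trailing characters that the streaming tokenizer is still holding back.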
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+
+    let dt = start_gen.elapsed();
+    println!(
+        "\n\n{} tokens generated ({} token/s)\n",
+        token_generated,
+        (token_generated - 1) as f64 / dt.as_secs_f64(),
+    );
+    Ok(())
+}
diff --git a/candle-transformers/src/models/granite.rs b/candle-transformers/src/models/granite.rs
new file mode 100644
index 0000000000..6d25c339b2
--- /dev/null
+++ b/candle-transformers/src/models/granite.rs
@@ -0,0 +1,458 @@
+use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
+use std::{collections::HashMap, f32::consts::PI};
+
+pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub enum GraniteRopeType {
+    #[serde(rename = "granite")]
+    Granite,
+    #[default]
+    #[serde(rename = "default")]
+    Default,
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub struct GraniteRopeConfig {
+    pub factor: f32,
+    pub low_freq_factor: f32,
+    pub high_freq_factor: f32,
+    pub original_max_position_embeddings: usize,
+    pub rope_type: GraniteRopeType,
+}
+#[derive(Debug, Clone, serde::Deserialize)]
+#[serde(untagged)]
+pub enum GraniteEosToks {
+    Single(u32),
+    Multiple(Vec<u32>),
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct GraniteConfig {
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub vocab_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: Option<usize>,
+    pub rms_norm_eps: f64,
+    #[serde(default = "default_rope")]
+    pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<GraniteEosToks>,
+    pub rope_scaling: Option<GraniteRopeConfig>,
+    pub max_position_embeddings: usize,
+}
+
+impl GraniteConfig {
+    pub fn num_key_value_heads(&self) -> usize {
+        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
+    }
+}
+
+fn default_rope() -> f32 {
+    10_000.0
+}
+
+impl GraniteConfig {
+    pub fn into_config(self, use_flash_attn: bool) -> Config {
+        Config {
+            hidden_size: self.hidden_size,
+            intermediate_size: self.intermediate_size,
+            vocab_size: self.vocab_size,
+            num_hidden_layers: self.num_hidden_layers,
+            num_attention_heads: self.num_attention_heads,
+            num_key_value_heads: self.num_key_value_heads(),
+            rms_norm_eps: self.rms_norm_eps,
+            rope_theta: self.rope_theta,
+            use_flash_attn,
+            bos_token_id: self.bos_token_id,
+            eos_token_id: self.eos_token_id,
+            rope_scaling: self.rope_scaling,
+            max_position_embeddings: self.max_position_embeddings,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub vocab_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: usize,
+    pub use_flash_attn: bool,
+    pub rms_norm_eps: f64,
+    pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<GraniteEosToks>,
+    pub rope_scaling: Option<GraniteRopeConfig>,
+    pub max_position_embeddings: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct Cache {
+    masks: HashMap<usize, Tensor>,
+    pub use_kv_cache: bool,
+    kvs: Vec<Option<(Tensor, Tensor)>>,
+    cos: Tensor,
+    sin: Tensor,
+    device: Device,
+}
+
+fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
+    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
+impl Cache {
+    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
+        // precompute freqs_cis
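+        // Scaled rope: frequencies with wavelengths longer than the low-frequency cutoff are
+        // divided by `factor`, those shorter than the high-frequency cutoff are kept as-is,
+        // and a smooth blend is used in between.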
+        let theta = match &config.rope_scaling {
+            None
+            | Some(GraniteRopeConfig {
+                rope_type: GraniteRopeType::Default,
+                ..
+            }) => calculate_default_inv_freq(config),
+            Some(rope_scaling) => {
+                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.low_freq_factor;
+                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.high_freq_factor;
+
+                calculate_default_inv_freq(config)
+                    .into_iter()
+                    .map(|freq| {
+                        let wavelen = 2. * PI / freq;
+                        if wavelen < high_freq_wavelen {
+                            freq
+                        } else if wavelen > low_freq_wavelen {
+                            freq / rope_scaling.factor
+                        } else {
+                            let smooth = (rope_scaling.original_max_position_embeddings as f32
+                                / wavelen
+                                - rope_scaling.low_freq_factor)
+                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            }
+        };
+
+        let theta = Tensor::new(theta, device)?;
+
+        let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
+            .to_dtype(DType::F32)?
+            .reshape((config.max_position_embeddings, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
+        Ok(Self {
+            masks: HashMap::new(),
+            use_kv_cache,
+            kvs: vec![None; config.num_hidden_layers],
+            device: device.clone(),
+            cos,
+            sin,
+        })
+    }
+
+    fn mask(&mut self, t: usize) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+            self.masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CausalSelfAttention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    num_attention_heads: usize,
+    num_key_value_heads: usize,
+    head_dim: usize,
+    use_flash_attn: bool,
+    span: tracing::Span,
+    span_rot: tracing::Span,
+    max_position_embeddings: usize,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
+impl CausalSelfAttention {
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
+        let _enter = self.span_rot.enter();
+        let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
+        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
+        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
+        candle_nn::rotary_emb::rope(x, &cos, &sin)
+    }
+
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_sz, seq_len, hidden_size) = x.dims3()?;
+        let q = self.q_proj.forward(x)?;
+        let k = self.k_proj.forward(x)?;
+        let v = self.v_proj.forward(x)?;
+
+        let q = q
+            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?;
+        let k = k
+            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?;
+        let mut v = v
+            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
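+            // (b_sz, seq_len, num_kv_heads, head_dim) -> (b_sz, num_kv_heads, seq_len, head_dim)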
+            .transpose(1, 2)?;
+
+        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
+        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
+
+        if cache.use_kv_cache {
+            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
+                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
+                let k_seq_len = k.dims()[1];
+                if k_seq_len > self.max_position_embeddings {
+                    k = k
+                        .narrow(
+                            D::Minus1,
+                            k_seq_len - self.max_position_embeddings,
+                            self.max_position_embeddings,
+                        )?
+                        .contiguous()?
+                }
+                let v_seq_len = v.dims()[1];
+                if v_seq_len > 2 * self.max_position_embeddings {
+                    v = v
+                        .narrow(
+                            D::Minus1,
+                            v_seq_len - self.max_position_embeddings,
+                            self.max_position_embeddings,
+                        )?
+                        .contiguous()?
+                }
+            }
+            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
+        }
+
+        let k = self.repeat_kv(k)?;
+        let v = self.repeat_kv(v)?;
+
+        let y = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = q.transpose(1, 2)?;
+            let k = k.transpose(1, 2)?;
+            let v = v.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
+        } else {
+            let in_dtype = q.dtype();
+            let q = q.to_dtype(DType::F32)?;
+            let k = k.to_dtype(DType::F32)?;
+            let v = v.to_dtype(DType::F32)?;
+            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+            let att = if seq_len == 1 {
+                att
+            } else {
+                let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+                masked_fill(&att, &mask, f32::NEG_INFINITY)?
+            };
+            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+            // Convert to contiguous as matmul doesn't support strided vs for now.
+            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
+        };
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+        let y = self.o_proj.forward(&y)?;
+        Ok(y)
+    }
+
+    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
+    }
+
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attn");
+        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+        let size_in = cfg.hidden_size;
+        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
+        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
+        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
+        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
+        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
+        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            num_attention_heads: cfg.num_attention_heads,
+            num_key_value_heads: cfg.num_key_value_heads,
+            head_dim: cfg.hidden_size / cfg.num_attention_heads,
+            use_flash_attn: cfg.use_flash_attn,
+            span,
+            span_rot,
+            max_position_embeddings: cfg.max_position_embeddings,
+        })
+    }
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+    c_fc1: Linear,
+    c_fc2: Linear,
+    c_proj: Linear,
+    span: tracing::Span,
+}
+
+impl Mlp {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
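+        // SwiGLU feed-forward: silu(gate_proj(x)) * up_proj(x), then project back down.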
+        self.c_proj.forward(&x)
+    }
+
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
+        let h_size = cfg.hidden_size;
+        let i_size = cfg.intermediate_size;
+        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
+        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
+        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+            span,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    rms_1: RmsNorm,
+    attn: CausalSelfAttention,
+    rms_2: RmsNorm,
+    mlp: Mlp,
+    span: tracing::Span,
+}
+
+impl Block {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let residual = x;
+        let x = self.rms_1.forward(x)?;
+        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
+        let residual = &x;
+        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
+        Ok(x)
+    }
+
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "block");
+        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
+        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
+        let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let rms_2 = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            rms_1,
+            attn,
+            rms_2,
+            mlp,
+            span,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Granite {
+    wte: Embedding,
+    blocks: Vec<Block>,
+    ln_f: RmsNorm,
+    lm_head: Linear,
+}
+
+impl Granite {
+    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
+        let (_b_sz, seq_len) = x.dims2()?;
+        let mut x = self.wte.forward(x)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            x = block.forward(&x, index_pos, block_idx, cache)?;
+        }
+        let x = self.ln_f.forward(&x)?;
+        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
+        let logits = self.lm_head.forward(&x)?;
+        logits.to_dtype(DType::F32)
+    }
+
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
+        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
+            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
+            .collect();
+
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+        })
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 07672bcc33..dae41a807c 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -24,6 +24,7 @@ pub mod flux;
 pub mod gemma;
 pub mod gemma2;
 pub mod glm4;
+pub mod granite;
 pub mod hiera;
 pub mod jina_bert;
 pub mod llama;