forked from huggingface/candle
Adding Granite 7b Instruct model example (huggingface#2487)
* Adding Granite 7b Instruct model example
* Minor refactoring to make it a little more idiomatic
* Clippy fixes.
* Adding a README with some information about supported Granite models
* Changing the default prompt to better accommodate the Language modality of the Granite 7b Instruct model

---------

Co-authored-by: Laurent <[email protected]>
1 parent c58c5d5 · commit 5fc4f17
Showing 4 changed files with 730 additions and 0 deletions.
@@ -0,0 +1,20 @@
# candle-granite LLMs from IBM Research

[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business, to help drive trust and scalability in AI-driven applications.

## Running the example

```bash
$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
    --prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"

Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.

In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
```

## Supported Models

There are two different modalities for the Granite family models: Language and Code.

### Granite for language

1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)
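Beyond the prompt, the example exposes the usual sampling and runtime knobs through the `Args` struct defined in the example source below. As a minimal sketch, a CPU run with greedy decoding and a capped output length might look like this (the flag values here are illustrative, not recommendations):

```bash
# --temperature 0 selects arg-max (greedy) decoding; -n caps the sample length.
$ cargo run --example granite -r -- --model-type "granite7b-instruct" \
    --cpu --temperature 0 -n 256 \
    --prompt "List three business use cases for the Granite models."
```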
@@ -0,0 +1,251 @@

```rust
// An implementation of the different Granite models: https://www.ibm.com/granite

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::granite as model;
use model::{Granite, GraniteConfig};

use std::time::Instant;

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum GraniteModel {
    Granite7bInstruct,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 10000)]
    sample_len: usize,

    /// Disable the key-value cache.
    #[arg(long)]
    no_kv_cache: bool,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// Use a different dtype than f16.
    #[arg(long)]
    dtype: Option<String>,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long, default_value = "granite7b-instruct")]
    model_type: GraniteModel,

    #[arg(long)]
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(args.cpu)?;
    let dtype = match args.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
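    // f16 is the default dtype; for a 7b-parameter model that is on the order
    // of 14 GB of weights (approximate), so --dtype f32 roughly doubles the
    // footprint and is mainly useful on devices without native f16 support.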
    let (granite, tokenizer_filename, mut cache, config) = {
        let api = Api::new()?;
        let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
            GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
        });
        println!("loading the model weights from {model_id}");
        let revision = args.revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(args.use_flash_attn);

        let filenames = match args.model_type {
            GraniteModel::Granite7bInstruct => {
                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
        };
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

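        // The safetensors files are memory-mapped rather than read into RAM;
        // this is `unsafe` because the mapping is invalidated if the files are
        // modified on disk while the model is in use.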
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (
            Granite::load(vb, &config)?,
            tokenizer_filename,
            cache,
            config,
        )
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = config.eos_token_id.or_else(|| {
        tokenizer
            .token_to_id(EOS_TOKEN)
            .map(model::GraniteEosToks::Single)
    });

    let default_prompt = match args.model_type {
        GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
    };

    let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    println!("Starting the inference loop:");
    print!("{prompt}");
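    // A non-positive temperature short-circuits to greedy (arg-max) decoding;
    // otherwise the top-k / top-p flags select the corresponding sampling
    // strategy, applied in k-then-p order when both are given.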
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let mut start_gen = Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    let use_cache_kv = cache.use_kv_cache;

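    // Generation loop: each iteration runs one forward pass, samples a token,
    // and streams it to stdout. The timer is restarted at index 1 so that the
    // initial prompt-processing pass is excluded from the tokens/s figure.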
    (0..args.sample_len)
        .inspect(|index| {
            if *index == 1 {
                start_gen = Instant::now();
            }
        })
        .try_for_each(|index| -> Result<()> {
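            // With the KV cache enabled, only the freshly sampled token needs
            // to be fed after the first pass; without it, the whole sequence
            // is re-encoded from position 0 on every step.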
            let (context_size, context_index) = if use_cache_kv && index > 0 {
                (1, index_pos)
            } else {
                (tokens.len(), 0)
            };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
            let logits = granite
                .forward(&input, context_index, &mut cache)?
                .squeeze(0)?;

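            // Dampen the logits of recently generated tokens; a penalty of 1.0
            // leaves the distribution untouched, so that branch is skipped.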
            let logits = if args.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    args.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            index_pos += ctxt.len();

            let next_token = logits_processor.sample(&logits)?;
            token_generated += 1;
            tokens.push(next_token);

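            // An `Err` here is only a control-flow signal used to break out of
            // `try_for_each`; it is swallowed by the `unwrap_or(())` below.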
            if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
                if next_token == eos_tok_id {
                    return Err(E::msg("EOS token found"));
                }
            } else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
                if eos_ids.contains(&next_token) {
                    return Err(E::msg("EOS token found"));
                }
            }

            if let Some(t) = tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
            Ok(())
        })
        .unwrap_or(());

    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

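    // The first token is subtracted because the timer only starts once it has
    // been produced, so `token_generated - 1` tokens fall inside `dt`.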
    let dt = start_gen.elapsed();
    println!(
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
```
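For profiling, the `--tracing` flag above wires a `tracing-chrome` layer into the subscriber and, per its doc comment, leaves a trace-timestamp.json file behind. A minimal sketch of such a run (prompt and sample length are illustrative):

```bash
# The resulting trace-timestamp.json can be opened in a Chrome-style trace viewer.
$ cargo run --example granite -r -- --model-type "granite7b-instruct" \
    --tracing -n 64 --prompt "What is IBM Granite?"
```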