
Commit

Add tokenizers support.
LaurentMazare committed Dec 4, 2024
1 parent fe86fb9 commit 53dd806
Showing 3 changed files with 54 additions and 10 deletions.
1 change: 1 addition & 0 deletions yomikomi/Cargo.toml
@@ -15,4 +15,5 @@ sentencepiece = "0.11.2"
serde_json = "1.0.108"
symphonia = { version = "0.5.3", features = ["all-codecs"] }
thiserror = "1.0.50"
tokenizers = "0.21.0"
zstd = "0.13.0"
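For context, a minimal standalone sketch of the `tokenizers` API that the rest of this commit relies on (`Tokenizer::from_file`, `encode`, `get_ids`); the `tokenizer.json` path and the sample text are placeholders, not part of the change:

```rust
// Load a Hugging Face tokenizer definition and encode a string to ids.
// The path below is a placeholder; any exported tokenizer.json works.
use tokenizers::tokenizer::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // `false` skips the special-token post-processing, matching the
    // `encode(str, false)` call used in tokenize.rs below.
    let encoding = tokenizer.encode("Hello world", false)?;
    let ids: &[u32] = encoding.get_ids();
    println!("{} tokens: {ids:?}", ids.len());
    Ok(())
}
```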
3 changes: 3 additions & 0 deletions yomikomi/src/error.rs
@@ -46,6 +46,9 @@ pub enum Error {
#[error(transparent)]
Io(#[from] std::io::Error),

#[error(transparent)]
Tokenizers(#[from] tokenizers::tokenizer::Error),

/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
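This `#[from]` conversion is what lets the tokenize.rs changes below write `Tokenizer::from_file(path)?` and `inner.encode(str, false)?` directly in functions returning the crate's `Result`. A minimal sketch of the thiserror pattern, using a stand-in error enum rather than yomikomi's actual one:

```rust
// Stand-in error type illustrating the #[error(transparent)] / #[from]
// pattern: thiserror generates a From impl, so `?` converts a tokenizers
// error into the library's own error type automatically.
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;

#[derive(Error, Debug)]
pub enum MyError {
    #[error(transparent)]
    Tokenizers(#[from] tokenizers::tokenizer::Error),
}

fn load(path: &str) -> Result<Tokenizer, MyError> {
    // The `?` relies on the generated From<tokenizers::tokenizer::Error>.
    Ok(Tokenizer::from_file(path)?)
}
```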
60 changes: 50 additions & 10 deletions yomikomi/src/tokenize.rs
@@ -1,9 +1,43 @@
- use crate::{Array, Error, Result, Stream};
+ use crate::{Array, Error as E, Result, Stream};
use sentencepiece::SentencePieceProcessor;
use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer;

enum Processor {

[Check failure on line 6 in yomikomi/src/tokenize.rs (GitHub Actions / Clippy): large size difference between variants]
Tokenizers { inner: Tokenizer, bos_id: Option<u32>, eos_id: Option<u32> },
SentencePiece(SentencePieceProcessor),
}

impl Processor {
fn bos_id(&self) -> Option<u32> {
match self {
Self::SentencePiece(p) => p.bos_id(),
Self::Tokenizers { inner: _, bos_id, eos_id: _ } => bos_id.as_ref().copied(),
}
}

fn eos_id(&self) -> Option<u32> {
match self {
Self::SentencePiece(p) => p.eos_id(),
Self::Tokenizers { inner: _, bos_id: _, eos_id } => eos_id.as_ref().copied(),
}
}

fn encode(&self, str: &str) -> Result<Vec<u32>> {
let tokens: Vec<_> = match self {
Self::SentencePiece(p) => {
p.encode(str).map_err(E::wrap)?.iter().map(|v| v.id).collect()
}
Self::Tokenizers { inner, bos_id: _, eos_id: _ } => {
inner.encode(str, false)?.get_ids().to_vec()
}
};
Ok(tokens)
}
}
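An aside, not part of this commit: the Clippy failure flagged above ("large size difference between variants", the `large_enum_variant` lint) is conventionally addressed either by allowing the lint on the enum or by boxing whichever variant Clippy reports as large, presumably the `tokenizers` one here. A hypothetical sketch of the boxed form:

```rust
// Hypothetical variant of the enum above with the (presumably) large
// Tokenizer boxed, so the enum no longer carries the full Tokenizer inline.
// Not part of this commit; call sites would deref through the Box.
use sentencepiece::SentencePieceProcessor;
use tokenizers::tokenizer::Tokenizer;

enum Processor {
    Tokenizers { inner: Box<Tokenizer>, bos_id: Option<u32>, eos_id: Option<u32> },
    SentencePiece(SentencePieceProcessor),
}
```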

pub struct Tokenize<T> {
- spp: Arc<SentencePieceProcessor>,
+ processor: Arc<Processor>,
input: T,
in_key: String,
out_key: String,
@@ -23,14 +57,20 @@ impl<T> Tokenize<T> {
include_bos: bool,
include_eos: bool,
) -> Result<Self> {
- let spp = SentencePieceProcessor::open(path).map_err(Error::wrap)?;
- let nl_id = match spp.encode("\n").map_err(Error::wrap)?.last() {
+ let path = path.as_ref();
+ let processor = if path.extension().map_or(false, |v| v == "json") {
+     let inner = Tokenizer::from_file(path)?;
+     Processor::Tokenizers { inner, bos_id: None, eos_id: None }
+ } else {
+     Processor::SentencePiece(SentencePieceProcessor::open(path).map_err(E::wrap)?)
+ };
+ let nl_id = match processor.encode("\n").map_err(E::wrap)?.last() {
None => crate::bail!("no specific token id for newline"),
- Some(p) => p.id,
+ Some(p) => *p,
};
let tokens_and_chars = if report_bpb { Some(Mutex::new((0, 0))) } else { None };
Ok(Self {
- spp: Arc::new(spp),
+ processor: Arc::new(processor),
input,
in_key,
out_key,
@@ -62,7 +102,7 @@ impl<T: Stream> Stream for Tokenize<T> {
let text = String::from_utf8_lossy(values);
let mut all_tokens = Vec::new();
if self.include_bos {
- if let Some(bos_id) = self.spp.bos_id() {
+ if let Some(bos_id) = self.processor.bos_id() {
all_tokens.push(bos_id)
}
}
@@ -72,7 +112,7 @@
if idx > 0 {
all_tokens.push(self.nl_id)
}
- let tokens = match self.spp.encode(text) {
+ let tokens = match self.processor.encode(text) {
Ok(tokens) => tokens,
Err(err) => {
eprintln!("tokenizer encode error {err:?}");
@@ -86,11 +126,11 @@
bpb = Some(tokens_and_chars.0 as f64 / tokens_and_chars.1 as f64 / f64::ln(2.))
};
for token in tokens {
- all_tokens.push(token.id)
+ all_tokens.push(token)
}
}
if self.include_eos {
- if let Some(eos_id) = self.spp.eos_id() {
+ if let Some(eos_id) = self.processor.eos_id() {
all_tokens.push(eos_id)
}
}
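Taken together, the constructor now picks the tokenizer backend from the file extension: a `.json` path goes through `tokenizers`, anything else is treated as a SentencePiece model as before. A standalone sketch of that dispatch under the same assumptions (the function name and paths are placeholders, not yomikomi's public API):

```rust
// Standalone sketch of the extension-based dispatch added in Tokenize::new;
// the function is a placeholder, not part of the yomikomi API.
use sentencepiece::SentencePieceProcessor;
use std::path::Path;
use tokenizers::tokenizer::Tokenizer;

fn encode_line(path: &Path, line: &str) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
    if path.extension().map_or(false, |v| v == "json") {
        // Hugging Face tokenizers: a single exported tokenizer.json file.
        let tokenizer = Tokenizer::from_file(path)?;
        Ok(tokenizer.encode(line, false)?.get_ids().to_vec())
    } else {
        // Anything else is treated as a SentencePiece model, as before.
        let spp = SentencePieceProcessor::open(path)?;
        Ok(spp.encode(line)?.iter().map(|p| p.id).collect())
    }
}
```

The in-stream `Processor` keeps both backends behind a single `encode` returning `Vec<u32>`, which is why the loop above now pushes plain `token` values instead of the SentencePiece-specific `token.id`.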
