Better handling of bos/eos.
LaurentMazare committed Dec 4, 2024
1 parent 87ed1e8 commit 58b9fe9
Showing 2 changed files with 26 additions and 10 deletions.
yomikomi-pyo3/src/lib.rs: 10 additions, 1 deletion

@@ -214,6 +214,8 @@ struct Tokenize {
     report_bpb: bool,
     include_bos: bool,
     include_eos: bool,
+    bos_id: Option<u32>,
+    eos_id: Option<u32>,
 }

@@ -227,6 +229,8 @@ impl Iterable for Tokenize {
             self.report_bpb,
             self.include_bos,
             self.include_eos,
+            self.bos_id,
+            self.eos_id,
         )
         .map_err(w)?;
         Ok(StreamIter { stream: Box::new(stream) })

@@ -409,7 +413,8 @@ impl YkIterable {

     /// Loads a sentencepiece tokenizer, and use it to tokenize the field passed as an argument of
     /// this function.
-    #[pyo3(signature = (path, *, in_field="text".to_string(), out_field=None, report_bpb=true, include_bos=true, include_eos=false))]
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (path, *, in_field="text".to_string(), out_field=None, report_bpb=true, include_bos=true, include_eos=false, bos_id=None, eos_id=None))]
     fn tokenize(
         &self,
         path: std::path::PathBuf,

@@ -418,6 +423,8 @@ impl YkIterable {
         report_bpb: bool,
         include_bos: bool,
         include_eos: bool,
+        bos_id: Option<u32>,
+        eos_id: Option<u32>,
     ) -> PyResult<Self> {
         let out_field = out_field.unwrap_or_else(|| in_field.clone());
         let inner = Tokenize {

@@ -428,6 +435,8 @@ impl YkIterable {
             report_bpb,
             include_bos,
             include_eos,
+            bos_id,
+            eos_id,
         };
         Ok(Self { inner: Arc::new(inner) })
     }
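
The new keyword-only arguments default to None, so existing Python callers are unaffected; pyo3 maps them to Option<u32> on the Rust side. Below is a minimal sketch of this keyword-default pattern, assuming a recent pyo3 with the Bound module API; the module and function names are illustrative, not part of yomikomi:

use pyo3::prelude::*;

// Keyword-only overrides defaulting to `None`, mirroring the
// `tokenize` signature in the diff above. `token_ids` and `demo`
// are hypothetical names used only for this sketch.
#[pyfunction]
#[pyo3(signature = (text, *, bos_id=None, eos_id=None))]
fn token_ids(text: &str, bos_id: Option<u32>, eos_id: Option<u32>) -> String {
    format!("text={text:?} bos_id={bos_id:?} eos_id={eos_id:?}")
}

#[pymodule]
fn demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(token_ids, m)?)?;
    Ok(())
}

From Python this would be called as token_ids("hi", bos_id=1), with both overrides simply omitted by callers that predate the change.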
yomikomi/src/tokenize.rs: 16 additions, 9 deletions

@@ -4,22 +4,22 @@ use std::sync::{Arc, Mutex};
 use tokenizers::tokenizer::Tokenizer;

 enum Processor {
-    Tokenizers { inner: Box<Tokenizer>, bos_id: Option<u32>, eos_id: Option<u32> },
+    Tokenizers(Box<Tokenizer>),
     SentencePiece(SentencePieceProcessor),
 }

 impl Processor {
     fn bos_id(&self) -> Option<u32> {
         match self {
             Self::SentencePiece(p) => p.bos_id(),
-            Self::Tokenizers { inner: _, bos_id, eos_id: _ } => bos_id.as_ref().copied(),
+            Self::Tokenizers(_) => None,
         }
     }

     fn eos_id(&self) -> Option<u32> {
         match self {
             Self::SentencePiece(p) => p.eos_id(),
-            Self::Tokenizers { inner: _, bos_id: _, eos_id } => eos_id.as_ref().copied(),
+            Self::Tokenizers(_) => None,
         }
     }

@@ -28,9 +28,7 @@ impl Processor {
             Self::SentencePiece(p) => {
                 p.encode(str).map_err(E::wrap)?.iter().map(|v| v.id).collect()
             }
-            Self::Tokenizers { inner, bos_id: _, eos_id: _ } => {
-                inner.encode(str, false)?.get_ids().to_vec()
-            }
+            Self::Tokenizers(p) => p.encode(str, false)?.get_ids().to_vec(),
         };
         Ok(tokens)
     }

@@ -45,9 +43,12 @@ pub struct Tokenize<T> {
     tokens_and_chars: Option<Mutex<(usize, usize)>>,
     include_bos: bool,
     include_eos: bool,
+    bos_id: Option<u32>,
+    eos_id: Option<u32>,
 }

 impl<T> Tokenize<T> {
+    #[allow(clippy::too_many_arguments)]
     pub fn new<P: AsRef<std::path::Path>>(
         path: P,
         input: T,

@@ -56,11 +57,13 @@ impl<T> Tokenize<T> {
         report_bpb: bool,
         include_bos: bool,
         include_eos: bool,
+        bos_id: Option<u32>,
+        eos_id: Option<u32>,
     ) -> Result<Self> {
         let path = path.as_ref();
         let processor = if path.extension().map_or(false, |v| v == "json") {
             let inner = Box::new(Tokenizer::from_file(path)?);
-            Processor::Tokenizers { inner, bos_id: None, eos_id: None }
+            Processor::Tokenizers(inner)
         } else {
             Processor::SentencePiece(SentencePieceProcessor::open(path).map_err(E::wrap)?)
         };

@@ -78,6 +81,8 @@ impl<T> Tokenize<T> {
             tokens_and_chars,
             include_bos,
             include_eos,
+            bos_id,
+            eos_id,
         })
     }
 }

@@ -102,7 +107,8 @@ impl<T: Stream> Stream for Tokenize<T> {
         let text = String::from_utf8_lossy(values);
         let mut all_tokens = Vec::new();
         if self.include_bos {
-            if let Some(bos_id) = self.processor.bos_id() {
+            let bos_id = self.bos_id.or_else(|| self.processor.bos_id());
+            if let Some(bos_id) = bos_id {
                 all_tokens.push(bos_id)
             }
         }

@@ -130,7 +136,8 @@ impl<T: Stream> Stream for Tokenize<T> {
             }
         }
         if self.include_eos {
-            if let Some(eos_id) = self.processor.eos_id() {
+            let eos_id = self.eos_id.or_else(|| self.processor.eos_id());
+            if let Some(eos_id) = eos_id {
                 all_tokens.push(eos_id)
             }
         }
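
Taken together, the commit moves the bos/eos override out of Processor and into Tokenize: an explicit bos_id/eos_id now takes precedence over the processor's own ids, and tokenizers-backed (JSON) tokenizers, which report None for both, can get bos/eos tokens attached at all. Here is a self-contained sketch of the resulting assembly logic; assemble and its parameters are illustrative names, while in the crate the same logic lives inside Stream::next for Tokenize<T>:

// Simplified version of the bos/eos handling after this commit.
fn assemble(
    tokens: &[u32],
    include_bos: bool,
    include_eos: bool,
    bos_override: Option<u32>,   // self.bos_id in the diff
    eos_override: Option<u32>,   // self.eos_id in the diff
    processor_bos: Option<u32>,  // self.processor.bos_id()
    processor_eos: Option<u32>,  // self.processor.eos_id()
) -> Vec<u32> {
    let mut all_tokens = Vec::new();
    if include_bos {
        // Explicit override wins; otherwise fall back to the processor.
        if let Some(bos_id) = bos_override.or(processor_bos) {
            all_tokens.push(bos_id)
        }
    }
    all_tokens.extend_from_slice(tokens);
    if include_eos {
        if let Some(eos_id) = eos_override.or(processor_eos) {
            all_tokens.push(eos_id)
        }
    }
    all_tokens
}

fn main() {
    // JSON (tokenizers) case: the processor reports no ids, so the
    // overrides are the only way to get bos/eos added.
    assert_eq!(assemble(&[5, 6], true, true, Some(1), Some(2), None, None), vec![1, 5, 6, 2]);
    // SentencePiece case with no overrides: processor defaults apply.
    assert_eq!(assemble(&[5, 6], true, false, None, None, Some(0), Some(7)), vec![0, 5, 6]);
}

Note that include_bos/include_eos still gate everything: setting bos_id alone has no effect unless include_bos is true.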
