Add tokenizers support. #1

Merged
merged 7 commits on Dec 4, 2024
4 changes: 2 additions & 2 deletions .github/workflows/ykpy-ci.yml
@@ -23,7 +23,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
-target: [x86_64, x86, aarch64, armv7, s390x, ppc64le]
+target: [x86_64, x86]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
@@ -69,7 +69,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
-target: [x86_64, aarch64]
+target: [x86_64]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -6,7 +6,7 @@ members = [
resolver = "2"

[workspace.package]
version = "0.2.0"
version = "0.3.0"
edition = "2021"
description = "Dataloader for training large text models."
repository = "https://github.com/kyutai-labs/yomikomi"
2 changes: 1 addition & 1 deletion yomikomi-pyo3/Cargo.toml
@@ -15,4 +15,4 @@ crate-type = ["cdylib"]
[dependencies]
numpy = "0.22.0"
pyo3 = "0.22.0"
-yomikomi = { path = "../yomikomi", version = "0.2.0" }
+yomikomi = { path = "../yomikomi", version = "0.3.0" }
13 changes: 12 additions & 1 deletion yomikomi-pyo3/py_src/yomikomi/__init__.pyi
@@ -103,7 +103,18 @@ class YkIterable:
""" """
pass

-def tokenize(self, path, *, in_field=..., out_field=None, report_bpb=True, include_bos=True, include_eos=False):
+def tokenize(
+self,
+path,
+*,
+in_field=...,
+out_field=None,
+report_bpb=True,
+include_bos=True,
+include_eos=False,
+bos_id=None,
+eos_id=None
+):
"""
Loads a sentencepiece tokenizer and uses it to tokenize the field passed as an argument of
this function.
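
For context, a rough usage sketch of the extended signature. Here `ds` stands for a YkIterable built earlier in a pipeline, and the tokenizer path, field names, and token ids are made up for illustration:

# `ds` is assumed to be an existing YkIterable whose samples carry a "text" field.
# A path ending in .json selects the new Hugging Face `tokenizers` backend added
# by this PR; any other path keeps the previous sentencepiece behaviour.
ds = ds.tokenize(
    "tokenizer.json",    # hypothetical tokenizers JSON file
    in_field="text",     # field holding the raw UTF-8 text
    out_field="tokens",  # write the token ids under a separate key
    report_bpb=True,     # keep the bits-per-byte statistic
    include_bos=True,
    include_eos=False,
    bos_id=1,            # explicit id, since tokenizers files expose no BOS/EOS ids
)
# Assuming iteration yields dict-like samples, the tokenized field can be inspected:
for sample in ds:
    print(sample["tokens"][:10])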
11 changes: 10 additions & 1 deletion yomikomi-pyo3/src/lib.rs
@@ -214,6 +214,8 @@ struct Tokenize {
report_bpb: bool,
include_bos: bool,
include_eos: bool,
+bos_id: Option<u32>,
+eos_id: Option<u32>,
}

impl Iterable for Tokenize {
@@ -227,6 +229,8 @@ impl Iterable for Tokenize {
self.report_bpb,
self.include_bos,
self.include_eos,
+self.bos_id,
+self.eos_id,
)
.map_err(w)?;
Ok(StreamIter { stream: Box::new(stream) })
@@ -409,7 +413,8 @@ impl YkIterable {

/// Loads a sentencepiece tokenizer and uses it to tokenize the field passed as an argument of
/// this function.
-#[pyo3(signature = (path, *, in_field="text".to_string(), out_field=None, report_bpb=true, include_bos=true, include_eos=false))]
+#[allow(clippy::too_many_arguments)]
+#[pyo3(signature = (path, *, in_field="text".to_string(), out_field=None, report_bpb=true, include_bos=true, include_eos=false, bos_id=None, eos_id=None))]
fn tokenize(
&self,
path: std::path::PathBuf,
@@ -418,6 +423,8 @@ impl YkIterable {
report_bpb: bool,
include_bos: bool,
include_eos: bool,
+bos_id: Option<u32>,
+eos_id: Option<u32>,
) -> PyResult<Self> {
let out_field = out_field.unwrap_or_else(|| in_field.clone());
let inner = Tokenize {
@@ -428,6 +435,8 @@ impl YkIterable {
report_bpb,
include_bos,
include_eos,
+bos_id,
+eos_id,
};
Ok(Self { inner: Arc::new(inner) })
}
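
One detail worth calling out from the signature above: when out_field is None it falls back to in_field, so the token ids overwrite the text in place. A minimal sketch, again with an assumed upstream `ds` and a made-up sentencepiece model path:

# in_field defaults to "text" and out_field falls back to in_field,
# so this replaces sample["text"] with the token ids.
ds = ds.tokenize("spm.model", include_eos=True)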
1 change: 1 addition & 0 deletions yomikomi/Cargo.toml
@@ -15,4 +15,5 @@ sentencepiece = "0.11.2"
serde_json = "1.0.108"
symphonia = { version = "0.5.3", features = ["all-codecs"] }
thiserror = "1.0.50"
+tokenizers = "0.21.0"
zstd = "0.13.0"
3 changes: 3 additions & 0 deletions yomikomi/src/error.rs
@@ -46,6 +46,9 @@ pub enum Error {
#[error(transparent)]
Io(#[from] std::io::Error),

+#[error(transparent)]
+Tokenizers(#[from] tokenizers::tokenizer::Error),
+
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
2 changes: 1 addition & 1 deletion yomikomi/src/strided_index.rs
@@ -27,7 +27,7 @@ impl<'a> StridedIndex<'a> {
}
}

-impl<'a> Iterator for StridedIndex<'a> {
+impl Iterator for StridedIndex<'_> {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
67 changes: 57 additions & 10 deletions yomikomi/src/tokenize.rs
@@ -1,19 +1,54 @@
-use crate::{Array, Error, Result, Stream};
+use crate::{Array, Error as E, Result, Stream};
use sentencepiece::SentencePieceProcessor;
use std::sync::{Arc, Mutex};
+use tokenizers::tokenizer::Tokenizer;

+enum Processor {
+Tokenizers(Box<Tokenizer>),
+SentencePiece(SentencePieceProcessor),
+}
+
+impl Processor {
+fn bos_id(&self) -> Option<u32> {
+match self {
+Self::SentencePiece(p) => p.bos_id(),
+Self::Tokenizers(_) => None,
+}
+}
+
+fn eos_id(&self) -> Option<u32> {
+match self {
+Self::SentencePiece(p) => p.eos_id(),
+Self::Tokenizers(_) => None,
+}
+}
+
+fn encode(&self, str: &str) -> Result<Vec<u32>> {
+let tokens: Vec<_> = match self {
+Self::SentencePiece(p) => {
+p.encode(str).map_err(E::wrap)?.iter().map(|v| v.id).collect()
+}
+Self::Tokenizers(p) => p.encode(str, false)?.get_ids().to_vec(),
+};
+Ok(tokens)
+}
+}

pub struct Tokenize<T> {
-spp: Arc<SentencePieceProcessor>,
+processor: Arc<Processor>,
input: T,
in_key: String,
out_key: String,
nl_id: u32,
tokens_and_chars: Option<Mutex<(usize, usize)>>,
include_bos: bool,
include_eos: bool,
+bos_id: Option<u32>,
+eos_id: Option<u32>,
}

impl<T> Tokenize<T> {
+#[allow(clippy::too_many_arguments)]
pub fn new<P: AsRef<std::path::Path>>(
path: P,
input: T,
@@ -22,22 +57,32 @@ impl<T> Tokenize<T> {
report_bpb: bool,
include_bos: bool,
include_eos: bool,
+bos_id: Option<u32>,
+eos_id: Option<u32>,
) -> Result<Self> {
-let spp = SentencePieceProcessor::open(path).map_err(Error::wrap)?;
-let nl_id = match spp.encode("\n").map_err(Error::wrap)?.last() {
+let path = path.as_ref();
+let processor = if path.extension().map_or(false, |v| v == "json") {
+let inner = Box::new(Tokenizer::from_file(path)?);
+Processor::Tokenizers(inner)
+} else {
+Processor::SentencePiece(SentencePieceProcessor::open(path).map_err(E::wrap)?)
+};
+let nl_id = match processor.encode("\n").map_err(E::wrap)?.last() {
None => crate::bail!("no specific token id for newline"),
-Some(p) => p.id,
+Some(p) => *p,
};
let tokens_and_chars = if report_bpb { Some(Mutex::new((0, 0))) } else { None };
Ok(Self {
-spp: Arc::new(spp),
+processor: Arc::new(processor),
input,
in_key,
out_key,
nl_id,
tokens_and_chars,
include_bos,
include_eos,
+bos_id,
+eos_id,
})
}
}
@@ -62,7 +107,8 @@ impl<T: Stream> Stream for Tokenize<T> {
let text = String::from_utf8_lossy(values);
let mut all_tokens = Vec::new();
if self.include_bos {
-if let Some(bos_id) = self.spp.bos_id() {
+let bos_id = self.bos_id.or_else(|| self.processor.bos_id());
+if let Some(bos_id) = bos_id {
all_tokens.push(bos_id)
}
}
@@ -72,7 +118,7 @@ impl<T: Stream> Stream for Tokenize<T> {
if idx > 0 {
all_tokens.push(self.nl_id)
}
-let tokens = match self.spp.encode(text) {
+let tokens = match self.processor.encode(text) {
Ok(tokens) => tokens,
Err(err) => {
eprintln!("tokenizer encode error {err:?}");
@@ -86,11 +132,12 @@ impl<T: Stream> Stream for Tokenize<T> {
bpb = Some(tokens_and_chars.0 as f64 / tokens_and_chars.1 as f64 / f64::ln(2.))
};
for token in tokens {
-all_tokens.push(token.id)
+all_tokens.push(token)
}
}
if self.include_eos {
-if let Some(eos_id) = self.spp.eos_id() {
+let eos_id = self.eos_id.or_else(|| self.processor.eos_id());
+if let Some(eos_id) = eos_id {
all_tokens.push(eos_id)
}
}
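
To summarize the behaviour introduced here from the Python side: the backend is picked from the file extension, explicit bos_id/eos_id take precedence over the tokenizer's own ids, and the tokenizers backend reports no BOS/EOS of its own. A sketch with made-up paths and ids, assuming an existing pipeline `ds`:

# sentencepiece model: BOS comes from the model unless an explicit id is given.
ds_sp = ds.tokenize("spm.model", include_bos=True)              # model's own bos id
ds_sp2 = ds.tokenize("spm.model", include_bos=True, bos_id=5)   # explicit id wins

# tokenizers JSON file: Processor::bos_id()/eos_id() return None, so
# include_bos/include_eos only take effect when ids are passed explicitly.
ds_hf = ds.tokenize("tokenizer.json", include_bos=True)             # no BOS added
ds_hf2 = ds.tokenize("tokenizer.json", include_bos=True, bos_id=1)  # BOS id 1 added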