From 1c1817e06f4539f5058b63cada5228e1cfcb559e Mon Sep 17 00:00:00 2001
From: aoiasd
Date: Thu, 7 Nov 2024 11:39:06 +0800
Subject: [PATCH] Add chinese and english analyzer with refactor jieba tokenizer

Signed-off-by: aoiasd
---
 .../tantivy/tantivy-binding/Cargo.lock        | 32 +++----
 .../tantivy/tantivy-binding/Cargo.toml        |  3 +-
 .../tantivy-binding/src/jieba_tokenizer.rs    | 79 +++++++++++++++
 .../tantivy/tantivy-binding/src/lib.rs        |  2 +
 .../tantivy/tantivy-binding/src/stop_words.rs |  5 +
 .../tantivy/tantivy-binding/src/tokenizer.rs  | 95 +++++++++++++------
 .../tantivy-binding/src/tokenizer_filter.rs   | 55 ++++++++++-
 .../tantivy/tantivy-binding/src/util.rs       | 26 ++++-
 internal/util/function/bm25_function.go       |  1 -
 9 files changed, 242 insertions(+), 56 deletions(-)
 create mode 100644 internal/core/thirdparty/tantivy/tantivy-binding/src/jieba_tokenizer.rs
 create mode 100644 internal/core/thirdparty/tantivy/tantivy-binding/src/stop_words.rs

diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock
index a72e056522e8d..0232ecabe64f6 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock
@@ -904,14 +904,14 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.10.4"
+version = "1.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
+checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-automata 0.4.6",
- "regex-syntax 0.8.2",
+ "regex-automata 0.4.8",
+ "regex-syntax 0.8.5",
 ]
 
 [[package]]
@@ -925,13 +925,13 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.6"
+version = "0.4.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
+checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-syntax 0.8.2",
+ "regex-syntax 0.8.5",
 ]
 
 [[package]]
@@ -942,9 +942,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
 
 [[package]]
 name = "regex-syntax"
-version = "0.8.2"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
+checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
 
 [[package]]
 name = "rust-stemmers"
@@ -1163,13 +1163,14 @@ dependencies = [
  "cbindgen",
  "env_logger",
  "futures",
+ "jieba-rs",
  "lazy_static",
  "libc",
  "log",
+ "regex",
  "scopeguard",
  "serde_json",
  "tantivy",
- "tantivy-jieba",
  "zstd-sys",
 ]
 
 [[package]]
@@ -1222,17 +1223,6 @@ dependencies = [
  "utf8-ranges",
 ]
 
-[[package]]
-name = "tantivy-jieba"
-version = "0.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44022293c12a8f878e03439b2f11806d3d394130fe33d4e7781cba91abbac0a4"
-dependencies = [
- "jieba-rs",
- "lazy_static",
- "tantivy-tokenizer-api",
-]
-
 [[package]]
 name = "tantivy-query-grammar"
 version = "0.21.0"
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml
index 6b26b3ab67e7e..c0c857b168fbf 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml
@@ -13,9 +13,10 @@ scopeguard = "1.2"
 zstd-sys = "=2.0.9"
 env_logger = "0.11.3"
 log = "0.4.21"
-tantivy-jieba = "0.10.0"
 lazy_static = "1.4.0"
 serde_json = "1.0.128"
+jieba-rs = "0.6.8"
+regex = "1.11.1"
 
 [build-dependencies]
 cbindgen = "0.26.0"
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/jieba_tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/jieba_tokenizer.rs
new file mode 100644
index 0000000000000..c608dc339525d
--- /dev/null
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/jieba_tokenizer.rs
@@ -0,0 +1,79 @@
+use jieba_rs;
+use tantivy::tokenizer::{Token, TokenStream, Tokenizer};
+
+#[derive(Clone)]
+pub enum JiebaMode {
+    Exact,
+    Search,
+}
+
+#[derive(Clone)]
+pub struct JiebaTokenizer{
+    tokenizer: jieba_rs::Jieba,
+    mode: JiebaMode,
+    hmm: bool,
+}
+
+pub struct JiebaTokenStream {
+    tokens: Vec<Token>,
+    index: usize,
+}
+
+impl TokenStream for JiebaTokenStream {
+    fn advance(&mut self) -> bool {
+        if self.index < self.tokens.len() {
+            self.index += 1;
+            true
+        } else {
+            false
+        }
+    }
+
+    fn token(&self) -> &Token {
+        &self.tokens[self.index - 1]
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        &mut self.tokens[self.index - 1]
+    }
+}
+
+impl JiebaTokenizer {
+    pub fn new() -> JiebaTokenizer{
+        JiebaTokenizer{tokenizer: jieba_rs::Jieba::new(), mode: JiebaMode::Search, hmm: true}
+    }
+
+    fn tokenize(&self, text: &str) -> Vec<Token>{
+        let mut indices = text.char_indices().collect::<Vec<_>>();
+        indices.push((text.len(), '\0'));
+        let ori_tokens = match self.mode{
+            JiebaMode::Exact => {
+                self.tokenizer.tokenize(text, jieba_rs::TokenizeMode::Default, self.hmm)
+            },
+            JiebaMode::Search => {
+                self.tokenizer.tokenize(text, jieba_rs::TokenizeMode::Search, self.hmm)
+            },
+        };
+
+        let mut tokens = Vec::new();
+        for token in ori_tokens {
+            tokens.push(Token {
+                offset_from: indices[token.start].0,
+                offset_to: indices[token.end].0,
+                position: token.start,
+                text: String::from(&text[(indices[token.start].0)..(indices[token.end].0)]),
+                position_length: token.end - token.start,
+            });
+        }
+        tokens
+    }
+}
+
+impl Tokenizer for JiebaTokenizer {
+    type TokenStream<'a> = JiebaTokenStream;
+
+    fn token_stream(&mut self, text: &str) -> JiebaTokenStream {
+        let tokens = self.tokenize(text);
+        JiebaTokenStream { tokens, index: 0 }
+    }
+}
\ No newline at end of file
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
index 90bfa80fd11c7..c789e66de84be 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
@@ -21,6 +21,8 @@ mod util;
 mod error;
 mod util_c;
 mod vec_collector;
+mod stop_words;
+mod jieba_tokenizer;
 
 pub fn add(left: usize, right: usize) -> usize {
     left + right
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/stop_words.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/stop_words.rs
new file mode 100644
index 0000000000000..ae78b86f12515
--- /dev/null
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/stop_words.rs
@@ -0,0 +1,5 @@
+pub const ENGLISH: &[&str] = &[
+    "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in",
+    "into", "is", "it", "no", "not", "of", "on", "or", "such", "that", "the",
+    "their", "then", "there", "these", "they", "this", "to", "was", "will", "with",
+];
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
index d831c9d918c6f..054cebd47d191 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
@@ -1,18 +1,42 @@
 use log::warn;
 use std::collections::HashMap;
 use tantivy::tokenizer::*;
+use tantivy::tokenizer::StopWordFilter;
 use serde_json as json;
 
+use crate::stop_words;
 use crate::tokenizer_filter::*;
+use crate::jieba_tokenizer::JiebaTokenizer;
 use crate::error::TantivyError;
 use crate::util::*;
 
 // default build-in analyzer
 pub(crate) fn standard_analyzer(stop_words: Vec<String>) -> TextAnalyzer {
+    let builder = standard_builder()
+        .filter(LowerCaser);
+
+    if stop_words.len() > 0{
+        return builder.filter(StopWordFilter::remove(stop_words)).build();
+    }
+
+    builder.build()
+}
+
+fn chinese_analyzer(stop_words: Vec<String>) -> TextAnalyzer{
+    let builder = jieba_builder().filter(CnCharOnlyFilter);
+    if stop_words.len() > 0{
+        return builder.filter(StopWordFilter::remove(stop_words)).build();
+    }
+
+    builder.build()
+}
+
+fn english_analyzer(stop_words: Vec<String>) -> TextAnalyzer{
     let builder = standard_builder()
         .filter(LowerCaser)
-        .filter(RemoveLongFilter::limit(40));
+        .filter(Stemmer::new(Language::English))
+        .filter(StopWordFilter::remove(stop_words::ENGLISH.iter().map(|&word| word.to_owned())));
 
     if stop_words.len() > 0{
         return builder.filter(StopWordFilter::remove(stop_words)).build();
@@ -29,10 +53,15 @@ fn whitespace_builder()-> TextAnalyzerBuilder{
     TextAnalyzer::builder(WhitespaceTokenizer::default()).dynamic()
 }
 
+fn jieba_builder() -> TextAnalyzerBuilder{
+    TextAnalyzer::builder(JiebaTokenizer::new()).dynamic()
+}
+
 fn get_builder_by_name(name: &String) -> Result<TextAnalyzerBuilder, TantivyError>{
     match name.as_str() {
         "standard" => Ok(standard_builder()),
         "whitespace" => Ok(whitespace_builder()),
+        "jieba" => Ok(jieba_builder()),
         other => {
             warn!("unsupported tokenizer: {}", other);
             Err(format!("unsupported tokenizer: {}", other).into())
@@ -92,6 +121,7 @@ impl AnalyzerBuilder<'_>{
         }
         let filters = params.as_array().unwrap();
+
         for filter in filters{
             if filter.is_string(){
                 let filter_name = filter.as_str().unwrap();
@@ -127,30 +157,34 @@ impl AnalyzerBuilder<'_>{
                     // build with filter if filter param exist
                     builder=self.build_filter(builder, value)?;
                 },
-                "max_token_length" => {
-                    if !value.is_u64(){
-                        return Err("max token length should be int type".into());
-                    }
-                    builder = builder.filter_dynamic(RemoveLongFilter::limit(value.as_u64().unwrap() as usize));
-                }
                 other => return Err(format!("unknown analyzer option key: {}", other).into()),
             }
         }
         Ok(builder)
     }
 
+    fn get_stop_words_option(&self) -> Result<Vec<String>, TantivyError>{
+        let value = self.params.get("stop_words");
+        match value{
+            Some(value)=>{
+                let str_list = get_string_list(value, "filter stop_words")?;
+                Ok(get_stop_words_list(str_list))
+            }
+            None => Ok(vec![])
+        }
+    }
+
     fn build_template(self, type_: &str)-> Result<TextAnalyzer, TantivyError>{
         match type_{
             "standard" => {
-                let value = self.params.get("stop_words");
-                match value{
-                    Some(value)=>{
-                        let str_list = get_string_list(value, "filter stop_words")?;
-                        Ok(standard_analyzer(str_list))
-                    }
-                    None => Ok(standard_analyzer(vec![]))
-                }
+                Ok(standard_analyzer(self.get_stop_words_option()?))
             },
+            "chinese" => {
+                Ok(chinese_analyzer(self.get_stop_words_option()?))
+            },
+            "english" => {
+                Ok(english_analyzer(self.get_stop_words_option()?))
+            }
             other_ => Err(format!("unknown build-in analyzer type: {}", other_).into())
         }
     }
@@ -168,13 +202,7 @@ impl AnalyzerBuilder<'_>{
         };
 
         //build custom analyzer
-        let tokenizer_name = self.get_tokenizer_name()?;
-
-        // jieba analyzer can't add filter.
-        if tokenizer_name == "jieba"{
-            return Ok(tantivy_jieba::JiebaTokenizer{}.into());
-        }
-
+        let tokenizer_name = self.get_tokenizer_name()?;
         let mut builder=get_builder_by_name(&tokenizer_name)?;
 
         // build with option
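For orientation, a sketch of analyzer params the code above is meant to accept. The template names "standard"/"chinese"/"english", the "jieba" tokenizer, the "cncharonly" filter and the "stop_words" option all come from this patch; the exact JSON keys ("type", "tokenizer", "filter") follow the existing params convention and are assumptions here:

    // Illustrative params only; create_tokenizer (below) is the entry point.
    let chinese = r#"{"type": "chinese", "stop_words": ["的"]}"#.to_string();
    let english = r#"{"type": "english", "stop_words": ["_english_"]}"#.to_string();
    let custom = r#"{"tokenizer": "jieba", "filter": ["cncharonly"]}"#.to_string();

    for params in [chinese, english, custom] {
        let analyzer = create_tokenizer(&params);
        assert!(analyzer.is_ok());
    }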
@@ -227,28 +255,37 @@ pub(crate) fn create_tokenizer(params: &String) -> Result<TextAnalyzer, TantivyError>
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs
             Self::LowerCase(filter) => builder.filter(filter).dynamic(),
             Self::AsciiFolding(filter) => builder.filter(filter).dynamic(),
             Self::AlphaNumOnly(filter) => builder.filter(filter).dynamic(),
+            Self::CnCharOnly(filter) => builder.filter(filter).dynamic(),
             Self::Length(filter) => builder.filter(filter).dynamic(),
             Self::Stop(filter) => builder.filter(filter).dynamic(),
             Self::Decompounder(filter) => builder.filter(filter).dynamic(),
@@ -51,7 +54,7 @@ fn get_stop_words_filter(params: &json::Map)-> Result)-> Result{
@@ -125,6 +128,7 @@ impl From<&str> for SystemFilter{
             "lowercase" => Self::LowerCase(LowerCaser),
             "asciifolding" => Self::AsciiFolding(AsciiFoldingFilter),
             "alphanumonly" => Self::AlphaNumOnly(AlphaNumOnlyFilter),
+            "cncharonly" => Self::CnCharOnly(CnCharOnlyFilter),
             _ => Self::Invalid,
         }
     }
@@ -152,3 +156,52 @@ impl TryFrom<&json::Map<String, json::Value>> for SystemFilter {
         }
     }
 }
+
+pub struct CnCharOnlyFilter;
+
+pub struct CnCharOnlyFilterStream<T> {
+    regex: regex::Regex,
+    tail: T,
+}
+
+impl TokenFilter for CnCharOnlyFilter{
+    type Tokenizer<T: Tokenizer> = CnCharOnlyFilterWrapper<T>;
+
+    fn transform<T: Tokenizer>(self, tokenizer: T) -> CnCharOnlyFilterWrapper<T> {
+        CnCharOnlyFilterWrapper(tokenizer)
+    }
+}
+
+#[derive(Clone)]
+pub struct CnCharOnlyFilterWrapper<T>(T);
+
+impl<T: Tokenizer> Tokenizer for CnCharOnlyFilterWrapper<T> {
+    type TokenStream<'a> = CnCharOnlyFilterStream<T::TokenStream<'a>>;
+
+    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+        CnCharOnlyFilterStream {
+            regex: regex::Regex::new("\\p{Han}+").unwrap(),
+            tail: self.0.token_stream(text),
+        }
+    }
+}
+
+impl<T: TokenStream> TokenStream for CnCharOnlyFilterStream<T> {
+    fn advance(&mut self) -> bool {
+        while self.tail.advance() {
+            if self.regex.is_match(&self.tail.token().text) {
+                return true;
+            }
+        }
+
+        false
+    }
+
+    fn token(&self) -> &Token {
+        self.tail.token()
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        self.tail.token_mut()
+    }
+}
\ No newline at end of file
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/util.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/util.rs
index e705b5df072b1..8e33b43214192 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/util.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/util.rs
@@ -1,10 +1,11 @@
 use std::ffi::c_void;
 use std::ops::Bound;
 use serde_json as json;
-use crate::error::TantivyError;
-
 use tantivy::{directory::MmapDirectory, Index};
 
+use crate::stop_words;
+use crate::error::TantivyError;
+
 pub fn index_exist(path: &str) -> bool {
     let dir = MmapDirectory::open(path).unwrap();
     Index::exists(&dir).unwrap()
@@ -45,4 +46,23 @@ pub(crate) fn get_string_list(value: &json::Value, label: &str) -> Result<Vec<String>, TantivyError>
+
+pub(crate) fn get_stop_words_list(str_list: Vec<String>) -> Vec<String>{
+    let mut stop_words = Vec::new();
+    for str in str_list{
+        if str.len()>0 && str.chars().nth(0).unwrap() == '_'{
+            match str.as_str(){
+                "_english_" =>{
+                    for word in stop_words::ENGLISH{
+                        stop_words.push(word.to_string());
+                    }
+                    continue;
+                }
+                _other => {}
+            }
+        }
+        stop_words.push(str);
+    }
+    stop_words
+}
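A small sketch of how get_stop_words_list expands the "_english_" alias while passing every other entry through unchanged; the assertions are illustrative only:

    let words = get_stop_words_list(vec!["_english_".to_string(), "的".to_string()]);
    assert!(words.contains(&"the".to_string()));          // pulled in from stop_words::ENGLISH
    assert!(words.contains(&"的".to_string()));            // non-alias entries are kept as-is
    assert!(!words.contains(&"_english_".to_string()));   // the alias itself is dropped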
diff --git a/internal/util/function/bm25_function.go b/internal/util/function/bm25_function.go
index ebdd055d6c158..b7d04987abba8 100644
--- a/internal/util/function/bm25_function.go
+++ b/internal/util/function/bm25_function.go
@@ -64,7 +64,6 @@ func NewBM25FunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Fun
 	for _, field := range coll.GetFields() {
 		if field.GetFieldID() == schema.GetOutputFieldIds()[0] {
 			runner.outputField = field
-			break
 		}
 
 		if field.GetFieldID() == schema.GetInputFieldIds()[0] {