diff --git a/internal/core/src/segcore/tokenizer_c.cpp b/internal/core/src/segcore/tokenizer_c.cpp
index 85a3cc39d4f55..a33a6bd9bfd85 100644
--- a/internal/core/src/segcore/tokenizer_c.cpp
+++ b/internal/core/src/segcore/tokenizer_c.cpp
@@ -10,6 +10,7 @@
 // or implied. See the License for the specific language governing permissions and limitations under the License
 
 #include "segcore/tokenizer_c.h"
+#include <memory>
 #include "common/FieldMeta.h"
 #include "common/protobuf_utils.h"
 #include "pb/schema.pb.h"
@@ -30,6 +31,17 @@ create_tokenizer(CMap m, CTokenizer* tokenizer) {
     }
 }
 
+CStatus
+clone_tokenizer(CTokenizer* tokenizer, CTokenizer* rst) {
+    try {
+        auto impl = reinterpret_cast<milvus::tantivy::Tokenizer*>(*tokenizer);
+        *rst = impl->Clone().release();
+        return milvus::SuccessCStatus();
+    } catch (std::exception& e) {
+        return milvus::FailureCStatus(&e);
+    }
+}
+
 void
 free_tokenizer(CTokenizer tokenizer) {
     auto impl = reinterpret_cast<milvus::tantivy::Tokenizer*>(tokenizer);
diff --git a/internal/core/src/segcore/tokenizer_c.h b/internal/core/src/segcore/tokenizer_c.h
index 901689c5337ef..3f84da729efaa 100644
--- a/internal/core/src/segcore/tokenizer_c.h
+++ b/internal/core/src/segcore/tokenizer_c.h
@@ -26,6 +26,9 @@ typedef void* CTokenizer;
 CStatus
 create_tokenizer(CMap m, CTokenizer* tokenizer);
 
+CStatus
+clone_tokenizer(CTokenizer* tokenizer, CTokenizer* rst);
+
 void
 free_tokenizer(CTokenizer tokenizer);
 
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock
index 47872ac8120b8..a72e056522e8d 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock
@@ -1021,11 +1021,12 @@ dependencies = [
 
 [[package]]
 name = "serde_json"
-version = "1.0.115"
+version = "1.0.128"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd"
+checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
 dependencies = [
  "itoa",
+ "memchr",
  "ryu",
  "serde",
 ]
@@ -1166,6 +1167,7 @@ dependencies = [
  "libc",
  "log",
  "scopeguard",
+ "serde_json",
  "tantivy",
  "tantivy-jieba",
  "zstd-sys",
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml
index 3bf9759d470f8..6b26b3ab67e7e 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml
@@ -15,6 +15,7 @@ env_logger = "0.11.3"
 log = "0.4.21"
 tantivy-jieba = "0.10.0"
 lazy_static = "1.4.0"
+serde_json = "1.0.128"
 
 [build-dependencies]
 cbindgen = "0.26.0"
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h
index c443ec7fc7a0e..391cece60bccd 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h
@@ -159,6 +159,8 @@ const char *tantivy_token_stream_get_token(void *token_stream);
 
 void *tantivy_create_tokenizer(void *tokenizer_params);
 
+void *tantivy_clone_tokenizer(void *ptr);
+
 void tantivy_free_tokenizer(void *tokenizer);
 
 bool tantivy_index_exist(const char *path);
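Note: the clone path added above bottoms out in tantivy's `TextAnalyzer`, which implements `Clone`; `tantivy_clone_tokenizer` only needs to clone the analyzer behind the raw pointer, and `clone_tokenizer` / `Tokenizer::Clone` wrap the new pointer again. A minimal sketch of that property, assuming the same tantivy version the binding builds against (the snippet is illustrative and not part of the patch):

```rust
use tantivy::tokenizer::{LowerCaser, SimpleTokenizer, TextAnalyzer};

fn main() {
    let analyzer = TextAnalyzer::builder(SimpleTokenizer::default())
        .filter(LowerCaser)
        .build();
    // TextAnalyzer implements Clone, so an independent copy is easy to obtain;
    // tantivy_clone_tokenizer does exactly this behind the raw pointer.
    let mut copy = analyzer.clone();
    assert!(copy.token_stream("Hello Tantivy").advance());
}
```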
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
index fd73108fd4954..f5df4dc10ff15 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
@@ -15,6 +15,7 @@ mod log;
 mod string_c;
 mod token_stream_c;
 mod tokenizer;
+mod tokenizer_filter;
 mod tokenizer_c;
 mod util;
 mod util_c;
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
index 2e0d283947377..f48760c4ac936 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
@@ -1,7 +1,10 @@
 use lazy_static::lazy_static;
-use log::{info, warn};
+use log::warn;
 use std::collections::HashMap;
-use tantivy::tokenizer::{TextAnalyzer, TokenizerManager};
+use tantivy::tokenizer::*;
+use serde_json::{self as json, value};
+
+use crate::tokenizer_filter::*;
 use crate::log::init_log;
 
 lazy_static! {
@@ -12,32 +15,128 @@ pub(crate) fn default_tokenizer() -> TextAnalyzer {
     DEFAULT_TOKENIZER_MANAGER.get("default").unwrap()
 }
 
-fn jieba_tokenizer() -> TextAnalyzer {
-    tantivy_jieba::JiebaTokenizer {}.into()
+struct TantivyBuilder<'a>{
+    // builder: TextAnalyzerBuilder
+    filters: HashMap<String, SystemFilter>,
+    params: &'a json::Map<String, json::Value>
 }
 
-pub(crate) fn create_tokenizer(params: &HashMap<String, String>) -> Option<TextAnalyzer> {
-    init_log();
+impl TantivyBuilder<'_>{
+    fn new(params: &json::Map<String, json::Value>) -> TantivyBuilder{
+        TantivyBuilder{
+            filters: HashMap::new(),
+            params: params,
+        }
+    }
+
+    fn add_custom_filter(&mut self, name: &String, params: &json::Map<String, json::Value>){
+        match SystemFilter::try_from(params){
+            Ok(filter) => {self.filters.insert(name.to_string(), filter);},
+            Err(_e) => {},
+        };
+    }
 
-    match params.get("tokenizer") {
-        Some(tokenizer_name) => match tokenizer_name.as_str() {
-            "default" => {
-                Some(default_tokenizer())
+    fn add_custom_filters(&mut self, params: &json::Map<String, json::Value>){
+        for (name, value) in params{
+            if !value.is_object(){
+                continue;
             }
+
+            self.add_custom_filter(name, value.as_object().unwrap());
+        }
+    }
+
+    fn build(mut self) -> Option<TextAnalyzer>{
+        let tokenizer = self.params.get("tokenizer");
+        if !tokenizer.is_none() && !tokenizer.unwrap().is_string(){
+            return None;
+        }
+
+        let tokenizer_name = {
+            if !tokenizer.is_none(){
+                tokenizer.unwrap().as_str().unwrap()
+            }else{
+                "standard"
+            }
+        };
+
+        match tokenizer_name {
+            "standard" => {
+                let mut builder = TextAnalyzer::builder(SimpleTokenizer::default()).dynamic();
+                let filters = self.params.get("filter");
+                if !filters.is_none() && filters.unwrap().is_array(){
+                    for filter in filters.unwrap().as_array().unwrap(){
+                        if filter.is_string(){
+                            let filter_name = filter.as_str().unwrap();
+                            let custom = self.filters.remove(filter_name);
+                            if !custom.is_none(){
+                                builder = custom.unwrap().transform(builder);
+                                continue;
+                            }
+                            // check if the filter is a built-in system filter
+                            let system = SystemFilter::from(filter_name);
+                            match system {
+                                SystemFilter::Invalid => {
+                                    log::warn!("build analyzer failed, filter not found: {}", filter_name);
+                                    return None
+                                }
+                                other => {
+                                    builder = other.transform(builder);
+                                },
+                            }
+                        }
+                    }
+                }
+                Some(builder.build())
+            }
             "jieba" => {
-                Some(jieba_tokenizer())
+                Some(tantivy_jieba::JiebaTokenizer {}.into())
             }
             s => {
                 warn!("unsupported tokenizer: {}", s);
                 None
             }
-        },
-        None => {
-            Some(default_tokenizer())
         }
     }
 }
 
+pub(crate) fn create_tokenizer(params: &HashMap<String, String>) -> Option<TextAnalyzer> {
+    init_log();
+
+    let analyzer_json_value = match params.get("analyzer"){
+        Some(value) => {
+            let json_analyzer = json::from_str::<json::Value>(value);
+            if json_analyzer.is_err() {
+                return None;
+            }
+            let json_value = json_analyzer.unwrap();
+            if !json_value.is_object(){
+                return None
+            }
+            json_value
+        }
+        None => json::Value::Object(json::Map::<String, json::Value>::new()),
+    };
+
+    let analyzer_params = analyzer_json_value.as_object().unwrap();
+    let mut builder = TantivyBuilder::new(analyzer_params);
+    let str_filter = params.get("filter");
+    if !str_filter.is_none(){
+        let json_filter = json::from_str::<json::Value>(str_filter.unwrap());
+        if json_filter.is_err(){
+            return None
+        }
+
+        let filter_params = json_filter.unwrap();
+        if !filter_params.is_object(){
+            return None
+        }
+
+        builder.add_custom_filters(filter_params.as_object().unwrap());
+    }
+    builder.build()
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
@@ -46,8 +145,12 @@ mod tests {
     #[test]
     fn test_create_tokenizer() {
         let mut params : HashMap<String, String> = HashMap::new();
-        params.insert("tokenizer".parse().unwrap(), "jieba".parse().unwrap());
+        let analyzer_params = r#"
+        {
+            "tokenizer": "jieba"
+        }"#;
+        params.insert("analyzer".to_string(), analyzer_params.to_string());
 
         let tokenizer = create_tokenizer(&params);
         assert!(tokenizer.is_some());
     }
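For reference, a sketch of how the reworked `create_tokenizer` above is driven (crate-internal, e.g. as an extra test next to `test_create_tokenizer`): the `"analyzer"` entry picks the tokenizer plus an ordered filter list, and an optional `"filter"` entry defines custom filters that the list references by name. The `my_length` name is illustrative, not something defined by this patch:

```rust
use std::collections::HashMap;

let mut params: HashMap<String, String> = HashMap::new();
// Built-in filter names ("lowercase") and custom names ("my_length") can be mixed in order.
params.insert(
    "analyzer".to_string(),
    r#"{"tokenizer": "standard", "filter": ["lowercase", "my_length"]}"#.to_string(),
);
// Custom filters are declared under "filter"; TantivyBuilder resolves them before system filters.
params.insert(
    "filter".to_string(),
    r#"{"my_length": {"type": "length", "max": 10}}"#.to_string(),
);
let analyzer = create_tokenizer(&params);
assert!(analyzer.is_some());
```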
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_c.rs
index c2caf097fc34c..ef572fcc4f2a6 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_c.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_c.rs
@@ -20,6 +20,13 @@ pub extern "C" fn tantivy_create_tokenizer(tokenizer_params: *mut c_void) -> *mu
     }
 }
 
+#[no_mangle]
+pub extern "C" fn tantivy_clone_tokenizer(ptr: *mut c_void) -> *mut c_void {
+    let analyzer = ptr as *mut TextAnalyzer;
+    let clone = unsafe { (*analyzer).clone() };
+    create_binding(clone)
+}
+
 #[no_mangle]
 pub extern "C" fn tantivy_free_tokenizer(tokenizer: *mut c_void) {
     free_binding::<TextAnalyzer>(tokenizer);
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs
new file mode 100644
index 0000000000000..9d4c27aa15ae7
--- /dev/null
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer_filter.rs
@@ -0,0 +1,159 @@
+use tantivy::tokenizer::*;
+use serde_json as json;
+
+pub(crate) enum SystemFilter{
+    Invalid,
+    LowerCase(LowerCaser),
+    AsciiFolding(AsciiFoldingFilter),
+    AlphaNumOnly(AlphaNumOnlyFilter),
+    Length(RemoveLongFilter),
+    Stop(StopWordFilter),
+    Decompounder(SplitCompoundWords),
+    Stemmer(Stemmer)
+}
+
+impl SystemFilter{
+    pub(crate) fn transform(self, builder: TextAnalyzerBuilder) -> TextAnalyzerBuilder{
+        match self{
+            Self::LowerCase(filter) => builder.filter(filter).dynamic(),
+            Self::AsciiFolding(filter) => builder.filter(filter).dynamic(),
+            Self::AlphaNumOnly(filter) => builder.filter(filter).dynamic(),
+            Self::Length(filter) => builder.filter(filter).dynamic(),
+            Self::Stop(filter) => builder.filter(filter).dynamic(),
+            Self::Decompounder(filter) => builder.filter(filter).dynamic(),
+            Self::Stemmer(filter) => builder.filter(filter).dynamic(),
+            Self::Invalid => builder,
+        }
+    }
+}
+
+// create length filter from params
+// {
+//     "type": "length",
+//     "max": 10, // length
+// }
+// TODO support min length
+fn get_length_filter(params: &json::Map<String, json::Value>) -> Result<SystemFilter, ()>{
+    let limit_str = params.get("max");
+    if limit_str.is_none() || !limit_str.unwrap().is_u64(){
+        return Err(())
+    }
+    let limit = limit_str.unwrap().as_u64().unwrap() as usize;
+    Ok(SystemFilter::Length(RemoveLongFilter::limit(limit)))
+}
+
+fn get_stop_filter(params: &json::Map<String, json::Value>) -> Result<SystemFilter, ()>{
+    let value = params.get("stop_words");
+    if value.is_none() || !value.unwrap().is_array(){
+        return Err(())
+    }
+
+    let stop_words = value.unwrap().as_array().unwrap();
+    let mut str_list = Vec::<String>::new();
+    for element in stop_words{
+        match element.as_str(){
+            Some(word) => str_list.push(word.to_string()),
+            None => return Err(())
+        }
+    };
+    Ok(SystemFilter::Stop(StopWordFilter::remove(str_list)))
+}
+
+fn get_decompounder_filter(params: &json::Map<String, json::Value>) -> Result<SystemFilter, ()>{
+    let value = params.get("word_list");
+    if value.is_none() || !value.unwrap().is_array(){
+        return Err(())
+    }
+
+    let stop_words = value.unwrap().as_array().unwrap();
+    let mut str_list = Vec::<String>::new();
+    for element in stop_words{
+        match element.as_str(){
+            Some(word) => str_list.push(word.to_string()),
+            None => return Err(())
+        }
+    };
+
+    match SplitCompoundWords::from_dictionary(str_list){
+        Ok(f) => Ok(SystemFilter::Decompounder(f)),
+        Err(_e) => Err(())
+    }
+}
+
+fn get_stemmer_filter(params: &json::Map<String, json::Value>) -> Result<SystemFilter, ()>{
+    let value = params.get("language");
+    if value.is_none() || !value.unwrap().is_string(){
+        return Err(())
+    }
+
+    match value.unwrap().as_str().unwrap().into_language(){
+        Ok(language) => Ok(SystemFilter::Stemmer(Stemmer::new(language))),
+        Err(_e) => Err(()),
+    }
+}
+
+trait LanguageParser {
+    type Error;
+    fn into_language(self) -> Result<Language, Self::Error>;
+}
+
+impl LanguageParser for &str {
+    type Error = ();
+    fn into_language(self) -> Result<Language, Self::Error> {
+        match self {
+            "arabic" => Ok(Language::Arabic),
+            "danish" => Ok(Language::Danish),
+            "dutch" => Ok(Language::Dutch),
+            "english" => Ok(Language::English),
+            "finnish" => Ok(Language::Finnish),
+            "french" => Ok(Language::French),
+            "german" => Ok(Language::German),
+            "greek" => Ok(Language::Greek),
+            "hungarian" => Ok(Language::Hungarian),
+            "italian" => Ok(Language::Italian),
+            "norwegian" => Ok(Language::Norwegian),
+            "portuguese" => Ok(Language::Portuguese),
+            "romanian" => Ok(Language::Romanian),
+            "russian" => Ok(Language::Russian),
+            "spanish" => Ok(Language::Spanish),
+            "swedish" => Ok(Language::Swedish),
+            "tamil" => Ok(Language::Tamil),
+            "turkish" => Ok(Language::Turkish),
+            _ => Err(()),
+        }
+    }
+}
+
+impl From<&str> for SystemFilter{
+    fn from(value: &str) -> Self {
+        match value{
+            "lowercase" => Self::LowerCase(LowerCaser),
+            "asciifolding" => Self::AsciiFolding(AsciiFoldingFilter),
+            "alphanumonly" => Self::AlphaNumOnly(AlphaNumOnlyFilter),
+            _ => Self::Invalid,
+        }
+    }
+}
+
+impl TryFrom<&json::Map<String, json::Value>> for SystemFilter {
+    type Error = ();
+
+    fn try_from(params: &json::Map<String, json::Value>) -> Result<Self, Self::Error> {
+        match params.get(&"type".to_string()){
+            Some(value) => {
+                if !value.is_string(){
+                    return Err(());
+                };
+
+                match value.as_str().unwrap(){
+                    "length" => get_length_filter(params),
+                    "stop" => get_stop_filter(params),
+                    "decompounder" => get_decompounder_filter(params),
+                    "stemmer" => get_stemmer_filter(params),
+                    _other => Err(()),
+                }
+            }
+            None => Err(()),
+        }
+    }
+}
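A small crate-internal sketch of how the new module composes, under the same assumptions as the code above (serde_json for parsing, tantivy's dynamic builder API): a JSON filter spec becomes a `SystemFilter` via `TryFrom`, and `transform` folds it into a `TextAnalyzerBuilder`, exactly as `TantivyBuilder::build` does:

```rust
use serde_json as json;
use tantivy::tokenizer::{SimpleTokenizer, TextAnalyzer};

// Parse a filter definition of the shape handled by SystemFilter::try_from.
let spec: json::Map<String, json::Value> =
    json::from_str(r#"{"type": "stop", "stop_words": ["a", "an", "the"]}"#).unwrap();
let filter = SystemFilter::try_from(&spec).expect("valid filter spec");

// Runtime-selected filters are chained through the dynamic builder.
let builder = TextAnalyzer::builder(SimpleTokenizer::default()).dynamic();
let mut analyzer = filter.transform(builder).build();
let mut stream = analyzer.token_stream("the quick fox");
while stream.advance() {
    assert_ne!(stream.token().text, "the");
}
```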
diff --git a/internal/core/thirdparty/tantivy/tokenizer.h b/internal/core/thirdparty/tantivy/tokenizer.h
index dd753205aa196..6f42eecbfcbe2 100644
--- a/internal/core/thirdparty/tantivy/tokenizer.h
+++ b/internal/core/thirdparty/tantivy/tokenizer.h
@@ -20,6 +20,9 @@ struct Tokenizer {
         }
     }
+    explicit Tokenizer(void* _ptr) : ptr_(_ptr) {
+    }
+
     ~Tokenizer() {
         if (ptr_ != nullptr) {
             tantivy_free_tokenizer(ptr_);
         }
@@ -34,6 +37,12 @@ struct Tokenizer {
         return std::make_unique<TokenStream>(token_stream, shared_text);
     }
 
+    std::unique_ptr<Tokenizer>
+    Clone() {
+        auto newptr = tantivy_clone_tokenizer(ptr_);
+        return std::make_unique<Tokenizer>(newptr);
+    }
+
     // CreateTokenStreamCopyText will copy the text and then create token stream based on the text.
     std::unique_ptr<TokenStream>
     CreateTokenStreamCopyText(const std::string& text) {
diff --git a/internal/util/ctokenizer/c_tokenizer.go b/internal/util/ctokenizer/c_tokenizer.go
index 915aa4cfa1938..e9f44aeb23a79 100644
--- a/internal/util/ctokenizer/c_tokenizer.go
+++ b/internal/util/ctokenizer/c_tokenizer.go
@@ -33,6 +33,15 @@ func (impl *CTokenizer) NewTokenStream(text string) tokenizerapi.TokenStream {
 	return NewCTokenStream(ptr)
 }
 
+func (impl *CTokenizer) Clone() (tokenizerapi.Tokenizer, error) {
+	var newptr C.CTokenizer
+	status := C.clone_tokenizer(&impl.ptr, &newptr)
+	if err := HandleCStatus(&status, "failed to clone tokenizer"); err != nil {
+		return nil, err
+	}
+	return NewCTokenizer(newptr), nil
+}
+
 func (impl *CTokenizer) Destroy() {
 	C.free_tokenizer(impl.ptr)
 }
diff --git a/internal/util/function/bm25_function.go b/internal/util/function/bm25_function.go
index 275be8e412f29..225a3fa30893f 100644
--- a/internal/util/function/bm25_function.go
+++ b/internal/util/function/bm25_function.go
@@ -19,6 +19,7 @@
 package function
 
 import (
+	"encoding/json"
 	"fmt"
 	"sync"
 
@@ -40,6 +41,28 @@ type BM25FunctionRunner struct {
 	concurrency int
 }
 
+// TODO Use a json string instead of map[string]string as tokenizer params
+func getTokenizerParams(field *schemapb.FieldSchema) (map[string]string, error) {
+	result := map[string]string{}
+	for _, param := range field.GetTypeParams() {
+		if param.Key == "tokenizer_params" {
+			params := map[string]interface{}{}
+			err := json.Unmarshal([]byte(param.GetValue()), &params)
+			if err != nil {
+				return nil, err
+			}
+			for key, param := range params {
+				bytes, err := json.Marshal(param)
+				if err != nil {
+					return nil, err
+				}
+				result[key] = string(bytes)
+			}
+		}
+	}
+	return result, nil
+}
+
 func NewBM25FunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*BM25FunctionRunner, error) {
 	if len(schema.GetOutputFieldIds()) != 1 {
 		return nil, fmt.Errorf("bm25 function should only have one output field, but now %d", len(schema.GetOutputFieldIds()))
@@ -49,17 +72,26 @@ func NewBM25FunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Fun
 		schema:      schema,
 		concurrency: 8,
 	}
+	var params map[string]string
 	for _, field := range coll.GetFields() {
 		if field.GetFieldID() == schema.GetOutputFieldIds()[0] {
 			runner.outputField = field
 			break
 		}
+
+		if field.GetFieldID() == schema.GetInputFieldIds()[0] {
+			var err error
+			params, err = getTokenizerParams(field)
+			if err != nil {
+				return nil, err
+			}
+		}
 	}
 
 	if runner.outputField == nil {
 		return nil, fmt.Errorf("no output field")
 	}
-	tokenizer, err := ctokenizer.NewTokenizer(map[string]string{})
+	tokenizer, err := ctokenizer.NewTokenizer(params)
 	if err != nil {
 		return nil, err
 	}
@@ -69,8 +101,7 @@ func NewBM25FunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Fun
 }
 
 func (v *BM25FunctionRunner) run(data []string, dst []map[uint32]float32) error {
-	// TODO AOIASD Support single Tokenizer concurrency
-	tokenizer, err := ctokenizer.NewTokenizer(map[string]string{})
+	tokenizer, err := v.tokenizer.Clone()
 	if err != nil {
 		return err
 	}
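The BM25 runner now builds one tokenizer from the input field's `tokenizer_params` and calls `Clone()` per batch instead of re-creating it; down the stack that is just `TextAnalyzer::clone` on the Rust side. A sketch of why per-worker clones are safe to use concurrently (a Rust analogue of the Go flow, assuming `TextAnalyzer` is `Send` as in tantivy; names are illustrative):

```rust
use std::thread;
use tantivy::tokenizer::{LowerCaser, SimpleTokenizer, TextAnalyzer};

fn main() {
    let analyzer = TextAnalyzer::builder(SimpleTokenizer::default())
        .filter(LowerCaser)
        .build();
    // Each worker tokenizes with its own clone, mirroring BM25FunctionRunner.run.
    let handles: Vec<_> = (0..4)
        .map(|id| {
            let mut local = analyzer.clone();
            thread::spawn(move || {
                let mut stream = local.token_stream("Milvus BM25 function input");
                let mut count = 0;
                while stream.advance() {
                    count += 1;
                }
                (id, count)
            })
        })
        .collect();
    for handle in handles {
        let (id, count) = handle.join().unwrap();
        println!("worker {} produced {} tokens", id, count);
    }
}
```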
diff --git a/internal/util/tokenizerapi/tokenizer.go b/internal/util/tokenizerapi/tokenizer.go
index 2b6debbec71f6..6dab31257122c 100644
--- a/internal/util/tokenizerapi/tokenizer.go
+++ b/internal/util/tokenizerapi/tokenizer.go
@@ -3,5 +3,6 @@ package tokenizerapi
 //go:generate mockery --name=Tokenizer --with-expecter
 type Tokenizer interface {
 	NewTokenStream(text string) TokenStream
+	Clone() (Tokenizer, error)
 	Destroy()
 }