diff --git a/requirements.txt b/requirements.txt index 0bd0a93..24c413c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ progressbar2~=3.53.1 tokenizers>=0.10.3 transformers>=4.6.1 sudachipy>=0.6.2 -sudachidict_core>=20210802 +sudachidict_core==20210802.* \ No newline at end of file diff --git a/sudachitra/tokenization_bert_sudachipy.py b/sudachitra/tokenization_bert_sudachipy.py index 8173543..fccb0b8 100644 --- a/sudachitra/tokenization_bert_sudachipy.py +++ b/sudachitra/tokenization_bert_sudachipy.py @@ -152,22 +152,6 @@ def __init__( sudachipy_kwargs=None, **kwargs ): - super().__init__( - do_lower_case=do_lower_case, - do_nfkc=do_nfkc, - do_word_tokenize=do_word_tokenize, - do_subword_tokenize=do_subword_tokenize, - word_tokenizer_type=word_tokenizer_type, - subword_tokenizer_type=subword_tokenizer_type, - unk_token=unk_token, - sep_token=sep_token, - pad_token=pad_token, - cls_token=cls_token, - mask_token=mask_token, - word_form_type=word_form_type, - sudachipy_kwargs=sudachipy_kwargs, - **kwargs, - ) if not os.path.isfile(vocab_file): raise ValueError(f"Can't find a vocabulary file at path '{vocab_file}'.") @@ -196,12 +180,30 @@ def __init__( if subword_tokenizer_type == "pos_substitution": self.subword_tokenizer = None elif subword_tokenizer_type == "wordpiece": - self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) + self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) elif subword_tokenizer_type == "character": - self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token) + self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=str(unk_token)) else: raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.") + super().__init__( + do_lower_case=do_lower_case, + do_nfkc=do_nfkc, + do_word_tokenize=do_word_tokenize, + do_subword_tokenize=do_subword_tokenize, + word_tokenizer_type=word_tokenizer_type, + subword_tokenizer_type=subword_tokenizer_type, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + word_form_type=word_form_type, + sudachipy_kwargs=sudachipy_kwargs, + **kwargs, + ) + + @property def do_lower_case(self): return self.lower_case