diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index 5a45af284..3e9f0a41d 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -36,6 +36,9 @@ class InjectorToModularTokenizerLib:
         for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
             for example: "blah.boo.banana" or "data.input.encoder_input"

+        for text following <@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT> is expected to be a key to the sample NDict
+            for example: "blah.boo.banana" or "data.input.encoder_input"
+
         example usage:

         encoder_input:
@@ -43,6 +46,10 @@ class InjectorToModularTokenizerLib:
         labels:
         <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY

+        for embeddings from dict:
+        encoder_input:
+        <@TOKENIZER-TYPE=AA><1><8><@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>{protein1_key}<@TOKENIZER-TYPE=AA@MAX-LEN={max_len_1}>{protein_seq_1}
+
     """

     @staticmethod
@@ -50,11 +57,14 @@ def build_placeholder_meta_tokenization(
         *,
         sequence: Union[str, list, tuple],
         sample_dict: Optional[NDict] = None,
+        default_sub_tokenizer_name: str = "AA",
     ) -> Tuple[str, List[str]]:
         """
         In order to avoid modifying and rewriting the logic in modular tokenizer, especially regarding padding, limitation of max length of certain sub-parts,
         we put placeholders to make sure that the total size is known/fixed and respects the meta instructions to the modular tokenizer

+        default_sub_tokenizer_name: Specifies the name of the default sub-tokenizer. This tokenizer is used for handling special tokens, such as and .
+
         Returns: a tuple with 2 elements
         (
         a single string with the full query containing placeholder tokens for FLOAT and VECTOR meta tokenizer parts,
@@ -90,8 +100,8 @@ def build_placeholder_meta_tokenization(
         ):
             if tokenizer_type.startswith("SCALARS_"):
                 with_placeholders.append(
-                    "<@TOKENIZER-TYPE=AA>"
-                )  # AA tokenizer selection is arbitrary, we only take the special token from it
+                    f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
+                )  # tokenizer selection is arbitrary, we only take the special token from it

                 if tokenizer_type == "SCALARS_LITERALS":
                     values = subseq.split(",")
@@ -115,7 +125,11 @@
                     raise Exception(f"tokenizer_type={tokenizer_type} is not supported")

                 with_placeholders.append(seq)
-
+            elif tokenizer_type.startswith("EXTERNAL_EMBEDDINGS_"):
+                with_placeholders.append(
+                    f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
+                )  # tokenizer selection is arbitrary, we only take the special token from it
+                with_placeholders.append("")
             elif tokenizer_type.startswith("VECTORS_"):
                 raise Exception("VECTOR_* are not supported yet")
             else:
@@ -125,7 +139,7 @@
         return "".join(with_placeholders), hints_and_subseq

     @staticmethod
-    def build_scalars(
+    def build_scalars_and_embeddings(
         *,
         per_meta_tokenizer_data: List[str],
         per_meta_encoding_including_placeholders: List[Encoding],
@@ -157,6 +171,9 @@ def build_scalars(
         # for each element, whether it's a scalar or not
         all_scalars_valid_mask = []
         scalar_default_unfound_value = -1000.0
+        external_embeddings_info = dict()  # a dict mapping location -> embedding input
+        num_tokens_token_so_far = 0
+        num_inputs_needing_embeddings = 0

         for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
             per_meta_tokenizer_data[::2],
@@ -175,12 +192,6 @@
                     curr_scalar_values = torch.tensor(
                         curr_scalar_values, dtype=torch.float32
                     )
-                    all_scalars_values.append(curr_scalar_values)
-                    all_scalars_valid_mask.append(
-                        torch.full_like(
-                            curr_scalar_values, fill_value=True, dtype=torch.bool
-                        )
-                    )
                 elif "SCALARS_FROM_DICT" == tokenizer_name:
                     if sample_dict is None:
                         raise Exception(
@@ -188,22 +199,32 @@
                         )
                     curr_scalar_values = sample_dict[curr_str_data]
                     assert len(curr_scalar_values.shape) == 1
-                    all_scalars_values.append(curr_scalar_values)
-                    all_scalars_valid_mask.append(
-                        torch.full_like(
-                            curr_scalar_values, fill_value=True, dtype=torch.bool
-                        )
-                    )
-
                 else:
                     raise Exception(
                         "Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
                     )
-
-            elif tokenizer_name.startswith("VECTORS_"):
-                raise NotImplementedError
+                all_scalars_values.append(curr_scalar_values)
+                all_scalars_valid_mask.append(
+                    torch.full_like(
+                        curr_scalar_values, fill_value=True, dtype=torch.bool
+                    )
+                )
+                num_tokens_token_so_far += len(curr_scalar_values)
             else:
-                # prev_index_end += len(curr_placeholder_encoding.ids)
+                if tokenizer_name == "EXTERNAL_EMBEDDINGS_FROM_DICT":
+                    if sample_dict is None:
+                        raise Exception(
+                            "EXTERNAL_EMBEDDINGS_FROM_DICT used but the provided sample_dict is None"
+                        )
+                    embedding_input = sample_dict[curr_str_data]
+                    external_embeddings_info[num_inputs_needing_embeddings] = (
+                        num_tokens_token_so_far,
+                        embedding_input,
+                    )
+                    num_inputs_needing_embeddings += 1
+                elif tokenizer_name.startswith("VECTORS_"):
+                    raise NotImplementedError
+
                 curr_scalar_values = torch.full(
                     (len(curr_placeholder_encoding.ids),),
                     fill_value=scalar_default_unfound_value,
@@ -214,6 +235,7 @@
                         curr_scalar_values, fill_value=False, dtype=torch.bool
                     )
                 )
+                num_tokens_token_so_far += len(curr_placeholder_encoding.ids)

         all_scalars_values = torch.concat(all_scalars_values)
         all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
@@ -257,4 +279,5 @@
         return {
             "scalars_values": all_scalars_values,  # 1d - its length is the number of actual scalars (provided) found
             "scalars_valid_mask": all_scalars_valid_mask,  # 1d - values of provided scalars
+            "external_embeddings_info": external_embeddings_info,  # dict - number of input needing embeddings -> (location in the query, embeddings input)
         }
diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py
index 43861583f..fb622a07f 100644
--- a/fuse/data/tokenizers/modular_tokenizer/op.py
+++ b/fuse/data/tokenizers/modular_tokenizer/op.py
@@ -433,6 +433,10 @@ def __init__(
             verbose=verbose,
             **kwargs,
         )
+        # default_sub_tokenizer_name is used as the tokenizer of special tokens such as scalars and external_embeddings tokens.
+        self.default_sub_tokenizer_name = next(
+            iter(self._tokenizer.tokenizers_info.values())
+        )["name"]

     def __call__(
         self,
@@ -447,6 +451,7 @@
         verbose: Optional[int] = 1,
         validate_ends_with_eos: Optional[bool] = None,
         key_out_scalars: Optional[str] = None,
+        key_out_external_embeddings_info: Optional[str] = None,
         additional_caller_info_text: Optional[str] = "",
     ) -> NDict:
         """_summary_
@@ -480,7 +485,9 @@
             with_placeholders_str,
             per_meta_orig,
         ) = InjectorToModularTokenizerLib.build_placeholder_meta_tokenization(
-            sequence=sample_dict[key_in], sample_dict=sample_dict
+            sequence=sample_dict[key_in],
+            sample_dict=sample_dict,
+            default_sub_tokenizer_name=self.default_sub_tokenizer_name,
         )

         sample_dict[key_in + ".with_placeholders"] = with_placeholders_str
@@ -500,7 +507,7 @@
             + ".per_meta_part_encoding",  # using the key_in as base for the name because key_out_* are optional
         )

-        prepared_data = InjectorToModularTokenizerLib.build_scalars(
+        prepared_data = InjectorToModularTokenizerLib.build_scalars_and_embeddings(
             per_meta_tokenizer_data=per_meta_orig,
             per_meta_encoding_including_placeholders=sample_dict[
                 key_in + ".per_meta_part_encoding"
@@ -514,6 +521,10 @@
             sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
                 "scalars_valid_mask"
             ]
+        if key_out_external_embeddings_info is not None:
+            sample_dict[key_out_external_embeddings_info] = prepared_data[
+                "external_embeddings_info"
+            ]

         return sample_dict
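
Usage sketch (not part of the patch): calling build_placeholder_meta_tokenization directly with the new default_sub_tokenizer_name argument. The query format follows the docstring added above; the key "data.protein1_emb", the random tensor, the MAX-LEN value, and the short AA sequence are made-up values for illustration only.

import torch
from fuse.utils import NDict
from fuse.data.tokenizers.modular_tokenizer.inject_utils import InjectorToModularTokenizerLib

sample = NDict()
sample["data.protein1_emb"] = torch.randn(16)  # hypothetical embedding input stored in the sample

query = (
    "<@TOKENIZER-TYPE=AA><1><8>"
    "<@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>data.protein1_emb"
    "<@TOKENIZER-TYPE=AA@MAX-LEN=64>ISGGDAIYSSTGRCSLGFNVRSG"
)

with_placeholders_str, hints_and_subseq = InjectorToModularTokenizerLib.build_placeholder_meta_tokenization(
    sequence=query,
    sample_dict=sample,
    default_sub_tokenizer_name="AA",
)
# with_placeholders_str: the query with placeholder tokens substituted for the injected parts
# hints_and_subseq: alternating [tokenizer hint, sub-sequence, ...], consumed later by build_scalars_and_embeddings
print(with_placeholders_str)
print(hints_and_subseq[:4])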
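
The external_embeddings_info dict returned by build_scalars_and_embeddings (and exposed via key_out_external_embeddings_info) maps an input index to (token position in the query, embedding input taken from the sample NDict). A minimal downstream sketch of one possible way to consume it, assuming the stored inputs are already d_model-sized vectors; the function and variable names here (inject_external_embeddings, token_embeddings) are hypothetical and not part of the patch.

import torch

def inject_external_embeddings(
    token_embeddings: torch.Tensor,  # [seq_len, d_model] embeddings of the tokenized query
    external_embeddings_info: dict,  # {input index: (token position, embedding input)}
) -> torch.Tensor:
    out = token_embeddings.clone()
    for _, (position, embedding_input) in sorted(external_embeddings_info.items()):
        # the op stores whatever value was found in the sample NDict; here we assume it is
        # already a 1D tensor of size d_model (a real model might first encode/project it)
        out[position] = embedding_input.to(dtype=out.dtype)
    return out

# toy usage: one external input placed at token position 3
token_embeddings = torch.zeros(10, 4)
info = {0: (3, torch.ones(4))}
print(inject_external_embeddings(token_embeddings, info)[3])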