diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py index 540e3cff5..c7a9a2dd1 100644 --- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py +++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py @@ -7,6 +7,7 @@ TypedInput, list_to_tokenizer_string, ) +from warnings import warn class InjectorToModularTokenizerLib: @@ -28,17 +29,15 @@ class InjectorToModularTokenizerLib: supported syntax/format: for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format: - ',' separated float values and/or tokens - - for example: "2.7,3.99,-12.9" or "" or "2.19,,3.19," + ',' separated float values. For example: "2.7,3.99,-12.9" for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict for example: "blah.boo.banana" or "data.input.encoder_input" - note: in SCALARS_FROM_DICT you can't describe masked scalars (outputs) you can only describe inputs example usage: encoder_input: - <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS><@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY + <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY labels: <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY @@ -67,16 +66,15 @@ def build_placeholder_meta_tokenization( ) if len(sequence) > 0: if isinstance(sequence[0], TypedInput): - sequence_str = list_to_tokenizer_string( + sequence = list_to_tokenizer_string( sequence ) # currently supporting it in this simple way. Consider optimizing if it causes a bottleneck. else: raise Exception( f"Expected sequence to be either string or a list of TypedInput elements. Got a list, but the first element is of type {type(sequence[0])}" ) - else: - sequence_str = sequence - hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence_str)[ + + hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[ 1: ] # the first element is blank - removing it assert ( @@ -91,19 +89,18 @@ def build_placeholder_meta_tokenization( if tokenizer_type.startswith("SCALARS_"): with_placeholders.append( "<@TOKENIZER-TYPE=AA>" - ) # won't use AA tokens, just an arbitrary one to be able to use a token like + ) # AA tokenizer selection is arbitrary, we only take the special token from it - if ( - tokenizer_type == "SCALARS_LITERALS" - ): # note: masking is only supported in literals (not in "from dict") + if tokenizer_type == "SCALARS_LITERALS": values = subseq.split(",") - # seq = "" * len(values) - seq = "".join( - [ - "" if x == "" else "" - for x in values - ] - ) + # validate that all values can be converted to float + try: + [float(x) for x in values] + except: + raise ValueError( + f'expected a string with "," separated values that can each be converted to float. Got {subseq}' + ) + seq = "" * len(values) elif tokenizer_type == "SCALARS_FROM_DICT": if sample_dict is None: raise Exception( @@ -126,11 +123,13 @@ def build_placeholder_meta_tokenization( return "".join(with_placeholders), hints_and_subseq @staticmethod - def prepare_info_for_model_step( + def build_scalars( *, per_meta_tokenizer_data: List[str], per_meta_encoding_including_placeholders: List[Encoding], + token_ids: List[int], sample_dict: Optional[NDict] = None, + crop_report: str = "warn", ) -> Dict: """ since we: @@ -144,13 +143,18 @@ def prepare_info_for_model_step( per_meta_encoding_including_placeholders: a list of Encoding elements. This is used to extract per tokenizer final tokens num (after all of the padding and cropping logic was already done) sample_dict: a fuse sample_dict - optional. needed only if the meta tokenizer instruction uses a syntax of lookup from the dictionary - + crop_report: one of None (no action), 'warn' - print a warning, 'raise' - raise an exception + will be triggered if cropping happened """ - scalars_indices = [] - scalars_values = [] - scalars_masked_indices = [] - prev_index_end = -1 + assert crop_report in ["warn", "raise", None] + ## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatanated in the end of this function + + # one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars + all_scalars_values = [] + # for each element, whether it's a scalar or not + all_scalars_valid_mask = [] + scalar_default_unfound_value = -1000.0 for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip( per_meta_tokenizer_data[::2], @@ -165,42 +169,30 @@ def prepare_info_for_model_step( f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}" ) - curr_indices = [] - curr_data = [] - - for i, val in enumerate(curr_str_data): - if val != "": - curr_indices.append(i + prev_index_end + 1) - curr_data.append(float(val)) - else: - scalars_masked_indices.append(i + prev_index_end + 1) - - if len(curr_indices) > 0: - curr_indices = torch.tensor(curr_indices, dtype=torch.int64) - curr_data = torch.tensor(curr_data, dtype=torch.float32) - - scalars_indices.append(curr_indices) - scalars_values.append(curr_data) - - assert len(curr_data.shape) == 1 - - prev_index_end += len(curr_str_data) + curr_scalar_values = [float(val) for val in curr_str_data] + curr_scalar_values = torch.tensor( + curr_scalar_values, dtype=torch.float32 + ) + all_scalars_values.append(curr_scalar_values) + all_scalars_valid_mask.append( + torch.full_like( + curr_scalar_values, fill_value=True, dtype=torch.bool + ) + ) elif "SCALARS_FROM_DICT" == tokenizer_name: if sample_dict is None: raise Exception( "SCALARS_FROM_DICT used but the provided sample_dict is None" ) - curr_data = sample_dict[curr_str_data] - assert len(curr_data.shape) == 1 - curr_indices = torch.arange( - prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0] + curr_scalar_values = sample_dict[curr_str_data] + assert len(curr_scalar_values.shape) == 1 + all_scalars_values.append(curr_scalar_values) + all_scalars_valid_mask.append( + torch.full_like( + curr_scalar_values, fill_value=True, dtype=torch.bool + ) ) - scalars_indices.append(curr_indices) - scalars_values.append(curr_data) - - prev_index_end += curr_data.shape[0] - else: raise Exception( "Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT" @@ -209,24 +201,58 @@ def prepare_info_for_model_step( elif tokenizer_name.startswith("VECTORS_"): raise NotImplementedError else: - prev_index_end += len(curr_placeholder_encoding.ids) - - if len(scalars_indices) > 0: - scalars_indices = torch.concat(scalars_indices) - scalars_values = torch.concat(scalars_values) - else: - scalars_indices = None - scalars_values = None - - if len(scalars_masked_indices) > 0: - scalars_masked_indices = torch.tensor( - scalars_masked_indices, dtype=torch.int64 + # prev_index_end += len(curr_placeholder_encoding.ids) + curr_scalar_values = torch.full( + (len(curr_placeholder_encoding.ids),), + fill_value=scalar_default_unfound_value, + ) + all_scalars_values.append(curr_scalar_values) + all_scalars_valid_mask.append( + torch.full_like( + curr_scalar_values, fill_value=False, dtype=torch.bool + ) + ) + + all_scalars_values = torch.concat(all_scalars_values) + all_scalars_valid_mask = torch.concat(all_scalars_valid_mask) + + assert all_scalars_values.shape == all_scalars_valid_mask.shape + + # pad if needed + full_query_len = len(token_ids) + if full_query_len > all_scalars_values.shape[0]: + pad_len = full_query_len - all_scalars_values.shape[0] + all_scalars_values = torch.concat( + [ + all_scalars_values, + torch.full( + (pad_len,), + fill_value=scalar_default_unfound_value, + dtype=all_scalars_values.dtype, + ), + ] ) - else: - scalars_masked_indices = None + all_scalars_valid_mask = torch.concat( + [ + all_scalars_valid_mask, + torch.full( + (pad_len,), fill_value=False, dtype=all_scalars_valid_mask.dtype + ), + ] + ) + elif full_query_len < all_scalars_values.shape[0]: + if crop_report in ["warn", "raise"]: + _msg = f"warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}" + if crop_report == "warn": + warn(_msg) + elif crop_report == "raise": + raise Exception(_msg) + else: + assert False, "should not get here" + all_scalars_values = all_scalars_values[:full_query_len] + all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len] return { - "scalars_indices": scalars_indices, # 1d - its length is the number of actual scalars (provided) found - "scalars_values": scalars_values, # 1d - values of provided scalars - "scalars_masked_indices": scalars_masked_indices, # 1d - indices of masked scalars + "scalars_values": all_scalars_values, # 1d - its length is the number of actual scalars (provided) found + "scalars_valid_mask": all_scalars_valid_mask, # 1d - values of provided scalars } diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py index 9ccf6650a..261dccebd 100644 --- a/fuse/data/tokenizers/modular_tokenizer/op.py +++ b/fuse/data/tokenizers/modular_tokenizer/op.py @@ -372,8 +372,7 @@ class ModularTokenizerOp(ModularTokenizerWithoutInjectOp): supported syntax/format: for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format: - ',' separated float values and/or tokens - - for example: "2.7,3.99,-12.9" or "" or "2.19,,3.19," + ',' separated float values for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict for example: "blah.boo.banana" or "data.input.encoder_input" @@ -437,9 +436,7 @@ def __call__( on_unknown: Optional[str] = "warn", verbose: Optional[int] = 1, validate_ends_with_eos: Optional[bool] = None, - key_out_scalars_indices: Optional[str] = None, - key_out_scalars_values: Optional[str] = None, - key_out_masked_scalars_indices: Optional[str] = None, + key_out_scalars: Optional[str] = None, ) -> NDict: """_summary_ @@ -458,10 +455,10 @@ def __call__( verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning with full data. Defaults to 1. validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos - key_out_scalars_inputs_indices:str optional - if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar elements. - key_out_scalars_inputs_values:str optional - if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar values. + key_out_scalars:str optional + if provided, will write to: + `sample_dict[f'{key_out_scalars}.values]` - a 1D torch tensor with all the scalars values + `sample_dict[f'{key_out_scalars}.valid_mask]` - a 1D torch boolean tensor representing which elements have scalar values Returns: NDict: _description_ @@ -490,39 +487,20 @@ def __call__( + ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional ) - prepared_data = InjectorToModularTokenizerLib.prepare_info_for_model_step( + prepared_data = InjectorToModularTokenizerLib.build_scalars( per_meta_tokenizer_data=per_meta_orig, per_meta_encoding_including_placeholders=sample_dict[ key_in + ".per_meta_part_encoding" ], + token_ids=sample_dict[key_out_tokens_ids], sample_dict=sample_dict, ) - if key_out_scalars_indices is not None: - sample_dict[key_out_scalars_indices] = prepared_data["scalars_indices"] - else: - if prepared_data["scalars_indices"] is not None: - raise Exception( - "non None scalars_indices found but no key_out_scalars_indices found" - ) - - if key_out_scalars_values is not None: - sample_dict[key_out_scalars_values] = prepared_data["scalars_values"] - else: - if prepared_data["scalars_values"] is not None: - raise Exception( - "non None scalars_value found but no key_out_scalars_values found" - ) - - if key_out_masked_scalars_indices is not None: - sample_dict[key_out_masked_scalars_indices] = prepared_data[ - "scalars_masked_indices" + if key_out_scalars is not None: + sample_dict[key_out_scalars + ".values"] = prepared_data["scalars_values"] + sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[ + "scalars_valid_mask" ] - else: - if prepared_data["scalars_masked_indices"] is not None: - raise Exception( - "non None scalars_masked_indices found but no key_out_masked_scalars_indices found" - ) return sample_dict