add default sub tokenizer and improve logic
Ben Shapira committed Dec 10, 2024
1 parent 1961bcc commit 2aafd32
Showing 2 changed files with 31 additions and 42 deletions.
66 changes: 25 additions & 41 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -55,6 +55,7 @@ def build_placeholder_meta_tokenization(
*,
sequence: Union[str, list, tuple],
sample_dict: Optional[NDict] = None,
default_sub_tokenizer_name: str = "AA",
) -> Tuple[str, List[str]]:
"""
In order to avoid modifying and rewriting the logic in modular tokenizer, especially regarding padding, limitation of max length of certain sub-parts,
@@ -95,8 +96,8 @@ def build_placeholder_meta_tokenization(
):
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # AA tokenizer selection is arbitrary, we only take the special token <SCALAR> from it
f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
) # tokenizer selection is arbitrary, we only take the special token <SCALAR> from it

if tokenizer_type == "SCALARS_LITERALS":
values = subseq.split(",")
@@ -122,8 +123,8 @@ def build_placeholder_meta_tokenization(
with_placeholders.append(seq)
elif tokenizer_type.startswith("EXTERNAL_EMBEDDINGS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # AA tokenizer selection is arbitrary, we only take the special token <SCALAR> from it
f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
) # tokenizer selection is arbitrary, we only take the special token <EMBEDDINGS> from it
with_placeholders.append("<EMBEDDINGS>")
elif tokenizer_type.startswith("VECTORS_"):
raise Exception("VECTOR_* are not supported yet")
@@ -187,59 +188,43 @@ def build_scalars_and_embeddings(
curr_scalar_values = torch.tensor(
curr_scalar_values, dtype=torch.float32
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_scalar_values = sample_dict[curr_str_data]
assert len(curr_scalar_values.shape) == 1
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
)
num_tokens_token_so_far += len(curr_scalar_values)
elif tokenizer_name == "EXTERNAL_EMBEDDINGS_FROM_DICT":
if sample_dict is None:
raise Exception(
"EXTERNAL_EMBEDDINGS_FROM_DICT used but the provided sample_dict is None"
)
embedding_input = sample_dict[curr_str_data]
external_embeddings_info[num_inputs_needing_embeddings] = (
num_tokens_token_so_far,
embedding_input,
)

curr_scalar_values = torch.full(
(1,),
fill_value=scalar_default_unfound_value,
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=False, dtype=torch.bool
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
num_tokens_token_so_far += 1
num_inputs_needing_embeddings += 1

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
num_tokens_token_so_far += len(curr_scalar_values)
else:
# prev_index_end += len(curr_placeholder_encoding.ids)
if tokenizer_name == "EXTERNAL_EMBEDDINGS_FROM_DICT":
if sample_dict is None:
raise Exception(
"EXTERNAL_EMBEDDINGS_FROM_DICT used but the provided sample_dict is None"
)
embedding_input = sample_dict[curr_str_data]
external_embeddings_info[num_inputs_needing_embeddings] = (
num_tokens_token_so_far,
embedding_input,
)
num_tokens_token_so_far += 1
num_inputs_needing_embeddings += 1

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
else:
num_tokens_token_so_far += len(curr_placeholder_encoding.ids)

curr_scalar_values = torch.full(
(len(curr_placeholder_encoding.ids),),
fill_value=scalar_default_unfound_value,
@@ -250,7 +235,6 @@ def build_scalars_and_embeddings(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)
num_tokens_token_so_far += len(curr_placeholder_encoding.ids)

all_scalars_values = torch.concat(all_scalars_values)
all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
7 changes: 6 additions & 1 deletion fuse/data/tokenizers/modular_tokenizer/op.py
@@ -433,6 +433,9 @@ def __init__(
verbose=verbose,
**kwargs,
)
self.default_sub_tokenizer_name = next(
iter(self._tokenizer.tokenizers_info.values())
)["name"]

def __call__(
self,
@@ -481,7 +484,9 @@ def __call__(
with_placeholders_str,
per_meta_orig,
) = InjectorToModularTokenizerLib.build_placeholder_meta_tokenization(
sequence=sample_dict[key_in], sample_dict=sample_dict
sequence=sample_dict[key_in],
sample_dict=sample_dict,
default_sub_tokenizer_name=self.default_sub_tokenizer_name,
)
sample_dict[key_in + ".with_placeholders"] = with_placeholders_str
