support new TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT #387

Merged: 4 commits, Dec 15, 2024
65 changes: 44 additions & 21 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -34,25 +34,35 @@ class InjectorToModularTokenizerLib:
the text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key into the sample NDict,
for example: "blah.boo.banana" or "data.input.encoder_input"

the text following <@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT> is expected to be a key into the sample NDict,
for example: "blah.boo.banana" or "data.input.encoder_input"

example usage:

encoder_input:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
labels:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>

for embeddings from dict:
encoder_input:
<@TOKENIZER-TYPE=AA><BIOT5_TASK_ID><1><8><SENTINEL_ID_0><@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>{protein1_key}<@TOKENIZER-TYPE=AA@MAX-LEN={max_len_1}><SEQUENCE_NATURAL_START>{protein_seq_1}<SEQUENCE_NATURAL_END><EOS>

"""

@staticmethod
def build_placeholder_meta_tokenization(
*,
sequence: Union[str, list, tuple],
sample_dict: Optional[NDict] = None,
default_sub_tokenizer_name: str = "AA",
) -> Tuple[str, List[str]]:
"""
To avoid modifying and rewriting the modular tokenizer logic, especially around padding and the max-length limits of certain sub-parts,
we insert placeholders so that the total size is known and fixed and respects the meta instructions given to the modular tokenizer.

default_sub_tokenizer_name: Specifies the name of the default sub-tokenizer. This tokenizer is used for handling special tokens, such as <SCALAR> and <EMBEDDINGS>.

Returns: a tuple with 2 elements
(
a single string with the full query containing placeholder tokens for FLOAT and VECTOR meta tokenizer parts,
@@ -88,8 +98,8 @@ def build_placeholder_meta_tokenization(
):
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # AA tokenizer selection is arbitrary, we only take the special token <SCALAR> from it
f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
) # tokenizer selection is arbitrary, we only take the special token <SCALAR> from it

if tokenizer_type == "SCALARS_LITERALS":
values = subseq.split(",")
@@ -113,7 +123,11 @@ def build_placeholder_meta_tokenization(
raise Exception(f"tokenizer_type={tokenizer_type} is not supported")

with_placeholders.append(seq)

elif tokenizer_type.startswith("EXTERNAL_EMBEDDINGS_"):
with_placeholders.append(
f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
) # tokenizer selection is arbitrary, we only take the special token <EMBEDDINGS> from it
with_placeholders.append("<EMBEDDINGS>")
elif tokenizer_type.startswith("VECTORS_"):
raise Exception("VECTOR_* are not supported yet")
else:
@@ -123,7 +137,7 @@ def build_placeholder_meta_tokenization(
return "".join(with_placeholders), hints_and_subseq

@staticmethod
def build_scalars(
def build_scalars_and_embeddings(
*,
per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
@@ -155,6 +169,9 @@ def build_scalars(
# for each element, whether it's a scalar or not
all_scalars_valid_mask = []
scalar_default_unfound_value = -1000.0
external_embeddings_info = dict()  # maps input index -> (token position in the query, embedding input)
num_tokens_token_so_far = 0
num_inputs_needing_embeddings = 0

for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
@@ -173,35 +190,39 @@ def build_scalars(
curr_scalar_values = torch.tensor(
curr_scalar_values, dtype=torch.float32
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_scalar_values = sample_dict[curr_str_data]
assert len(curr_scalar_values.shape) == 1
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
)

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
num_tokens_token_so_far += len(curr_scalar_values)
else:
# prev_index_end += len(curr_placeholder_encoding.ids)
if tokenizer_name == "EXTERNAL_EMBEDDINGS_FROM_DICT":
if sample_dict is None:
raise Exception(
"EXTERNAL_EMBEDDINGS_FROM_DICT used but the provided sample_dict is None"
)
embedding_input = sample_dict[curr_str_data]
external_embeddings_info[num_inputs_needing_embeddings] = (
num_tokens_token_so_far,
embedding_input,
)
num_inputs_needing_embeddings += 1
elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError

curr_scalar_values = torch.full(
(len(curr_placeholder_encoding.ids),),
fill_value=scalar_default_unfound_value,
@@ -212,6 +233,7 @@ def build_scalars(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)
num_tokens_token_so_far += len(curr_placeholder_encoding.ids)

all_scalars_values = torch.concat(all_scalars_values)
all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
@@ -255,4 +277,5 @@ def build_scalars(
return {
"scalars_values": all_scalars_values, # 1d - its length is the number of actual scalars (provided) found
"scalars_valid_mask": all_scalars_valid_mask, # 1d - values of provided scalars
"external_embeddings_info": external_embeddings_info, # dict - number of input needing embeddings -> (location in the query, embeddings input)
}
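As a rough, hypothetical illustration of how a caller could consume the returned `external_embeddings_info` dict (the shapes, positions, and tensors below are invented for the example and are not part of this PR):

```python
import torch

# Assume the model has already embedded the token ids of the full query.
seq_len, hidden_dim = 16, 8
token_embeddings = torch.zeros(seq_len, hidden_dim)  # e.g. output of an nn.Embedding

# What build_scalars_and_embeddings might return for one <EMBEDDINGS> placeholder:
# input index 0 sits at token position 5 and carries some external embedding input.
external_vec = torch.randn(hidden_dim)  # assumes the NDict value is already a hidden-size vector
prepared_data = {"external_embeddings_info": {0: (5, external_vec)}}

for _input_idx, (token_position, embedding_input) in prepared_data[
    "external_embeddings_info"
].items():
    # Overwrite the <EMBEDDINGS> placeholder token with the external embedding.
    token_embeddings[token_position] = embedding_input
```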
15 changes: 13 additions & 2 deletions fuse/data/tokenizers/modular_tokenizer/op.py
@@ -433,6 +433,10 @@ def __init__(
verbose=verbose,
**kwargs,
)
# default_sub_tokenizer_name is the sub-tokenizer used for special tokens such as the scalar and external-embeddings tokens.
self.default_sub_tokenizer_name = next(
iter(self._tokenizer.tokenizers_info.values())
)["name"]

def __call__(
self,
@@ -447,6 +451,7 @@ def __call__(
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
key_out_scalars: Optional[str] = None,
key_out_external_embeddings_info: Optional[str] = None,
additional_caller_info_text: Optional[str] = "",
) -> NDict:
"""_summary_
@@ -480,7 +485,9 @@ def __call__(
with_placeholders_str,
per_meta_orig,
) = InjectorToModularTokenizerLib.build_placeholder_meta_tokenization(
sequence=sample_dict[key_in], sample_dict=sample_dict
sequence=sample_dict[key_in],
sample_dict=sample_dict,
default_sub_tokenizer_name=self.default_sub_tokenizer_name,
)
sample_dict[key_in + ".with_placeholders"] = with_placeholders_str

@@ -500,7 +507,7 @@ def __call__(
+ ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional
)

prepared_data = InjectorToModularTokenizerLib.build_scalars(
prepared_data = InjectorToModularTokenizerLib.build_scalars_and_embeddings(
per_meta_tokenizer_data=per_meta_orig,
per_meta_encoding_including_placeholders=sample_dict[
key_in + ".per_meta_part_encoding"
@@ -514,6 +521,10 @@ def __call__(
sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
"scalars_valid_mask"
]
if key_out_external_embeddings_info is not None:
sample_dict[key_out_external_embeddings_info] = prepared_data[
"external_embeddings_info"
]

return sample_dict
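A hedged usage sketch of the updated `__call__`: the `tokenizer_op` instance, the NDict keys, the query contents, and the 768-dim tensor are all illustrative assumptions rather than values taken from this PR:

```python
import torch
from fuse.utils import NDict

# `tokenizer_op` is assumed to be an already-constructed instance of the
# modular tokenizer op modified in this diff; its construction is omitted here.
sample = NDict(
    {
        "data": {
            "query": {
                "encoder_input": (
                    "<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY>"
                    "<@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>data.protein.embedding_input"
                    "<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIY<SEQUENCE_NATURAL_END><EOS>"
                )
            },
            "protein": {"embedding_input": torch.randn(768)},  # illustrative payload
        }
    }
)

sample = tokenizer_op(
    sample_dict=sample,
    key_in="data.query.encoder_input",
    key_out_external_embeddings_info="data.query.external_embeddings_info",
)

# sample["data.query.external_embeddings_info"] now maps
#   {input index -> (token position of the <EMBEDDINGS> placeholder,
#                    the embedding input pulled from the sample NDict)}.
```

The returned info can then be used downstream to swap the `<EMBEDDINGS>` placeholder embeddings for externally computed ones, as sketched above for `build_scalars_and_embeddings`.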
