support new TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT #387

Merged: 4 commits merged on Dec 15, 2024

Changes from 1 commit
44 changes: 42 additions & 2 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -34,13 +34,20 @@ class InjectorToModularTokenizerLib:
the text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key into the sample NDict,
for example: "blah.boo.banana" or "data.input.encoder_input"

the text following <@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT> is expected to be a key into the sample NDict,
for example: "blah.boo.banana" or "data.input.encoder_input"

example usage:

encoder_input:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
labels:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>

for embeddings from dict:
encoder_input:
<@TOKENIZER-TYPE=AA><BIOT5_TASK_ID><1><8><SENTINEL_ID_0><@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>{protein1_key}<@TOKENIZER-TYPE=AA@MAX-LEN={max_len_1}><SEQUENCE_NATURAL_START>{protein_seq_1}<SEQUENCE_NATURAL_END><EOS>

"""

@staticmethod
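
A minimal sketch (not part of this diff) of how a caller might assemble the query and sample for the external-embeddings case documented above. The key name "data.query.protein1_embedding", the embedding shape, and the truncated sequence are illustrative assumptions; NDict is assumed importable from fuse.utils.

import torch
from fuse.utils import NDict

sample_dict = NDict()
# hypothetical precomputed external embedding for the first protein
sample_dict["data.query.protein1_embedding"] = torch.randn(1, 768)

protein_seq_1 = "ISGGDAIYSSTGRCS"  # truncated, for illustration only
max_len_1 = 64
encoder_input = (
    "<@TOKENIZER-TYPE=AA><BIOT5_TASK_ID><1><8><SENTINEL_ID_0>"
    "<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
    "<@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>data.query.protein1_embedding"
    f"<@TOKENIZER-TYPE=AA@MAX-LEN={max_len_1}>"
    f"<SEQUENCE_NATURAL_START>{protein_seq_1}<SEQUENCE_NATURAL_END><EOS>"
)
sample_dict["data.query.encoder_input"] = encoder_input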
@@ -113,7 +120,11 @@ def build_placeholder_meta_tokenization(
raise Exception(f"tokenizer_type={tokenizer_type} is not supported")

with_placeholders.append(seq)

elif tokenizer_type.startswith("EXTERNAL_EMBEDDINGS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # AA tokenizer selection is arbitrary, we only take the special token <EMBEDDINGS> from it
with_placeholders.append("<EMBEDDINGS>")
elif tokenizer_type.startswith("VECTORS_"):
raise Exception("VECTOR_* are not supported yet")
else:
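
Conceptually, the branch above rewrites each EXTERNAL_EMBEDDINGS_FROM_DICT segment into a single <EMBEDDINGS> placeholder handled by the (arbitrarily chosen) AA sub-tokenizer, so the meta-tokenization reserves one token position per external embedding. A rough, hypothetical illustration:

# hypothetical segment from a query string
segment = "<@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>data.query.protein1_embedding"
# after the branch above, the placeholder form of this segment is simply:
placeholder = "<@TOKENIZER-TYPE=AA>" + "<EMBEDDINGS>"
# the dict key itself ("data.query.protein1_embedding") is not consumed here; it is
# resolved later, in build_scalars_and_embeddings, against the sample NDict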
@@ -123,7 +134,7 @@ def build_placeholder_meta_tokenization(
return "".join(with_placeholders), hints_and_subseq

@staticmethod
def build_scalars(
def build_scalars_and_embeddings(
*,
per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
@@ -155,6 +166,9 @@ def build_scalars(
# for each element, whether it's a scalar or not
all_scalars_valid_mask = []
scalar_default_unfound_value = -1000.0
external_embeddings_info = dict() # a dict mapping input index -> (token location, embedding input)
num_tokens_token_so_far = 0
num_inputs_needing_embeddings = 0

for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
@@ -197,6 +211,30 @@ def build_scalars(
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
)
num_tokens_token_so_far += len(curr_scalar_values)
elif tokenizer_name == "EXTERNAL_EMBEDDINGS_FROM_DICT":
if sample_dict is None:
raise Exception(
"EXTERNAL_EMBEDDINGS_FROM_DICT used but the provided sample_dict is None"
)
embedding_input = sample_dict[curr_str_data]
external_embeddings_info[num_inputs_needing_embeddings] = (
num_tokens_token_so_far,
embedding_input,
)

curr_scalar_values = torch.full(
(1,),
fill_value=scalar_default_unfound_value,
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)
num_tokens_token_so_far += 1
num_inputs_needing_embeddings += 1

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
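
To make the bookkeeping above concrete, a hypothetical walk-through continuing the illustrative sample_dict from the earlier sketch; the token count is made up:

# if the AA segments preceding the external-embedding segment occupy 5 token positions,
# the loop above records the embedding request at token position 5:
expected_info = {0: (5, sample_dict["data.query.protein1_embedding"])}  # input index -> (token position, embedding input)
# the matching scalar buffers hold only placeholders at that position, i.e.
# all_scalars_values[5] == -1000.0 and bool(all_scalars_valid_mask[5]) is False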
@@ -212,6 +250,7 @@ def build_scalars(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)
num_tokens_token_so_far += len(curr_placeholder_encoding.ids)

all_scalars_values = torch.concat(all_scalars_values)
all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
@@ -255,4 +294,5 @@ def build_scalars(
return {
"scalars_values": all_scalars_values, # 1d - its length is the number of actual scalars (provided) found
"scalars_valid_mask": all_scalars_valid_mask, # 1d - values of provided scalars
"external_embeddings_info": external_embeddings_info, # dict - number of input needing embeddings -> (location in the query, embeddings input)
}
7 changes: 6 additions & 1 deletion fuse/data/tokenizers/modular_tokenizer/op.py
@@ -447,6 +447,7 @@ def __call__(
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
key_out_scalars: Optional[str] = None,
key_out_external_embeddings_info: Optional[str] = None,
additional_caller_info_text: Optional[str] = "",
) -> NDict:
"""_summary_
@@ -500,7 +501,7 @@ def __call__(
+ ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional
)

prepared_data = InjectorToModularTokenizerLib.build_scalars(
prepared_data = InjectorToModularTokenizerLib.build_scalars_and_embeddings(
per_meta_tokenizer_data=per_meta_orig,
per_meta_encoding_including_placeholders=sample_dict[
key_in + ".per_meta_part_encoding"
@@ -514,6 +515,10 @@ def __call__(
sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
"scalars_valid_mask"
]
if key_out_external_embeddings_info is not None:
sample_dict[key_out_external_embeddings_info] = prepared_data[
"external_embeddings_info"
]

return sample_dict
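
A minimal usage sketch for the new argument; only key_in, key_out_scalars, and key_out_external_embeddings_info come from this diff, while the op construction, other parameters, and key names are illustrative assumptions:

# hypothetical call; tokenizer_op stands in for however the modular tokenizer op is normally constructed
sample_dict = tokenizer_op(
    sample_dict,
    key_in="data.query.encoder_input",
    key_out_scalars="data.query.scalars",
    key_out_external_embeddings_info="data.query.external_embeddings_info",
)
# downstream, a model can splice each external embedding in at the recorded token position:
for idx, (pos, emb) in sample_dict["data.query.external_embeddings_info"].items():
    pass  # e.g. replace the input embedding at position pos with emb (model-specific)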
