advancing to scalars gen2 #374
Changes from 5 commits
@@ -28,17 +28,15 @@ class InjectorToModularTokenizerLib:
     supported syntax/format:

     for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
-    ',' separated float values and/or <MASK> tokens -
-    for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"
+    ',' separated float values. For example: "2.7,3.99,-12.9"

     for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
     for example: "blah.boo.banana" or "data.input.encoder_input"
-    note: in SCALARS_FROM_DICT you can't describe masked scalars (outputs) you can only describe inputs

     example usage:

     encoder_input:
-    <@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
+    <@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
     labels:
     <@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
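For reference, a minimal sketch (not part of the diff) of how a prompt in the format described above splits into alternating tokenizer hints and subsequences. It reuses the re.split pattern that build_placeholder_meta_tokenization applies in the changes below; the token names are taken from the docstring example.

```python
import re

# Sketch only: split a tagged prompt into alternating (tokenizer hint, subsequence) pairs.
sequence = (
    "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR>"
    "<@TOKENIZER-TYPE=SCALARS_LITERALS>2.7,3.99,-12.9"
    "<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIY<SEQUENCE_NATURAL_END>"
)
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[1:]
print(hints_and_subseq)
# ['AA', '<BINDING_AFFINITY_NANOMOLAR>',
#  'SCALARS_LITERALS', '2.7,3.99,-12.9',
#  'AA', '<SEQUENCE_NATURAL_START>ISGGDAIY<SEQUENCE_NATURAL_END>']
```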
@@ -67,16 +65,15 @@ def build_placeholder_meta_tokenization(
         )
         if len(sequence) > 0:
             if isinstance(sequence[0], TypedInput):
-                sequence_str = list_to_tokenizer_string(
+                sequence = list_to_tokenizer_string(
                     sequence
                 )  # currently supporting it in this simple way. Consider optimizing if it causes a bottleneck.
             else:
                 raise Exception(
                     f"Expected sequence to be either string or a list of TypedInput elements. Got a list, but the first element is of type {type(sequence[0])}"
                 )
-        else:
-            sequence_str = sequence
-        hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence_str)[
+
+        hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[
             1:
         ]  # the first element is blank - removing it
         assert (
@@ -91,19 +88,18 @@ def build_placeholder_meta_tokenization(
             if tokenizer_type.startswith("SCALARS_"):
                 with_placeholders.append(
                     "<@TOKENIZER-TYPE=AA>"
-                )  # won't use AA tokens, just an arbitrary one to be able to use a token like <SCALAR>
+                )  # AA tokenizer selection is arbitrary, we only take the special token <SCALAR> from it

-                if (
-                    tokenizer_type == "SCALARS_LITERALS"
-                ):  # note: masking is only supported in literals (not in "from dict")
+                if tokenizer_type == "SCALARS_LITERALS":

Comment: Remind me, can we put mask in scalar literals?
Reply: I stopped supporting this option, intentionally.

                     values = subseq.split(",")
-                    # seq = "<SCALAR>" * len(values)
-                    seq = "".join(
-                        [
-                            "<MASKED_SCALAR>" if x == "<MASK>" else "<SCALAR>"
-                            for x in values
-                        ]
-                    )
+                    # validate that all values can be converted to float
+                    try:
+                        [float(x) for x in values]
+                    except:
+                        raise ValueError(
+                            f'expected a string with "," separated values that can each be converted to float. Got {subseq}'
+                        )
+                    seq = "<SCALAR>" * len(values)
                 elif tokenizer_type == "SCALARS_FROM_DICT":
                     if sample_dict is None:
                         raise Exception(
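A small usage sketch of the new SCALARS_LITERALS handling above (not part of the diff; the helper name scalars_literals_placeholders is hypothetical): every comma-separated value must parse as a float and contributes one <SCALAR> placeholder, and <MASK> entries are no longer accepted.

```python
def scalars_literals_placeholders(subseq: str) -> str:
    # Mirrors the validation + placeholder logic from the diff above.
    values = subseq.split(",")
    try:
        [float(x) for x in values]
    except ValueError:
        raise ValueError(
            f'expected a string with "," separated values that can each be converted to float. Got {subseq}'
        )
    return "<SCALAR>" * len(values)


print(scalars_literals_placeholders("2.7,3.99,-12.9"))  # <SCALAR><SCALAR><SCALAR>
# scalars_literals_placeholders("2.19,<MASK>,3.19")  # raises ValueError - masking is no longer supported
```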
@@ -126,10 +122,11 @@ def build_placeholder_meta_tokenization(
         return "".join(with_placeholders), hints_and_subseq

     @staticmethod
-    def prepare_info_for_model_step(
+    def build_scalars(
         *,

Comment: Maybe build_scalars would be a better name here?
Reply: np, renamed.

         per_meta_tokenizer_data: List[str],
         per_meta_encoding_including_placeholders: List[Encoding],
+        token_ids: List[int],
         sample_dict: Optional[NDict] = None,
     ) -> Dict:
         """
@@ -147,10 +144,13 @@ def prepare_info_for_model_step(
         """
-        scalars_indices = []
-        scalars_values = []
-        scalars_masked_indices = []
-        prev_index_end = -1
+        ## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatenated at the end of this function
+
+        # one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
+        all_scalars_values = []
+        # for each element, whether it's a scalar or not
+        all_scalars_valid_mask = []
+        scalar_default_unfound_value = -1000.0

Comment: What does it mean?
Reply: Since we now keep the scalar values and mask at the size of the entire sequence, this is the default value for positions that don't actually hold a scalar value.

         for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
             per_meta_tokenizer_data[::2],
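To illustrate the reply above (the tensors below are made up for illustration): in gen2 the scalar values and validity mask span the entire query, so positions that are not scalars carry scalar_default_unfound_value and are marked False in the mask.

```python
import torch

# Illustration only - these tensors are made up, not produced by the diff's code.
scalar_default_unfound_value = -1000.0

scalars_values = torch.tensor(
    [-1000.0, -1000.0, 0.3, -1000.0, 12.4, -1000.0], dtype=torch.float32
)
scalars_valid_mask = torch.tensor(
    [False, False, True, False, True, False], dtype=torch.bool
)

# Downstream code can recover just the provided scalars via the mask:
print(scalars_values[scalars_valid_mask])  # tensor([ 0.3000, 12.4000])
```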
@@ -165,42 +165,30 @@ def prepare_info_for_model_step(
                         f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}"
                     )

-                    curr_indices = []
-                    curr_data = []
-
-                    for i, val in enumerate(curr_str_data):
-                        if val != "<MASK>":
-                            curr_indices.append(i + prev_index_end + 1)
-                            curr_data.append(float(val))
-                        else:
-                            scalars_masked_indices.append(i + prev_index_end + 1)
-
-                    if len(curr_indices) > 0:
-                        curr_indices = torch.tensor(curr_indices, dtype=torch.int64)
-                        curr_data = torch.tensor(curr_data, dtype=torch.float32)
-
-                        scalars_indices.append(curr_indices)
-                        scalars_values.append(curr_data)
-
-                        assert len(curr_data.shape) == 1
-
-                    prev_index_end += len(curr_str_data)
+                    curr_scalar_values = [float(val) for val in curr_str_data]
+                    curr_scalar_values = torch.tensor(
+                        curr_scalar_values, dtype=torch.float32
+                    )
+                    all_scalars_values.append(curr_scalar_values)
+                    all_scalars_valid_mask.append(
+                        torch.full_like(
+                            curr_scalar_values, fill_value=True, dtype=torch.bool
+                        )
+                    )
                 elif "SCALARS_FROM_DICT" == tokenizer_name:
                     if sample_dict is None:
                         raise Exception(
                             "SCALARS_FROM_DICT used but the provided sample_dict is None"
                         )
-                    curr_data = sample_dict[curr_str_data]
-                    assert len(curr_data.shape) == 1
-                    curr_indices = torch.arange(
-                        prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0]
-                    )
-
-                    scalars_indices.append(curr_indices)
-                    scalars_values.append(curr_data)
-
-                    prev_index_end += curr_data.shape[0]
-
+                    curr_scalar_values = sample_dict[curr_str_data]
+                    assert len(curr_scalar_values.shape) == 1
+                    all_scalars_values.append(curr_scalar_values)
+                    all_scalars_valid_mask.append(
+                        torch.full_like(
+                            curr_scalar_values, fill_value=True, dtype=torch.bool
+                        )
+                    )
                 else:
                     raise Exception(
                         "Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
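A minimal sketch of the SCALARS_FROM_DICT branch above (not part of the diff; a plain dict stands in for the fuse NDict and the key name is hypothetical): the subsequence text is used as a key into sample_dict, must resolve to a 1-D tensor, and is marked fully valid.

```python
import torch

# Sketch only: a plain dict stands in for NDict; the key "data.input.scalars" is hypothetical.
sample_dict = {
    "data.input.scalars": torch.tensor([0.11, 0.52, 0.93], dtype=torch.float32)
}

curr_str_data = "data.input.scalars"  # the text that followed <@TOKENIZER-TYPE=SCALARS_FROM_DICT>
curr_scalar_values = sample_dict[curr_str_data]
assert len(curr_scalar_values.shape) == 1  # must be a 1-D tensor, as asserted in the diff

# Every position coming from the dict is a real scalar, so the mask is all True.
valid_mask = torch.full_like(curr_scalar_values, fill_value=True, dtype=torch.bool)
print(curr_scalar_values)  # tensor([0.1100, 0.5200, 0.9300])
print(valid_mask)          # tensor([True, True, True])
```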
@@ -209,24 +197,52 @@ def prepare_info_for_model_step(
             elif tokenizer_name.startswith("VECTORS_"):
                 raise NotImplementedError
             else:
-                prev_index_end += len(curr_placeholder_encoding.ids)
+                # prev_index_end += len(curr_placeholder_encoding.ids)
+                curr_scalar_values = torch.full(
+                    (len(curr_placeholder_encoding.ids),),
+                    fill_value=scalar_default_unfound_value,
+                )
+                all_scalars_values.append(curr_scalar_values)
+                all_scalars_valid_mask.append(
+                    torch.full_like(
+                        curr_scalar_values, fill_value=False, dtype=torch.bool
+                    )
+                )

-        if len(scalars_indices) > 0:
-            scalars_indices = torch.concat(scalars_indices)
-            scalars_values = torch.concat(scalars_values)
-        else:
-            scalars_indices = None
-            scalars_values = None
-
-        if len(scalars_masked_indices) > 0:
-            scalars_masked_indices = torch.tensor(
-                scalars_masked_indices, dtype=torch.int64
-            )
-        else:
-            scalars_masked_indices = None
+        all_scalars_values = torch.concat(all_scalars_values)
+        all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
+
+        assert all_scalars_values.shape == all_scalars_valid_mask.shape
+
+        # pad if needed
+        full_query_len = len(token_ids)
+        if full_query_len > all_scalars_values.shape[0]:

Comment: Can't it happen?
Reply: Do you mean "can it happen"? The main code logic (before this point) iterates over each sub-part (with a specific sub-tokenizer), so it does not contain the padding. I can explain more if it isn't clear.

+            pad_len = full_query_len - all_scalars_values.shape[0]
+            all_scalars_values = torch.concat(
+                [
+                    all_scalars_values,
+                    torch.full(
+                        (pad_len,),
+                        fill_value=scalar_default_unfound_value,
+                        dtype=all_scalars_values.dtype,
+                    ),
+                ]
+            )
+            all_scalars_valid_mask = torch.concat(
+                [
+                    all_scalars_valid_mask,
+                    torch.full(
+                        (pad_len,), fill_value=False, dtype=all_scalars_valid_mask.dtype
+                    ),
+                ]
+            )
+        elif full_query_len < all_scalars_values.shape[0]:

Comment: @mosheraboh see here, related to what we talked about. I'll try to add more unit tests with interesting cases by the end of this week.

+            print(
+                f"warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]}, after cropping it is {full_query_len}"
+            )
+            all_scalars_values = all_scalars_values[:full_query_len]
+            all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]

         return {
-            "scalars_indices": scalars_indices,  # 1d - its length is the number of actual scalars (provided) found
-            "scalars_values": scalars_values,  # 1d - values of provided scalars
-            "scalars_masked_indices": scalars_masked_indices,  # 1d - indices of masked scalars
+            "scalars_values": all_scalars_values,  # 1d - same length as the full (padded) query; non-scalar positions hold scalar_default_unfound_value
+            "scalars_valid_mask": all_scalars_valid_mask,  # 1d bool - True where the position holds a provided scalar value
         }
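A short sketch of the padding branch above (not part of the diff; lengths and token ids are made up): the concatenated values and mask are grown with default values and False entries until they match len(token_ids), so the returned tensors always align with the padded query.

```python
import torch

# Sketch only: pad the scalar values/mask up to the padded query length.
scalar_default_unfound_value = -1000.0
token_ids = [101, 5, 7, 9, 102, 0, 0, 0]  # hypothetical padded query, length 8

all_scalars_values = torch.tensor([-1000.0, 0.3, -1000.0], dtype=torch.float32)
all_scalars_valid_mask = torch.tensor([False, True, False], dtype=torch.bool)

full_query_len = len(token_ids)
pad_len = full_query_len - all_scalars_values.shape[0]
if pad_len > 0:
    all_scalars_values = torch.concat(
        [
            all_scalars_values,
            torch.full((pad_len,), scalar_default_unfound_value, dtype=all_scalars_values.dtype),
        ]
    )
    all_scalars_valid_mask = torch.concat(
        [
            all_scalars_valid_mask,
            torch.full((pad_len,), False, dtype=all_scalars_valid_mask.dtype),
        ]
    )

assert all_scalars_values.shape[0] == all_scalars_valid_mask.shape[0] == full_query_len
```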
Comment: The general solution can't assume there is an AA sub-tokenizer. Maybe we need a default empty sub-tokenizer? Maybe SCALARS can be an empty sub-tokenizer?
Reply: Interesting point. SCALARS is currently fully programmatic and does not rely on any dictionary, so I would rather not mix the two. It is probably better to have a "base" sub-tokenizer that gets automatically generated and supported, since the modular tokenizer already knows how to handle special tokens; maybe "Base" or "SpecialTokensBase" or something similar.