advancing to scalars gen2 (#374)
YoelShoshan authored Oct 10, 2024
1 parent 1626ae1 commit 957f37f
Showing 2 changed files with 109 additions and 105 deletions.
168 changes: 97 additions & 71 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -7,6 +7,7 @@
TypedInput,
list_to_tokenizer_string,
)
from warnings import warn


class InjectorToModularTokenizerLib:
@@ -28,17 +29,15 @@ class InjectorToModularTokenizerLib:
supported syntax/format:
for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
',' separated float values and/or <MASK> tokens -
for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"
',' separated float values. For example: "2.7,3.99,-12.9"
for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
for example: "blah.boo.banana" or "data.input.encoder_input"
note: in SCALARS_FROM_DICT you can't describe masked scalars (outputs) you can only describe inputs
example usage:
encoder_input:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
labels:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
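Illustrative sketch (not part of the committed code; no fuse APIs are assumed): assembling a labels-style query string in the format documented above with plain string formatting. The special-token names are copied from the example; the values are made up.

    # hypothetical values; the special-token names mirror the docstring example above
    molecular_weight = 0.3
    binding_affinity_nanomolar = 12.4
    aa_seq = "ISGGDAIYSSTGR"  # shortened protein sequence, for the example only
    labels_str = (
        f"<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT>"
        f"<@TOKENIZER-TYPE=SCALARS_LITERALS>{molecular_weight}"
        f"<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR>"
        f"<@TOKENIZER-TYPE=SCALARS_LITERALS>{binding_affinity_nanomolar}"
        f"<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>{aa_seq}<SEQUENCE_NATURAL_END>"
    )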
@@ -67,16 +66,15 @@ def build_placeholder_meta_tokenization(
)
if len(sequence) > 0:
if isinstance(sequence[0], TypedInput):
sequence_str = list_to_tokenizer_string(
sequence = list_to_tokenizer_string(
sequence
) # currently supporting it in this simple way. Consider optimizing if it causes a bottleneck.
else:
raise Exception(
f"Expected sequence to be either string or a list of TypedInput elements. Got a list, but the first element is of type {type(sequence[0])}"
)
else:
sequence_str = sequence
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence_str)[

hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[
1:
] # the first element is blank - removing it
assert (
@@ -91,19 +89,18 @@
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # won't use AA tokens, just an arbitrary one to be able to use a token like <SCALAR>
) # AA tokenizer selection is arbitrary; we only take the special token <SCALAR> from it

if (
tokenizer_type == "SCALARS_LITERALS"
): # note: masking is only supported in literals (not in "from dict")
if tokenizer_type == "SCALARS_LITERALS":
values = subseq.split(",")
# seq = "<SCALAR>" * len(values)
seq = "".join(
[
"<MASKED_SCALAR>" if x == "<MASK>" else "<SCALAR>"
for x in values
]
)
# validate that all values can be converted to float
try:
[float(x) for x in values]
except ValueError:
raise ValueError(
f'expected a string with "," separated values that can each be converted to float. Got {subseq}'
)
seq = "<SCALAR>" * len(values)
elif tokenizer_type == "SCALARS_FROM_DICT":
if sample_dict is None:
raise Exception(
@@ -126,11 +123,13 @@ def build_placeholder_meta_tokenization(
return "".join(with_placeholders), hints_and_subseq

@staticmethod
def prepare_info_for_model_step(
def build_scalars(
*,
per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
token_ids: List[int],
sample_dict: Optional[NDict] = None,
crop_report: Optional[str] = "warn",
) -> Dict:
"""
since we:
@@ -144,13 +143,18 @@ def prepare_info_for_model_step(
per_meta_encoding_including_placeholders: a list of Encoding elements. This is used to extract the final number of tokens per tokenizer (after all of the padding and cropping logic was already done)
sample_dict: a fuse sample_dict - optional.
needed only if the meta tokenizer instruction uses a syntax of lookup from the dictionary
crop_report: one of None (no action), 'warn' (print a warning), or 'raise' (raise an exception).
triggered only if cropping actually happened
"""
scalars_indices = []
scalars_values = []
scalars_masked_indices = []
prev_index_end = -1
assert crop_report in ["warn", "raise", None]
## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatenated at the end of this function

# one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
all_scalars_values = []
# for each element, whether it's a scalar or not
all_scalars_valid_mask = []
scalar_default_unfound_value = -1000.0
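To make the expected input concrete, a hypothetical illustration (not taken from the repository) of the alternating hint/payload layout that per_meta_tokenizer_data carries; the loop below consumes it in pairs:

    # hypothetical content, mirroring the hints_and_subseq list produced by build_placeholder_meta_tokenization
    per_meta_tokenizer_data = [
        "AA", "<BINDING_AFFINITY_NANOMOLAR>",
        "SCALARS_LITERALS", "2.7,3.99,-12.9",
        "AA", "<SEQUENCE_NATURAL_START>ACDE<SEQUENCE_NATURAL_END>",
    ]
    tokenizer_names = per_meta_tokenizer_data[::2]  # ["AA", "SCALARS_LITERALS", "AA"]
    payloads = per_meta_tokenizer_data[1::2]        # the text routed to each tokenizer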

for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
@@ -165,42 +169,30 @@ def prepare_info_for_model_step(
f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}"
)

curr_indices = []
curr_data = []

for i, val in enumerate(curr_str_data):
if val != "<MASK>":
curr_indices.append(i + prev_index_end + 1)
curr_data.append(float(val))
else:
scalars_masked_indices.append(i + prev_index_end + 1)

if len(curr_indices) > 0:
curr_indices = torch.tensor(curr_indices, dtype=torch.int64)
curr_data = torch.tensor(curr_data, dtype=torch.float32)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

assert len(curr_data.shape) == 1

prev_index_end += len(curr_str_data)
curr_scalar_values = [float(val) for val in curr_str_data]
curr_scalar_values = torch.tensor(
curr_scalar_values, dtype=torch.float32
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_data = sample_dict[curr_str_data]
assert len(curr_data.shape) == 1
curr_indices = torch.arange(
prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0]
curr_scalar_values = sample_dict[curr_str_data]
assert len(curr_scalar_values.shape) == 1
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

prev_index_end += curr_data.shape[0]

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
@@ -209,24 +201,58 @@ def prepare_info_for_model_step(
elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
else:
prev_index_end += len(curr_placeholder_encoding.ids)

if len(scalars_indices) > 0:
scalars_indices = torch.concat(scalars_indices)
scalars_values = torch.concat(scalars_values)
else:
scalars_indices = None
scalars_values = None

if len(scalars_masked_indices) > 0:
scalars_masked_indices = torch.tensor(
scalars_masked_indices, dtype=torch.int64
# prev_index_end += len(curr_placeholder_encoding.ids)
curr_scalar_values = torch.full(
(len(curr_placeholder_encoding.ids),),
fill_value=scalar_default_unfound_value,
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)

all_scalars_values = torch.concat(all_scalars_values)
all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)

assert all_scalars_values.shape == all_scalars_valid_mask.shape

# pad if needed
full_query_len = len(token_ids)
if full_query_len > all_scalars_values.shape[0]:
pad_len = full_query_len - all_scalars_values.shape[0]
all_scalars_values = torch.concat(
[
all_scalars_values,
torch.full(
(pad_len,),
fill_value=scalar_default_unfound_value,
dtype=all_scalars_values.dtype,
),
]
)
else:
scalars_masked_indices = None
all_scalars_valid_mask = torch.concat(
[
all_scalars_valid_mask,
torch.full(
(pad_len,), fill_value=False, dtype=all_scalars_valid_mask.dtype
),
]
)
elif full_query_len < all_scalars_values.shape[0]:
if crop_report in ["warn", "raise"]:
_msg = f"warning: scalars sequence had to be cropped. The full length (including all subtokenizers) was {all_scalars_values.shape[0]}; after cropping it is {full_query_len}"
if crop_report == "warn":
warn(_msg)
elif crop_report == "raise":
raise Exception(_msg)
else:
assert False, "should not get here"
all_scalars_values = all_scalars_values[:full_query_len]
all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]

return {
"scalars_indices": scalars_indices, # 1d - its length is the number of actual scalars (provided) found
"scalars_values": scalars_values, # 1d - values of provided scalars
"scalars_masked_indices": scalars_masked_indices, # 1d - indices of masked scalars
"scalars_values": all_scalars_values, # 1d - aligned with token_ids; holds scalar_default_unfound_value at positions that are not scalars
"scalars_valid_mask": all_scalars_valid_mask, # 1d - boolean mask marking which positions hold provided scalar values
}
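A hedged sketch of how a downstream model step might consume the two aligned tensors returned above; the projection layer and the tensor contents are hypothetical, only torch is assumed:

    import torch

    token_embeddings = torch.randn(7, 16)  # [seq_len, hidden], e.g. from the token embedding table
    scalars_values = torch.tensor([-1000.0, 0.3, -1000.0, -1000.0, 12.4, -1000.0, -1000.0])
    scalars_valid_mask = torch.tensor([False, True, False, False, True, False, False])

    scalar_proj = torch.nn.Linear(1, 16)  # hypothetical scalar-to-hidden projection
    projected = scalar_proj(scalars_values.unsqueeze(-1))  # [seq_len, hidden]
    # overwrite embeddings only at positions where a real scalar was provided
    token_embeddings = torch.where(scalars_valid_mask.unsqueeze(-1), projected, token_embeddings)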
46 changes: 12 additions & 34 deletions fuse/data/tokenizers/modular_tokenizer/op.py
@@ -372,8 +371,7 @@ class ModularTokenizerOp(ModularTokenizerWithoutInjectOp):
supported syntax/format:
for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
',' separated float values and/or <MASK> tokens -
for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"
',' separated float values
for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
for example: "blah.boo.banana" or "data.input.encoder_input"
@@ -437,9 +436,7 @@ def __call__(
on_unknown: Optional[str] = "warn",
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
key_out_scalars_indices: Optional[str] = None,
key_out_scalars_values: Optional[str] = None,
key_out_masked_scalars_indices: Optional[str] = None,
key_out_scalars: Optional[str] = None,
) -> NDict:
"""_summary_
@@ -458,10 +455,10 @@ def __call__(
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
key_out_scalars_inputs_indices:str optional
if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar elements.
key_out_scalars_inputs_values:str optional
if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar values.
key_out_scalars:str optional
if provided, will write to:
`sample_dict[f'{key_out_scalars}.values']` - a 1D torch tensor with all the scalar values
`sample_dict[f'{key_out_scalars}.valid_mask']` - a 1D torch boolean tensor representing which elements have scalar values
Returns:
NDict: _description_
@@ -490,39 +487,20 @@ def __call__(
+ ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional
)

prepared_data = InjectorToModularTokenizerLib.prepare_info_for_model_step(
prepared_data = InjectorToModularTokenizerLib.build_scalars(
per_meta_tokenizer_data=per_meta_orig,
per_meta_encoding_including_placeholders=sample_dict[
key_in + ".per_meta_part_encoding"
],
token_ids=sample_dict[key_out_tokens_ids],
sample_dict=sample_dict,
)

if key_out_scalars_indices is not None:
sample_dict[key_out_scalars_indices] = prepared_data["scalars_indices"]
else:
if prepared_data["scalars_indices"] is not None:
raise Exception(
"non None scalars_indices found but no key_out_scalars_indices found"
)

if key_out_scalars_values is not None:
sample_dict[key_out_scalars_values] = prepared_data["scalars_values"]
else:
if prepared_data["scalars_values"] is not None:
raise Exception(
"non None scalars_value found but no key_out_scalars_values found"
)

if key_out_masked_scalars_indices is not None:
sample_dict[key_out_masked_scalars_indices] = prepared_data[
"scalars_masked_indices"
if key_out_scalars is not None:
sample_dict[key_out_scalars + ".values"] = prepared_data["scalars_values"]
sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
"scalars_valid_mask"
]
else:
if prepared_data["scalars_masked_indices"] is not None:
raise Exception(
"non None scalars_masked_indices found but no key_out_masked_scalars_indices found"
)

return sample_dict
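A hedged usage sketch, assuming tokenizer_op is an already constructed ModularTokenizerOp and the tagged query string sits under a hypothetical key "data.query"; it shows where the two new outputs land when key_out_scalars is provided:

    sample_dict = tokenizer_op(
        sample_dict,
        key_in="data.query",
        key_out_tokens_ids="data.tokens_ids",
        key_out_scalars="data.scalars",
    )
    scalars_values = sample_dict["data.scalars.values"]          # 1D float tensor, aligned with the token ids
    scalars_valid_mask = sample_dict["data.scalars.valid_mask"]  # 1D bool tensor, True where a scalar was provided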

