Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
yoel shoshan committed Oct 9, 2024
1 parent 0ac3c6d commit 2ae2510
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
TypedInput,
list_to_tokenizer_string,
)
from warnings import warn


class InjectorToModularTokenizerLib:
Expand Down Expand Up @@ -128,6 +129,7 @@ def build_scalars(
per_meta_encoding_including_placeholders: List[Encoding],
token_ids: List[int],
sample_dict: Optional[NDict] = None,
crop_report: str = "warn",
) -> Dict:
"""
since we:
Expand All @@ -141,9 +143,11 @@ def build_scalars(
per_meta_encoding_including_placeholders: a list of Encoding elements. This is used to extract per tokenizer final tokens num (after all of the padding and cropping logic was already done)
sample_dict: a fuse sample_dict - optional.
needed only if the meta tokenizer instruction uses a syntax of lookup from the dictionary
crop_report: one of None (no action), 'warn' - print a warning, 'raise' - raise an exception
will be triggered if cropping happened
"""
assert crop_report in ["warn", "raise", None]
## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatanated in the end of this function

# one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
Expand Down Expand Up @@ -237,9 +241,14 @@ def build_scalars(
]
)
elif full_query_len < all_scalars_values.shape[0]:
print(
"warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}"
)
if crop_report in ["warn", "raise"]:
_msg = f"warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}"
if crop_report == "warn":
warn(_msg)
elif crop_report == "raise":
raise Exception(_msg)
else:
assert False, "should not get here"
all_scalars_values = all_scalars_values[:full_query_len]
all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]

Expand Down

0 comments on commit 2ae2510

Please sign in to comment.