Added unittest for structured_compress_prompt and fixed bugs (#95)
- fix bugs and add unittests
- make style
- add nltk init
- fix nltk file exist error
- add unittest for different models

Co-authored-by: Siyun Zhao <[email protected]>
Co-authored-by: Qianhui Wu <[email protected]>
Co-authored-by: Xufang Luo <[email protected]>
Co-authored-by: Yuqing Yang <[email protected]>
5 people authored Feb 28, 2024
1 parent 049a113 commit 9f97ba7
Showing 3 changed files with 225 additions and 32 deletions.
70 changes: 38 additions & 32 deletions llmlingua/prompt_compressor.py
@@ -248,7 +248,10 @@ def structured_compress_prompt(
context = [" "]
if isinstance(context, str):
context = [context]

+        context = [
+            self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids)
+            for c in context
+        ]
context_tokens_length = [self.get_token_length(c) for c in context]
instruction_tokens_length, question_tokens_length = self.get_token_length(
instruction
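
The added list comprehension round-trips every context through the tokenizer (encode, then decode) so that all later character-level arithmetic operates on the decoded form of the text rather than on the raw input. A minimal standalone sketch of the same normalization, assuming a Hugging Face tokenizer; the checkpoint name is illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")

context = ["Lorem\u00a0ipsum  dolor", "sit amet"]
# Encode and decode each context so that character offsets computed later
# (for example in get_prefix_length below) match what tokenizer.decode() yields.
context = [
    tokenizer.decode(tokenizer(c, add_special_tokens=False).input_ids)
    for c in context
]
print(context)
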
@@ -488,7 +491,7 @@ def compress_prompt(
if condition_flag:
prefix = question + "\n\n" + instruction if add_instruction else question
if (
-                self.get_token_length(prefix) + 2 + iterative_size * 2
+                self.get_token_length(prefix + "\n\n") + iterative_size * 2
> self.max_position_embeddings
):
tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
@@ -502,7 +505,7 @@
+ 2 * iterative_size :
]
)
-            start = self.get_token_length(prefix) + 2
+            start = self.get_prefix_length(prefix + "\n\n", context[0])
context = [prefix] + context
else:
start = 0
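
The two changes above drop the old assumption that the prepended prefix always costs get_token_length(prefix) + 2 tokens, i.e. that the "\n\n" separator encodes to exactly two tokens and that token boundaries never shift when strings are concatenated. A quick, purely illustrative check of why that assumption is fragile (a Hugging Face tokenizer is assumed, the checkpoint name is an example, and the printed counts depend on the tokenizer):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")

def count(s: str) -> int:
    return len(tok(s, add_special_tokens=False).input_ids)

prefix, doc = "What is LLMLingua?", "LLMLingua compresses prompts."
# The count of the joined string does not have to equal
# count(prefix) + count("\n\n") + count(doc): merges can differ at the boundary.
print(count(prefix), count("\n\n"), count(doc), count(prefix + "\n\n" + doc))
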
@@ -556,6 +559,18 @@ def get_token_length(self, text: str, add_special_tokens: bool = True):
self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
)

+    def get_prefix_length(self, prefix: str, text: str):
+        possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
+        full_input_ids = self.tokenizer(
+            prefix + text[:100], add_special_tokens=False
+        ).input_ids
+        for i in range(possible_prefix_token, len(full_input_ids)):
+            cur_prefix = self.tokenizer.decode(full_input_ids[:i])
+            if cur_prefix == prefix:
+                break
+        assert self.tokenizer.decode(full_input_ids[i:]) == text[:100]
+        return i
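
The new helper answers a narrow question: how many tokens at the front of tokenizer(prefix + text) decode back to exactly prefix? It starts a few tokens below the standalone prefix length (the token at the boundary may have merged characters from text) and scans forward until the decoded slice matches, then asserts that the remaining tokens decode to the original text. A standalone sketch of the same idea, not the library's API; the function name and checkpoint are illustrative, and the inputs are assumed to be tokenizer-normalized as in the first hunk:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")

def prefix_token_count(prefix: str, text: str) -> int:
    ids = tokenizer(prefix + text[:100], add_special_tokens=False).input_ids
    # Start a little below the standalone prefix length in case the boundary
    # token swallowed characters from the following text.
    lower = max(len(tokenizer(prefix, add_special_tokens=False).input_ids) - 3, 1)
    for i in range(lower, len(ids)):
        if tokenizer.decode(ids[:i]) == prefix:
            return i
    raise ValueError("prefix does not decode cleanly from the joint encoding")

question = "What is LLMLingua?\n\n"
document = "LLMLingua compresses prompts before they reach the model."
print(prefix_token_count(question, document))

In compress_prompt above, the offset returned by get_prefix_length becomes start, the position in the joint token stream where the real context begins.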

def get_condition_ppl(
self,
text: str,
@@ -633,48 +648,53 @@ def get_structured_dynamic_compression_ratio(
seg_info: List[List[tuple]] = None,
):
if start:
-            context = context[1:]
-        global_dynamic_rate, global_dynamic_compress, tmp_context = [], [], []
-        for context_idx, text in enumerate(context):
+            pure_context = context[1:]
+        else:
+            pure_context = context
+        global_dynamic_rate, global_dynamic_compress, segments = [], [], []
+        for context_idx, text in enumerate(pure_context):
text_seen = 0
for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate(
seg_info[context_idx]
):
seg_text = text[text_seen : text_seen + seg_len]
if (
seg_idx == len(seg_info[context_idx]) - 1
-                    and context_idx != len(context) - 1
+                    and context_idx != len(pure_context) - 1
):
seg_text += "\n\n"
-                tmp_context.append(seg_text)
+                segments.append(seg_text)
if seg_compress:
global_dynamic_rate.append(seg_rate)
else:
global_dynamic_rate.append(1.0)
global_dynamic_compress.append(seg_compress)
text_seen += seg_len
-        origin_text = "\n\n".join(context)
-        assert len("".join(tmp_context)) == len(origin_text)
+        origin_text = "\n\n".join(pure_context)
+        assert len("".join(segments)) == len(origin_text)
+        assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress)

+        text_input_ids = self.tokenizer(
+            "\n\n".join(context), add_special_tokens=False
+        ).input_ids[start:]
+        assert self.tokenizer.decode(text_input_ids) == origin_text
dynamic_compression_ratio = self.token_segment(
origin_text,
+            text_input_ids,
iterative_size,
-            tmp_context,
+            segments,
global_dynamic_rate,
global_dynamic_compress,
)
return dynamic_compression_ratio
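
The reworked loop above flattens the per-context seg_info (one list of (char_length, rate, compress_flag) tuples per context) into three parallel lists, attaching the "\n\n" separator to the last segment of every context except the final one, so that "".join(segments) reproduces "\n\n".join(pure_context) exactly; that is what the new assertions verify. A pure-Python toy sketch of the invariant, with made-up data:

contexts = ["alpha beta gamma", "delta epsilon"]
# One (char_length, rate, compress?) tuple per segment; lengths cover each context.
seg_info = [
    [(6, 0.5, True), (10, 1.0, False)],  # "alpha " / "beta gamma"
    [(13, 0.3, True)],                   # "delta epsilon"
]

segments, rates, flags = [], [], []
for ci, text in enumerate(contexts):
    seen = 0
    for si, (seg_len, seg_rate, seg_compress) in enumerate(seg_info[ci]):
        seg_text = text[seen : seen + seg_len]
        # The separator between contexts rides along with the last segment of
        # every context except the final one.
        if si == len(seg_info[ci]) - 1 and ci != len(contexts) - 1:
            seg_text += "\n\n"
        segments.append(seg_text)
        rates.append(seg_rate if seg_compress else 1.0)
        flags.append(seg_compress)
        seen += seg_len

assert "".join(segments) == "\n\n".join(contexts)
print(list(zip(segments, rates, flags)))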

def token_segment(
self,
text: str,
+        text_input_ids: List[int],
iterative_size: int,
segments: List[str],
global_dynamic_rate: List[float],
global_dynamic_compress: List[bool],
):
assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress)
assert text == "".join(segments)
-        text_input_ids = self.tokenizer(text, add_special_tokens=False).input_ids
decode_window = 3
seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1
dynamic_compression_rate, local_compresssion_rate = [], []
@@ -953,26 +973,12 @@ def sync_sentence(segments, text):
new_segments_info = []
for s in sentences:
tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
res.append("\n".join(tmp))
res.append("".join(tmp))
if context_segs is not None:
segment_ratio = []
for ii in range(len(s)):
if sentence_flags[idx + ii]:
-                        last_element = (
-                            sen2seg_ratio[idx + ii][-1][0] + 1,
-                            sen2seg_ratio[idx + ii][-1][1],
-                            sen2seg_ratio[idx + ii][-1][2],
-                        )
-                        segment_ratio.extend(
-                            sen2seg_ratio[idx + ii][:-1] + [last_element]
-                        )
-                segment_ratio = segment_ratio[:-1] + [
-                    (
-                        segment_ratio[-1][0] - 1,
-                        segment_ratio[-1][1],
-                        segment_ratio[-1][2],
-                    )
-                ]
+                        segment_ratio.extend(sen2seg_ratio[idx + ii])
new_segments_info.append(segment_ratio)
idx += len(s)
if context_segs is not None:
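
The last hunk simplifies the per-sentence bookkeeping and joins the retained sentences with "" rather than "\n". A toy illustration of the difference, assuming (as the surrounding sync_sentence logic suggests) that each sentence string already carries its own trailing separator, so joining with "\n" would insert characters that were never in the source text:

sentences = ["First sentence. ", "Second sentence.\n", "Third sentence."]
keep = [True, False, True]  # pretend the middle sentence was compressed away

kept = [s for s, k in zip(sentences, keep) if k]
print(repr("".join(kept)))   # 'First sentence. Third sentence.'
print(repr("\n".join(kept))) # 'First sentence. \nThird sentence.' (spurious newline)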