diff --git a/scripts/py/run_hf_low.py b/scripts/py/run_hf_low.py
index bd539ac5..bdf5bc9e 100644
--- a/scripts/py/run_hf_low.py
+++ b/scripts/py/run_hf_low.py
@@ -27,23 +27,6 @@ def check_stop(runner: AiciRunner, seq_id: int):
         raise StopGeneration
 
 
-def apply_bias(runner: AiciRunner, seq_id: int, scores: torch.Tensor) -> torch.Tensor:
-    bias_tensor = runner.recv_logit_bias()
-    runner.print_logs()
-    check_stop(runner, seq_id)
-    ff_tokens, backtrack = runner.mid_status(seq_id)
-    assert backtrack == 0, "backtrack not implemented"
-    assert len(ff_tokens) == 0, "ff_tokens not implemented"
-    bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
-    # print(bias_tensor.shape, scores.shape, input_ids.shape)
-    vocab_size = bias_tensor.shape[1]
-    # scores should be the size of vocabulary but some models (phi-2) make it slightly bigger
-    assert scores.shape[1] <= vocab_size + 1000
-    scores = scores[:, 0:vocab_size]
-    assert scores.shape == bias_tensor.shape
-    return bias_tensor + scores
-
-
 def main(args):
     tokenizer = cast(
         PreTrainedTokenizer, AutoTokenizer.from_pretrained(args.tokenizer or args.model)
@@ -98,7 +81,9 @@ def main(args):
             runner.exec_mid()
 
             m_inp = input_ids.unsqueeze(0)
-            model_kwargs["attention_mask"] = torch.ones(m_inp.shape, dtype=torch.long, device=model.device)
+            model_kwargs["attention_mask"] = torch.ones(
+                m_inp.shape, dtype=torch.long, device=model.device
+            )
             model_inputs = model.prepare_inputs_for_generation(m_inp, **model_kwargs)
             outputs = model(
                 **model_inputs,
@@ -106,7 +91,7 @@
                 output_attentions=False,
                 output_hidden_states=False,
             )
-            scores: torch.Tensor = outputs.logits[:, -1, :] / temperature
+            scores: torch.Tensor = outputs.logits[:, -1, :]
 
             bias_tensor = runner.recv_logit_bias()
             runner.print_logs()
@@ -114,18 +99,24 @@
             ff_tokens, backtrack = runner.mid_status(seq_id)
             assert backtrack == 0, "backtrack not implemented"
             assert len(ff_tokens) == 0, "ff_tokens not implemented"
-            bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
+            bias_tensor = (
+                torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
+            )
             # print(bias_tensor.shape, scores.shape, input_ids.shape)
             vocab_size = bias_tensor.shape[1]
             # scores should be the size of vocabulary but some models (phi-2) make it slightly bigger
             assert scores.shape[1] <= vocab_size + 1000
             scores = scores[:, 0:vocab_size]
             assert scores.shape == bias_tensor.shape
+            scores += bias_tensor
+            scores /= temperature
 
             probs = nn.functional.softmax(scores, dim=-1)
 
             if backtrack > 0 or len(ff_tokens) > 0:
-                next_tokens = torch.tensor(ff_tokens, dtype=torch.long, device=model.device)
+                next_tokens = torch.tensor(
+                    ff_tokens, dtype=torch.long, device=model.device
+                )
             else:
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
@@ -148,7 +139,13 @@ def main(args):
             # len(m) == num_layers, len(m[0]) == 2 (key, value)
             # shape of each elt is (batch_size, num_heads, seq_len, head_dim)
             if m is not None and backtrack > 0:
-                m = [(q[0][:, :, 0:computed_kv_len, :], q[1][:, :, 0:computed_kv_len, :]) for q in m]
+                m = [
+                    (
+                        q[0][:, :, 0:computed_kv_len, :],
+                        q[1][:, :, 0:computed_kv_len, :],
+                    )
+                    for q in m
+                ]
             model_kwargs["past_key_values"] = m
 
             suspend, num_forks, ff_tokens = runner.pre_status(seq_id)
@@ -157,7 +154,13 @@
             assert num_forks <= 1, "forking not implemented"
 
             if len(ff_tokens) > 0:
-                input_ids = torch.cat([input_ids, torch.tensor(ff_tokens, dtype=torch.long, device=model.device)], dim=0)
+                input_ids = torch.cat(
+                    [
+                        input_ids,
+                        torch.tensor(ff_tokens, dtype=torch.long, device=model.device),
+                    ],
+                    dim=0,
+                )
 
     except StopGeneration:
         runner.print_logs()