Commit

fix sampling
mmoskal committed Feb 13, 2024
1 parent 742f9ff commit a91287d
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions scripts/py/run_hf_low.py
@@ -27,23 +27,6 @@ def check_stop(runner: AiciRunner, seq_id: int):
raise StopGeneration


-def apply_bias(runner: AiciRunner, seq_id: int, scores: torch.Tensor) -> torch.Tensor:
-bias_tensor = runner.recv_logit_bias()
-runner.print_logs()
-check_stop(runner, seq_id)
-ff_tokens, backtrack = runner.mid_status(seq_id)
-assert backtrack == 0, "backtrack not implemented"
-assert len(ff_tokens) == 0, "ff_tokens not implemented"
-bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
-# print(bias_tensor.shape, scores.shape, input_ids.shape)
-vocab_size = bias_tensor.shape[1]
-# scores should be the size of vocabulary but some models (phi-2) make it slightly bigger
-assert scores.shape[1] <= vocab_size + 1000
-scores = scores[:, 0:vocab_size]
-assert scores.shape == bias_tensor.shape
-return bias_tensor + scores


def main(args):
tokenizer = cast(
PreTrainedTokenizer, AutoTokenizer.from_pretrained(args.tokenizer or args.model)
@@ -98,34 +81,42 @@ def main(args):
runner.exec_mid()

m_inp = input_ids.unsqueeze(0)
model_kwargs["attention_mask"] = torch.ones(m_inp.shape, dtype=torch.long, device=model.device)
model_kwargs["attention_mask"] = torch.ones(
m_inp.shape, dtype=torch.long, device=model.device
)
model_inputs = model.prepare_inputs_for_generation(m_inp, **model_kwargs)
outputs = model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
-scores: torch.Tensor = outputs.logits[:, -1, :] / temperature
+scores: torch.Tensor = outputs.logits[:, -1, :]

bias_tensor = runner.recv_logit_bias()
runner.print_logs()
check_stop(runner, seq_id)
ff_tokens, backtrack = runner.mid_status(seq_id)
assert backtrack == 0, "backtrack not implemented"
assert len(ff_tokens) == 0, "ff_tokens not implemented"
-bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
+bias_tensor = (
+torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
+)
# print(bias_tensor.shape, scores.shape, input_ids.shape)
vocab_size = bias_tensor.shape[1]
# scores should be the size of vocabulary but some models (phi-2) make it slightly bigger
assert scores.shape[1] <= vocab_size + 1000
scores = scores[:, 0:vocab_size]
assert scores.shape == bias_tensor.shape
scores += bias_tensor
+scores /= temperature

probs = nn.functional.softmax(scores, dim=-1)

if backtrack > 0 or len(ff_tokens) > 0:
-next_tokens = torch.tensor(ff_tokens, dtype=torch.long, device=model.device)
+next_tokens = torch.tensor(
+ff_tokens, dtype=torch.long, device=model.device
+)
else:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

@@ -148,7 +139,13 @@ def main(args):
# len(m) == num_layers, len(m[0]) == 2 (key, value)
# shape of each elt is (batch_size, num_heads, seq_len, head_dim)
if m is not None and backtrack > 0:
-m = [(q[0][:, :, 0:computed_kv_len, :], q[1][:, :, 0:computed_kv_len, :]) for q in m]
+m = [
+(
+q[0][:, :, 0:computed_kv_len, :],
+q[1][:, :, 0:computed_kv_len, :],
+)
+for q in m
+]
model_kwargs["past_key_values"] = m

suspend, num_forks, ff_tokens = runner.pre_status(seq_id)
@@ -157,7 +154,13 @@
assert num_forks <= 1, "forking not implemented"

if len(ff_tokens) > 0:
-input_ids = torch.cat([input_ids, torch.tensor(ff_tokens, dtype=torch.long, device=model.device)], dim=0)
+input_ids = torch.cat(
+[
+input_ids,
+torch.tensor(ff_tokens, dtype=torch.long, device=model.device),
+],
+dim=0,
+)

except StopGeneration:
runner.print_logs()
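
For context, the net effect of this change on the sampling step: the controller's logit bias is added to the raw logits first, and temperature is applied afterwards, before softmax and multinomial sampling. A minimal standalone sketch of that order (the helper function and names are illustrative, not part of the script):

```python
import torch
import torch.nn.functional as F

def sample_with_bias(logits: torch.Tensor, bias: torch.Tensor, temperature: float) -> torch.Tensor:
    # logits: (batch, vocab_or_larger) raw model outputs for the last position
    # bias:   (batch, vocab) additive logit bias received from the controller
    # Some models (e.g. phi-2) report logits slightly wider than the vocabulary,
    # so clip to the bias width before combining, as the script does.
    vocab_size = bias.shape[1]
    scores = logits[:, :vocab_size] + bias  # bias on the raw logits first ...
    scores = scores / temperature           # ... then temperature, per this commit
    probs = F.softmax(scores, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)
```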

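The diff also reflows the `past_key_values` truncation used when the controller backtracks; as the comments in the diff note, the cache here is a per-layer list of (key, value) tensors of shape (batch_size, num_heads, seq_len, head_dim). A small hypothetical helper showing the same slicing in isolation (the function name is illustrative, not from the script):

```python
from typing import List, Tuple
import torch

def truncate_kv_cache(
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], keep_len: int
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # Drop the tail of the seq_len axis so the cache only covers tokens
    # that remain part of the sequence after backtracking.
    return [(k[:, :, :keep_len, :], v[:, :, :keep_len, :]) for k, v in past_key_values]
```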