hf support for ff tokens
mmoskal committed Feb 13, 2024
1 parent 7c5d0db commit 742f9ff
Showing 4 changed files with 183 additions and 7 deletions.
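In brief, the new scripts/py/run_hf_low.py drives a Hugging Face causal LM one token at a time and lets an AICI controller bias the logits, fast-forward ("ff") tokens, and backtrack. The following is a condensed sketch of that per-token loop, assuming model, runner, seq_id and input_ids are set up as in the full script below; it uses only the AiciRunner calls that appear in this diff, omits KV caching, log printing and stop handling, and the helper name aici_step is purely illustrative:

import torch

def aici_step(model, runner, seq_id, input_ids, temperature=0.01):
    # Kick off the controller's mid-process step before the forward pass.
    runner.add_mid(seq_id)
    runner.exec_mid()

    logits = model(input_ids.unsqueeze(0)).logits[:, -1, :] / temperature

    # Collect the controller's verdict: a per-vocab logit bias plus ff/backtrack info.
    bias = torch.from_numpy(runner.recv_logit_bias()).to(logits.device).to(logits.dtype)
    ff_tokens, backtrack = runner.mid_status(seq_id)

    if backtrack > 0:
        # The controller rolled back some previously generated tokens.
        input_ids = input_ids[:-backtrack]
    if len(ff_tokens) > 0:
        # The controller forces ("fast-forwards") these tokens verbatim.
        next_tokens = torch.tensor(ff_tokens, dtype=torch.long, device=input_ids.device)
    else:
        # Otherwise sample from the bias-adjusted distribution.
        vocab_size = bias.shape[1]
        probs = torch.softmax(logits[:, :vocab_size] + bias, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

    # Report the appended tokens and run the post/pre step for the next iteration.
    runner.tokens_generated(seq_id, next_tokens.tolist())
    runner.exec_post_pre()
    return torch.cat([input_ids, next_tokens])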
2 changes: 1 addition & 1 deletion aicirt/src/worker.rs
@@ -405,7 +405,7 @@ impl SeqCtx
         let res = self.mutinst().post_process(req);
         let json = serde_json::to_string(&res)?;
         let e = t0.elapsed();
-        if e.as_micros() > 100 {
+        if e.as_micros() > 500 {
             log::warn!("post_process took {:?}", e);
         }
         json
4 changes: 2 additions & 2 deletions scripts/hf.sh
@@ -5,9 +5,9 @@ set -x
 
 RUST_LOG=info,tokenizers=error,rllm=debug,aicirt=info \
 PYTHONPATH=py \
-python3 scripts/py/run_hf.py \
+python3 scripts/py/run_hf_low.py \
 --aici-rt ./target/release/aicirt \
 --controller gh:microsoft/aici/pyctrl \
---controller-arg controllers/pyctrl/samples/phi.py \
+--controller-arg controllers/pyctrl/samples/test.py \
 --aici-tokenizer phi \
 --model microsoft/phi-2
4 changes: 0 additions & 4 deletions scripts/py/run_hf.py
@@ -1,7 +1,3 @@
-#
-# This is outdated
-#
-
 import argparse
 
 from typing import cast, Optional, Union, List
180 changes: 180 additions & 0 deletions scripts/py/run_hf_low.py
@@ -0,0 +1,180 @@
import argparse

from typing import cast, Optional, Union, List
import torch
import pyaici
import pyaici.comms
from pyaici.comms import AiciRunner
from torch import nn

from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    AutoModelForCausalLM,
    PreTrainedTokenizer,
)

device = "cuda" if torch.cuda.is_available() else "cpu"


class StopGeneration(Exception):
    pass


def check_stop(runner: AiciRunner, seq_id: int):
    to_stop = runner.get_seqs_to_stop()
    if seq_id in to_stop:
        raise StopGeneration


def apply_bias(runner: AiciRunner, seq_id: int, scores: torch.Tensor) -> torch.Tensor:
    bias_tensor = runner.recv_logit_bias()
    runner.print_logs()
    check_stop(runner, seq_id)
    ff_tokens, backtrack = runner.mid_status(seq_id)
    assert backtrack == 0, "backtrack not implemented"
    assert len(ff_tokens) == 0, "ff_tokens not implemented"
    bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
    # print(bias_tensor.shape, scores.shape, input_ids.shape)
    vocab_size = bias_tensor.shape[1]
    # scores should be the size of vocabulary but some models (phi-2) make it slightly bigger
    assert scores.shape[1] <= vocab_size + 1000
    scores = scores[:, 0:vocab_size]
    assert scores.shape == bias_tensor.shape
    return bias_tensor + scores
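
# NOTE: main() below inlines this bias logic rather than calling apply_bias(), so that
# it can also act on the ff_tokens/backtrack values returned by mid_status().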


def main(args):
    tokenizer = cast(
        PreTrainedTokenizer, AutoTokenizer.from_pretrained(args.tokenizer or args.model)
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model = cast(PreTrainedModel, model)
    empty_tokens = cast(
        List[int], tokenizer.convert_tokens_to_ids(tokenizer.tokenize(""))
    )

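    # pyaici.runner_from_cli() builds the AiciRunner from the CLI flags added by
    # pyaici.add_cli_args() below (--aici-rt, --controller, --controller-arg,
    # --aici-tokenizer; see scripts/hf.sh above for an example invocation).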
    runner = pyaici.runner_from_cli(args)

    arg = ""
    if args.controller_arg:
        with open(args.controller_arg) as f:
            arg = f.read()
    req_id = "r1"  # arbitrary string
    seq_id = 1  # there can be multiple sequences in a single request
    runner.instantiate(req_id, empty_tokens, args.controller, arg)
    runner.assign_seq_id(req_id, seq_id)
    runner.print_logs()

    # we execute first post_pre here, so we get the initial ff_tokens
    runner.exec_post_pre()
    runner.print_logs()
    suspend, num_forks, ff_tokens = runner.pre_status(seq_id)
    to_stop = runner.get_seqs_to_stop()
    if seq_id in to_stop:
        print("AICI decided to stop")
        exit(1)
    assert not suspend, "forking not implemented"
    assert num_forks <= 1, "forking not implemented"

    prompt = torch.tensor(
        empty_tokens + ff_tokens, dtype=torch.long, device=model.device
    ).unsqueeze(0)

    model_kwargs = {
        "attention_mask": None,
        "use_cache": True,
    }
    input_ids = prompt.squeeze(0)
    temperature = 0.01

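    # Per-token protocol with the AICI runtime, as used below:
    #  1. add_mid()/exec_mid() kick off the controller's mid-process step before the
    #     model forward pass; its result is collected afterwards with recv_logit_bias().
    #  2. mid_status() reports any fast-forward tokens and a backtrack count for this step.
    #  3. tokens_generated() tells the runtime which tokens were appended, and
    #     exec_post_pre()/pre_status() run the post/pre step that may stop the sequence
    #     or produce further fast-forward tokens.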
    try:
        for _ in range(2000):
            runner.add_mid(seq_id)
            runner.exec_mid()

            m_inp = input_ids.unsqueeze(0)
            model_kwargs["attention_mask"] = torch.ones(
                m_inp.shape, dtype=torch.long, device=model.device
            )
            model_inputs = model.prepare_inputs_for_generation(m_inp, **model_kwargs)
            outputs = model(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            scores: torch.Tensor = outputs.logits[:, -1, :] / temperature

            bias_tensor = runner.recv_logit_bias()
            runner.print_logs()
            check_stop(runner, seq_id)
            ff_tokens, backtrack = runner.mid_status(seq_id)
            bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype)
            # print(bias_tensor.shape, scores.shape, input_ids.shape)
            vocab_size = bias_tensor.shape[1]
            # scores should be the size of vocabulary but some models (phi-2) make it slightly bigger
            assert scores.shape[1] <= vocab_size + 1000
            scores = scores[:, 0:vocab_size]
            assert scores.shape == bias_tensor.shape
            # apply the controller's logit bias before sampling
            scores = scores + bias_tensor

            probs = nn.functional.softmax(scores, dim=-1)

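            # If the controller fast-forwards tokens, append them verbatim;
            # otherwise sample from the bias-adjusted distribution.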
            if backtrack > 0 or len(ff_tokens) > 0:
                next_tokens = torch.tensor(ff_tokens, dtype=torch.long, device=model.device)
            else:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            runner.tokens_generated(seq_id, next_tokens.tolist())
            runner.exec_post_pre()
            runner.print_logs()
            check_stop(runner, seq_id)

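            # On backtrack, drop the rolled-back tokens and the matching entries of the
            # KV cache so the next forward pass only recomputes the newly appended tail.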
            if backtrack > 0:
                input_ids = input_ids[:-backtrack]
            computed_kv_len = input_ids.shape[0]
            input_ids = torch.cat([input_ids, next_tokens], dim=0)
            model_kwargs = model._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=model.config.is_encoder_decoder,
            )
            if "past_key_values" in model_kwargs:
                m = model_kwargs["past_key_values"]
                # len(m) == num_layers, len(m[0]) == 2 (key, value)
                # shape of each elt is (batch_size, num_heads, seq_len, head_dim)
                if m is not None and backtrack > 0:
                    m = [
                        (q[0][:, :, 0:computed_kv_len, :], q[1][:, :, 0:computed_kv_len, :])
                        for q in m
                    ]
                    model_kwargs["past_key_values"] = m

            suspend, num_forks, ff_tokens = runner.pre_status(seq_id)
            check_stop(runner, seq_id)
            assert not suspend, "forking not implemented"
            assert num_forks <= 1, "forking not implemented"

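            # The post/pre step may also fast-forward tokens (this is how the initial
            # prompt was produced above); append them before the next forward pass.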
            if len(ff_tokens) > 0:
                input_ids = torch.cat(
                    [input_ids, torch.tensor(ff_tokens, dtype=torch.long, device=model.device)],
                    dim=0,
                )

    except StopGeneration:
        runner.print_logs()
        print("AICI decided to stop")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Demo of using HF Transformers with aicirt"
    )
    parser.add_argument("--model", type=str, required=True, help="model to use")
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="",
        help="tokenizer to use; defaults to model name",
    )
    pyaici.add_cli_args(parser, single=True)
    args = parser.parse_args()
    main(args)
