Add support for batched generation and synthetic long prompt
Chillee committed Aug 30, 2024
1 parent bc04265 commit 61c193d
Showing 1 changed file with 61 additions and 27 deletions.
generate.py (61 additions, 27 deletions)
@@ -7,7 +7,7 @@
import sys
import time
from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

import torch
import torch._dynamo.config
@@ -24,7 +24,9 @@ def device_sync(device):

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# Experimental features to reduce compilation times, will be on by default in future
+torch._inductor.config.fx_graph_cache = True
+torch._functorch.config.enable_autograd_cache = True

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

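Note: both cache flags should be set before the first torch.compile call in the process; they persist compiled artifacts so later runs of the same script can skip recompilation. A minimal sketch of the pattern (the Linear model is illustrative, not from this repo):

    import torch

    # Enable the experimental compile caches before any compilation happens.
    torch._inductor.config.fx_graph_cache = True
    torch._functorch.config.enable_autograd_cache = True

    model = torch.nn.Linear(8, 8)
    compiled = torch.compile(model)    # later runs can reuse cached graphs
    out = compiled(torch.randn(4, 8))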
@@ -50,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

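Note: switching logits[0, -1] to logits[:, -1] is the core batching change in sampling: it keeps the batch dimension and draws one token per sequence. A shape-only sketch, using torch.multinomial in place of the repo's multinomial_sample_one_no_sync helper:

    import torch

    batch_size, seq_len, vocab = 4, 16, 32000
    logits = torch.randn(batch_size, seq_len, vocab)

    last = logits[:, -1]                                # [batch_size, vocab], last position per sequence
    probs = torch.softmax(last / 0.8, dim=-1)           # temperature-scaled distribution per row
    idx_next = torch.multinomial(probs, num_samples=1)  # [batch_size, 1], one token per sequence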
@@ -76,7 +78,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, ...):
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token.clone()

return new_tokens, new_probs

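Note: cur_token = next_token.clone() preserves the [batch_size, 1] shape coming out of sample, where the old view(1, -1) hard-coded a batch of one. A toy loop showing the shape flow, with a stand-in for decode_one_token:

    import torch

    def fake_decode_one_token(cur_token):               # stand-in; the real code runs the Transformer
        return torch.randint(0, 32000, cur_token.shape)

    cur_token = torch.randint(0, 32000, (4, 1))         # [batch_size, 1]
    new_tokens = []
    for _ in range(3):
        next_token = fake_decode_one_token(cur_token)
        new_tokens.append(next_token.clone())
        cur_token = next_token.clone()                  # still [batch_size, 1]; no reshape needed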
@@ -139,6 +141,7 @@ def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
+    batch_size: int,
*,
interactive: bool,
draft_model: Transformer,
@@ -152,7 +155,7 @@ def generate(

is_speculative = draft_model is not None
# create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
+    T = prompt.size(-1)
T_new = T + max_new_tokens
if interactive:
max_seq_length = 350
@@ -162,20 +165,22 @@ def generate(
device, dtype = prompt.device, prompt.dtype
max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
-    empty[:T] = prompt
+    empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
+    # We are just making the same prompt for every batch
+    prompt = prompt.view(1, -1).repeat(batch_size, 1)
+    empty[:, :T] = prompt
seq = empty
input_pos = torch.arange(0, T, device=device)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
+        prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
+    seq[:, T] = next_token.squeeze()

input_pos = torch.tensor([T], device=device, dtype=torch.int)
accept_counts = [0] * (speculate_k + 1)
@@ -197,8 +202,8 @@ def generate(
input_pos = input_pos + num_added
next_token = next_tokens[-1]
else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)

generate_stats = {
'accept_counts': accept_counts
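Note: generate now allocates a [batch_size, T_new] buffer and tiles the single prompt across the batch, so every row decodes the same prefix. A shape-only sketch of the bookkeeping (random ids stand in for prefill's output):

    import torch

    batch_size, T, max_new_tokens = 4, 6, 10
    prompt = torch.arange(T)                                # pretend token ids, shape [T]

    T_new = T + max_new_tokens
    seq = torch.empty(batch_size, T_new, dtype=torch.long)
    prompt = prompt.view(1, -1).repeat(batch_size, 1)       # same prompt for every batch element
    seq[:, :T] = prompt

    next_token = torch.randint(0, 32000, (batch_size, 1))   # stand-in for prefill's first sample
    seq[:, T] = next_token.squeeze()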
@@ -245,6 +250,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):

def _get_model_size(model):
model_size = 0
+    params = 0
for name, child in model.named_children():
if not isinstance(child, torch.nn.Embedding):
model_size += sum(
@@ -253,15 +259,22 @@ def _get_model_size(model):
for p in itertools.chain(child.parameters(), child.buffers())
]
)
-    return model_size
+            params += sum(
+                [
+                    p.numel()
+                    for p in itertools.chain(child.parameters(), child.buffers())
+                ]
+            )
+    return model_size, params

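Note: _get_model_size now returns both non-embedding bytes (used for the bandwidth figure) and non-embedding parameter count (used for the FLOPS figure); embeddings are excluded, presumably because a lookup touches only a few rows per token. The same accounting on a toy model:

    import itertools
    import torch

    model = torch.nn.Sequential(torch.nn.Embedding(1000, 64), torch.nn.Linear(64, 64))

    model_size, params = 0, 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            tensors = list(itertools.chain(child.parameters(), child.buffers()))
            model_size += sum(p.numel() * p.dtype.itemsize for p in tensors)  # bytes
            params += sum(p.numel() for p in tensors)                         # element count

    print(model_size, params)  # Linear only: 4160 fp32 params -> 16640 bytes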
B_INST, E_INST = "[INST]", "[/INST]"

def main(
-    prompt: str = "Hello, my name is",
+    prompt: Union[int, str] = "Hello, my name is",
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
+    batch_size: int = 1,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -307,11 +320,15 @@ def main(

tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    prompt_length = encoded.size(0)
+    if isinstance(prompt, str):
+        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    else:
+        # generate a fully synthetic prompt
+        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
+    prompt_length = encoded.size(-1)

torch.manual_seed(1234)
-    model_size = _get_model_size(model)
+    model_size, params = _get_model_size(model)
if compile:
if is_speculative and use_tp: # and ("cuda" in device):
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
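Note: an integer prompt is interpreted as a length, not text: the prompt becomes that many random token ids, which makes it easy to benchmark long-context prefill without writing a real prompt (the generations will be gibberish, but the timing is representative). A minimal sketch:

    import torch

    prompt = 2048  # target prompt length in tokens
    encoded = torch.randint(0, 1024, (prompt,), dtype=torch.int64)
    print(encoded.size(-1))  # 2048; ids are uniform over [0, 1024)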
@@ -371,6 +388,7 @@ def callback(x):
model,
encoded,
max_new_tokens,
+                batch_size=batch_size,
draft_model=draft_model,
speculate_k=speculate_k,
interactive=interactive,
@@ -391,21 +409,30 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive:
-        print(tokenizer.decode(y.tolist()))
+        # Just displaying the first generation
+        if batch_size > 1:
+            print("Only displaying the first generation of the batch")
+        print(tokenizer.decode(y[0].tolist()))
else:
print()
-        tokens_generated = y.size(0) - prompt_length
-        tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        tokens_generated = y.size(-1) - prompt_length
+        generated_tokens_sec = tokens_generated / t
+        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
+        print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
+        total_tokens_sec = y.numel() / t
+        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
print()
print("==========")
if is_speculative:
counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
print(f"Acceptance probs: {acceptance_probs}")
print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

print(f"Batch Size: {batch_size}")
print(f"Prompt Length: {prompt_length}")
print(f"Generated tokens: {max_new_tokens}")
print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

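Note: a worked example of the three reported metrics, under assumed values (roughly 7e9 non-embedding params at 2 bytes each in fp16, batch 4, 28 prompt tokens, 100 new tokens, 2.0 s per sample):

    model_size, params = 14e9, 7e9                        # assumed values, see above
    batch_size, prompt_length, max_new_tokens, t = 4, 28, 100, 2.0

    generated_tokens_sec = max_new_tokens / t             # 50.0; per-sequence decode rate
    bandwidth = model_size * generated_tokens_sec / 1e9   # 700.0 GB/s: weights re-read once per decode step
    total_tokens_sec = batch_size * (prompt_length + max_new_tokens) / t  # 256.0, i.e. y.numel() / t
    flops = params * total_tokens_sec * 2 / 1e12          # ~3.58 TF/s: ~2 FLOPs per param per token

    print(f"{generated_tokens_sec:.1f} tok/s, {bandwidth:.1f} GB/s, {flops:.2f} TF/s")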
@@ -414,10 +441,17 @@ def callback(x):
import argparse
parser = argparse.ArgumentParser(description='Your CLI description.')

-    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    def int_or_str(x):
+        try:
+            return int(x)
+        except:
+            return x
+
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
@@ -430,7 +464,7 @@ def callback(x):

args = parser.parse_args()
main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
args.speculate_k, args.device
)

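Note: int_or_str is what lets one flag serve both prompt modes; anything that parses as an integer selects the synthetic-prompt path. A small demonstration plus hypothetical invocations (flag values are illustrative):

    def int_or_str(x):
        try:
            return int(x)
        except ValueError:  # the diff uses a bare except; ValueError is the narrow case
            return x

    print(int_or_str("8192"))               # 8192 (int): synthetic 8192-token prompt
    print(int_or_str("Hello, my name is"))  # unchanged string: normal tokenized prompt

    # python generate.py --prompt 8192 --batch_size 4 --max_new_tokens 128
    # python generate.py --prompt "Hello, my name is" --batch_size 8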