I'm glad that torch.compile gives such a large speed-up. On an A5000 it speeds up generation by about 60%, but on an L4 there is no acceleration at all. I'd like to understand why this happens.
Here is my code; pass --compile when running it:
import time
import os

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
from transformers import set_seed

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def print_separator():
    print("=" * 20, "\n")


def get_model_and_tokenizer(model_path, device, dtype):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=device,
    )
    model.tokenizer = tokenizer
    return model, tokenizer


def benchmark_throughput(model, model_inputs, args):
    device = model.device
    set_seed(args.seed)
    # model.device is a torch.device, so compare its .type; comparing the
    # device object against the string "cuda" is always False and would
    # silently skip the synchronization.
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    greedy_output = model.generate(
        **model_inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        top_k=args.top_k,
        temperature=args.temperature,
        output_scores=True,
        return_dict_in_generate=True,
        use_cache=True,
    ).sequences
    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()
    time_elapsed = t1 - t0
    num_tokens = greedy_output.numel() - model_inputs['input_ids'].numel()
    print("Output:\n" + 100 * '-')
    print(model.tokenizer.decode(greedy_output[0], skip_special_tokens=False))
    print("Generated Tokens:", num_tokens)
    print("Time Elapsed (s):", time_elapsed)
    throughput = num_tokens / time_elapsed
    return throughput


def main(args):
print("torch and transformer version:", torch.__version__, transformers.__version__)
print(torch.__config__.parallel_info())
print(f"device: {args.device}, dtype: {args.dtype}")
print(f"model: {args.model_path}")
print_separater()
model, tokenizer=get_model_and_tokenizer(args.model_path, args.device, args.dtype)
model_inputs=tokenizer(args.prompt, return_tensors='pt').to(args.device)
warm_up_tokens=20set_seed(args.seed)
warm_up_output=model.generate(**model_inputs, max_new_tokens=warm_up_tokens)
throughput=benchmark_throughput(model, model_inputs, args)
print("throughput eager (token/s):", throughput)
    if args.compile:
        t0 = time.time()
        model._static_cache = StaticCache(
            config=model.config,
            max_batch_size=1,
            max_cache_len=4096,
            device=model.device,
            dtype=torch.float16,
        )
        # Compile only the decoder stack; generate() itself stays eager.
        model.model.forward = torch.compile(
            model.model.forward,
            backend=args.dynamo_backend,
            mode=args.dynamo_mode,
            dynamic=None,
            fullgraph=True,
            disable=False,
        )
        t1 = time.time()
        print("Compile time (s):", t1 - t0)

        # The first compiled generate() triggers the actual compilation and
        # doubles as a parity check against the eager warm-up output.
        set_seed(args.seed)
        warm_up_output_compiled = model.generate(
            **model_inputs, max_new_tokens=warm_up_tokens)
        print("Warm-up result agree:", torch.equal(warm_up_output, warm_up_output_compiled))
        print_separator()

        throughput_compiled = benchmark_throughput(model, model_inputs, args)
        print_separator()
        print("compile speed-up:", throughput_compiled / throughput)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Benchmark generate() with and without torch.compile.')
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--dtype', default=torch.float16)
    parser.add_argument('--model_path', type=str,
                        default="meta-llama/Meta-Llama-3-8B", help='HF model name or path.')
    parser.add_argument('--prompt', type=str,
                        default="Q: What is the largest animal?\nA:", help='Input prompt.')
    parser.add_argument('--max_new_tokens', type=int,
                        default=256, help='Maximum number of new tokens.')
    parser.add_argument('--do_sample', action='store_true',
                        help='Whether to use sampling. Default is greedy search.')
    parser.add_argument('--top_k', type=int,
                        default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float,
                        default=0.8, help='Temperature for sampling.')
    parser.add_argument('--compile', action='store_true',
                        help='Whether to compile the model.')
    parser.add_argument('--dynamo_backend', type=str,
                        default="inductor", help='See torch._dynamo.list_backends().')
    parser.add_argument('--dynamo_mode', type=str,
                        default="default",
                        help='One of ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"].')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    args = parser.parse_args()
    main(args)
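For reference, recent transformers releases (4.38+) document a supported way to combine a static KV cache with torch.compile, instead of assigning model._static_cache by hand. A minimal sketch, assuming that API is available in your installed version:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# generate() allocates the StaticCache itself when this flag is set.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Q: What is the largest animal?\nA:", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))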
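To rule out a compile failure on the L4 (rather than a hardware difference), a small diagnostic can be run right after the compiled generate(). This is only a sketch: torch._dynamo.utils.counters is an internal API and may change between torch versions.

import torch
import torch._dynamo.utils

props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name}, SMs: {props.multi_processor_count}, "
      f"VRAM: {props.total_memory / 2**30:.1f} GiB")

# Run the compiled generate() once before checking this: an empty
# graph_break counter suggests the fullgraph compile really succeeded.
print(dict(torch._dynamo.utils.counters["graph_break"]))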