From 16ee24963d565a2c26f94dd1a2ce48ddeab35984 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Tue, 2 Jul 2024 11:20:09 -0400 Subject: [PATCH] Point to main instead of unify-sd branch --- .../custom_models/llm_cmd_opts.py | 289 ++++++ .../custom_models/stateless_llama.py | 930 ++++++++++++------ 2 files changed, 927 insertions(+), 292 deletions(-) create mode 100644 models/turbine_models/custom_models/llm_cmd_opts.py diff --git a/models/turbine_models/custom_models/llm_cmd_opts.py b/models/turbine_models/custom_models/llm_cmd_opts.py new file mode 100644 index 000000000..2e7248d3c --- /dev/null +++ b/models/turbine_models/custom_models/llm_cmd_opts.py @@ -0,0 +1,289 @@ +import argparse +import os +from pathlib import Path + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +def is_valid_file(arg): + if not os.path.exists(arg): + return None + else: + return arg + + +# Note: this is where command-line options for the scripts in this directory +# are defined along with their defaults. Thus, they should not be referenced +# within modelling or inference code, only at the entry point to the script. + +# We should consider separating out the options that are "model configs" from +# the options that control the compiler, runtime, and script behavior, +# when applicable, as the former would best be kept in a separate +# config or imported from huggingface. + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# SDXL Huggingface Options +############################################################################## + +p.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging Face auth token, if required", + default=None, +) +p.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="Trelis/Llama-2-7b-chat-hf-function-calling-v2", +) +p.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="Euler", +) + +############################################################################## +# SDXL Inference Options +# These options are used to control runtime parameters for SDXL inference. +############################################################################## + +p.add_argument( + "--prompt", + type=str, + default=" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + help="Prompt input to stable diffusion.", +) + +p.add_argument( + "--negative_prompt", + type=str, + default="Watermark, blurry, oversaturated, low resolution, pollution", + help="Negative prompt input to stable diffusion.", +) + +p.add_argument( + "--num_inference_steps", type=int, default=30, help="Number of UNet inference steps" +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to run for a single prompt", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.", +) + +p.add_argument( + "--seed", type=float, default=0, help="Seed for random number/latents generation." +) + +p.add_argument( + "--external_weight_path", + type=str, + default="", + help="Path to external weights file, for jobs with one weights filepath. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.", +) + +p.add_argument( + "--external_weights_dir", + type=str, + default="", + help="Directory containing external weights for a job that requires more than one weights file. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from. Files will then be saved according to the parameters that make them unique, i.e. ___.", +) + +p.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) + +p.add_argument( + "--pipeline_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled meta-module", +) + +p.add_argument( + "--external_weight_file", + type=str, + default=None, + help="Path to external weights, used in benchmark scripts.", +) + +p.add_argument( + "--pipeline_dir", + type=str, + default=None, + help="Directory to save pipeline artifacts", +) + +p.add_argument( + "--compiled_pipeline", + default=False, + action="store_true", + help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", +) + +############################################################################## +# SDXL Modelling Options +# These options are used to control model defining parameters for SDXL. +# These are MLIR - changing variables! If you change them, you will need +# to import/download and recompile the model. +############################################################################## + +p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--height", type=int, default=1024, help="Height of Stable Diffusion output image." +) +p.add_argument( + "--width", type=int, default=1024, help="Width of Stable Diffusion output image" +) +p.add_argument( + "--precision", + type=str, + default="fp16", + help="Precision of Stable Diffusion weights and graph.", +) +p.add_argument( + "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion" +) +p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode") +p.add_argument( + "--return_index", + action="store_true", + help="Make scheduled unet compiled module return the step index.", +) + +p.add_argument( + "--vae_decomp_attn", + type=bool, + default=False, + help="Decompose attention for VAE decode only at fx graph level", +) + +############################################################################## +# SDXL script general options. +############################################################################## + +p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb") + +p.add_argument( + "--external_weights", + type=str, + default=None, + choices=["safetensors", "irpa", "gguf", None], + help="Externalizes model weights from the torch dialect IR and its successors", +) + +# See --external_weight_path and --external_weight_dir to specify where to save the model weights. + +p.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) +p.add_argument( + "--exit_on_vmfb", + default=True, + action="store_false", + help="Exit program on vmfb compilation completion. Most scripts will also save .mlir if this is disabled.", +) +p.add_argument( + "--input_mlir", + type=str, + default=None, + help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.", +) +p.add_argument( + "--download_mlir", + default=False, + action="store_true", + help="Download missing mlir files from Azure storage.", +) +p.add_argument( + "--container_name", + type=str, + default=None, + help="Azure storage container name to download mlir files from.", +) + + +############################################################################## +# IREE Compiler Options +############################################################################## + +p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") + +p.add_argument( + "--rt_device", + type=str, + default="local-task", + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) + +# TODO: Bring in detection for target triple +p.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) + +p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") + +p.add_argument( + "--attn_flags", + type=str, + default="", + help="extra iree-compile options for models with iree_linalg_ext.attention ops.", +) + +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.", +) + +p.add_argument( + "--clip_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling CLIP/prompt_encoder. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--vae_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling VAE. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--unet_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + + +args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py index baa4e2348..51fafbce9 100644 --- a/models/turbine_models/custom_models/stateless_llama.py +++ b/models/turbine_models/custom_models/stateless_llama.py @@ -3,6 +3,7 @@ import re import json from turbine_models.turbine_tank import turbine_tank +from pathlib import Path os.environ["TORCH_LOGS"] = "dynamic" from transformers import AutoTokenizer, AutoModelForCausalLM @@ -13,6 +14,8 @@ from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import ( enable_llama_pos_shift_attention, ) +from turbine_models.custom_models.sd_inference.utils import compile_to_vmfb +from turbine_models.model_runner import vmfbRunner from turbine_models.custom_models import remap_gguf import safetensors @@ -30,7 +33,7 @@ "--hf_model_name", type=str, help="HF model name", - default="meta-llama/Llama-2-7b-chat-hf", + default="Trelis/Llama-2-7b-chat-hf-function-calling-v2", ) parser.add_argument("--quantization", type=str, default="unquantized") parser.add_argument("--external_weight_file", type=str, default="") @@ -62,6 +65,11 @@ action="store_true", help="Compile LLM with StreamingLLM optimizations", ) +parser.add_argument( + "--decomp_attn", + action="store_true", + help="Decompose attention ops at fx graph level.", +) def generate_schema(num_layers): @@ -116,14 +124,39 @@ def export_transformer_model( quantization=None, precision=None, device=None, - target_triple=None, + target_triple="x86_64-unknown-linux-gnu", vulkan_max_allocation=None, streaming_llm=False, vmfb_path=None, upload_ir=False, mod=None, tokenizer=None, + decomp_attn=False, + input_mlir=None, + iree_flags=[], ): + safe_name = hf_model_name.replace("-", "_").replace("/", "_") + if streaming_llm: + safe_name += "_streaming" + if not vmfb_path: + vmfb_path = safe_name + "_" + target_triple + + ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"} + if target_triple in ukernel_supported_arch: + iree_flags.extend(["--iree-rocm-enable-ukernels=argmax"]) + if input_mlir is not None: + vmfb_path = compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags=iree_flags, + safe_name=vmfb_path.split(".vmfb")[0], + return_path=True, + const_expr_hoisting=True, + mlir_source="file", + save_mlir=False, + attn_spec="mfma" if "gfx9" in target_triple else "wmma", + ) if tokenizer == None: tokenizer = AutoTokenizer.from_pretrained( hf_model_name, @@ -175,243 +208,261 @@ def export_transformer_model( tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS) mapper = tensor_mapper.mapping - class StateUpdateModule(CompiledModule): - if external_weights: - params = export_parameters( - mod, external=True, external_scope="", name_mapper=mapper.get - ) - else: - params = export_parameters(mod) - global_seq_step = export_global(AbstractIndex, mutable=True) - global_k_caches = export_global_tree( - kv_cache_structure, uninitialized=True, mutable=True - ) - global_v_caches = export_global_tree( - kv_cache_structure, uninitialized=True, mutable=True - ) + initial_table = decompositions.current_aot_decompositions() + print("Decomposing torch SDPA") + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=[ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.masked_fill_.Scalar, + torch.ops.aten.copy, + ], + ): + current_table = decompositions.current_aot_decompositions() - def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): - init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ] - token, *state = self.initialize(x, constraints=init_const) - self.global_seq_step = IREE.tensor_dim( - state[0], 1 - ) # ? dimension of arbitrarily 0th kv tensor - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2], 1, self.global_seq_step, HEADS, HIDDEN_DIM - ) - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 - ) - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2 + 1], 1, self.global_seq_step, HEADS, HIDDEN_DIM - ) - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 + class StateUpdateModule(CompiledModule): + if external_weights: + params = export_parameters( + mod, external=True, external_scope="", name_mapper=mapper.get ) - return token - - def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): - state_arg = slice_up_to_step( - self.global_k_caches, - self.global_v_caches, - self.global_seq_step, - HEADS, - HIDDEN_DIM, - NUM_LAYERS, + else: + params = export_parameters(mod) + global_seq_step = export_global(AbstractIndex, mutable=True) + global_k_caches = export_global_tree( + kv_cache_structure, uninitialized=True, mutable=True ) - forw_const = ( - [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] - + [ - x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) - for x in state_arg[1:] - ] - + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] + global_v_caches = export_global_tree( + kv_cache_structure, uninitialized=True, mutable=True ) - token, *state_update = self.forward(x, *state_arg, constraints=forw_const) - for i in range(NUM_LAYERS): - update = IREE.tensor_reshape( - state_update[i * 2], 1, 1, HEADS, HIDDEN_DIM - ) - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], - update, - 0, + + def run_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): + init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ] + token, *state = self.initialize(x, constraints=init_const) + self.global_seq_step = IREE.tensor_dim( + state[0], 1 + ) # ? dimension of arbitrarily 0th kv tensor + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2], 1, self.global_seq_step, HEADS, HIDDEN_DIM + ) + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 + ) + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2 + 1], 1, self.global_seq_step, HEADS, HIDDEN_DIM + ) + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 + ) + return token + + def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): + state_arg = slice_up_to_step( + self.global_k_caches, + self.global_v_caches, self.global_seq_step, - 0, - 0, + HEADS, + HIDDEN_DIM, + NUM_LAYERS, ) - for i in range(NUM_LAYERS): - update = IREE.tensor_reshape( - state_update[i * 2 + 1], 1, 1, HEADS, HIDDEN_DIM + forw_const = ( + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] + ] + + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] ) - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], - update, - 0, - self.global_seq_step, - 0, - 0, + token, *state_update = self.forward( + x, *state_arg, constraints=forw_const ) - self.global_seq_step = self.global_seq_step + 1 - return token - - def get_seq_step(self): - return self.global_seq_step - - @jittable - def initialize(input_ids): - result = mod.forward(input_ids) - state1_flat, _ = pytree.tree_flatten(result.past_key_values) - token1 = torch.argmax(result.logits[:, -1, :], dim=1) - token1 = token1[None, :] - state1_flat = [torch.transpose(x, 1, 2) for x in state1_flat] - return token1, *state1_flat - - @jittable - def forward(token0: torch.Tensor, *state0_flat): - # Unpad the states. - state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] - state0 = pytree.tree_unflatten(state0_flat, state_schema) - result = mod.forward(token0, past_key_values=state0) - state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat] - token1 = torch.argmax(result.logits[:, -1, :], dim=1) - token1 = token1[None, :] - return token1, *state1_flat - - class StreamingStateUpdateModule(StateUpdateModule): - def run_cached_initialize( - self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) - ): - state_arg = slice_up_to_step( - self.global_k_caches, - self.global_v_caches, - self.global_seq_step, - HEADS, - HIDDEN_DIM, - NUM_LAYERS, - ) - forw_const = ( - [x.dynamic_dim(1) < MAX_STEP_SEQ] - + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] - + [ - x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) - for x in state_arg[1:] + for i in range(NUM_LAYERS): + update = IREE.tensor_reshape( + state_update[i * 2], 1, 1, HEADS, HIDDEN_DIM + ) + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], + update, + 0, + self.global_seq_step, + 0, + 0, + ) + for i in range(NUM_LAYERS): + update = IREE.tensor_reshape( + state_update[i * 2 + 1], 1, 1, HEADS, HIDDEN_DIM + ) + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], + update, + 0, + self.global_seq_step, + 0, + 0, + ) + self.global_seq_step = self.global_seq_step + 1 + return token + + def get_seq_step(self): + return self.global_seq_step + + @jittable + def initialize(input_ids): + result = mod.forward(input_ids) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + state1_flat = [torch.transpose(x, 1, 2) for x in state1_flat] + return token1, *state1_flat + + @jittable + def forward(token0: torch.Tensor, *state0_flat): + # Unpad the states. + state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] + state0 = pytree.tree_unflatten(state0_flat, state_schema) + result = mod.forward(token0, past_key_values=state0) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + state1_flat = [ + torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat ] - + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] - ) - token, *state = self.cached_initialize( - x, *state_arg, constraints=forw_const - ) - len_of_new_tokens = IREE.tensor_dim( - state[0], 1 - ) # ? dimension of arbitrarily 0th kv tensor - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2], 1, len_of_new_tokens, HEADS, HIDDEN_DIM - ) - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], - slice_of_state, - 0, - self.global_seq_step, - 0, - 0, - ) - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2 + 1], 1, len_of_new_tokens, HEADS, HIDDEN_DIM - ) - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], - slice_of_state, - 0, + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + return token1, *state1_flat + + class StreamingStateUpdateModule(StateUpdateModule): + def run_cached_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): + state_arg = slice_up_to_step( + self.global_k_caches, + self.global_v_caches, self.global_seq_step, - 0, - 0, + HEADS, + HIDDEN_DIM, + NUM_LAYERS, ) - self.global_seq_step = self.global_seq_step + len_of_new_tokens - return token - - @jittable - def cached_initialize(input_ids, *state0_flat): - # Unpad the states. - cur_token_len = state0_flat[0].size(1) - state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] - state0 = pytree.tree_unflatten(state0_flat, state_schema) - result = mod.forward(input_ids, past_key_values=state0) - state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [ - torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat - ] - token1 = torch.argmax(result.logits[:, -1, :], dim=1) - token1 = token1[None, :] - return token1, *state1_flat - - # Streaming-LLM KVCache evict algorithm: - # slice1 = KVCache[0 : sink] - # slice2 = KVCache[seq_len - window_size : seq_len] - # KVCache = torch.cat([slice1, slice2]) - # TODO: Add move to handle overlap of data. - def evict_kvcache_space(self): - # TODO: Replace hardcoded with global variable. - sink_size = 4 - window_size = 252 - most_recent_window = self.global_seq_step + (-window_size) - for i in range(NUM_LAYERS): - update_window_state = IREE.tensor_slice( - self.global_k_caches["layer_idx"][i], - 0, - (most_recent_window, window_size), - (0, HEADS), - (0, HIDDEN_DIM), - ) # sequence context dim - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], - update_window_state, - 0, - sink_size, - 0, - 0, + forw_const = ( + [x.dynamic_dim(1) < MAX_STEP_SEQ] + + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] + ] + + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] ) - for i in range(NUM_LAYERS): - update_window_state = IREE.tensor_slice( - self.global_v_caches["layer_idx"][i], - 0, - (most_recent_window, window_size), - (0, HEADS), - (0, HIDDEN_DIM), - ) # sequence context dim - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], - update_window_state, - 0, - sink_size, - 0, - 0, + token, *state = self.cached_initialize( + x, *state_arg, constraints=forw_const ) - self.global_seq_step.set(window_size + sink_size) - return self.global_seq_step + len_of_new_tokens = IREE.tensor_dim( + state[0], 1 + ) # ? dimension of arbitrarily 0th kv tensor + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2], 1, len_of_new_tokens, HEADS, HIDDEN_DIM + ) + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], + slice_of_state, + 0, + self.global_seq_step, + 0, + 0, + ) + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2 + 1], 1, len_of_new_tokens, HEADS, HIDDEN_DIM + ) + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], + slice_of_state, + 0, + self.global_seq_step, + 0, + 0, + ) + self.global_seq_step = self.global_seq_step + len_of_new_tokens + return token - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - if streaming_llm: - print("Compiling with Streaming LLM") - inst = StreamingStateUpdateModule(context=Context(), import_to=import_to) - else: - inst = StateUpdateModule(context=Context(), import_to=import_to) - # TODO: Integrate with external parameters to actually be able to run - # TODO: Make more generalizable to be able to quantize with all compile_to options - if quantization == "int4" and not compile_to == "linalg": - from shark_turbine.transforms.quantization import mm_group_quant - - mm_group_quant.MMGroupQuantRewriterPass( - CompiledModule.get_mlir_module(inst).operation - ).run() - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + @jittable + def cached_initialize(input_ids, *state0_flat): + # Unpad the states. + cur_token_len = state0_flat[0].size(1) + state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] + state0 = pytree.tree_unflatten(state0_flat, state_schema) + result = mod.forward(input_ids, past_key_values=state0) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + state1_flat = [ + torch.transpose(x[:, :, cur_token_len:, :], 1, 2) + for x in state1_flat + ] + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + return token1, *state1_flat + + # Streaming-LLM KVCache evict algorithm: + # slice1 = KVCache[0 : sink] + # slice2 = KVCache[seq_len - window_size : seq_len] + # KVCache = torch.cat([slice1, slice2]) + # TODO: Add move to handle overlap of data. + def evict_kvcache_space(self): + # TODO: Replace hardcoded with global variable. + sink_size = 4 + window_size = 252 + most_recent_window = self.global_seq_step + (-window_size) + for i in range(NUM_LAYERS): + update_window_state = IREE.tensor_slice( + self.global_k_caches["layer_idx"][i], + 0, + (most_recent_window, window_size), + (0, HEADS), + (0, HIDDEN_DIM), + ) # sequence context dim + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], + update_window_state, + 0, + sink_size, + 0, + 0, + ) + for i in range(NUM_LAYERS): + update_window_state = IREE.tensor_slice( + self.global_v_caches["layer_idx"][i], + 0, + (most_recent_window, window_size), + (0, HEADS), + (0, HIDDEN_DIM), + ) # sequence context dim + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], + update_window_state, + 0, + sink_size, + 0, + 0, + ) + self.global_seq_step.set(window_size + sink_size) + return self.global_seq_step + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + if streaming_llm: + print("Compiling with Streaming LLM") + inst = StreamingStateUpdateModule(context=Context(), import_to=import_to) + else: + inst = StateUpdateModule(context=Context(), import_to=import_to) + # TODO: Integrate with external parameters to actually be able to run + # TODO: Make more generalizable to be able to quantize with all compile_to options + if quantization == "int4" and not compile_to == "linalg": + from shark_turbine.transforms.quantization import mm_group_quant + + mm_group_quant.MMGroupQuantRewriterPass( + CompiledModule.get_mlir_module(inst).operation + ).run() + module_str = str(CompiledModule.get_mlir_module(inst)) if upload_ir: with open(f"{safe_name}.mlir", "w+") as f: f.write(module_str) @@ -423,84 +474,379 @@ def evict_kvcache_space(self): if compile_to != "vmfb": return module_str, tokenizer else: - flags = [ - "--iree-input-type=torch", - "--mlir-print-debuginfo", - "--mlir-print-op-on-diagnostic=false", - "--iree-llvmcpu-target-cpu-features=host", - "--iree-llvmcpu-target-triple=x86_64-linux-gnu", - "--iree-stream-resource-index-bits=64", - "--iree-vm-target-index-bits=64", - ] - if device == "cpu" or device == "llvm-cpu": - flags.append("--iree-llvmcpu-enable-ukernels=all") - device = "llvm-cpu" - elif device == "vulkan": - flags.extend( - [ - "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" - + vulkan_max_allocation, - ] + blob_name = compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags=iree_flags, + safe_name=vmfb_path.split(".vmfb")[0], + return_path=True, + const_expr_hoisting=True, + mlir_source="str", + save_mlir=False, + attn_spec="mfma" if "gfx9" in target_triple else "wmma", + ) + if upload_ir: + return blob_name + return blob_name, tokenizer + + +llm_model_map = { + "meta-llama/Llama-2-7b-chat-hf": { + "initializer": export_transformer_model, + "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", + "compile_flags": ["--iree-opt-const-expr-hoisting=False"], + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "Trelis/Llama-2-7b-chat-hf-function-calling-v2": { + "initializer": export_transformer_model, + "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "compile_flags": ["--iree-opt-const-expr-hoisting=False"], + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "TinyPixel/small-llama2": { + "initializer": export_transformer_model, + "hf_model_name": "TinyPixel/small-llama2", + "compile_flags": ["--iree-opt-const-expr-hoisting=True"], + "stop_token": 2, + "max_tokens": 1024, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, +} + + +class StatelessLlamaPipeline: + def __init__( + self, + hf_model_name: str, + scheduler_id: str, + height: int, + width: int, + precision: str, + max_length: int, + batch_size: int, + num_inference_steps: int, + device: str, + iree_target_triple: str, + ireec_flags: list = [], + attn_spec: str = None, + decomp_attn: bool = False, + pipeline_dir: str | Path = "./shark_vmfbs", + external_weights_dir: str | Path = "./shark_weights", + external_weights: str = "safetensors", + custom_vae: str = None, + vae_decomp_attn: bool = True, + hf_auth_token: str = None, + ): + self.hf_model_name = hf_model_name + self.iree_dtype = "float32" if precision == "fp32" else "float16" + self.torch_dtype = torch.float32 if precision == "fp32" else torch.float16 + self.cpu_scheduling = True + self.scheduler_id = scheduler_id + self.height = height + self.width = width + self.precision = precision + self.max_length = max_length + self.model_max_length = max_length + self.batch_size = batch_size + self.num_inference_steps = num_inference_steps + self.device = device + self.iree_target_triple = iree_target_triple + self.ireec_flags = ireec_flags + self.attn_spec = attn_spec + self.decomp_attn = decomp_attn + self.pipeline_dir = pipeline_dir + self.external_weights_dir = external_weights_dir + self.external_weights = external_weights + self.custom_vae = custom_vae + self.vae_decomp_attn = vae_decomp_attn + + self.first_input = True + self.max_tokens = llm_model_map[self.hf_model_name]["max_tokens"] + self.global_iter = 0 + self.prev_token_len = 0 + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + self.safe_name = "_".join( + [ + self.hf_model_name.replace("/", "_").replace("-", "_"), + self.precision, + ] + ) + self.model = None + self.hf_auth_token=hf_auth_token + + # FILE MANAGEMENT AND PIPELINE SETUP + + def check_prepared( + self, + mlir: str, + vmfb: str, + weight: str, + interactive: bool = False, + quantization: str = None, + ): + ready, vmfb, weight = self.is_prepared(vmfb, weight) + if not ready: + if interactive: + do_continue = input( + f"\nIt seems you are missing some necessary files. Would you like to generate them now? (y/n)" + ) + if do_continue.lower() != "y": + exit() + else: + do_continue = "y" + if do_continue.lower() == "y": + if vmfb is None: + v, w = self.export(input_mlir=mlir, quantization=quantization) + vmfb = v + if weight is None: + weight = w + if weight is None: + _, w = self.export(weights_only=True, quantization=quantization) + weight = w + ready, vmfb, weight = self.is_prepared(vmfb, weight) + if ready: + print("All necessary files found.") + return vmfb, weight + else: + print("There was an error generating the necessary files.") + exit() + else: + print("All necessary files found. Loading pipeline.") + return vmfb, weight + + def is_prepared(self, vmfb, weight): + missing = [] + default_filepath = os.path.join(self.pipeline_dir, self.safe_name + ".vmfb") + + # vmfb + if vmfb is None and os.path.exists(default_filepath): + vmfb = default_filepath + else: + missing.append(vmfb) + + # External weight + if not (weight is not None and os.path.exists(weight)): + if self.external_weights is None: + weight = None + else: + default_name = os.path.join( + self.external_weights_dir, self.safe_name + "." + self.external_weights ) - elif device == "rocm": - flags.extend( - [ - "--iree-rocm-target-chip=" + target_triple, - "--iree-rocm-link-bc=true", - "--iree-vm-bytecode-module-strip-source-map=true", - "--iree-opt-strip-assertions=true", - "--iree-vm-target-truncate-unsupported-floats", - ] + if weight is None and os.path.exists(default_name): + weight = os.path.join(default_name) + else: + missing.append(weight) + if len(missing) > 0: + # print(f"Missing files: " + ", ".join(missing)) + return False, vmfb, weight + else: + return True, vmfb, weight + + # IMPORT / COMPILE PHASE + + def export( + self, + quantization: str = None, + input_mlir: str = None, + weights_only: bool = False, + ): + safe_name = self.hf_model_name.replace("-", "_").replace("/", "_") + # if self.streaming_llm: + safe_name += "_streaming" + + if not os.path.exists(self.pipeline_dir): + os.makedirs(self.pipeline_dir) + if self.external_weights_dir: + if not os.path.exists(self.external_weights_dir): + os.makedirs(external_weights_dir, exist_ok=True) + external_weight_path = os.path.join( + self.external_weights_dir, safe_name + self.external_weights ) - ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"} - if target_triple in ukernel_supported_arch: - flags.extend(["--iree-rocm-enable-ukernels=argmax"]) - elif device == "cuda": - flags.extend( - [ - "--iree-hal-cuda-llvm-target-arch=" + target_triple, - "--iree-vm-bytecode-module-strip-source-map=true", - "--iree-vm-target-truncate-unsupported-floats", - ] + elif self.external_weights is None: + print( + "No external weights type specified using --external_weights, weights for imported .mlir files will not be externalized." ) + external_weight_path = None else: - print("Unknown device kind: ", device) - import iree.compiler as ireec + print( + f"No external weights directory specified using --external_weights_dir, we assume you have your own weights in {self.pipeline_dir}." + ) + external_weights_dir = self.pipeline_dir + external_weight_path = os.path.join( + self.pipeline_dir, safe_name + self.external_weights + ) + if weights_only: + input_mlir = None - flatbuffer_blob = ireec.compile_str( - module_str, - target_backends=[device], - extra_args=flags, + _, vmfb = export_transformer_model( + self.hf_model_name, + hf_auth_token=self.hf_auth_token, + compile_to="vmfb", + external_weights=self.external_weights, + external_weight_file=external_weight_path, + quantization=quantization, + precision=self.precision, + device=self.device, + target_triple=self.iree_target_triple, + vulkan_max_allocation=None, + streaming_llm=True, + vmfb_path=os.path.join(self.pipeline_dir, safe_name + ".vmfb"), + upload_ir=False, + mod=None, + tokenizer=None, + decomp_attn=False, + input_mlir=input_mlir, + iree_flags=self.ireec_flags, ) - if vmfb_path is None: - vmfb_path = f"{safe_name}.vmfb" - with open(vmfb_path, "wb+") as f: - f.write(flatbuffer_blob) - print("saved to ", safe_name + ".vmfb") - if upload_ir: - return blob_name - return module_str, tokenizer + return vmfb, external_weight_path + + # LOAD + def load_pipeline( + self, + vmfb: str, + weight: str, + rt_device: str = "local-task", + compiled_pipeline: bool = False, + ): + self.model = vmfbRunner(rt_device, vmfb, weight) + + # RUN + + def chat(self, prompt): + prompt = self.sanitize_prompt(prompt) + + input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids + + def format_out(results): + return torch.tensor(results.to_host()[0][0]) + + history = [] + for iter in range(self.max_tokens): + # if self.streaming_llm: + token_slice = max(self.prev_token_len - 1, 0) + input_tensor = input_tensor[:, token_slice:] + # if self.streaming_llm and self.model["get_seq_step"]() > 600: + if self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + token_len = input_tensor.shape[-1] + device_inputs = [ + ireert.asdevicearray(self.device, input_tensor) + ] + if self.first_input: # or not self.streaming_llm: + st_time = time.time() + token = self.model["run_initialize"](*device_inputs) + total_time = time.time() - st_time + token_len += 1 + self.first_input = False + else: + st_time = time.time() + token = self.model["run_cached_initialize"](*device_inputs) + total_time = time.time() - st_time + token_len += 1 + + history.append(format_out(token)) + while ( + format_out(token) != llm_model_map[self.hf_model_name]["stop_token"] + and len(history) < self.max_tokens + ): + dec_time = time.time() + if self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + token = self.model["run_forward"](token) + history.append(format_out(token)) + total_time = time.time() - dec_time + yield self.tokenizer.decode(history), total_time + + self.prev_token_len = token_len + len(history) + + if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]: + break + + for i in range(len(history)): + if type(history[i]) != int: + history[i] = int(history[i]) + result_output = self.tokenizer.decode(history) + self.global_iter += 1 + return result_output, total_time if __name__ == "__main__": - args = parser.parse_args() - mod_str, _ = export_transformer_model( + from turbine_models.custom_models.llm_cmd_opts import args + + mlir = args.input_mlir + vmfb = None + weight = None + + flags = [] + if "cpu" in args.device: + flags.extend( + [ + "--iree-global-opt-enable-quantized-matmul-reassociation", + ] + ) + elif args.device == "vulkan": + flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"]) + elif args.device == "rocm": + flags.extend( + [ + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-opt-outer-dim-concat=true", + "--iree-flow-enable-aggressive-fusion", + ] + ) + if "gfx9" in args.iree_target_triple: + flags.extend( + [ + f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(args.iree_target_triple, get_checkpoints_path())}", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + ] + ) + flags.extend(llm_model_map[args.hf_model_name]["compile_flags"]) + + if not args.pipeline_dir: + args.pipeline_dir = "./shark_vmfbs" + if not args.external_weights_dir and args.external_weights: + args.external_weights_dir = args.pipeline_dir + + sd_pipe = StatelessLlamaPipeline( args.hf_model_name, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_file, - args.quantization, + args.scheduler_id, + args.height, + args.width, args.precision, + args.max_length, + args.batch_size, + args.num_inference_steps, args.device, args.iree_target_triple, - args.vulkan_max_allocation, - args.streaming_llm, - args.vmfb_path, + flags, + args.attn_spec, + args.decomp_attn, + args.pipeline_dir, + args.external_weights_dir, + args.external_weights, + args.vae_decomp_attn, + args.hf_auth_token, + ) + vmfb, weight = sd_pipe.check_prepared(mlir, vmfb, weight, interactive=False, quantization="int4") + sd_pipe.load_pipeline(vmfb, weight, args.rt_device, args.compiled_pipeline) + sd_pipe.generate_images( + args.prompt, + args.negative_prompt, + args.batch_count, + args.guidance_scale, + args.seed, + False, ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to ", safe_name + ".mlir")