Point to Azure links for specs and fix timesteps dim in GPU scheduler.
monorimet committed Jun 19, 2024
1 parent b1f20f1 commit 618d01f
Showing 2 changed files with 5 additions and 5 deletions.
@@ -82,7 +82,7 @@ def prepare_model_input(self, sample, t, timesteps):
             latent_model_input = torch.cat([sample] * 2)
         else:
             latent_model_input = sample
-        t = t.expand(sample.shape[0])
+        t = t.expand(latent_model_input.shape[0])
         return latent_model_input.type(self.dtype), t.type(self.dtype)

     def step(self, noise_pred, t, sample, guidance_scale, i):
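The one-line fix above matters whenever classifier-free guidance is active: latent_model_input is the latents concatenated with themselves, so its batch dim is twice sample's, and a timestep tensor expanded to sample.shape[0] no longer lines up with the model input. A minimal sketch of the mismatch in torch (shapes are illustrative, not taken from the repo):

    import torch

    sample = torch.randn(1, 4, 64, 64)  # one latent
    t = torch.tensor(981.0)             # scalar timestep

    # Classifier-free guidance duplicates the latents along the batch dim:
    latent_model_input = torch.cat([sample] * 2)     # shape [2, 4, 64, 64]

    t_old = t.expand(sample.shape[0])                # shape [1]: too short under CFG
    t_new = t.expand(latent_model_input.shape[0])    # shape [2]: matches the input
    assert t_new.shape[0] == latent_model_input.shape[0]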
models/turbine_models/custom_models/sd_inference/utils.py (8 changes: 4 additions & 4 deletions)
@@ -35,7 +35,7 @@
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-rocm-waves-per-eu=2",
         "--iree-flow-inline-constants-max-byte-length=1",
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))",
     ],
     "unet": [
         "--iree-flow-enable-aggressive-fusion",
@@ -275,7 +275,7 @@ def create_safe_name(hf_model_name, model_name_str):


 def get_mfma_spec_path(target_chip, save_dir):
-    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
+    url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
     attn_spec = urlopen(url).read().decode("utf-8")
     spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
     if os.path.exists(spec_path):
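With the Azure URL in place, get_mfma_spec_path downloads the gfx942 attention/matmul spec once and caches it under save_dir. A hypothetical call site (the transform-dialect-library flag is the usual way such specs reach iree-compile; it is not part of this diff):

    # Assumed usage: fetches and caches the spec, then points the compiler at it.
    spec_path = get_mfma_spec_path("gfx942", "./tmp_specs")
    extra_flags = [f"--iree-codegen-transform-dialect-library={spec_path}"]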
@@ -287,9 +287,9 @@ def get_mfma_spec_path(target_chip, save_dir):

 def get_wmma_spec_path(target_chip, save_dir):
     if target_chip == "gfx1100":
-        url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1100.spec.mlir"
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir"
     elif target_chip in ["gfx1103", "gfx1150"]:
-        url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1103.spec.mlir"
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir"
     else:
         return None
     attn_spec = urlopen(url).read().decode("utf-8")
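get_wmma_spec_path mirrors the MFMA helper but dispatches on RDNA3 chip IDs and returns None for unsupported targets; note that the gfx1103/gfx1150 branch now fetches the gfx1150 spec. A small usage sketch under the same assumptions:

    # Assumed usage: None signals there is no published spec for the target,
    # so callers should compile without a transform library.
    for chip in ["gfx1100", "gfx1103", "gfx90a"]:
        spec = get_wmma_spec_path(chip, "./tmp_specs")
        print(chip, spec if spec else "no tuned spec; compile without one")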
