Commit

Merge branch 'develop' into add_matmulintegertofloat_contrib_op
TedThemistokleous authored Dec 11, 2024
2 parents 6e2a36c + 2e59073 commit 13063df
Showing 14 changed files with 156 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -71,7 +71,7 @@ def rocmtestnode(Map conf) {
pre()
sh "docker pull ${DOCKER_IMAGE}:${env.IMAGE_TAG}"
withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: "--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE -v=/home/jenkins/:/home/jenkins ${docker_args}") {
timeout(time: 3, unit: 'HOURS') {
timeout(time: 4, unit: 'HOURS') {
body(cmake_build)
}
}
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.in
@@ -1,2 +1,2 @@
rocm-docs-core==1.10.0
rocm-docs-core==1.11.0
sphinx-collapse
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.txt
@@ -116,7 +116,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.10.0
rocm-docs-core==1.11.0
# via -r requirements.in
smmap==5.0.1
# via gitdb
20 changes: 11 additions & 9 deletions examples/diffusion/python_stable_diffusion_3/README.md
@@ -35,15 +35,9 @@ huggingface-cli login
```

Export the models to onnx.
Currently, optimum does not have the changes required in their latest release. Please follow the steps to build optimum from scratch.
Currently, optimum does not have the changes required in their latest release. Please install from their development branch instead.
```bash
git clone --single-branch --branch main https://github.com/huggingface/optimum.git
cd optimum
make build_dist_install_tools
make build_dist
cd dist
pip install *.whl
cd ../..
python -m pip install optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git
```
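If you want to confirm the development install succeeded, a quick check is (a minimal sketch; the printed version string depends on the commit installed):
```python
# Sanity check: the development build of optimum should be importable.
import optimum
print(optimum.__version__)
```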

Once optimum is built, use the following command to export the models:
@@ -54,7 +48,7 @@ optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers
Run the text-to-image script with the following example prompt and seed (optionally, you can change the batch size / number of images generated for that prompt)

```bash
MIGRAPHX_DISABLE_REDUCE_FUSION=1 python txt2img.py --prompt "a photograph of an astronaut riding a horse" --steps 50 --output astro_horse.jpg
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --steps 50 --output astro_horse.jpg
```
> [!NOTE]
> The first run will compile the models and cache them to make subsequent runs faster. New batch sizes will result in the models re-compiling.
@@ -63,3 +57,11 @@ The result should look like this:

![example_output.jpg](./example_output.jpg)

## Lower Memory Usage Pipeline
The entire pipeline is memory intensive, even when quantizing to fp16. The T5XXL encoder can be disabled alongside fp16 quantization to reduce total GPU memory usage to under 16G.

There will be a slight accuracy penalty when disabling T5XXL.
```bash
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --steps 50 --skip-t5 --fp16=all --output astro_horse.jpg
```
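
For reference, the same low-memory configuration can be set up directly through the `StableDiffusionMGX` class in `txt2img.py` (updated below). A minimal sketch, assuming the class is imported from the script and using placeholder model paths:
```python
# Sketch only: "models/sd3" and "compiled/sd3" are placeholder paths.
from txt2img import StableDiffusionMGX

sd = StableDiffusionMGX(onnx_model_path="models/sd3",
                        compiled_model_path="compiled/sd3",
                        fp16=None,      # the script treats None as an empty list
                        batch=1,
                        skip_t5=True)   # drop text_encoder_3 to reduce GPU memory
sd.warmup(5)
```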

61 changes: 43 additions & 18 deletions examples/diffusion/python_stable_diffusion_3/txt2img.py
@@ -94,6 +94,14 @@ def get_args():
help="Perform exhaustive tuning when compiling onnx models",
)

parser.add_argument(
"--skip-t5",
action="store_true",
default=False,
help=
"Skip the third text encoder. Small accuracy penalty but large memory savings."
)

# Runtime
parser.add_argument(
"-s",
@@ -207,15 +215,22 @@ def allocate_torch_tensors(model):


class StableDiffusionMGX():
def __init__(self, onnx_model_path, compiled_model_path, fp16, batch,
force_compile, exhaustive_tune):
def __init__(self,
onnx_model_path,
compiled_model_path,
fp16,
batch,
force_compile=False,
exhaustive_tune=False,
skip_t5=False):

self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
subfolder="scheduler")

self.tokenizer = SD3Tokenizer()
self.device = "cuda"
self.skip_t5 = skip_t5

if fp16 is None:
fp16 = []
@@ -254,15 +269,6 @@ def __init__(self, onnx_model_path, compiled_model_path, fp16, batch,
force_compile=force_compile,
exhaustive_tune=exhaustive_tune,
offload_copy=False),
"t5xxl":
StableDiffusionMGX.load_mgx_model(
"text_encoder_3", {"input_ids": [1, 77]},
onnx_model_path,
compiled_model_path=compiled_model_path,
use_fp16="clip" in fp16,
force_compile=force_compile,
exhaustive_tune=exhaustive_tune,
offload_copy=False),
"mmdit":
StableDiffusionMGX.load_mgx_model(
"transformer", {
@@ -283,19 +289,32 @@ def __init__(self, onnx_model_path, compiled_model_path, fp16, batch,
self.tensors = {
"clip-g": allocate_torch_tensors(self.models["clip-g"]),
"clip-l": allocate_torch_tensors(self.models["clip-l"]),
"t5xxl": allocate_torch_tensors(self.models["t5xxl"]),
# "t5xxl": allocate_torch_tensors(self.models["t5xxl"]),
"mmdit": allocate_torch_tensors(self.models["mmdit"]),
"vae": allocate_torch_tensors(self.models["vae"]),
}

self.model_args = {
"clip-g": tensors_to_args(self.tensors['clip-g']),
"clip-l": tensors_to_args(self.tensors['clip-l']),
"t5xxl": tensors_to_args(self.tensors['t5xxl']),
# "t5xxl": tensors_to_args(self.tensors['t5xxl']),
"mmdit": tensors_to_args(self.tensors['mmdit']),
"vae": tensors_to_args(self.tensors['vae']),
}

if not self.skip_t5:
self.models["t5xxl"] = StableDiffusionMGX.load_mgx_model(
"text_encoder_3", {"input_ids": [1, 77]},
onnx_model_path,
compiled_model_path=compiled_model_path,
use_fp16="clip" in fp16,
force_compile=force_compile,
exhaustive_tune=exhaustive_tune,
offload_copy=False)
self.tensors["t5xxl"] = allocate_torch_tensors(
self.models["t5xxl"])
self.model_args["t5xxl"] = tensors_to_args(self.tensors['t5xxl'])

self.events = {
"warmup":
HipEventPair(start=hip.hipEventCreate()[1],
@@ -468,7 +487,11 @@ def get_embeddings(self, prompt_tokens):
prompt_tokens["l"])
g_out, g_pooled = self.encode_token_weights("clip-g",
prompt_tokens["g"])
t5_out, _ = self.encode_token_weights("t5xxl", prompt_tokens["t5xxl"])
if not self.skip_t5:
t5_out, _ = self.encode_token_weights("t5xxl",
prompt_tokens["t5xxl"])
else:
t5_out = torch.zeros((1, 77, 4096)).cuda()
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
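
Substituting zeros keeps the prompt-embedding shapes consistent for the MMDiT. A rough sketch of the bookkeeping with the usual SD3 encoder widths (CLIP-L 768, CLIP-G 1280, batch size 1); the final token-axis concatenation is an assumption based on the reference SD3 pipeline rather than code shown in this diff:
```python
import torch

# Hypothetical encoder outputs (widths assumed from the SD3 reference pipeline).
l_out = torch.randn(1, 77, 768)
g_out = torch.randn(1, 77, 1280)

lg_out = torch.cat([l_out, g_out], dim=-1)                               # (1, 77, 2048)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))   # (1, 77, 4096)
t5_out = torch.zeros((1, 77, 4096))          # stands in for the skipped T5XXL output
context = torch.cat([lg_out, t5_out], dim=-2)                            # (1, 154, 4096)
```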

@@ -539,8 +562,9 @@ def warmup(self, num_runs):
torch.ones((1, 77)).to(torch.int32))
copy_tensor_sync(self.tensors["clip-g"]["input_ids"],
torch.ones((1, 77)).to(torch.int32))
copy_tensor_sync(self.tensors["t5xxl"]["input_ids"],
torch.ones((1, 77)).to(torch.int32))
if not self.skip_t5:
copy_tensor_sync(self.tensors["t5xxl"]["input_ids"],
torch.ones((1, 77)).to(torch.int32))
copy_tensor_sync(
self.tensors["mmdit"]["hidden_states"],
torch.randn((2 * self.batch, 16, 128, 128)).to(torch.float))
@@ -558,7 +582,8 @@ def warmup(self, num_runs):
for _ in range(num_runs):
run_model_sync(self.models["clip-l"], self.model_args["clip-l"])
run_model_sync(self.models["clip-g"], self.model_args["clip-g"])
run_model_sync(self.models["t5xxl"], self.model_args["t5xxl"])
if not self.skip_t5:
run_model_sync(self.models["t5xxl"], self.model_args["t5xxl"])
run_model_sync(self.models["mmdit"], self.model_args["mmdit"])
run_model_sync(self.models["vae"], self.model_args["vae"])
self.profile_end("warmup")
@@ -569,7 +594,7 @@ def warmup(self, num_runs):

sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path,
args.fp16, args.batch, args.force_compile,
args.exhaustive_tune)
args.exhaustive_tune, args.skip_t5)
print("Warmup")
sd.warmup(5)
print("Run")
2 changes: 1 addition & 1 deletion requirements.txt
@@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@c443cf85a09f289c147d7b01f93c1e51390ff65f -DBUILD_FAT_LIBROCKCOMPILER=On
ROCm/rocMLIR@e61b0f0e516f09144445b3c8eb372f39eb82d53b -DBUILD_FAT_LIBROCKCOMPILER=On
1 change: 1 addition & 0 deletions src/include/migraphx/generic_float.hpp
@@ -31,6 +31,7 @@
#include <limits>
#include <iostream>
#include <tuple>
#include <cstdint>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
6 changes: 3 additions & 3 deletions src/onnx/parse_matmulnbits.cpp
@@ -66,10 +66,10 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
to_string_range(expected_b_lens) +
". Actual dims: " + to_string_range(args[1]->get_shape().lens()));

std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
if(args[2]->get_shape().lens() != expected_scales_lens)
const size_t expected_scales_lens = n * n_blocks_per_col;
if(args[2]->get_shape().elements() != expected_scales_lens)
MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
to_string_range(expected_scales_lens) +
to_string(expected_scales_lens) +
". Actual dims: " + to_string_range(args[2]->get_shape().lens()));

if(args.size() > 3)
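
The relaxed check above compares only the total element count of the scales input, so any layout with the right number of scales is accepted. A worked example with made-up sizes, assuming `n_blocks_per_col = ceil(K / block_size)` as in the MatMulNBits contrib-op definition:
```python
import math

# Hypothetical MatMulNBits dimensions.
N, K, block_size = 4096, 4096, 32
n_blocks_per_col = math.ceil(K / block_size)    # 128
expected_scales = N * n_blocks_per_col          # 524288

# A flat (524288,) scales tensor and a (4096, 128) scales tensor both contain
# 524288 elements, so both now pass; the old check required dims exactly {524288}.
```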
4 changes: 2 additions & 2 deletions src/targets/gpu/fuse_ops.cpp
@@ -677,7 +677,7 @@ struct find_rocblas_gemm_pointwise : gemm_pointwise
shape s = c_ins->get_shape();
// const-fold input if not standard shape since rocblas can't handle it
// Updated for a case where "standard" shape has out-of-sequence strides
if(not s.standard() or s.normalize_standard() != s)
if(not s.standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
@@ -903,7 +903,7 @@ struct find_layernorm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::any_of[match::inputs()](
return precompile_name("pointwise")(match::arg(0)(
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}

9 changes: 5 additions & 4 deletions src/targets/gpu/gemm_impl.cpp
@@ -79,10 +79,11 @@ rocblas_datatype get_type(shape::type_t type)
MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}

void blas_shape(const shape& s)
void blas_shape(const shape& in_shape)
{
if(s.lens().size() < 2)
if(in_shape.lens().size() < 2)
return;
auto s = in_shape.normalize_standard();
if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; }))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; }))
@@ -591,7 +592,7 @@ void gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
[](const argument& x) { return x.get_shape().normalize_standard(); });
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
@@ -608,7 +609,7 @@ void gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
[](const argument& x) { return x.get_shape().normalize_standard(); });
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
43 changes: 26 additions & 17 deletions src/targets/gpu/target.cpp
@@ -105,31 +105,39 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
// whitelist supported Ops for the FP8 types
// different between fp8e4m3fnuz and OCP types because rocBLAS only has
// support for fp8e4m3fnuz
std::set<std::string> unsupported_fp8fnuz_ops = {};
if(not gpu::rocblas_fp8_available())
std::set<std::string> unsupported_fp8e4m3fnuz_ops = {};
if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}) and not gpu::rocblas_fp8_available())
{
unsupported_fp8fnuz_ops.insert("dot");
unsupported_fp8fnuz_ops.insert("quant_dot");
unsupported_fp8e4m3fnuz_ops.insert("dot");
unsupported_fp8e4m3fnuz_ops.insert("quant_dot");
}
#if MIGRAPHX_USE_MIOPEN
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8fnuz_ops.insert("pooling");
unsupported_fp8e4m3fnuz_ops.insert("pooling");
#endif
if(not gpu::gfx_has_fp8fnuz_intrinsics())
{
unsupported_fp8fnuz_ops.insert("convolution");
unsupported_fp8fnuz_ops.insert("quant_convolution");
unsupported_fp8e4m3fnuz_ops.insert("convolution");
unsupported_fp8e4m3fnuz_ops.insert("quant_convolution");
}
// add all device kernels
unsupported_fp8fnuz_ops.insert("logsoftmax");
unsupported_fp8fnuz_ops.insert("nonzero");
unsupported_fp8fnuz_ops.insert("prefix_scan_sum");
unsupported_fp8fnuz_ops.insert("scatter_none");
unsupported_fp8fnuz_ops.insert("topk");
unsupported_fp8fnuz_ops.insert("rnn_var_sl_shift_output");
unsupported_fp8fnuz_ops.insert("multinomial");
unsupported_fp8fnuz_ops.insert("argmax");
unsupported_fp8fnuz_ops.insert("argmin");
unsupported_fp8e4m3fnuz_ops.insert("logsoftmax");
unsupported_fp8e4m3fnuz_ops.insert("nonzero");
unsupported_fp8e4m3fnuz_ops.insert("prefix_scan_sum");
unsupported_fp8e4m3fnuz_ops.insert("scatter_none");
unsupported_fp8e4m3fnuz_ops.insert("topk");
unsupported_fp8e4m3fnuz_ops.insert("rnn_var_sl_shift_output");
unsupported_fp8e4m3fnuz_ops.insert("multinomial");
unsupported_fp8e4m3fnuz_ops.insert("argmax");
unsupported_fp8e4m3fnuz_ops.insert("argmin");

std::set<std::string> unsupported_fp8e5m2fnuz_ops = unsupported_fp8e4m3fnuz_ops;
// disable gemm for fp8e5m2fnuz if rocBLAS is being used
if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}))
{
unsupported_fp8e5m2fnuz_ops.insert("dot");
unsupported_fp8e5m2fnuz_ops.insert("quant_dot");
}

std::set<std::string> unsupported_fp8ocp_ops = {};
// TODO: remove this when the flag is removed
@@ -194,7 +202,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
prefuse_ops{},
dead_code_elimination{},
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::fp8e5m2fnuz_type}, shape::float_type, unsupported_fp8fnuz_ops},
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8e4m3fnuz_ops},
eliminate_data_type{{migraphx::shape::fp8e5m2fnuz_type}, shape::float_type, unsupported_fp8e5m2fnuz_ops},
eliminate_data_type{{migraphx::shape::fp8e4m3fn_type, migraphx::shape::fp8e5m2_type}, shape::float_type, unsupported_fp8ocp_ops},
dead_code_elimination{},
rewrite_reduce{},
3 changes: 1 addition & 2 deletions test/gpu/fuse_ops.cpp
@@ -103,8 +103,7 @@ TEST_CASE(layernorm_pointwise)
{
migraphx::program p1 = create_program(false);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
EXPECT(p1 == create_program(false));
}
}

2 changes: 1 addition & 1 deletion test/onnx/.onnxrt-commit
@@ -1 +1 @@
1128882bfd2a97c20f8a2a5ddb26cb0d42d9ebba
d27fecd3d3837864a268bc96f00f2b8dce294697