rename marlinV2 to machete
LucasWilkinson committed Jul 31, 2024
Commit 7f8cf90 (1 parent: ccf4457)
Showing 17 changed files with 117 additions and 117 deletions.
32 changes: 16 additions & 16 deletions CMakeLists.txt
@@ -225,34 +225,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

#
-# For the CUTLASS-Marlin (marlinv2, mixed precision cutlass kernels) we
+# For the CUTLASS-Marlin (machete, mixed precision cutlass kernels) we
# automatically generate sources for various preselected input type pairs
# and schedules.
# Generate sources:
execute_process(
  COMMAND ${CMAKE_COMMAND} -E env
  PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH
-  ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/marlinv2/generate.py
-  RESULT_VARIABLE marlinv2_generation_result
-  OUTPUT_VARIABLE marlinv2_generation_output
-  OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlinv2_generation.log
-  ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlinv2_generation.log
+  ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
+  RESULT_VARIABLE machete_generation_result
+  OUTPUT_VARIABLE machete_generation_output
+  OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
+  ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
)

-if (NOT marlinv2_generation_result EQUAL 0)
-  file(READ ${CMAKE_CURRENT_BINARY_DIR}/marlinv2_generation.log log)
-  message(FATAL_ERROR "MarlinV2 generation failed."
-          " Result: \"${marlinv2_generation_result}\""
+if (NOT machete_generation_result EQUAL 0)
+  file(READ ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log log)
+  message(FATAL_ERROR "Machete generation failed."
+          " Result: \"${machete_generation_result}\""
          "\nCheck the log for details: "
-          "${CMAKE_CURRENT_BINARY_DIR}/marlinv2_generation.log")
+          "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
-  message(STATUS "MarlinV2 generation completed successfully.")
+  message(STATUS "Machete generation completed successfully.")
endif()

-# Add marlinv2 generated sources
-file(GLOB MARLINV2_GEN_SOURCES "csrc/quantization/marlinv2/generated/*.cu")
+# Add machete generated sources
+file(GLOB MARLINV2_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
list(APPEND VLLM_EXT_SRC ${MARLINV2_GEN_SOURCES})
-message(STATUS "MarlinV2 generated sources: ${MARLINV2_GEN_SOURCES}")
+message(STATUS "Machete generated sources: ${MARLINV2_GEN_SOURCES}")

# See comment above for scaled_mm_c3x (same if condition)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
@@ -265,7 +265,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

# Add pytorch binding
list(APPEND VLLM_EXT_SRC
-    csrc/quantization/marlinv2/marlinv2_pytorch.cu)
+    csrc/quantization/machete/machete_pytorch.cu)
endif()

define_gpu_extension_target(
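Note: the hunk above shells out to the generator at configure time and aborts the configure step with a pointer to the log if it returns non-zero. A rough standalone equivalent for debugging the generator outside CMake (a sketch only; running from the vLLM checkout root and the CUTLASS checkout location are assumptions):

```python
# Roughly what the CMake execute_process() above runs, as a debugging helper.
# Paths mirror the diff; CUTLASS_DIR is an assumed environment variable.
import os
import subprocess
import sys

repo_root = os.path.abspath(".")  # assumed: run from the vLLM checkout root
cutlass_dir = os.environ.get("CUTLASS_DIR", "third_party/cutlass")

env = dict(os.environ)
env["PYTHONPATH"] = os.pathsep.join([
    os.path.join(repo_root, "csrc/cutlass_extensions/"),
    os.path.join(cutlass_dir, "python/"),
    env.get("PYTHONPATH", ""),
])

result = subprocess.run(
    [sys.executable,
     os.path.join(repo_root, "csrc/quantization/machete/generate.py")],
    env=env, capture_output=True, text=True,
)
# CMake treats any non-zero status as "Machete generation failed."
if result.returncode != 0:
    print(result.stdout + result.stderr)
    raise SystemExit(result.returncode)
```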
benchmarks/kernels/benchmark_machete.py
@@ -26,11 +26,11 @@
DEFAULT_TP_SIZES = [1]


-def marlinv2_pack_weights(w_q: torch.tensor,
+def machete_pack_weights(w_q: torch.tensor,
                          wtype: ScalarType) -> torch.tensor:
    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # make col major
-    return ops.marlinv2_prepack_B(w_q, wtype)
+    return ops.machete_prepack_B(w_q, wtype)


def make_bench_tensors(
@@ -94,11 +94,11 @@ def bench(atype: torch.dtype,
          label: str,
          sub_label: str,
          benchmark_marlinv1: bool = True,
-         benchmark_marlinv2_best: bool = True) -> Iterable[TMeasurement]:
+         benchmark_machete_best: bool = True) -> Iterable[TMeasurement]:
    a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
    sub_label += f", L={len(weights)}"

-    weights_marlinv2 = [(w_ref, marlinv2_pack_weights(w_q, wtype), w_s, w_zp)
+    weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
                        for w_ref, w_q, w_s, w_zp in weights]

    timers = []
@@ -151,30 +151,30 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
size_k=w_ref.shape[0],
is_k_full=True))))

-    # marlinv2
+    # machete
    timers.append(
        bench_fn(
-            label, sub_label, "marlinv2_heuristic", lambda: loop_over_weights(
-                a, weights_marlinv2, lambda a, _, w_q, w_s: ops.marlinv2_gemm(
+            label, sub_label, "machete_heuristic", lambda: loop_over_weights(
+                a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
                    a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))

-    if benchmark_marlinv2_best:
-        print("Finding best schedule for marlinv2")
+    if benchmark_machete_best:
+        print("Finding best schedule for machete")
        best = None
        best_schedule = None
-        schedules = ops.marlinv2_supported_schedules(wtype)
+        schedules = ops.machete_supported_schedules(wtype)
        for schedule in reversed(schedules):

            def run(a, _, w_q, w_s, schedule=schedule):
-                ops.marlinv2_gemm(a,
+                ops.machete_gemm(a,
                                 w_q,
                                 wtype,
                                 w_s,
                                 b_group_size=group_size,
                                 schedule=schedule)

-            res = bench_fn(label, sub_label, "marlinv2_best",
-                           lambda: loop_over_weights(a, weights_marlinv2, run))
+            res = bench_fn(label, sub_label, "machete_best",
+                           lambda: loop_over_weights(a, weights_machete, run))

            print(f" {res.median:5.5} ", schedule)
            if not best or res.median < best.median:
@@ -299,16 +299,16 @@ def to_torch_dtype(dt):

parser = FlexibleArgumentParser(
description="""
-Benchmark MarlinV2 GEMM.
+Benchmark Machete GEMM.
To run square GEMMs:
-    python3 ./benchmarks/kernels/benchmark_marlinv2.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
+    python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
-    python3 ./benchmarks/kernels/benchmark_marlinv2.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
+    python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
-    python3 ./benchmarks/kernels/benchmark_marlinv2.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
+    python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
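Taken together, this benchmark exercises a two-step API: pack the quantized weight once with ops.machete_prepack_B, then call ops.machete_gemm, either letting the heuristic pick a schedule or sweeping ops.machete_supported_schedules as in the search loop above. A condensed sketch of that call sequence (import paths, shapes, and the uint4b8 weight type are illustrative assumptions, not fixed by this commit):

```python
# Sketch of the sequence the benchmark times, using only names visible in
# this diff (ops.machete_prepack_B, ops.machete_gemm,
# ops.machete_supported_schedules, pack_rows). Import paths are assumptions.
import torch
from vllm import _custom_ops as ops                # assumed import path
from vllm.scalar_type import scalar_types          # assumed location
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows)                                     # assumed location

m, n, k, group_size = 16, 4096, 4096, 128
wtype = scalar_types.uint4b8                       # assumed 4-bit weight type

a = torch.randn(m, k, dtype=torch.float16, device="cuda")
w_q = torch.randint(0, 16, (k, n), dtype=torch.int32, device="cuda")  # toy values
w_s = torch.randn(k // group_size, n, dtype=torch.float16, device="cuda")

# One-time prepack, mirroring machete_pack_weights() above.
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t()                     # make col major
w_q = ops.machete_prepack_B(w_q, wtype)

# Let the heuristic pick a schedule...
out = ops.machete_gemm(a, w_q, wtype, b_scales=w_s, b_group_size=group_size)

# ...or sweep the supported schedules explicitly, as the benchmark does.
for schedule in ops.machete_supported_schedules(wtype):
    out = ops.machete_gemm(a, w_q, wtype, b_scales=w_s,
                           b_group_size=group_size, schedule=schedule)
```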
benchmarks/kernels/graph_machete_bench.py
@@ -61,4 +61,4 @@
plt.ylabel("time (median, s)")
axs_idx += 1
plt.tight_layout()
plt.savefig("graph_marlinv2_bench.pdf")
plt.savefig("graph_machete_bench.pdf")
4 changes: 2 additions & 2 deletions csrc/ops.h
@@ -83,7 +83,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);

-namespace marlinv2 {
+namespace machete {

std::vector<vllm::ScalarTypeTorchPtr> supported_types();

@@ -102,7 +102,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B,
torch::Tensor prepack_B(torch::Tensor const B,
vllm::ScalarTypeTorchPtr const& btype);

-}; // namespace marlinv2
+}; // namespace machete

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
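These declarations are the C++ surface behind the machete_* ops used in the benchmark above; machete_pytorch.cu, added to VLLM_EXT_SRC in the CMake diff, binds them into torch. A tiny hedged sketch of probing that surface from Python (binding module path and the weight type are assumptions):

```python
# Probe the generated machete kernels before dispatching; the benchmark's
# schedule sweep is built on exactly this query.
from vllm import _custom_ops as ops          # assumed import path
from vllm.scalar_type import scalar_types    # assumed location

wtype = scalar_types.uint4b8                 # assumed supported weight type
for schedule in ops.machete_supported_schedules(wtype):
    print(schedule)                          # one entry per generated schedule
```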
csrc/quantization/machete/generate.py
@@ -22,9 +22,9 @@
#

DISPATCH_TEMPLATE = """
#include "../marlinv2_mm_launcher.cuh"
#include "../machete_mm_launcher.cuh"
namespace marlinv2 {
namespace machete {
using KernelDispatcher_ = KernelDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
@@ -46,7 +46,7 @@
return impl_{{ type_name }}_sch_{{ schedule_name }}(args);
}
{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false, "marlinv2_gemm(..) is not implemented for "
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
"schedule = ", *args.schedule);
}
@@ -59,13 +59,13 @@
};
}
-}; // namespace marlinv2
+}; // namespace machete
"""

IMPL_TEMPLATE = """
#include "../marlinv2_mm_launcher.cuh"
#include "../machete_mm_launcher.cuh"
namespace marlinv2 {
namespace machete {
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
using Kernel = KernelTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
@@ -104,20 +104,20 @@
}{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(
false, "for the sake of compile times and binary size marlinv2_mm(..) is "
false, "for the sake of compile times and binary size machete_mm(..) is "
" not implemented for with_C=", with_C, ", with_scales=", with_scales,
", with_zeropoints=", with_zeropoints,
" (for {{type_name}}_sch_{{schedule_name}})");
}
{% endfor %}
-}; // namespace marlinv2
+}; // namespace machete
"""

PREPACK_TEMPLATE = """
#include "../marlinv2_prepack_launcher.cuh"
#include "../machete_prepack_launcher.cuh"
namespace marlinv2 {
namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
@@ -138,7 +138,7 @@
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
return prepack_impl<PrepackedLayoutB>(B);
}
-}; // namespace marlinv2
+}; // namespace machete
"""

TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
@@ -279,14 +279,14 @@ def create_sources(impl_config: ImplConfig, num_impl_files=2):
terse_type_name = generate_terse_type_signature(impl_config.type_config)

sources.append((
f"marlinv2_mm_{terse_type_name}",
f"machete_mm_{terse_type_name}",
mm_dispatch_template.render(type_name=type_name,
type_config=impl_config.type_config,
schedules=schedules_with_names),
))

sources.append((
f"marlinv2_prepack_{terse_type_name}",
f"machete_prepack_{terse_type_name}",
prepack_dispatch_template.render(
type_name=type_name,
type_config=impl_config.type_config,
@@ -299,7 +299,7 @@
file_schedules = schedules_with_names[i:i + schedules_per_file]

sources.append((
f"marlinv2_mm_{terse_type_name}_impl_part{part}",
f"machete_mm_{terse_type_name}_impl_part{part}",
mm_impl_template.render(
type_name=type_name,
type_config=impl_config.type_config,
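create_sources() above renders one dispatch file and one prepack file per type config, then splits the impl template across num_impl_files parts so no single translation unit carries every schedule instantiation. A minimal sketch of that render-and-split pattern (jinja2 is the real dependency; the template body and config shapes here are simplified stand-ins, not the commit's actual dataclasses):

```python
# Render-and-split pattern from create_sources(), reduced to its core.
import math
import jinja2

impl_template = jinja2.Template(
    "// {{ type_name }}: {{ schedules | length }} schedules in this file\n")

def split_impl_sources(type_name, schedules, num_impl_files=2):
    # Cap schedules per impl file so compile times stay bounded.
    sources = []
    per_file = math.ceil(len(schedules) / num_impl_files)
    for part, i in enumerate(range(0, len(schedules), per_file), start=1):
        sources.append((
            f"machete_mm_{type_name}_impl_part{part}",
            impl_template.render(type_name=type_name,
                                 schedules=schedules[i:i + per_file]),
        ))
    return sources

# Each (name, code) pair would be written to
# csrc/quantization/machete/generated/<name>.cu, which the CMake glob above
# picks up into VLLM_EXT_SRC.
for name, code in split_impl_sources("f16u4", ["sch_a", "sch_b", "sch_c"]):
    print(name, code)
```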
csrc/quantization/machete/machete_collective_builder.cuh
@@ -1,19 +1,19 @@
#pragma once

#include "cutlass_extensions/vllm_collective_builder.cuh"
#include "marlinv2_mainloop.cuh"
#include "machete_mainloop.cuh"

namespace cutlass::gemm::collective {
using namespace cute;

-struct MarlinV2KernelTag {};
+struct MacheteKernelTag {};

template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
-    MarlinV2KernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
+    MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
@@ -24,7 +24,7 @@ struct VLLMCollectiveBuilder<
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
-  using CollectiveOp = marlinv2::MarlinV2CollectiveMma<
+  using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
csrc/quantization/machete/machete_mainloop.cuh
@@ -30,7 +30,7 @@

/////////////////////////////////////////////////////////////////////////////////////////////////

-namespace marlinv2 {
+namespace machete {

using namespace cute;
using namespace cutlass;
@@ -46,7 +46,7 @@ template <class ElementATuple_, class GmemLayoutA, int AlignmentA,
class ElementAccumulator_, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType>
-struct MarlinV2CollectiveMma {
+struct MacheteCollectiveMma {
using Schedule = KernelScheduleType;
static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
@@ -1407,6 +1407,6 @@ struct MarlinV2CollectiveMma {

/////////////////////////////////////////////////////////////////////////////////////////////////

-} // namespace marlinv2
+} // namespace machete

/////////////////////////////////////////////////////////////////////////////////////////////////
csrc/quantization/machete/machete_mm_kernel.cuh
@@ -24,10 +24,10 @@
#include "cutlass_extensions/gemm/kernel/vllm_tile_schedulers.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "marlinv2_collective_builder.cuh"
#include "marlinv2_prepacked_layout.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"

-namespace marlinv2 {
+namespace machete {

using namespace cute;

@@ -123,8 +123,8 @@ struct KernelTemplate {

using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
-          cutlass::gemm::collective::MarlinV2KernelTag, ArchTag,
-          OperatorClass, BTypeTuple, PrepackedLayoutBB, AlignmentB, ElementA,
+          cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
+          BTypeTuple, PrepackedLayoutBB, AlignmentB, ElementA,
LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
@@ -216,4 +216,4 @@
};
};

-}; // namespace marlinv2
+}; // namespace machete
csrc/quantization/machete/machete_mm_launcher.cuh
@@ -3,10 +3,10 @@
#include <torch/all.h>
#include <Python.h>

#include "marlinv2_mm_kernel.cuh"
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"

-namespace marlinv2 {
+namespace machete {

struct PytorchArguments {
torch::Tensor const A;
@@ -85,4 +85,4 @@ struct KernelDispatcher {
static std::vector<std::string> supported_schedules();
};

-}; // namespace marlinv2
+}; // namespace machete