Skip to content

Commit

Permalink
Merge branch 'main' into amyachev/vstoolchain
Browse files Browse the repository at this point in the history
  • Loading branch information
anmyachev authored Nov 19, 2024
2 parents f5f7101 + a8ca9e5 commit e986f52
Show file tree
Hide file tree
Showing 43 changed files with 626 additions and 787 deletions.
65 changes: 49 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,24 @@ endif()

include(ExternalProject)

set(CMAKE_CXX_STANDARD 17)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

project(triton CXX)
include(CTest)

if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()


list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

# Options
# Compute the default used for the TRITON_BUILD_PROTON option:
# ON everywhere except Windows, where it is OFF.
if(NOT WIN32)
  set(DEFAULT_BUILD_PROTON ON)
else()
  set(DEFAULT_BUILD_PROTON OFF)
endif()

# Define the option with the determined default value
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON})
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
Expand All @@ -49,10 +50,21 @@ endif()
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
# Per-toolchain C++ standard and flags for the custom build types
# (TritonRelBuildWithAsserts / TritonBuildWithO1).
if(MSVC)
  set(CMAKE_CXX_STANDARD 20)
  # C and C++ share one set of compiler flags.
  foreach(triton_flag_lang C CXX)
    set(CMAKE_${triton_flag_lang}_FLAGS_TRITONRELBUILDWITHASSERTS
        "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
  endforeach()
  # Every linker-flag category gets the same options.
  foreach(triton_link_kind EXE MODULE SHARED STATIC)
    set(CMAKE_${triton_link_kind}_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS
        "/debug:fastlink /INCREMENTAL")
  endforeach()
  # Note: no TritonBuildWithO1 flags are set on this branch.
else()
  set(CMAKE_CXX_STANDARD 17)
  foreach(triton_flag_lang C CXX)
    set(CMAKE_${triton_flag_lang}_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
    set(CMAKE_${triton_flag_lang}_FLAGS_TRITONBUILDWITHO1 "-O1")
  endforeach()
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
Expand All @@ -70,7 +82,15 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
# Base C++ flags. Every branch defines __STDC_FORMAT_MACROS; the rest depends
# on the toolchain.
if(MSVC)
  # MSVC: additionally disable warnings C4244, C4624, C4715 and C4530.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
elseif(WIN32)
  # Non-MSVC compiler on Windows: no -fPIC, deprecation warnings silenced.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS -Wno-deprecated")
else()
  # Everything else: position-independent code.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS -fPIC")
endif()


# #########
Expand Down Expand Up @@ -124,7 +144,11 @@ endfunction()


# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
# Warning policy appended to the C++ flags, per toolchain.
if(MSVC)
  # MSVC builds pass /WX- instead of the GCC/Clang-style flags.
  string(APPEND CMAKE_CXX_FLAGS " /WX-")
else()
  string(APPEND CMAKE_CXX_FLAGS " -Werror -Wno-covered-switch-default -fvisibility=hidden")
endif()

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
Expand All @@ -134,7 +158,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
include_directories(${PROJECT_SOURCE_DIR}/third_party)
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files

# link_directories(${LLVM_LIBRARY_DIR})
link_directories(${LLVM_LIBRARY_DIR})

add_subdirectory(include)
add_subdirectory(lib)

Expand Down Expand Up @@ -163,6 +188,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
# using pip install.
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${PYBIND11_INCLUDE_DIR})
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
link_directories(${PYTHON_LIB_DIRS})
else()
# Otherwise, we might be building from top CMakeLists.txt directly.
# Try to find Python and pybind11 packages.
Expand Down Expand Up @@ -245,7 +272,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
LLVMAArch64CodeGen
LLVMAArch64AsmParser
)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
list(APPEND TRITON_LIBRARIES
LLVMX86CodeGen
LLVMX86AsmParser
Expand Down Expand Up @@ -280,6 +307,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
# Platform-specific link step for the `triton` target.
if(WIN32)
  target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
  # Output file is named lib<name>.pyd on Windows.
  set_target_properties(triton PROPERTIES
    SUFFIX ".pyd"
    PREFIX "lib"
  )
else()
  target_link_libraries(triton PRIVATE z)
endif()
Expand All @@ -306,6 +335,10 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
endif()
# Windows only: declare two boolean options, both defaulting to ON, ahead of
# the add_subdirectory() calls that follow.
if(WIN32)
# NOTE(review): presumably steers thread-library selection to Win32 threads; confirm which consumer reads it.
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
# NOTE(review): presumably read by the googletest build pulled in elsewhere; confirm.
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(third_party/f2reduce)
add_subdirectory(bin)
Expand Down
29 changes: 15 additions & 14 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ namespace mlir {
namespace triton {
class AllocationAnalysis;

/// Callback to allow backends to specify target-specific scratch sizes for
/// some operations.
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;

/// Scratch-size callback with the AllocationAnalysisScratchSizeFn signature;
/// ModuleAllocation uses it as the default when no getter is passed.
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);

// To convert a tensor from one layout to another, we need to allocate a
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
// require multiple iterations, with each iteration involving multiple
Expand Down Expand Up @@ -141,7 +147,8 @@ class Allocation {
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
void run(FuncAllocMapT &funcAllocMap,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
Expand Down Expand Up @@ -242,9 +249,6 @@ class Allocation {
size_t sharedMemorySize = 0;
};

template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);

/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
Expand All @@ -255,19 +259,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

template <typename AllocationAnalysis = triton::AllocationAnalysis>
static ModuleAllocation get(ModuleOp moduleOp) {
ModuleAllocation res(moduleOp);
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
ModuleAllocation(ModuleOp moduleOp,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
triton::defaultAllocationAnalysisScratchSizeFn)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.template run<AllocationAnalysis>(res.funcMap);
iter->second.run(funcMap, scratchSizeGetter);
});
return res;
}

size_t getSharedMemorySize() {
Expand All @@ -292,9 +296,6 @@ class ModuleAllocation : public CallGraph<Allocation> {
}

private:
explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {}

FuncOffsetMapT sharedMemoryValue;
};

Expand Down
7 changes: 4 additions & 3 deletions include/triton/Analysis/AxisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ class AxisInfo {
public:
AxisInfo() : AxisInfo({}, {}, {}) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
ArrayRef<int64_t> constancy)
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
std::optional<int64_t> constantValue)
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
ArrayRef<int64_t> constancy, std::optional<int64_t> constantValue)
: contiguity(contiguity), divisibility(divisibility),
constancy(constancy), constantValue(constantValue) {
assert(divisibility.size() == contiguity.size());
Expand Down
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
(ins "int":$opIdx,
"int":$kWidth)>,

InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrderForOperand",
(ins "int":$opIdx)>,

InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
"Type":$eltTy,
Expand Down
Loading

0 comments on commit e986f52

Please sign in to comment.