Skip to content

Commit

Permalink
Merge commit 'fa229d1c4bee16c094be9427334575ec1e79f66c'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Oct 18, 2024
2 parents e48642c + fa229d1 commit f213106
Show file tree
Hide file tree
Showing 31 changed files with 1,132 additions and 129 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ jobs:
lit -v "${LIT_TEST_DIR}"
- name: Run python tests on CUDA
run: |
SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
if [ ! -d "${SHARED_LIB_DIR}" ]; then
echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
Expand All @@ -257,7 +257,7 @@ jobs:
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
# Run hopper/test_flashattention.py separately to avoid out of gpu memory
python3 -m pytest -s hopper/test_flashattention.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
- name: Run interpreter tests
if: ${{ matrix.runner[0] == 'h100-runner-set' }}
Expand Down Expand Up @@ -401,9 +401,9 @@ jobs:
lit -v "${LIT_TEST_DIR}"
- name: Run python tests on HIP
run: |
SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C"
if [ ! -d "${SHARED_LIB_DIR}" ]; then
echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
fi
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
cd python/test/unit
Expand All @@ -412,7 +412,7 @@ jobs:
--ignore=test_debug.py
# TODO: uncomment
# pytest --capture=tee-sys -rfs test_debug.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
Expand Down
16 changes: 8 additions & 8 deletions .github/workflows/integration-tests.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ jobs:

- name: Run python tests on CUDA
run: |
SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
if [ ! -d "${SHARED_LIB_DIR}" ]; then
echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
Expand All @@ -291,7 +291,7 @@ jobs:
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
# Run hopper/test_flashattention.py separately to avoid out of gpu memory
python3 -m pytest -s hopper/test_flashattention.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

- name: Run interpreter tests
Expand Down Expand Up @@ -397,9 +397,9 @@ jobs:

- name: Run python tests on HIP
run: |
SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C"
if [ ! -d "${SHARED_LIB_DIR}" ]; then
echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
fi
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
cd python/test/unit
Expand All @@ -408,7 +408,7 @@ jobs:
--ignore=test_debug.py
# TODO: uncomment
# pytest --capture=tee-sys -rfs test_debug.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
Expand Down
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUStreamPipeline();
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
/**
* A backend-specific callback for appending auxiliary data during
* `LocalStoreOp` conversion.
*
* @param[in] op The reference to the re-written `LocalStoreOp`.
* @param[in] count The number of issued LLVM instructions.
* @param[in] type The input type of issued LLVM instructions.
*/
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
Type llvmOpType)>
localStoreOpConversion = nullptr;
};

void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);
// The given callback is invoked at the end of a successful rewrite. The
// callback receives 1) the current source op, 2) the number of issued LLVM
// instructions and 3) their input types. Each MLIR backend can provide a
// callback and, thus, handle backend-specific behaviors.
void populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);

void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
Expand Down
10 changes: 5 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);
void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
Expand Down
36 changes: 25 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
void lowerDistributedToShared(Location loc, Value src, Value dst,
Value adaptorSrc,
const SharedMemoryObject &smemObj,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) {
void lowerDistributedToShared(
Location loc, Value src, Value dst, Value adaptorSrc,
const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
Expand All @@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
auto dstStrides = smemObj.getStrides();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo);
loc, rewriter, targetInfo, llvmOpCount);
}

struct LocalAllocOpConversion
Expand Down Expand Up @@ -185,12 +184,15 @@ struct LocalStoreOpConversion
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
using BackendCallbackType =
decltype(BackendCallbacks::localStoreOpConversion);

LocalStoreOpConversion(const LLVMTypeConverter &converter,
const TargetInfoBase &targetInfo,
BackendCallbackType backendCallback,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
targetInfo(targetInfo) {}
targetInfo(targetInfo), backendCallback(backendCallback) {}

LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
Expand All @@ -200,24 +202,36 @@ struct LocalStoreOpConversion
getTypeConverter()->convertType(op.getDst().getType().getElementType());
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);

std::pair<size_t, Type> llvmOpCount;
lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
adaptor.getSrc(), smemObj, getTypeConverter(),
rewriter, targetInfo);
rewriter, targetInfo, &llvmOpCount);

if (backendCallback)
(backendCallback)(op, llvmOpCount.first, llvmOpCount.second);

rewriter.eraseOp(op);
return success();
}

private:
const TargetInfoBase &targetInfo;
BackendCallbackType backendCallback;
};

} // namespace

void mlir::triton::populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks) {
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);

auto backendCall =
backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
benefit);
}
8 changes: 7 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target) {
const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
Expand All @@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
});

if (!success)
llvm::report_fatal_error("Failed to emit transfer from register to shared");
}
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,11 @@ class RewriteTensorPointerPass
}

// update rewritedInfo
auto opResults = op.getResults();
unsigned oldResIdx = 0, newResIdx = 0;
while (oldResIdx < results.size()) {
if (!triton::isTensorPointerType(results[oldResIdx].getType())) {
opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx));
oldResIdx++;
newResIdx++;
} else {
Expand Down
6 changes: 0 additions & 6 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,26 +1909,20 @@ def where(condition, x, y, _builder=None):
def add(x, y, sanitize_overflow: constexpr = True, _builder=None):
x = _unwrap_if_constexpr(x)
y = _unwrap_if_constexpr(y)
x = semantic.to_tensor(x, _builder)
y = semantic.to_tensor(y, _builder)
return semantic.add(x, y, sanitize_overflow, _builder)


@builtin
def sub(x, y, sanitize_overflow: constexpr = True, _builder=None):
x = _unwrap_if_constexpr(x)
y = _unwrap_if_constexpr(y)
x = semantic.to_tensor(x, _builder)
y = semantic.to_tensor(y, _builder)
return semantic.sub(x, y, sanitize_overflow, _builder)


@builtin
def mul(x, y, sanitize_overflow: constexpr = True, _builder=None):
x = _unwrap_if_constexpr(x)
y = _unwrap_if_constexpr(y)
x = semantic.to_tensor(x, _builder)
y = semantic.to_tensor(y, _builder)
return semantic.mul(x, y, sanitize_overflow, _builder)


Expand Down
4 changes: 2 additions & 2 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,12 @@ run_instrumentation_tests() {
return
fi

SHARED_LIB_DIR=$(ls -1d $TRITON_PROJ/python/build/*lib*/triton/_C) || err "Could not find $TRITON_PROJ/python/build/*lib*/triton/_C, build Triton first"
INSTRUMENTATION_LIB_DIR=$(ls -1d $TRITON_PROJ/python/build/*lib*/triton/instrumentation) || err "Could not find $TRITON_PROJ/python/build/*lib*/triton/instrumentation, build Triton first"

cd $TRITON_PROJ/python/test/unit

TRITON_TEST_SUITE=instrumentation \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
pytest -vvv --device xpu instrumentation/test_gpuhello.py
}

Expand Down
27 changes: 18 additions & 9 deletions test/Triton/rewrite-tensor-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,42 +141,51 @@ tt.func public @rewrite_for(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>) {
// CHECK: tt.num_stages = 3

// -----
tt.func public @rewrite_if(%arg0: !tt.ptr<f16>, %arg1: i1) -> tensor<128x32xf16> {
tt.func public @rewrite_if(%arg0: !tt.ptr<f16>, %arg1: i1, %arg2: tensor<128x32xf32>) -> tensor<128x32xf16> {
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%c1_i64 = arith.constant 1 : i64
%c32_i64 = arith.constant 32 : i64
%c128_i64 = arith.constant 128 : i64
%0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
%1 = scf.if %arg1 -> (!tt.ptr<tensor<128x32xf16>>) {
%1:2 = scf.if %arg1 -> (tensor<128x32xf16>, !tt.ptr<tensor<128x32xf16>>) {
%2 = tt.advance %0, [%c32_i32, %c0_i32] : !tt.ptr<tensor<128x32xf16>>
scf.yield %2 : !tt.ptr<tensor<128x32xf16>>
%3 = arith.truncf %arg2 : tensor<128x32xf32> to tensor<128x32xf16>
scf.yield %3, %2 : tensor<128x32xf16>, !tt.ptr<tensor<128x32xf16>>
} else {
scf.yield %0 : !tt.ptr<tensor<128x32xf16>>
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16>
scf.yield %cst, %0 : tensor<128x32xf16>, !tt.ptr<tensor<128x32xf16>>
}
%4 = tt.load %1 {boundaryCheck = array<i32: 1>, padding = 2 : i32} : !tt.ptr<tensor<128x32xf16>>
tt.return %4 : tensor<128x32xf16>
%4 = tt.load %1#1 {boundaryCheck = array<i32: 1>, padding = 2 : i32} : !tt.ptr<tensor<128x32xf16>>
%5 = arith.addf %1#0, %4 : tensor<128x32xf16>
tt.return %5 : tensor<128x32xf16>
}

// CHECK-LABEL: tt.func public @rewrite_if(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr<f16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i1
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<128x32xf32>
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C32_I32:.*]] = arith.constant 32 : i32
// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64
// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64
// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64
// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG1]] -> (i64, i64) {
// CHECK: %[[IF:.*]]:3 = scf.if %[[ARG1]] -> (tensor<128x32xf16>, i64, i64) {
// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[C32_I32]] : i32 to i64
// CHECK: %[[ADDI0:.*]] = arith.addi %[[EXTSI0]], %[[EXTSI2]] : i64
// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64
// CHECK: %[[ADDI1:.*]] = arith.addi %[[EXTSI1]], %[[EXTSI3]] : i64
// CHECK: scf.yield %[[ADDI0]], %[[ADDI1]] : i64, i64
// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ARG2]] : tensor<128x32xf32> to tensor<128x32xf16>
// CHECK: scf.yield %[[TRUNCF]], %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64
// CHECK: } else {
// CHECK: scf.yield %[[EXTSI0]], %[[EXTSI1]] : i64, i64
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16>
// CHECK: scf.yield %[[CST]], %[[EXTSI0]], %[[EXTSI1]] : tensor<128x32xf16>, i64, i64
// CHECK: }
// CHECK: %{{.*}} = tt.splat %[[IF]]#1 : i64 -> tensor<128xi64>
// CHECK: %{{.*}} = tt.splat %[[IF]]#2 : i64 -> tensor<32xi64>
// CHECK: %{{.*}} = arith.addf %[[IF]]#0, %{{.*}} : tensor<128x32xf16>


// -----
Expand Down
Loading

0 comments on commit f213106

Please sign in to comment.