-
Notifications
You must be signed in to change notification settings - Fork 221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve matrices translation with SPV_KHR_untyped_pointers #2857
base: main
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2216,8 +2216,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, | |
auto *AC = static_cast<SPIRVAccessChainBase *>(BV); | ||
auto *Base = transValue(AC->getBase(), F, BB); | ||
SPIRVType *BaseSPVTy = AC->getBaseType(); | ||
if (BaseSPVTy->isTypePointer() && | ||
BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) { | ||
if ((BaseSPVTy->isTypePointer() && | ||
BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) || | ||
(isUntypedAccessChainOpCode(OC) && | ||
BaseSPVTy->isTypeCooperativeMatrixKHR())) { | ||
return mapValue(BV, transSPIRVBuiltinFromInst(AC, BB)); | ||
} | ||
Type *BaseTy = | ||
|
@@ -3426,23 +3428,65 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, | |
BasicBlock *BB) { | ||
std::string MangledName; | ||
auto Ops = BI->getOperands(); | ||
Op OC = BI->getOpCode(); | ||
if (isUntypedAccessChainOpCode(OC)) { | ||
auto *AC = static_cast<SPIRVAccessChainBase *>(BI); | ||
if (AC->getBaseType()->isTypeCooperativeMatrixKHR()) | ||
Ops.erase(Ops.begin()); | ||
} | ||
Type *RetTy = | ||
BI->hasType() ? transType(BI->getType()) : Type::getVoidTy(*Context); | ||
transOCLBuiltinFromInstPreproc(BI, RetTy, Ops); | ||
std::vector<Type *> ArgTys = | ||
transTypeVector(SPIRVInstruction::getOperandTypes(Ops), true); | ||
|
||
// Special handling for "truly" untyped pointers to preserve correct | ||
// builtin mangling of atomic operations. | ||
auto Ptr = findFirstPtrType(ArgTys); | ||
if (Ptr < ArgTys.size() && | ||
BI->getValueType(Ops[Ptr]->getId())->isTypeUntypedPointerKHR()) { | ||
if (isAtomicOpCodeUntypedPtrSupported(BI->getOpCode())) { | ||
// Special handling for "truly" untyped pointers to preserve correct | ||
// builtin mangling of atomic and matrix operations. | ||
if (isAtomicOpCodeUntypedPtrSupported(OC)) { | ||
auto *AI = static_cast<SPIRVAtomicInstBase *>(BI); | ||
ArgTys[Ptr] = TypedPointerType::get( | ||
transType(AI->getSemanticType()), | ||
SPIRSPIRVAddrSpaceMap::rmap( | ||
BI->getValueType(Ops[Ptr]->getId())->getPointerStorageClass())); | ||
} else if (OC == spv::OpCooperativeMatrixStoreKHR || | ||
OC == spv::internal::OpJointMatrixStoreINTEL || | ||
OC == spv::internal::OpCooperativeMatrixStoreCheckedINTEL || | ||
OC == spv::internal::OpJointMatrixLoadINTEL || | ||
OC == spv::OpCompositeConstruct || | ||
OC == spv::internal::OpCooperativeMatrixApplyFunctionINTEL) { | ||
// It will work but it'd be strange | ||
auto *Val = transValue(Ops[Ptr], BB->getParent(), BB); | ||
Val = Val->stripPointerCasts(); | ||
if (auto *GEP = dyn_cast<GetElementPtrInst>(Val)) | ||
ArgTys[Ptr] = TypedPointerType::get( | ||
GEP->getSourceElementType(), | ||
SPIRSPIRVAddrSpaceMap::rmap( | ||
BI->getValueType(Ops[Ptr]->getId())->getPointerStorageClass())); | ||
else if (auto *AI = dyn_cast<AllocaInst>(Val)) | ||
ArgTys[Ptr] = TypedPointerType::get( | ||
AI->getAllocatedType(), | ||
SPIRSPIRVAddrSpaceMap::rmap( | ||
BI->getValueType(Ops[Ptr]->getId())->getPointerStorageClass())); | ||
else if (isa<Argument>(Val) && RetTy) { | ||
// Pointer could be a function parameter. Assume that the type of the | ||
// pointer is the same as the return type. | ||
Type *Ty = nullptr; | ||
// it return type is array type, assign its element type to Ty | ||
if (RetTy->isArrayTy()) | ||
Ty = RetTy->getArrayElementType(); | ||
else if (RetTy->isVectorTy()) | ||
Ty = cast<VectorType>(RetTy)->getElementType(); | ||
else | ||
Ty = RetTy; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for store pointer parameter would be considered as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I don't get this at all. Don't we already query the type from the llvm instructions? Or do you mean specific SPIR-V opcode in mind? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On the line 3437 RetTy is defined like this:
So if the instruction doesn't have a return type (aka There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got you, you're right - it won't be correct, i was just "lucky" that we didn't have test case for that. Will prepare a fix |
||
|
||
ArgTys[Ptr] = TypedPointerType::get( | ||
Ty, | ||
SPIRSPIRVAddrSpaceMap::rmap( | ||
BI->getValueType(Ops[Ptr]->getId())->getPointerStorageClass())); | ||
} | ||
} | ||
} | ||
|
||
|
@@ -3455,8 +3499,7 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, | |
if (BM->getDesiredBIsRepresentation() != BIsRepresentation::SPIRVFriendlyIR) | ||
mangleOpenClBuiltin(FuncName, ArgTys, MangledName); | ||
else | ||
MangledName = | ||
getSPIRVFriendlyIRFunctionName(FuncName, BI->getOpCode(), ArgTys, Ops); | ||
MangledName = getSPIRVFriendlyIRFunctionName(FuncName, OC, ArgTys, Ops); | ||
|
||
opaquifyTypedPointers(ArgTys); | ||
|
||
|
@@ -3479,14 +3522,13 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, | |
Func->setCallingConv(CallingConv::SPIR_FUNC); | ||
if (isFuncNoUnwind()) | ||
Func->addFnAttr(Attribute::NoUnwind); | ||
auto OC = BI->getOpCode(); | ||
if (isGroupOpCode(OC) || isGroupNonUniformOpcode(OC) || | ||
isIntelSubgroupOpCode(OC) || isSplitBarrierINTELOpCode(OC) || | ||
OC == OpControlBarrier) | ||
Func->addFnAttr(Attribute::Convergent); | ||
} | ||
CallInst *Call; | ||
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR && | ||
if (OC == OpCooperativeMatrixLengthKHR && | ||
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) { | ||
// OpCooperativeMatrixLengthKHR needs special handling as its operand is | ||
// a Type instead of a Value. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, we can stop modifying code related to JointMatrixINTEL instructions that have CooperativeMatrixKHR counterparts aka trivial MatrixLoad/Store, MulAdd etc as they will be depricated. Meanwhile OpCooperativeMatrixStoreCheckedINTEL and other unique for SPV_INTEL_joint_matrix extension instruction must still be handled properly.