diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index bfdee67dd7f..e1f9d21dc37 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -579,14 +579,18 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, call->setDebugLoc(loc); } -Type *BlasInfo::fpType(LLVMContext &ctx) const { +Type *BlasInfo::fpType(LLVMContext &ctx, bool to_scalar) const { if (floatType == "d" || floatType == "D") { return Type::getDoubleTy(ctx); } else if (floatType == "s" || floatType == "S") { return Type::getFloatTy(ctx); } else if (floatType == "c" || floatType == "C") { + if (to_scalar) + return Type::getFloatTy(ctx); return VectorType::get(Type::getFloatTy(ctx), 2, false); } else if (floatType == "z" || floatType == "Z") { + if (to_scalar) + return Type::getDoubleTy(ctx); return VectorType::get(Type::getDoubleTy(ctx), 2, false); } else { assert(false && "Unreachable"); diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 62493de9046..02ce4b8b47e 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -678,7 +678,7 @@ struct BlasInfo { std::string function; bool is64; - llvm::Type *fpType(llvm::LLVMContext &ctx) const; + llvm::Type *fpType(llvm::LLVMContext &ctx, bool to_scalar = false) const; llvm::IntegerType *intType(llvm::LLVMContext &ctx) const; }; diff --git a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h index 33f63d16338..2ff92d047b0 100644 --- a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h @@ -16,7 +16,7 @@ inline void emit_BLASTypes(raw_ostream &os) { "\"cublas\" && StringRef(blas.suffix).contains(\"v2\");\n"; os << "TypeTree ttFloat;\n" - << "llvm::Type *floatType = blas.fpType(call.getContext()); \n" + << "llvm::Type *floatType = blas.fpType(call.getContext(), true); \n" << "if (byRefFloat) {\n" << " ttFloat.insert({-1},BaseType::Pointer);\n" << " ttFloat.insert({-1,0},floatType);\n"