From dc65cf2583aabe7ed1f53cd2a25ff8be6d612a95 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Sun, 28 Apr 2024 09:35:53 -0700 Subject: [PATCH] [Truncate] Handle casts and emit warnings (#1846) * Handle casts and emit warnings * only emit warning/error if trunc from type is used --- enzyme/Enzyme/EnzymeLogic.cpp | 81 ++++++++++--------- enzyme/test/Integration/Truncate/warnings.cpp | 53 ++++++++++++ 2 files changed, 95 insertions(+), 39 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/warnings.cpp diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index dbf1710e46b9..fbdbc01257fb 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5213,31 +5213,37 @@ class TruncateGenerator : public llvm::InstVisitor, oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()), Logic(Logic), ctx(newFunc->getContext()) {} - void checkHandled(llvm::Instruction &inst) { - // TODO - // if (all_of(inst.getOperandList(), - // [&](Use *use) { return use->get()->getType() == fromType; })) - // todo(inst); - } + void todo(llvm::Instruction &I) { + if (all_of(I.operands(), + [&](Use &U) { return U.get()->getType() != fromType; })) + return; - // TODO - void handleTrunc(); - void hendleIntToFloat(); - void handleFloatToInt(); + switch (mode) { + case TruncMemMode: + EmitFailure("FPEscaping", I.getDebugLoc(), &I, "FP value escapes!"); + break; + case TruncOpMode: + case TruncOpFullModuleMode: + EmitWarning( + "UnhandledTrunc", I, + "Operation not handled - it will be executed in the original way.", + I); + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } - void visitInstruction(llvm::Instruction &inst) { + void visitInstruction(llvm::Instruction &I) { using namespace llvm; - // TODO explicitly handle all instructions rather than using the catch all - // below - - switch (inst.getOpcode()) { + switch (I.getOpcode()) { // #include "InstructionDerivatives.inc" default: break; } - checkHandled(inst); + todo(I); } Value *truncate(IRBuilder<> &B, Value *v) { @@ -5264,21 +5270,6 @@ class TruncateGenerator : public llvm::InstVisitor, llvm_unreachable("Unknown trunc mode"); } - void todo(llvm::Instruction &I) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot handle unknown instruction\n" << I; - if (CustomErrorHandler) { - IRBuilder<> Builder2(getNewFromOriginal(&I)); - CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, - this, nullptr, wrap(&Builder2)); - return; - } else { - EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); - return; - } - } - void visitAllocaInst(llvm::AllocaInst &I) { return; } void visitICmpInst(llvm::ICmpInst &I) { return; } void visitFCmpInst(llvm::FCmpInst &CI) { @@ -5327,10 +5318,28 @@ class TruncateGenerator : public llvm::InstVisitor, void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { + // TODO Try to follow fps through trunc/exts switch (mode) { case TruncMemMode: { - if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType()) - todo(CI); + auto newI = getNewFromOriginal(&CI); + auto newSrc = newI->getOperand(0); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + if (isa(newSrc)) + return; + newI->setOperand(0, createFPRTGetCall(B, newSrc)); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + } else if (CI.getDestTy() == getFromType()) { + IRBuilder<> B(newI->getNextNode()); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + auto nres = createFPRTNewCall(B, newI); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceUsesWithIf(nres, + [&](Use &U) { return U.getUser() != nres; }); + } return; } case TruncOpMode: @@ -5585,12 +5594,6 @@ class TruncateGenerator : public llvm::InstVisitor, } return; } - void visitFPTruncInst(FPTruncInst &I) { return; } - void visitFPExtInst(FPExtInst &I) { return; } - void visitFPToUIInst(FPToUIInst &I) { return; } - void visitFPToSIInst(FPToSIInst &I) { return; } - void visitUIToFPInst(UIToFPInst &I) { return; } - void visitSIToFPInst(SIToFPInst &I) { return; } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, diff --git a/enzyme/test/Integration/Truncate/warnings.cpp b/enzyme/test/Integration/Truncate/warnings.cpp new file mode 100644 index 000000000000..4f730a433549 --- /dev/null +++ b/enzyme/test/Integration/Truncate/warnings.cpp @@ -0,0 +1,53 @@ +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi + +#include + +#define FROM 64 +#define TO 32 + +double bithack(double a) { + return *((int64_t *)&a) + 1; // expected-remark {{Will not follow FP through this cast.}}, expected-remark {{Will not follow FP through this cast.}} +} +__attribute__((noinline)) +float truncf(double a) { + return (float)a; // expected-remark {{Will not follow FP through this cast.}} +} + +double intrinsics(double a, double b) { + return bithack(a) * truncf(b); // expected-remark {{Will not follow FP through this cast.}} +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +extern fty __enzyme_truncate_mem_func_2(...); +extern fty2 __enzyme_truncate_mem_func(...); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); + + +int main() { + #ifdef TRUNC_MEM + { + double a = 2; + double b = 3; + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); + } + #endif + #ifdef TRUNC_OP + { + double a = 2; + double b = 3; + double trunc = __enzyme_truncate_op_func(intrinsics, FROM, TO)(a, b); + } + #endif + +}