Skip to content

Commit

Permalink
[Truncate] Handle casts and emit warnings (#1846)
Browse files Browse the repository at this point in the history
* Handle casts and emit warnings

* only emit warning/error if trunc from type is used
  • Loading branch information
ivanradanov authored Apr 28, 2024
1 parent 0fcd564 commit dc65cf2
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 39 deletions.
81 changes: 42 additions & 39 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5213,31 +5213,37 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()),
Logic(Logic), ctx(newFunc->getContext()) {}

void checkHandled(llvm::Instruction &inst) {
// TODO
// if (all_of(inst.getOperandList(),
// [&](Use *use) { return use->get()->getType() == fromType; }))
// todo(inst);
}
void todo(llvm::Instruction &I) {
if (all_of(I.operands(),
[&](Use &U) { return U.get()->getType() != fromType; }))
return;

// TODO
void handleTrunc();
void hendleIntToFloat();
void handleFloatToInt();
switch (mode) {
case TruncMemMode:
EmitFailure("FPEscaping", I.getDebugLoc(), &I, "FP value escapes!");
break;
case TruncOpMode:
case TruncOpFullModuleMode:
EmitWarning(
"UnhandledTrunc", I,
"Operation not handled - it will be executed in the original way.",
I);
break;
default:
llvm_unreachable("Unknown trunc mode");
}
}

void visitInstruction(llvm::Instruction &inst) {
void visitInstruction(llvm::Instruction &I) {
using namespace llvm;

// TODO explicitly handle all instructions rather than using the catch all
// below

switch (inst.getOpcode()) {
switch (I.getOpcode()) {
// #include "InstructionDerivatives.inc"
default:
break;
}

checkHandled(inst);
todo(I);
}

Value *truncate(IRBuilder<> &B, Value *v) {
Expand All @@ -5264,21 +5270,6 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
llvm_unreachable("Unknown trunc mode");
}

void todo(llvm::Instruction &I) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << "cannot handle unknown instruction\n" << I;
if (CustomErrorHandler) {
IRBuilder<> Builder2(getNewFromOriginal(&I));
CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate,
this, nullptr, wrap(&Builder2));
return;
} else {
EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str());
return;
}
}

void visitAllocaInst(llvm::AllocaInst &I) { return; }
void visitICmpInst(llvm::ICmpInst &I) { return; }
void visitFCmpInst(llvm::FCmpInst &CI) {
Expand Down Expand Up @@ -5327,10 +5318,28 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; }
void visitPHINode(llvm::PHINode &phi) { return; }
void visitCastInst(llvm::CastInst &CI) {
// TODO Try to follow fps through trunc/exts
switch (mode) {
case TruncMemMode: {
if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType())
todo(CI);
auto newI = getNewFromOriginal(&CI);
auto newSrc = newI->getOperand(0);
if (CI.getSrcTy() == getFromType()) {
IRBuilder<> B(newI);
if (isa<Constant>(newSrc))
return;
newI->setOperand(0, createFPRTGetCall(B, newSrc));
EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.",
CI);
} else if (CI.getDestTy() == getFromType()) {
IRBuilder<> B(newI->getNextNode());
EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.",
CI);
auto nres = createFPRTNewCall(B, newI);
nres->takeName(newI);
nres->copyIRFlags(newI);
newI->replaceUsesWithIf(nres,
[&](Use &U) { return U.getUser() != nres; });
}
return;
}
case TruncOpMode:
Expand Down Expand Up @@ -5585,12 +5594,6 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
}
return;
}
void visitFPTruncInst(FPTruncInst &I) { return; }
void visitFPExtInst(FPExtInst &I) { return; }
void visitFPToUIInst(FPToUIInst &I) { return; }
void visitFPToSIInst(FPToSIInst &I) { return; }
void visitUIToFPInst(UIToFPInst &I) { return; }
void visitSIToFPInst(SIToFPInst &I) { return; }
};

bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v,
Expand Down
53 changes: 53 additions & 0 deletions enzyme/test/Integration/Truncate/warnings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi
// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi
// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi
// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi

#include <math.h>

#define FROM 64
#define TO 32

double bithack(double a) {
return *((int64_t *)&a) + 1; // expected-remark {{Will not follow FP through this cast.}}, expected-remark {{Will not follow FP through this cast.}}
}
__attribute__((noinline))
float truncf(double a) {
return (float)a; // expected-remark {{Will not follow FP through this cast.}}
}

double intrinsics(double a, double b) {
return bithack(a) * truncf(b); // expected-remark {{Will not follow FP through this cast.}}
}

typedef double (*fty)(double *, double *, double *, int);

typedef double (*fty2)(double, double);

extern fty __enzyme_truncate_mem_func_2(...);
extern fty2 __enzyme_truncate_mem_func(...);
extern fty __enzyme_truncate_op_func_2(...);
extern fty2 __enzyme_truncate_op_func(...);
extern double __enzyme_truncate_mem_value(...);
extern double __enzyme_expand_mem_value(...);


int main() {
#ifdef TRUNC_MEM
{
double a = 2;
double b = 3;
a = __enzyme_truncate_mem_value(a, FROM, TO);
b = __enzyme_truncate_mem_value(b, FROM, TO);
double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO);
}
#endif
#ifdef TRUNC_OP
{
double a = 2;
double b = 3;
double trunc = __enzyme_truncate_op_func(intrinsics, FROM, TO)(a, b);
}
#endif

}

0 comments on commit dc65cf2

Please sign in to comment.