From 835485fd0a8681beb5896a8c4ccbe1fda03aae62 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 5 Nov 2024 11:54:01 +0100 Subject: [PATCH] mlir: implement MemorySlot Interfaces for Gradient ops --- enzyme/BUILD | 2 + enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 14 ++- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 115 ++++++++++++++++++++++++ enzyme/Enzyme/MLIR/Dialect/Ops.h | 1 + enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 2 + enzyme/test/MLIR/Passes/mem2reg.mlir | 14 +++ 6 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 enzyme/test/MLIR/Passes/mem2reg.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index 490c13a67372..1187dda10135 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -239,6 +239,7 @@ td_library( "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:FunctionInterfacesTdFiles", "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:MemorySlotInterfacesTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", @@ -604,6 +605,7 @@ cc_library( "@llvm-project//mlir:LinalgStructuredOpsIncGen", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemorySlotInterfaces", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:OpenMPDialect", diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 5d61cec8287b..ca7659143370 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -24,6 +24,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -121,8 +122,9 @@ def PopOp : Enzyme_Op<"pop"> { let results = (outs AnyType:$output); } -def InitOp : Enzyme_Op<"init"> { - let summary = "Creat enzyme.gradient and enzyme.cache"; +def InitOp : Enzyme_Op<"init", + [DeclareOpInterfaceMethods]> { + let summary = "Create enzyme.gradient and enzyme.cache"; let arguments = (ins ); let results = (outs AnyType); } @@ -147,14 +149,18 @@ def Gradient : Enzyme_Type<"Gradient"> { let assemblyFormat = "`<` $basetype `>`"; } -def SetOp : Enzyme_Op<"set"> { +def SetOp : Enzyme_Op<"set", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Store the current value of the gradient"; let arguments = (ins Arg:$gradient, AnyType : $value); let results = (outs ); } -def GetOp : Enzyme_Op<"get"> { +def GetOp : Enzyme_Op<"get", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Load current value of gradient"; let arguments = (ins Arg:$gradient); diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 5c2e4283d300..3e3185427306 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -8,12 +8,14 @@ #include "Ops.h" #include "Dialect.h" +#include "Interfaces/AutoDiffTypeInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -36,6 +38,119 @@ using namespace mlir; using namespace enzyme; using namespace mlir::arith; +//===----------------------------------------------------------------------===// +// InitOp +//===----------------------------------------------------------------------===// + +llvm::SmallVector InitOp::getPromotableSlots() { + auto Ty = this->getType(); + if (isa(Ty)) + return {}; + + if (!getOperation()->getBlock()->isEntryBlock()) + return {}; + + auto gTy = cast(Ty); + MemorySlot slot = {this->getResult(), gTy.getBasetype()}; + + return {slot}; +} + +Value InitOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { + auto gTy = cast(this->getType()); + return cast(gTy.getBasetype()) + .createNullValue(builder, this->getLoc()); +} + +void InitOp::handleBlockArgument(const MemorySlot &slot, BlockArgument argument, + OpBuilder &builder) {} + +std::optional +InitOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue, + OpBuilder &builder) { + if (defaultValue && defaultValue.use_empty()) + defaultValue.getDefiningOp()->erase(); + this->erase(); + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// GetOp +//===----------------------------------------------------------------------===// + +bool GetOp::loadsFrom(const MemorySlot &slot) { + return this->getGradient() == slot.ptr; +} + +bool GetOp::storesTo(const MemorySlot &slot) { return false; } + +Value GetOp::getStored(const MemorySlot &slot, OpBuilder &builder, + Value reachingDef, const DataLayout &dataLayout) { + return {}; +} + +bool GetOp::canUsesBeRemoved( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, + llvm::SmallVectorImpl &newBlockingUses, + const mlir::DataLayout &dataLayout) { + if (blockingUses.size() != 1) + return false; + + Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getGradient() == slot.ptr; +} + +DeletionKind GetOp::removeBlockingUses( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, OpBuilder &builder, + Value reachingDefinition, const DataLayout &dataLayout) { + this->getResult().replaceAllUsesWith(reachingDefinition); + return DeletionKind::Delete; +} + +llvm::LogicalResult GetOp::ensureOnlySafeAccesses( + const MemorySlot &slot, llvm::SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + return success(slot.ptr == getGradient()); +} + +//===----------------------------------------------------------------------===// +// SetOp +//===----------------------------------------------------------------------===// + +bool SetOp::loadsFrom(const MemorySlot &slot) { return false; } + +bool SetOp::storesTo(const MemorySlot &slot) { + return this->getGradient() == slot.ptr; +} + +Value SetOp::getStored(const MemorySlot &slot, OpBuilder &builder, + Value reachingDef, const DataLayout &dataLayout) { + return this->getValue(); +} + +bool SetOp::canUsesBeRemoved( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, + llvm::SmallVectorImpl &newBlockingUses, + const mlir::DataLayout &dataLayout) { + return true; +} + +DeletionKind SetOp::removeBlockingUses( + const MemorySlot &slot, + const llvm::SmallPtrSetImpl &blockingUses, OpBuilder &builder, + Value reachingDefinition, const DataLayout &dataLayout) { + return DeletionKind::Delete; +} + +llvm::LogicalResult SetOp::ensureOnlySafeAccesses( + const MemorySlot &slot, llvm::SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + return success(slot.ptr == getGradient()); +} + //===----------------------------------------------------------------------===// // GetFuncOp //===----------------------------------------------------------------------===// diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.h b/enzyme/Enzyme/MLIR/Dialect/Ops.h index 11aaa5f4291f..69aa6496b84f 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.h +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.h @@ -14,6 +14,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Bytecode/BytecodeOpInterface.h" diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index f608cecd3a2c..0e6bdf7b101e 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -77,6 +77,7 @@ int main(int argc, char **argv) { // Register the standard passes we want. mlir::registerCSEPass(); + mlir::registerMem2RegPass(); mlir::registerConvertAffineToStandardPass(); mlir::registerSCCPPass(); mlir::registerInlinerPass(); @@ -84,6 +85,7 @@ int main(int argc, char **argv) { mlir::registerSymbolDCEPass(); mlir::registerLoopInvariantCodeMotionPass(); mlir::registerConvertSCFToOpenMPPass(); + mlir::registerSCFToControlFlowPass(); mlir::affine::registerAffinePasses(); mlir::registerReconcileUnrealizedCasts(); diff --git a/enzyme/test/MLIR/Passes/mem2reg.mlir b/enzyme/test/MLIR/Passes/mem2reg.mlir new file mode 100644 index 000000000000..e850f1d8fccf --- /dev/null +++ b/enzyme/test/MLIR/Passes/mem2reg.mlir @@ -0,0 +1,14 @@ +// RUN: %eopt %s -mem2reg | FileCheck %s + +module { + func.func @main(%arg0: f32) -> f32 { + %0 = "enzyme.init"() : () -> !enzyme.Gradient + "enzyme.set"(%0, %arg0) : (!enzyme.Gradient, f32) -> () + %2 = "enzyme.get"(%0) : (!enzyme.Gradient) -> f32 + return %2 : f32 + } +} + +// CHECK: func.func @main(%arg0: f32) -> f32 { +// CHECK-NEXT: return %arg0 : f32 +// CHECK-NEXT: }