Skip to content

Commit

Permalink
mlir: implement MemorySlot Interfaces for Gradient ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Nov 5, 2024
1 parent 7b27adb commit 835485f
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 4 deletions.
2 changes: 2 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ td_library(
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
"@llvm-project//mlir:LoopLikeInterfaceTdFiles",
"@llvm-project//mlir:MemorySlotInterfacesTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
Expand Down Expand Up @@ -604,6 +605,7 @@ cc_library(
"@llvm-project//mlir:LinalgStructuredOpsIncGen",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemorySlotInterfaces",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:OpenMPDialect",
Expand Down
14 changes: 10 additions & 4 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


Expand Down Expand Up @@ -121,8 +122,9 @@ def PopOp : Enzyme_Op<"pop"> {
let results = (outs AnyType:$output);
}

def InitOp : Enzyme_Op<"init"> {
let summary = "Creat enzyme.gradient and enzyme.cache";
def InitOp : Enzyme_Op<"init",
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
let summary = "Create enzyme.gradient and enzyme.cache";
let arguments = (ins );
let results = (outs AnyType);
}
Expand All @@ -147,14 +149,18 @@ def Gradient : Enzyme_Type<"Gradient"> {
let assemblyFormat = "`<` $basetype `>`";
}

def SetOp : Enzyme_Op<"set"> {
def SetOp : Enzyme_Op<"set",
[DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
let summary = "Store the current value of the gradient";
let arguments = (ins Arg<AnyType, "the reference to store to",
[MemWrite]>:$gradient, AnyType : $value);
let results = (outs );
}

def GetOp : Enzyme_Op<"get"> {
def GetOp : Enzyme_Op<"get",
[DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
let summary = "Load current value of gradient";
let arguments = (ins Arg<AnyType, "the reference to load from",
[MemRead]>:$gradient);
Expand Down
115 changes: 115 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

#include "Ops.h"
#include "Dialect.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand All @@ -36,6 +38,119 @@ using namespace mlir;
using namespace enzyme;
using namespace mlir::arith;

//===----------------------------------------------------------------------===//
// InitOp
//===----------------------------------------------------------------------===//

llvm::SmallVector<MemorySlot> InitOp::getPromotableSlots() {
auto Ty = this->getType();
if (isa<CacheType>(Ty))
return {};

if (!getOperation()->getBlock()->isEntryBlock())
return {};

auto gTy = cast<GradientType>(Ty);
MemorySlot slot = {this->getResult(), gTy.getBasetype()};

return {slot};
}

Value InitOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
auto gTy = cast<GradientType>(this->getType());
return cast<AutoDiffTypeInterface>(gTy.getBasetype())
.createNullValue(builder, this->getLoc());
}

void InitOp::handleBlockArgument(const MemorySlot &slot, BlockArgument argument,
OpBuilder &builder) {}

std::optional<mlir::PromotableAllocationOpInterface>
InitOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
return std::nullopt;
}

//===----------------------------------------------------------------------===//
// GetOp
//===----------------------------------------------------------------------===//

bool GetOp::loadsFrom(const MemorySlot &slot) {
return this->getGradient() == slot.ptr;
}

bool GetOp::storesTo(const MemorySlot &slot) { return false; }

Value GetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
Value reachingDef, const DataLayout &dataLayout) {
return {};
}

bool GetOp::canUsesBeRemoved(
const MemorySlot &slot,
const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses,
llvm::SmallVectorImpl<OpOperand *> &newBlockingUses,
const mlir::DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;

Value blockingUse = (*blockingUses.begin())->get();
return blockingUse == slot.ptr && getGradient() == slot.ptr;
}

DeletionKind GetOp::removeBlockingUses(
const MemorySlot &slot,
const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder,
Value reachingDefinition, const DataLayout &dataLayout) {
this->getResult().replaceAllUsesWith(reachingDefinition);
return DeletionKind::Delete;
}

llvm::LogicalResult GetOp::ensureOnlySafeAccesses(
const MemorySlot &slot, llvm::SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(slot.ptr == getGradient());
}

//===----------------------------------------------------------------------===//
// SetOp
//===----------------------------------------------------------------------===//

bool SetOp::loadsFrom(const MemorySlot &slot) { return false; }

bool SetOp::storesTo(const MemorySlot &slot) {
return this->getGradient() == slot.ptr;
}

Value SetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
Value reachingDef, const DataLayout &dataLayout) {
return this->getValue();
}

bool SetOp::canUsesBeRemoved(
const MemorySlot &slot,
const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses,
llvm::SmallVectorImpl<OpOperand *> &newBlockingUses,
const mlir::DataLayout &dataLayout) {
return true;
}

DeletionKind SetOp::removeBlockingUses(
const MemorySlot &slot,
const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder,
Value reachingDefinition, const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

llvm::LogicalResult SetOp::ensureOnlySafeAccesses(
const MemorySlot &slot, llvm::SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(slot.ptr == getGradient());
}

//===----------------------------------------------------------------------===//
// GetFuncOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ int main(int argc, char **argv) {

// Register the standard passes we want.
mlir::registerCSEPass();
mlir::registerMem2RegPass();
mlir::registerConvertAffineToStandardPass();
mlir::registerSCCPPass();
mlir::registerInlinerPass();
mlir::registerCanonicalizerPass();
mlir::registerSymbolDCEPass();
mlir::registerLoopInvariantCodeMotionPass();
mlir::registerConvertSCFToOpenMPPass();
mlir::registerSCFToControlFlowPass();
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();

Expand Down
14 changes: 14 additions & 0 deletions enzyme/test/MLIR/Passes/mem2reg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: %eopt %s -mem2reg | FileCheck %s

module {
func.func @main(%arg0: f32) -> f32 {
%0 = "enzyme.init"() : () -> !enzyme.Gradient<f32>
"enzyme.set"(%0, %arg0) : (!enzyme.Gradient<f32>, f32) -> ()
%2 = "enzyme.get"(%0) : (!enzyme.Gradient<f32>) -> f32
return %2 : f32
}
}

// CHECK: func.func @main(%arg0: f32) -> f32 {
// CHECK-NEXT: return %arg0 : f32
// CHECK-NEXT: }

0 comments on commit 835485f

Please sign in to comment.