Skip to content

Commit

Permalink
[MLIR][LLVM] Remove bitcast pattern from type consistency pass (llvm#…
Browse files Browse the repository at this point in the history
…87755)

This commit removes the no longer required bitcast inserting pattern in
LLVM dialect's type consistency pattern. This was previously required to
enable Mem2Reg and SROA to promote accesses that had different types.
Recent changes to both passes added direct support for this feature to
them, so the pattern has no further use.
  • Loading branch information
Dinistro authored Apr 5, 2024
1 parent 9708d09 commit 5419623
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 56 deletions.
11 changes: 0 additions & 11 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,6 @@ class SplitStores : public OpRewritePattern<StoreOp> {
PatternRewriter &rewrite) const override;
};

/// Transforms type-inconsistent stores, aka stores where the type hint of
/// the address contradicts the value stored, by inserting a bitcast if
/// possible.
class BitcastStores : public OpRewritePattern<StoreOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(StoreOp store,
PatternRewriter &rewriter) const override;
};

/// Splits GEPs with more than two indices into multiple GEPs with exactly
/// two indices. The created GEPs are then guaranteed to index into only
/// one aggregate at a time.
Expand Down
28 changes: 0 additions & 28 deletions mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,6 @@ static Type isElementTypeInconsistent(Value addr, Type expectedType) {
return elemType;
}

/// Checks that two types are the same or can be bitcast into one another.
static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
!isa<LLVMStructType, LLVMArrayType>(rhs) &&
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
}

//===----------------------------------------------------------------------===//
// CanonicalizeAlignedGep
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -518,26 +511,6 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
return success();
}

LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
PatternRewriter &rewriter) const {
Type sourceType = store.getValue().getType();
Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
if (!typeHint) {
// Nothing to do, since it is already consistent.
return failure();
}

auto dataLayout = DataLayout::closest(store);
if (!areBitcastCompatible(dataLayout, typeHint, sourceType))
return failure();

auto bitcastOp =
rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
rewriter.modifyOpInPlace(store,
[&] { store.getValueMutable().assign(bitcastOp); });
return success();
}

LogicalResult SplitGEP::matchAndRewrite(GEPOp gepOp,
PatternRewriter &rewriter) const {
FailureOr<Type> typeHint = getRequiredConsistentGEPType(gepOp);
Expand Down Expand Up @@ -588,7 +561,6 @@ struct LLVMTypeConsistencyPass
RewritePatternSet rewritePatterns(&getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
rewritePatterns.add<BitcastStores>(&getContext());
rewritePatterns.add<SplitGEP>(&getContext());
FrozenRewritePatternSet frozen(std::move(rewritePatterns));

Expand Down
18 changes: 1 addition & 17 deletions mlir/test/Dialect/LLVMIR/type-consistency.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ llvm.func @coalesced_store_floats(%arg: i64) {
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
// CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
// CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
// CHECK: llvm.store %[[TRUNC]], %[[GEP]]
llvm.store %arg, %1 : i64, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
Expand Down Expand Up @@ -327,21 +326,6 @@ llvm.func @vector_write_split_struct(%arg: vector<2xi64>) {

// -----

// CHECK-LABEL: llvm.func @bitcast_insertion
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @bitcast_insertion(%arg: i32) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x f32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
// CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to f32
// CHECK: llvm.store %[[BIT_CAST]], %[[ALLOCA]]
llvm.store %arg, %1 : i32, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}

// -----

// CHECK-LABEL: llvm.func @gep_split
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @gep_split(%arg: i64) {
Expand Down

0 comments on commit 5419623

Please sign in to comment.