Skip to content

Commit

Permalink
Do not create tapes for whole arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 26, 2023
1 parent 46e7134 commit cabf700
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ReverseModeVisitor::CladTapeResult
ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) {
assert(E && "must be provided");
if (auto* IE = dyn_cast<ImplicitCastExpr>(E)) {
E = IE->getSubExpr()->IgnoreImplicit();
}
QualType EQt = E->getType();
if (isa<ArrayType>(EQt))
EQt = GetCladArrayOfType(utils::GetValueType(EQt));
E = E->IgnoreImplicit();
QualType TapeType =
GetCladTapeOfType(getNonConstType(EQt, m_Context, m_Sema));
GetCladTapeOfType(getNonConstType(E->getType(), m_Context, m_Sema));
LookupResult& Push = GetCladTapePush();
LookupResult& Pop = GetCladTapePop();
Expr* TapeRef =
Expand All @@ -93,17 +88,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* PopExpr =
m_Sema.ActOnCallExpr(getCurrentScope(), PopDRE, noLoc, TapeRef, noLoc)
.get();
Expr* exprToPush = E;
if (const auto* AT = dyn_cast<ArrayType>(E->getType())) {
Expr* init = getArraySizeExpr(AT, m_Context, *this);
llvm::SmallVector<Expr*, 2> pushArgs{E, init};
SourceLocation loc = E->getExprLoc();
TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(EQt, loc);
exprToPush =
m_Sema.BuildCXXTypeConstructExpr(TSI, loc, pushArgs, loc, false)
.get();
}
Expr* CallArgs[] = {TapeRef, exprToPush};
Expr* CallArgs[] = {TapeRef, E};
Expr* PushExpr =
m_Sema.ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, CallArgs, noLoc)
.get();
Expand Down

0 comments on commit cabf700

Please sign in to comment.