Skip to content

Commit

Permalink
Move if-stmt invariant code outside, fix a compiler warning. NFC
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 8, 2023
1 parent 232d17b commit e4c501a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ jobs:
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3
# When debugging increase to a suitable value!
timeout-minutes: ${{ github.event.pull_request && 100 || 20 }}
timeout-minutes: ${{ github.event.pull_request && 1 || 20 }}
- name: Prepare code coverage report
if: ${{ success() && (matrix.coverage == true) }}
run: |
Expand Down
27 changes: 10 additions & 17 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1304,33 +1304,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DeclRefExpr* clonedDRE = nullptr;
DeclRefExpr* clonedDRE = cast<DeclRefExpr>(Clone(DRE));
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
if (const auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
clonedDRE = cast<DeclRefExpr>(Clone(DRE));
if (auto* VD = dyn_cast<VarDecl>(clonedDRE->getDecl())) {
// If current context is different than the context of the original
// declaration (e.g. we are inside lambda), rebuild the DeclRefExpr
// with Sema::BuildDeclRefExpr. This is required in some cases, e.g.
// Sema::BuildDeclRefExpr is responsible for adding captured fields
// to the underlying struct of a lambda.
if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) {
auto* referencedDecl = cast<VarDecl>(clonedDRE->getDecl());
clonedDRE = cast<DeclRefExpr>(BuildDeclRef(referencedDecl));
}
} else
clonedDRE = cast<DeclRefExpr>(Clone(DRE));
if (VD->getDeclContext() != m_Sema.CurContext)
clonedDRE = cast<DeclRefExpr>(BuildDeclRef(VD));

if (auto* decl = dyn_cast<VarDecl>(clonedDRE->getDecl())) {
if (isVectorValued) {
if (m_VectorOutput.size() <= outputArrayCursor)
return StmtDiff(clonedDRE);

auto it = m_VectorOutput[outputArrayCursor].find(decl);
if (it == std::end(m_VectorOutput[outputArrayCursor])) {
// Is not an independent variable, ignored.
return StmtDiff(clonedDRE);
}
auto it = m_VectorOutput[outputArrayCursor].find(VD);
if (it == std::end(m_VectorOutput[outputArrayCursor]))
return StmtDiff(clonedDRE); // Not an independent variable, ignored.

// Create the (jacobianMatrix[idx] += dfdx) statement.
if (dfdx()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
Expand All @@ -1339,7 +1332,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
} else {
// Check DeclRefExpr is a reference to an independent variable.
auto it = m_Variables.find(decl);
auto it = m_Variables.find(VD);
if (it == std::end(m_Variables)) {
// Is not an independent variable, ignored.
return StmtDiff(clonedDRE);
Expand All @@ -1348,7 +1341,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (dfdx()) {
// FIXME: not sure if this is generic.
// Don't update derivatives of non-record types.
if (!decl->getType()->isRecordType()) {
if (!VD->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
Expand Down

0 comments on commit e4c501a

Please sign in to comment.