Skip to content

Commit

Permalink
Fix struct initialization and return stmts (#1142)
Browse files Browse the repository at this point in the history
* Fix struct initialization and struct return stmts

* Set decl replacement as used and referenced in the code
  • Loading branch information
kchristin22 authored Nov 22, 2024
1 parent fe11e45 commit 284c33e
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 7 deletions.
7 changes: 6 additions & 1 deletion lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,18 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
if (newPVD->getIdentifier())
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
/*AddToContext=*/false);
else {
IdentifierInfo* newName = CreateUniqueIdentifier("arg");
newPVD->setDeclName(newName);
m_DeclReplacements[PVD] = newPVD;
}

auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
*it = newPVD;
QualType dType = derivativeFnType->getParamType(dParamTypesIdx);
IdentifierInfo* dII =
CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
CreateUniqueIdentifier("_d_" + newPVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
paramDerivatives.push_back(dPVD);
Expand Down
12 changes: 10 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const Expr* value = RS->getRetValue();
QualType type = value->getType();
auto* dfdf = m_Pullback;
if (dfdf && (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf))) {
if (dfdf && (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf)) &&
type->isScalarType()) {
ExprResult tmp = dfdf;
dfdf = m_Sema
.ImpCastExprToType(tmp.get(), type,
Expand Down Expand Up @@ -1277,6 +1278,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) {
if (!dfdx())
return StmtDiff(Clone(ILE));
QualType ILEType = ILE->getType();
llvm::SmallVector<Expr*, 16> clonedExprs(ILE->getNumInits());
if (isArrayOrPointerType(ILEType)) {
Expand Down Expand Up @@ -4499,6 +4502,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (newPVD->getIdentifier())
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
/*AddToContext=*/false);
else {
IdentifierInfo* newName = CreateUniqueIdentifier("arg");
newPVD->setDeclName(newName);
m_DeclReplacements[PVD] = newPVD;
}

auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
Expand All @@ -4507,7 +4515,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_DiffReq.Mode == DiffMode::experimental_pullback) {
QualType dType = derivativeFnType->getParamType(dParamTypesIdx);
IdentifierInfo* dII =
CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
CreateUniqueIdentifier("_d_" + newPVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
paramDerivatives.push_back(dPVD);
Expand Down
4 changes: 3 additions & 1 deletion lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,8 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements)) {
DRE->setDecl(it->second);
DRE->getDecl()->setReferenced();
DRE->getDecl()->setIsUsed();
QualType NonRefQT = it->second->getType().getNonReferenceType();
if (NonRefQT != DRE->getType())
DRE->setType(NonRefQT);
Expand All @@ -552,7 +554,7 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
// FIXME: Handle the case when there are overloads found. Update
// it with the best match.
//
// FIXME: This is the right way to go in principe, however there is no
// FIXME: This is the right way to go in principle, however there is no
// properly built decl context.
// m_Sema.MarkDeclRefReferenced(clonedDRE);
if (!R.isSingleResult())
Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,4 @@ int main() {
TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00}
TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00}
TEST_DIFFERENTIATE(fnTuple1, 3, 4); // CHECK-EXEC: {2.00}
}
}
2 changes: 1 addition & 1 deletion test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -841,4 +841,4 @@ int main() {
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
58 changes: 57 additions & 1 deletion test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,27 @@ double fn11(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: }

struct MyStruct{
double a;
double b;
};

MyStruct fn12(MyStruct s) {
s = {2 * s.a, 2 * s.b + 2};
return s;
}

// CHECK: void fn12_grad(MyStruct s, MyStruct *_d_s) {
// CHECK-NEXT: MyStruct _t0 = s;
// CHECK-NEXT: clad::ValueAndAdjoint<MyStruct &, MyStruct &> _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {});
// CHECK-NEXT: {
// CHECK-NEXT: MyStruct _r0 = {};
// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {}, &(*_d_s), &_r0);
// CHECK-NEXT: (*_d_s).a += 2 * _r0.a;
// CHECK-NEXT: (*_d_s).b += 2 * _r0.b;
// CHECK-NEXT: }
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand All @@ -391,6 +412,10 @@ void print(const Tangent& t) {
}
}

void print(const MyStruct& s) {
printf("{%.2f, %.2f}\n", s.a, s.b);
}

int main() {
pairdd p(3, 5), d_p;
double i = 3, d_i, d_j;
Expand Down Expand Up @@ -425,6 +450,10 @@ int main() {
TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00}
TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 5, 10, &d_i, &d_j); // CHECK-EXEC: {1.00, 0.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, -14, &d_i, &d_j); // CHECK-EXEC: {1.00, -1.00}
MyStruct s = {1.0, 2.0}, d_s = {1.0, 1.0};
auto fn12_test = clad::gradient(fn12);
fn12_test.execute(s, &d_s);
print(d_s); // CHECK-EXEC: {2.00, 2.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down Expand Up @@ -546,4 +575,31 @@ int main() {
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: (*_d_t).data[0] += _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: inline constexpr void operator_equal_pullback(MyStruct &&arg, MyStruct _d_y, MyStruct *_d_this, MyStruct *_d_arg) noexcept {
// CHECK-NEXT: double _t0 = this->a;
// CHECK-NEXT: this->a = arg.a;
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = arg.b;
// CHECK-NEXT: {
// CHECK-NEXT: this->b = _t1;
// CHECK-NEXT: double _r_d1 = (*_d_this).b;
// CHECK-NEXT: (*_d_this).b = 0.;
// CHECK-NEXT: (*_d_arg).b += _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: this->a = _t0;
// CHECK-NEXT: double _r_d0 = (*_d_this).a;
// CHECK-NEXT: (*_d_this).a = 0.;
// CHECK-NEXT: (*_d_arg).a += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: inline constexpr clad::ValueAndAdjoint<MyStruct &, MyStruct &> operator_equal_forw(MyStruct &&arg, MyStruct *_d_this, MyStruct &&_d_arg) noexcept {
// CHECK-NEXT: double _t0 = this->a;
// CHECK-NEXT: this->a = arg.a;
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = arg.b;
// CHECK-NEXT: return {*this, (*_d_this)};
// CHECK-NEXT:}

0 comments on commit 284c33e

Please sign in to comment.