Skip to content

Commit

Permalink
Initialize adjoints of aggregate types with init lists
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 9, 2024
1 parent eee6faa commit 95300e1
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 61 deletions.
9 changes: 0 additions & 9 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,6 @@ template <typename T, ::std::size_t N, typename U>
void size_pullback(::std::array<T, N>* /*a*/, U /*d_y*/,
::std::array<T, N>* /*d_a*/) noexcept {}
template <typename T, ::std::size_t N>
::clad::ValueAndAdjoint<::std::array<T, N>, ::std::array<T, N>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::array<T, N>>,
const ::std::array<T, N>& arr,
const ::std::array<T, N>& d_arr) {
::std::array<T, N> a = arr;
::std::array<T, N> d_a = d_arr;
return {a, d_a};
}
template <typename T, ::std::size_t N>
void constructor_pullback(::std::array<T, N>* a, const ::std::array<T, N>& arr,
::std::array<T, N>* d_a, ::std::array<T, N>* d_arr) {
for (size_t i = 0; i < N; ++i)
Expand Down
34 changes: 22 additions & 12 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1309,7 +1309,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr();
}
Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get();
return StmtDiff(clonedILE);

const CXXRecordDecl* RD = ILEType->getAsCXXRecordDecl();
Expr* adjointInit = nullptr;
if (RD && RD->isAggregate()) {
llvm::SmallVector<Expr*, 4> adjParams;
for (const FieldDecl* FD : RD->fields())
adjParams.push_back(getZeroInit(FD->getType()));
adjointInit = m_Sema.ActOnInitList(noLoc, adjParams, noLoc).get();
}
return StmtDiff(clonedILE, nullptr, adjointInit);
}

// FIXME: This is a makeshift arrangement to differentiate an InitListExpr
Expand Down Expand Up @@ -2753,6 +2762,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

ConstructorPullbackCallInfo constructorPullbackInfo;

bool isConstructInit =
VD->getInit() && isa<CXXConstructExpr>(VD->getInit()->IgnoreImplicit());

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
if (promoteToFnScope)
Expand Down Expand Up @@ -2798,7 +2810,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
if (isConstructInit) {
m_TrackConstructorPullbackInfo = true;
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
Expand Down Expand Up @@ -2870,13 +2882,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
}

if (VD->getInit()) {
if (VDType->isStructureOrClassType()) {
if (!initDiff.getExpr())
initDiff = Visit(VD->getInit());
} else
initDiff = Visit(VD->getInit(), derivedE);
}
if (VD->getInit() && !isConstructInit)
initDiff = Visit(VD->getInit(), derivedE);

// If we are differentiating `VarDecl` corresponding to a local variable
// inside a loop, then we need to reset it to 0 at each iteration.
Expand Down Expand Up @@ -4155,7 +4162,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff
ReverseModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {

llvm::SmallVector<Expr*, 4> primalArgs;
llvm::SmallVector<Expr*, 4> adjointArgs;
llvm::SmallVector<Expr*, 4> reverseForwAdjointArgs;
Expand Down Expand Up @@ -4214,8 +4220,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Try to create a pullback constructor call
llvm::SmallVector<Expr*, 4> pullbackArgs;
QualType recordType =
m_Context.getRecordType(CE->getConstructor()->getParent());
const CXXRecordDecl* RD = CE->getConstructor()->getParent();
QualType recordType = m_Context.getRecordType(RD);
QualType recordPointerType = m_Context.getPointerType(recordType);
// thisE = object being created by this constructor call.
// dThisE = adjoint of the object being created by this constructor call.
Expand Down Expand Up @@ -4274,6 +4280,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn(
CE->getConstructor(), primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (RD->isAggregate())
diag(DiagnosticsEngine::Note, CE->getConstructor()->getBeginLoc(),
"No need to provide a custom constructor forward sweep for an "
"aggregate type.");
Expr* callRes = StoreAndRef(customReverseForwFnCall);
Expr* val =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,13 @@ namespace clad {
Expr* zero = ConstantFolder::synthesizeLiteral(T, m_Context, /*val=*/0);
return m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
}
if (const auto* RD = T->getAsCXXRecordDecl())
if (RD->hasDefinition() && !RD->isUnion() && RD->isAggregate()) {
llvm::SmallVector<Expr*, 4> adjParams;
for (const FieldDecl* FD : RD->fields())
adjParams.push_back(getZeroInit(FD->getType()));
return m_Sema.ActOnInitList(noLoc, adjParams, noLoc).get();
}
return m_Sema.ActOnInitList(noLoc, {}, noLoc).get();
}

Expand Down
33 changes: 16 additions & 17 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ int main() {
// CHECK-NEXT: clad::tape<std::array<double, 3> > _t2 = {};
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: clad::tape<std::array<double, 3> > _t4 = {};
// CHECK-NEXT: std::array<double, 3> _d_a({});
// CHECK-NEXT: std::array<double, 3> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 3> a;
// CHECK-NEXT: std::array<double, 3> _t0 = a;
// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, x, &_d_a, *_d_x);
Expand Down Expand Up @@ -544,7 +544,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn16_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 2> _d_a({});
// CHECK-NEXT: std::array<double, 2> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 2> a;
// CHECK-NEXT: std::array<double, 2> _t0 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r0);
Expand All @@ -554,7 +554,7 @@ int main() {
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t4 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1);
// CHECK-NEXT: double _t5 = _t4.value;
// CHECK-NEXT: _t4.value = y;
// CHECK-NEXT: std::array<double, 3> _d__b({});
// CHECK-NEXT: std::array<double, 3> _d__b({{.*}});
// CHECK-NEXT: std::array<double, 3> _b0;
// CHECK-NEXT: std::array<double, 3> _t6 = _b0;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t7 = {{.*}}operator_subscript_reverse_forw(&_b0, 0, &_d__b, _r2);
Expand All @@ -568,23 +568,22 @@ int main() {
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t13 = {{.*}}operator_subscript_reverse_forw(&_b0, 2, &_d__b, _r4);
// CHECK-NEXT: double _t14 = _t13.value;
// CHECK-NEXT: _t13.value = x * x;
// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array<double, {{3U|3UL}}>, ::std::array<double, {{3U|3UL}}> > _t15 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<array<double, 3> >(), _b0, _d__b);
// CHECK-NEXT: std::array<double, 3> _d_b = _t15.adjoint;
// CHECK-NEXT: const std::array<double, 3> b = _t15.value;
// CHECK-NEXT: std::array<double, 2> _t18 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t19 = {{.*}}back_reverse_forw(&a, &_d_a);
// CHECK-NEXT: std::array<double, 3> _d_b = {{.*}};
// CHECK-NEXT: const std::array<double, 3> b = _b0;
// CHECK-NEXT: std::array<double, 2> _t17 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t18 = {{.*}}back_reverse_forw(&a, &_d_a);
// CHECK-NEXT: std::array<double, 3> _t19 = b;
// CHECK-NEXT: {{.*}}value_type _t16 = b.front();
// CHECK-NEXT: std::array<double, 3> _t20 = b;
// CHECK-NEXT: {{.*}}value_type _t17 = b.front();
// CHECK-NEXT: {{.*}}value_type _t15 = b.at(2);
// CHECK-NEXT: std::array<double, 3> _t21 = b;
// CHECK-NEXT: {{.*}}value_type _t16 = b.at(2);
// CHECK-NEXT: std::array<double, 3> _t22 = b;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}back_pullback(&_t18, 1 * _t16 * _t17, &_d_a);
// CHECK-NEXT: {{.*}}front_pullback(&_t20, _t19.value * 1 * _t16, &_d_b);
// CHECK-NEXT: {{.*}}back_pullback(&_t17, 1 * _t15 * _t16, &_d_a);
// CHECK-NEXT: {{.*}}front_pullback(&_t19, _t18.value * 1 * _t15, &_d_b);
// CHECK-NEXT: {{.*}}size_type _r5 = {{0U|0UL}};
// CHECK-NEXT: {{.*}}at_pullback(&_t21, 2, _t19.value * _t17 * 1, &_d_b, &_r5);
// CHECK-NEXT: {{.*}}at_pullback(&_t20, 2, _t18.value * _t16 * 1, &_d_b, &_r5);
// CHECK-NEXT: {{.*}}size_type _r6 = {{0U|0UL}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t22, 1, 1, &_d_b, &_r6);
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t21, 1, 1, &_d_b, &_r6);
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}}constructor_pullback(&b, _b0, &_d_b, &_d__b);
// CHECK-NEXT: {
Expand Down Expand Up @@ -629,7 +628,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn17_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 50> _d_a({});
// CHECK-NEXT: std::array<double, 50> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 50> a;
// CHECK-NEXT: std::array<double, 50> _t0 = a;
// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, y + x + x, &_d_a, _r0);
Expand All @@ -653,7 +652,7 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn18_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 2> _d_a({});
// CHECK-NEXT: std::array<double, 2> _d_a({{.*}});
// CHECK-NEXT: std::array<double, 2> a;
// CHECK-NEXT: std::array<double, 2> _t0 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0);
Expand Down
27 changes: 24 additions & 3 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,10 @@ MyStruct fn12(MyStruct s) {

// CHECK: void fn12_grad(MyStruct s, MyStruct *_d_s) {
// CHECK-NEXT: MyStruct _t0 = s;
// CHECK-NEXT: clad::ValueAndAdjoint<MyStruct &, MyStruct &> _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {});
// CHECK-NEXT: clad::ValueAndAdjoint<MyStruct &, MyStruct &> _t1 = _t0.operator_equal_forw({2 * s.a, 2 * s.b + 2}, &(*_d_s), {0., 0.});
// CHECK-NEXT: {
// CHECK-NEXT: MyStruct _r0 = {};
// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {}, &(*_d_s), &_r0);
// CHECK-NEXT: MyStruct _r0 = {0., 0.};
// CHECK-NEXT: _t0.operator_equal_pullback({2 * s.a, 2 * s.b + 2}, {0., 0.}, &(*_d_s), &_r0);
// CHECK-NEXT: (*_d_s).a += 2 * _r0.a;
// CHECK-NEXT: (*_d_s).b += 2 * _r0.b;
// CHECK-NEXT: }
Expand Down Expand Up @@ -467,6 +467,24 @@ void fn13(double *x, double *y, int size)
// CHECK-NEXT: }
// CHECK-NEXT:}

double fn14(double x, double y) {
MyStruct s = {2 * y, 3 * x + 2};
return s.a * s.b;
}

// CHECK: void fn14_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: MyStruct _d_s = {0., 0.};
// CHECK-NEXT: MyStruct s = {2 * y, 3 * x + 2};
// CHECK-NEXT: {
// CHECK-NEXT: _d_s.a += 1 * s.b;
// CHECK-NEXT: _d_s.b += s.a * 1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: *_d_y += 2 * _d_s.a;
// CHECK-NEXT: *_d_x += 3 * _d_s.b;
// CHECK-NEXT: }
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand Down Expand Up @@ -535,6 +553,9 @@ int main() {
fn13_test.execute(x, y, 3, d_x, d_y);
printArray(d_x, size); // CHECK-EXEC: {2.00, 2.00, 2.00}
printArray(d_y, size); // CHECK-EXEC: {0.00, 0.00, 0.00}

INIT_GRADIENT(fn14);
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down
32 changes: 16 additions & 16 deletions test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,16 @@ int main() {
// CHECK: void f1_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x0);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t1 = {};
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t1 = {0.F, 0.F};
// CHECK-NEXT: ValueAndPushforward<float, float> _t10 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x0);
// CHECK-NEXT: {
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: _d__t1.pushforward += 1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r3 = {};
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r3 = {0.F, 0.F};
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&_t10, {{.*}}cos_pushforward(x, _d_x0), &_d__t1, &_r3);
// CHECK-NEXT: float _r4 = 0.F;
// CHECK-NEXT: float _r5 = 0.F;
Expand All @@ -204,7 +204,7 @@ int main() {
// CHECK-NEXT: _d__d_x += _r5;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r0 = {};
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r0 = {0.F, 0.F};
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&_t00, {{.*}}sin_pushforward(x, _d_x0), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand All @@ -225,11 +225,11 @@ int main() {
// CHECK: void f2_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r0 = {};
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r0 = {0.F, 0.F};
// CHECK-NEXT: {{.*}}constructor_pullback(&_t00, {{.*}}exp_pushforward(x, _d_x0), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand All @@ -250,11 +250,11 @@ int main() {
// CHECK: void f3_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r0 = {};
// CHECK-NEXT: {{.*}}ValueAndPushforward<float, float> _r0 = {0.F, 0.F};
// CHECK-NEXT: {{.*}}constructor_pullback(&_t00, {{.*}}log_pushforward(x, _d_x0), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand All @@ -275,11 +275,11 @@ int main() {
// CHECK: void f4_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x0, 0.F);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
// CHECK-NEXT: {{.*}} _r0 = {0.F, 0.F};
// CHECK-NEXT: {{.*}}constructor_pullback(&_t00, {{.*}}pow_pushforward(x, 4.F, _d_x0, 0.F), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand All @@ -300,11 +300,11 @@ int main() {
// CHECK: void f5_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0.F;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
// CHECK-NEXT: {{.*}} _r0 = {0.F, 0.F};
// CHECK-NEXT: {{.*}}constructor_pullback(&_t00, {{.*}}pow_pushforward(2.F, x, 0.F, _d_x0), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand All @@ -328,11 +328,11 @@ int main() {
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: float _d__d_y = 0.F;
// CHECK-NEXT: float _d_y0 = 0;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x0, _d_y0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
// CHECK-NEXT: {{.*}} _r0 = {0.F, 0.F};
// CHECK-NEXT: {{.*}}constructor_pullback(&_t00, {{.*}}pow_pushforward(x, y, _d_x0, _d_y0), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand All @@ -358,11 +358,11 @@ int main() {
// CHECK-NEXT: float _d_x0 = 0;
// CHECK-NEXT: float _d__d_y = 0.F;
// CHECK-NEXT: float _d_y0 = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = {0.F, 0.F};
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x0, _d_y0);
// CHECK-NEXT: _d__t0.pushforward += 1;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = {};
// CHECK-NEXT: {{.*}} _r0 = {0.F, 0.F};
// CHECK-NEXT: {{.*}}constructor_pullback(&_t00, {{.*}}pow_pushforward(x, y, _d_x0, _d_y0), &_d__t0, &_r0);
// CHECK-NEXT: float _r1 = 0.F;
// CHECK-NEXT: float _r2 = 0.F;
Expand Down
Loading

0 comments on commit 95300e1

Please sign in to comment.