From d96937ceb380ac353e873574a479196acb45ab9c Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 9 Dec 2024 22:16:03 +0200 Subject: [PATCH] warn fix --- lib/Differentiator/ReverseModeVisitor.cpp | 23 ++++++++++++---- test/Gradient/UserDefinedTypes.C | 33 ++++++++++++++++++++++- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ec4862b3f..beaa21a9e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4281,12 +4281,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CE->getConstructor(), primalArgs, reverseForwAdjointArgs, /*baseExpr=*/nullptr)) { if (RD->isAggregate()) { - diag(DiagnosticsEngine::Note, CE->getConstructor()->getBeginLoc(), - "No need to provide a custom constructor forward sweep for an " - "aggregate type."); + // set up printing policy + clang::LangOptions LangOpts; + LangOpts.CPlusPlus = true; + clang::PrintingPolicy Policy(LangOpts); + Policy.Bool = true; + SmallString<128> Name_class; + llvm::raw_svector_ostream OS_class(Name_class); + RD->getNameForDiagnostic(OS_class, Policy, /*qualified=*/true); diag(DiagnosticsEngine::Warning, CE->getBeginLoc(), - "No need to provide a custom constructor forward sweep for an " - "aggregate type."); + "'%0' is an aggregate type and its constructor does not require a " + "user-defined forward sweep function", + {OS_class.str()}); + const FunctionDecl* constr_forw = + cast(customReverseForwFnCall)->getDirectCallee(); + SmallString<128> Name_forw; + llvm::raw_svector_ostream OS_forw(Name_forw); + constr_forw->getNameForDiagnostic(OS_forw, Policy, /*qualified=*/true); + diag(DiagnosticsEngine::Note, constr_forw->getBeginLoc(), + "'%0' is defined here", {OS_forw.str()}); } Expr* callRes = StoreAndRef(customReverseForwFnCall); Expr* val = diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 6e2580a01..5c320afef 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s +// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out -Xclang -verify 2>&1 | %filecheck %s // RUN: ./UserDefinedTypes.out | %filecheck_exec %s // RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out // RUN: ./UserDefinedTypes.out | %filecheck_exec %s @@ -485,6 +485,35 @@ double fn14(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT:} +namespace clad { +namespace custom_derivatives { +namespace class_functions { +template<::std::size_t N> +::clad::ValueAndAdjoint<::std::array, ::std::array> // expected-note {{'clad::custom_derivatives::class_functions::constructor_reverse_forw<2UL>' is defined here}} +constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::array>) { + ::std::array a; + ::std::array d_a; + return {a, d_a}; +} +}}} + +double fn15(double x, double y) { + std::array arr; // expected-warning {{'std::array' is an aggregate type and its constructor does not require a user-defined forward sweep function}} + return arr[0]; +} + +// CHECK:void fn15_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array, ::std::array > _t0 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag >()); +// CHECK-NEXT: std::array _d_arr(_t0.adjoint); +// CHECK-NEXT: std::array arr(_t0.value); +// CHECK-NEXT: std::array _t1 = arr; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.operator_subscript_forw(0, &_d_arr, 0); +// CHECK-NEXT: { +// CHECK-NEXT: std::array::size_type _r0 = 0UL; +// CHECK-NEXT: _t1.operator_subscript_pullback(0, 1, &_d_arr, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -556,6 +585,8 @@ int main() { INIT_GRADIENT(fn14); TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00} + + INIT_GRADIENT(fn15); } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {