From 6797990818b6120761fcf5892e939275a9b1d0ad Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 9 Dec 2024 22:16:03 +0200 Subject: [PATCH] warn fix --- lib/Differentiator/ReverseModeVisitor.cpp | 22 ++++++++++--- test/Gradient/UserDefinedTypes.C | 38 ++++++++++++++++++++++- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ec4862b3f..a7d4dc6cb 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -34,7 +34,9 @@ #include #include +#include "llvm/ADT/SmallString.h" #include "llvm/Support/SaveAndRestore.h" +#include #include #include @@ -4281,12 +4283,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CE->getConstructor(), primalArgs, reverseForwAdjointArgs, /*baseExpr=*/nullptr)) { if (RD->isAggregate()) { - diag(DiagnosticsEngine::Note, CE->getConstructor()->getBeginLoc(), - "No need to provide a custom constructor forward sweep for an " - "aggregate type."); + SmallString<128> Name_class; + llvm::raw_svector_ostream OS_class(Name_class); + RD->getNameForDiagnostic(OS_class, m_Context.getPrintingPolicy(), + /*qualified=*/true); diag(DiagnosticsEngine::Warning, CE->getBeginLoc(), - "No need to provide a custom constructor forward sweep for an " - "aggregate type."); + "'%0' is an aggregate type and its constructor does not require a " + "user-defined forward sweep function", + {OS_class.str()}); + const FunctionDecl* constr_forw = + cast(customReverseForwFnCall)->getDirectCallee(); + SmallString<128> Name_forw; + llvm::raw_svector_ostream OS_forw(Name_forw); + constr_forw->getNameForDiagnostic( + OS_forw, m_Context.getPrintingPolicy(), /*qualified=*/true); + diag(DiagnosticsEngine::Note, constr_forw->getBeginLoc(), + "'%0' is defined here", {OS_forw.str()}); } Expr* callRes = StoreAndRef(customReverseForwFnCall); Expr* val = diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 6e2580a01..33b7fbc3d 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s +// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out -Xclang -verify 2>&1 | %filecheck %s // RUN: ./UserDefinedTypes.out | %filecheck_exec %s // RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out // RUN: ./UserDefinedTypes.out | %filecheck_exec %s @@ -485,6 +485,40 @@ double fn14(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT:} +template +struct SimpleArray { + T elements[N]; // Aggregate initialization +}; + +namespace clad { +namespace custom_derivatives { +namespace class_functions { +template<::std::size_t N> +::clad::ValueAndAdjoint, SimpleArray> // expected-note {{'clad::custom_derivatives::class_functions::constructor_reverse_forw<2}}{{' is defined here}} +constructor_reverse_forw(::clad::ConstructorReverseForwTag>) { + SimpleArray a; + SimpleArray d_a; + return {a, d_a}; +} +}}} + +double fn15(double x, double y) { + SimpleArray arr; // expected-warning {{'SimpleArray' is an aggregate type and its constructor does not require a user-defined forward sweep function}} + return arr.elements[0]; +} + +// CHECK:void fn15_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array, ::std::array > _t0 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag >()); +// CHECK-NEXT: std::array _d_arr(_t0.adjoint); +// CHECK-NEXT: std::array arr(_t0.value); +// CHECK-NEXT: std::array _t1 = arr; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.operator_subscript_forw(0, &_d_arr, 0); +// CHECK-NEXT: { +// CHECK-NEXT: size_type _r0 = {{0U|0UL|0ULL}}; +// CHECK-NEXT: _t1.operator_subscript_pullback(0, 1, &_d_arr, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -556,6 +590,8 @@ int main() { INIT_GRADIENT(fn14); TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00} + + INIT_GRADIENT(fn15); } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {