Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 23, 2024
1 parent 1c2dc1e commit dc21b9f
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 201 deletions.
10 changes: 7 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,10 +1372,13 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
// FIXME: Create unique identifier for derivative.
VarDecl* VDDerived = BuildVarDecl(
VarDecl* VDDerived = nullptr;
if (IsDifferentiableType(VD->getType())) {
VDDerived = BuildVarDecl(
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
}
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -1437,7 +1440,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx())
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down
122 changes: 61 additions & 61 deletions test/FirstDerivative/BasicArithmeticAddSub.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,90 +7,90 @@

extern "C" int printf(const char* fmt, ...);

int a_1(int x) {
int y = 4;
float a_1(float x) {
float y = 4;
return y + y; // == 0
}
// CHECK: int a_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float a_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_y + _d_y;
// CHECK-NEXT: }

int a_2(int x) {
float a_2(float x) {
return 1 + 1; // == 0
}
// CHECK: int a_2_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float a_2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return 0 + 0;
// CHECK-NEXT: }

int a_3(int x) {
float a_3(float x) {
return x + x; // == 2
}
// CHECK: int a_3_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float a_3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return _d_x + _d_x;
// CHECK-NEXT: }

int a_4(int x) {
int y = 4;
float a_4(float x) {
float y = 4;
return x + y + x + 3 + x; // == 3x
}
// CHECK: int a_4_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float a_4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_x + _d_y + _d_x + 0 + _d_x;
// CHECK-NEXT: }

int s_1(int x) {
int y = 4;
float s_1(float x) {
float y = 4;
return y - y; // == 0
}
// CHECK: int s_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float s_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_y - _d_y;
// CHECK-NEXT: }

int s_2(int x) {
float s_2(float x) {
return 1 - 1; // == 0
}
// CHECK: int s_2_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float s_2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return 0 - 0;
// CHECK-NEXT: }

int s_3(int x) {
float s_3(float x) {
return x - x; // == 0
}
// CHECK: int s_3_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float s_3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return _d_x - _d_x;
// CHECK-NEXT: }

int s_4(int x) {
int y = 4;
float s_4(float x) {
float y = 4;
return x - y - x - 3 - x; // == -1
}
// CHECK: int s_4_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float s_4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_x - _d_y - _d_x - 0 - _d_x;
// CHECK-NEXT: }

int as_1(int x) {
int y = 4;
float as_1(float x) {
float y = 4;
return x + x - x + y - y + 3 - 3; // == 1
}
// CHECK: int as_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float as_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_x + _d_x - _d_x + _d_y - _d_y + 0 - 0;
// CHECK-NEXT: }

Expand All @@ -103,45 +103,45 @@ float IntegerLiteralToFloatLiteral(float x, float y) {
// CHECK-NEXT: return _d_x * x + x * _d_x - _d_y;
// CHECK-NEXT: }

int a_1_darg0(int x);
int a_2_darg0(int x);
int a_3_darg0(int x);
int a_4_darg0(int x);
int s_1_darg0(int x);
int s_2_darg0(int x);
int s_3_darg0(int x);
int s_4_darg0(int x);
int as_1_darg0(int x);
float a_1_darg0(float x);
float a_2_darg0(float x);
float a_3_darg0(float x);
float a_4_darg0(float x);
float s_1_darg0(float x);
float s_2_darg0(float x);
float s_3_darg0(float x);
float s_4_darg0(float x);
float as_1_darg0(float x);
float IntegerLiteralToFloatLiteral_darg0(float x, float y);

int main () { // expected-no-diagnostics
int x = 4;
float x = 4;
clad::differentiate(a_1, 0);
printf("Result is = %d\n", a_1_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", a_1_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(a_2, 0);
printf("Result is = %d\n", a_2_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", a_2_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(a_3, 0);
printf("Result is = %d\n", a_3_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", a_3_darg0(1)); // CHECK-EXEC: Result is = 2

clad::differentiate(a_4, 0);
printf("Result is = %d\n", a_4_darg0(1)); // CHECK-EXEC: Result is = 3
printf("Result is = %f\n", a_4_darg0(1)); // CHECK-EXEC: Result is = 3

clad::differentiate(s_1, 0);
printf("Result is = %d\n", s_1_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", s_1_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(s_2, 0);
printf("Result is = %d\n", s_2_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", s_2_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(s_3, 0);
printf("Result is = %d\n", s_3_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", s_3_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(s_4, 0);
printf("Result is = %d\n", s_4_darg0(1)); // CHECK-EXEC: Result is = -1
printf("Result is = %f\n", s_4_darg0(1)); // CHECK-EXEC: Result is = -1

clad::differentiate(as_1, 0);
printf("Result is = %d\n", as_1_darg0(1)); // CHECK-EXEC: Result is = 1
printf("Result is = %f\n", as_1_darg0(1)); // CHECK-EXEC: Result is = 1

clad::differentiate(IntegerLiteralToFloatLiteral, 0);
printf("Result is = %f\n", IntegerLiteralToFloatLiteral_darg0(5., 0.)); // CHECK-EXEC: Result is = 10
Expand Down
34 changes: 17 additions & 17 deletions test/FirstDerivative/BasicArithmeticAll.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@

extern "C" int printf(const char* fmt, ...);

float basic_1(int x) {
int y = 4;
int z = 3;
float basic_1(float x) {
float y = 4;
float z = 3;
return (y + x) / (x - z) * ((x * y * z) / 5); // == y * z * (x * x - 2 * x * z - y * z) / (5 * (x - z) * (x - z))
}
// CHECK: float basic_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK-NEXT: int _d_z = 0;
// CHECK-NEXT: int z = 3;
// CHECK-NEXT: int _t0 = (y + x);
// CHECK-NEXT: int _t1 = (x - z);
// CHECK-NEXT: int _t2 = x * y;
// CHECK-NEXT: int _t3 = (_t2 * z);
// CHECK-NEXT: int _t4 = _t0 / _t1;
// CHECK-NEXT: int _t5 = (_t3 / 5);
// CHECK: float basic_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: float _d_z = 0;
// CHECK-NEXT: float z = 3;
// CHECK-NEXT: float _t0 = (y + x);
// CHECK-NEXT: float _t1 = (x - z);
// CHECK-NEXT: float _t2 = x * y;
// CHECK-NEXT: float _t3 = (_t2 * z);
// CHECK-NEXT: float _t4 = _t0 / _t1;
// CHECK-NEXT: float _t5 = (_t3 / 5);
// CHECK-NEXT: return (((_d_y + _d_x) * _t1 - _t0 * (_d_x - _d_z)) / (_t1 * _t1)) * _t5 + _t4 * ((((_d_x * y + x * _d_y) * z + _t2 * _d_z) * 5 - _t3 * 0) / (5 * 5));
// CHECK-NEXT: }

float basic_1_darg0(int x);
float basic_1_darg0(float x);

double fn1(double i, double j) {
double t = 1;
Expand Down Expand Up @@ -98,7 +98,7 @@ double fn3(double i, double j) {

int main () {
clad::differentiate(basic_1, 0);
printf("Result is = %f\n", basic_1_darg0(1)); // CHECK-EXEC: Result is = -6
printf("Result is = %.2f\n", basic_1_darg0(1)); // CHECK-EXEC: Result is = -10.20
INIT(fn1, "i");
INIT(fn2, "i");
INIT(fn3, "i");
Expand Down
Loading

0 comments on commit dc21b9f

Please sign in to comment.