Skip to content

Commit

Permalink
Use a local variable to pop the output argument's value inside a for loop, so as not to …
Browse files Browse the repository at this point in the history
…corrupt the forward pass
  • Loading branch information
kchristin22 committed Dec 15, 2024
1 parent 11b7d8e commit 31be5d8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion benchmark/RSBench/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,4 @@ edit:
vim -p $(source) rsbench.cuh

run:
./rsbench -m event -l 102000 > output.txt
./rsbench -m event -l 102000
6 changes: 3 additions & 3 deletions benchmark/RSBench/simulation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ __device__ void body( int i, double * __restrict__ macro_xs, int mat, double E,
__device__ void calculate_macro_xs( double * __restrict__ macro_xs, int mat, double E, Input input, int * __restrict__ num_nucs, int * __restrict__ mats, int max_num_nucs, double * __restrict__ concs, int * __restrict__ n_windows, double * __restrict__ pseudo_K0Rs, Window * __restrict__ windows, Pole * __restrict__ poles, int max_num_windows, int max_num_poles )
{
// zero out macro vector
for( int i = 0; i < 4; i++ )
macro_xs[i] = 0;
// for( int i = 0; i < 4; i++ )
// macro_xs[i] = 0;

// for nuclide in mat
int sz = num_nucs[mat];
Expand Down Expand Up @@ -438,7 +438,7 @@ __device__ RSComplex fast_nuclear_W( RSComplex Z )
RSComplex one = {1, 0};
RSComplex W = c_div(c_mul(i, ( c_sub(one, fast_cexp(c_mul(t1, Z))) )) , c_mul(t2, Z));
RSComplex sum = {0,0};
#pragma unroll
// #pragma unroll
for( int n = 0; n < 10; n++ )
{
RSComplex t3 = {neg_1n[n], 0};
Expand Down
17 changes: 16 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3407,7 +3407,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assert(E && "must be provided");
auto Type = getNonConstType(E->getType(), m_Context, m_Sema);

if (isInsideLoop) {
bool isOutVar = false;
if (auto arraySub = dyn_cast<ArraySubscriptExpr>(E))
if (auto declRef = dyn_cast<DeclRefExpr>(
arraySub->getBase()->IgnoreImpCasts()->IgnoreParens()))
if (m_Variables.find(declRef->getDecl()) != m_Variables.end())
isOutVar = true;

if (isInsideLoop && !isOutVar) {
auto CladTape = MakeCladTapeFor(Clone(E), prefix);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
Expand Down Expand Up @@ -3438,6 +3445,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid)
Restore = BuildOp(BO_Assign, Clone(E), Ref);

if (isInsideLoop) {
auto CladTape = MakeCladTapeFor(Clone(E), prefix);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
auto* popAssign = BuildOp(BinaryOperatorKind::BO_Assign, Ref, Pop);
return {Push, popAssign};
}

return {Store, Restore};
}

Expand Down

0 comments on commit 31be5d8

Please sign in to comment.