Skip to content

Commit

Permalink
Fix suggestions and format
Browse files: browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 21, 2024
1 parent ef0f784 commit fc417d9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ __global__ void atomicAdd_kernel(T* destPtr, T* srcPtr, size_t N) {

template <typename T>
void cudaMemcpy_pullback(T* destPtr, T* srcPtr, size_t count,
cudaMemcpyKind kind, T* d_destPtr, T* d_srcPtr,
size_t* d_count, cudaMemcpyKind* d_kind)
cudaMemcpyKind kind, T* d_destPtr, T* d_srcPtr,
size_t* d_count, cudaMemcpyKind* d_kind)
__attribute__((host)) {
T* aux_destPtr;
if (kind == cudaMemcpyDeviceToHost) {
Expand All @@ -111,18 +111,18 @@ void cudaMemcpy_pullback(T* destPtr, T* srcPtr, size_t count,
cudaGetDeviceProperties(&deviceProp, 0);
size_t maxThreads = deviceProp.maxThreadsPerBlock;
size_t maxBlocks = deviceProp.maxGridSize[0];

size_t numThreads = std::min(maxThreads, N);
size_t numBlocks = std::min(maxBlocks, (N + numThreads - 1) / numThreads);
custom_derivatives::atomicAdd_kernel<<<numBlocks, numThreads>>>(
d_srcPtr, aux_destPtr, N);
cudaDeviceSynchronize();
cudaDeviceSynchronize(); // needed in case the user uses a stream other than
// the default one
cudaFree(aux_destPtr);
} else if (kind == cudaMemcpyHostToDevice) {
// d_kind is device to host, so d_srcPtr is a host pointer
for (size_t i = 0; i < N; ++i) {
for (size_t i = 0; i < N; ++i)
d_srcPtr[i] += aux_destPtr[i];
}
free(aux_destPtr);
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the function is a global kernel, all its parameters reside in the
// global memory of the GPU
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
for (auto* param : params)
m_CUDAGlobalArgs.emplace(param);
m_Derivative->setBody(nullptr);

Expand Down

0 comments on commit fc417d9

Please sign in to comment.