Skip to content

Commit

Permalink
fixup! Add CUDA example for fit_external
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Apr 3, 2023
1 parent f96f4ef commit 4cd91c7
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions examples/fix_external_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,24 @@ LAMMPS.FixExternal(lmp, "julia_lj_cuda") do fix, timestep, nlocal, ids, x, fexte
cu_type = adapt(CuArray, type)

neighbors_array = VectorOfArrays{Int64, 2}()
ilist = Vector{Int64}()

# Copy neighbor_list to Julia datastructure
for ii in 1:nelements
# local atom index (i.e. in the range [0, nlocal + nghost)
_, neigh = LAMMPS.neighbors(lmp, idx, ii)
iatom, neigh = LAMMPS.neighbors(lmp, idx, ii)
push!(neighbors_array, reshape(neigh, (1, length(neigh))))
push!(ilist, iatom)
end
@show neighbors_array

neighbors_array = adapt(CuArray, neighbors_array)
ilist = adapt(CuArray, ilist)

function kernel(potential, x, fexternal, energies, neighbors, cutsq, nlocal, type, special_lj)
iatom = threadIdx().x
neighs = neighbors[iatom]
function kernel(potential, x, fexternal, energies, ilist, neighbors, cutsq, nlocal, type, special_lj)
ii = threadIdx().x
iatom = ilist[ii]
neighs = neighbors[ii]

iatom += 1 # 1-based indexing
xtmp = x[1, iatom]
ytmp = x[2, iatom]
ztmp = x[3, iatom]
Expand Down Expand Up @@ -129,7 +132,8 @@ LAMMPS.FixExternal(lmp, "julia_lj_cuda") do fix, timestep, nlocal, ids, x, fexte
return nothing
end

@cuda threads=nlocal kernel(potential, cu_x, cu_fexternal, energies, neighbors_array,
@cuda threads=nlocal kernel(potential, cu_x, cu_fexternal, energies,
ilist, neighbors_array,
cutsq, nlocal, cu_type, special_lj)

copyto!(fexternal, cu_fexternal) # TODO async
Expand Down

0 comments on commit 4cd91c7

Please sign in to comment.