Skip to content

Commit

Permalink
finish example
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Mar 15, 2023
1 parent 45b6d16 commit 879ba78
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 6 deletions.
33 changes: 29 additions & 4 deletions examples/fix_external.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,39 @@ command(lmp, "boundary p p p")
command(lmp, "region cell block 0 $x_hi 0 $y_hi 0 $z_hi units box")
command(lmp, "create_box 1 cell")

# Use `pair_style zero` to create neighbor list for `julia_lj`
cutoff = 2.5
command(lmp, "pair_style zero $cutoff")
command(lmp, "pair_coeff * *")
command(lmp, "fix julia_lj all external pf/callback 1 1")

# Register external fix
lj = LAMMPS.FixExternal(lmp, "julia_lj") do fix, timestep, nlocal, ids, x, fexternal
@info "julia lj called" timestep nlocal
LAMMPS.energy_global!(fix, 0.0)
const coefficients = Dict(
1 => Dict(
1 => [48.0, 24.0, 4.0,4.0]
)
)

function compute_force(rsq, itype, jtype)
coeff = coefficients[itype][jtype]
r2inv = 1.0/rsq
r6inv = r2inv^3
lj1 = coeff[1]
lj2 = coeff[2]
return (r6inv * (lj1*r6inv - lj2))*r2inv
end

function compute_energy(rsq, itype, jtype)
coeff = coefficients[itype][jtype]
r2inv = 1.0/rsq
r6inv = r2inv^3
lj3 = coeff[3]
lj4 = coeff[4]
return (r6inv * (lj3*r6inv - lj4))
end

# Register external fix
lj = LAMMPS.PairExternal(lmp, "julia_lj", "zero", compute_force, compute_energy, cutoff)

# Setup atoms
natoms = 10
command(lmp, "create_atoms 1 random $natoms 1 NULL")
Expand Down
33 changes: 33 additions & 0 deletions src/LAMMPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,39 @@ function gather_atoms(lmp::LMP, name, T, count)
return data
end

function pair_neighbor_list(lmp, name, exact, nsub, request)
idx = API.lammps_find_pair_neighlist(lmp, name, exact, nsub, request)
if idx == -1
error("Could not find neighbor list for pair $(name)")
end
return idx
end

function fix_neighbor_list(lmp, name, request)
idx = API.lammps_find_fix_neighlist(lmp, name, request)
if idx == -1
error("Could not find neighbor list for fix $(name)")
end
return idx
end


"""
neighbors(lmb::LMP, idx, element)
Given a neighbor list `idx` and the element therein,
return the atom index, and it's neigbors.
"""
function neighbors(lmp, idx, element)
r_iatom = Ref{Cint}()
r_numneigh = Ref{Cint}()
r_neighbors = Ref{Ptr{Cint}}(0)

API.lammps_neighlist_element_neighbors(lmp, idx, element - 1, r_iatom, r_numneigh, r_neighbors)

return r_iatom[], Base.unsafe_wrap(Array, r_neighbors[], r_numneigh[]; own = false)
end

include("external.jl")

end # module
59 changes: 57 additions & 2 deletions src/external.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ FixExternal(callback, lmp::LMP, name::String) = FixExternal(lmp::LMP, name::Stri

function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, ids::Ptr{Cint}, x::Ptr{Ptr{Float64}}, fexternal::Ptr{Ptr{Float64}})
fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal
nlocal = Int(nlocal)

@debug "Calling fix_external_callback on" fix timestep nlocal
shape = (nlocal, 3)
x = unsafe_wrap(x, shape)
fexternal = unsafe_wrap(fexternal, shape)
ids = unsafe_wrap(ids, (nlocal,))

@debug "Calling fix_external_callback on" fix timestep ids x fexternal
# TODO wrap as arrays
fix.callback(fix, timestep, nlocal, ids, x, fexternal)
return nothing
end
Expand All @@ -36,7 +41,57 @@ function energy_global!(fix::FixExternal, energy)
API.lammps_fix_external_set_energy_global(fix.lmp, fix.name, energy)
end

function neighbor_list(fix::FixExternal, request)
fix_neighbor_list(fix.lmp, fix.name, request)
end

# TODO
# virial_global!
# function virial_global!(fix::FixExternal, )

function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global)
cutsq = cut_global^2
FixExternal(lmp, name) do fix, timestep, nlocal, ids, x, fexternal
# Full neighbor list
idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0)
nelements = API.lammps_neighlist_num_elements(fix.lmp, idx)

type = LAMMPS.extract_atom(lmp, "type")

# zero-out fexternal (noticed some undef memory)
fexternal .= 0

for ii in 1:nelements
iatom, neigh = LAMMPS.neighbors(lmp, idx, ii)
iatom += 1 # 1-based indexing
xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray?
itype = type[iatom]
for jj in 1:length(neigh)
jatom = neigh[jj] + 1
delx = xtmp - x[1, jatom]
dely = ytmp - x[2, jatom]
delz = ztmp = x[3, jatom]
jtype = type[jatom]

rsq = delx*delx + dely*dely + delz*delz;

if rsq < cutsq
fpair = compute_force(rsq, itype, jtype)

fexternal[1, iatom] += delx*fpair
fexternal[2, iatom] += dely*fpair
fexternal[3, iatom] += delz*fpair
if jatom <= nlocal # newton_pair
fexternal[1, jatom] -= delx*fpair
fexternal[2, jatom] -= dely*fpair
fexternal[3, jatom] -= delz*fpair
end

# todo call compute_energy
# TODO eflag
# TODO evflag
end
end
end
end
end

0 comments on commit 879ba78

Please sign in to comment.