From 879ba78b985422a791a021686fdf167a3d766bc8 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 15 Mar 2023 16:08:41 -0400 Subject: [PATCH] finish example --- examples/fix_external.jl | 33 +++++++++++++++++++--- src/LAMMPS.jl | 33 ++++++++++++++++++++++ src/external.jl | 59 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 119 insertions(+), 6 deletions(-) diff --git a/examples/fix_external.jl b/examples/fix_external.jl index dd849d0..44e0e21 100644 --- a/examples/fix_external.jl +++ b/examples/fix_external.jl @@ -15,14 +15,39 @@ command(lmp, "boundary p p p") command(lmp, "region cell block 0 $x_hi 0 $y_hi 0 $z_hi units box") command(lmp, "create_box 1 cell") +# Use `pair_style zero` to create neighbor list for `julia_lj` +cutoff = 2.5 +command(lmp, "pair_style zero $cutoff") +command(lmp, "pair_coeff * *") command(lmp, "fix julia_lj all external pf/callback 1 1") -# Register external fix -lj = LAMMPS.FixExternal(lmp, "julia_lj") do fix, timestep, nlocal, ids, x, fexternal - @info "julia lj called" timestep nlocal - LAMMPS.energy_global!(fix, 0.0) +const coefficients = Dict( + 1 => Dict( + 1 => [48.0, 24.0, 4.0,4.0] + ) +) + +function compute_force(rsq, itype, jtype) + coeff = coefficients[itype][jtype] + r2inv = 1.0/rsq + r6inv = r2inv^3 + lj1 = coeff[1] + lj2 = coeff[2] + return (r6inv * (lj1*r6inv - lj2))*r2inv +end + +function compute_energy(rsq, itype, jtype) + coeff = coefficients[itype][jtype] + r2inv = 1.0/rsq + r6inv = r2inv^3 + lj3 = coeff[3] + lj4 = coeff[4] + return (r6inv * (lj3*r6inv - lj4)) end +# Register external fix +lj = LAMMPS.PairExternal(lmp, "julia_lj", "zero", compute_force, compute_energy, cutoff) + # Setup atoms natoms = 10 command(lmp, "create_atoms 1 random $natoms 1 NULL") diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 90b2fea..1f9e253 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -300,6 +300,39 @@ function gather_atoms(lmp::LMP, name, T, count) return data end +function pair_neighbor_list(lmp, name, exact, nsub, request) + idx = API.lammps_find_pair_neighlist(lmp, name, exact, nsub, request) + if idx == -1 + error("Could not find neighbor list for pair $(name)") + end + return idx +end + +function fix_neighbor_list(lmp, name, request) + idx = API.lammps_find_fix_neighlist(lmp, name, request) + if idx == -1 + error("Could not find neighbor list for fix $(name)") + end + return idx +end + + +""" + neighbors(lmb::LMP, idx, element) + +Given a neighbor list `idx` and the element therein, +return the atom index, and it's neigbors. +""" +function neighbors(lmp, idx, element) + r_iatom = Ref{Cint}() + r_numneigh = Ref{Cint}() + r_neighbors = Ref{Ptr{Cint}}(0) + + API.lammps_neighlist_element_neighbors(lmp, idx, element - 1, r_iatom, r_numneigh, r_neighbors) + + return r_iatom[], Base.unsafe_wrap(Array, r_neighbors[], r_numneigh[]; own = false) +end + include("external.jl") end # module diff --git a/src/external.jl b/src/external.jl index 1499c07..57db4fc 100644 --- a/src/external.jl +++ b/src/external.jl @@ -25,9 +25,14 @@ FixExternal(callback, lmp::LMP, name::String) = FixExternal(lmp::LMP, name::Stri function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, ids::Ptr{Cint}, x::Ptr{Ptr{Float64}}, fexternal::Ptr{Ptr{Float64}}) fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal + nlocal = Int(nlocal) + + @debug "Calling fix_external_callback on" fix timestep nlocal + shape = (nlocal, 3) + x = unsafe_wrap(x, shape) + fexternal = unsafe_wrap(fexternal, shape) + ids = unsafe_wrap(ids, (nlocal,)) - @debug "Calling fix_external_callback on" fix timestep ids x fexternal - # TODO wrap as arrays fix.callback(fix, timestep, nlocal, ids, x, fexternal) return nothing end @@ -36,7 +41,57 @@ function energy_global!(fix::FixExternal, energy) API.lammps_fix_external_set_energy_global(fix.lmp, fix.name, energy) end +function neighbor_list(fix::FixExternal, request) + fix_neighbor_list(fix.lmp, fix.name, request) +end + # TODO # virial_global! # function virial_global!(fix::FixExternal, ) +function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global) + cutsq = cut_global^2 + FixExternal(lmp, name) do fix, timestep, nlocal, ids, x, fexternal + # Full neighbor list + idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0) + nelements = API.lammps_neighlist_num_elements(fix.lmp, idx) + + type = LAMMPS.extract_atom(lmp, "type") + + # zero-out fexternal (noticed some undef memory) + fexternal .= 0 + + for ii in 1:nelements + iatom, neigh = LAMMPS.neighbors(lmp, idx, ii) + iatom += 1 # 1-based indexing + xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray? + itype = type[iatom] + for jj in 1:length(neigh) + jatom = neigh[jj] + 1 + delx = xtmp - x[1, jatom] + dely = ytmp - x[2, jatom] + delz = ztmp = x[3, jatom] + jtype = type[jatom] + + rsq = delx*delx + dely*dely + delz*delz; + + if rsq < cutsq + fpair = compute_force(rsq, itype, jtype) + + fexternal[1, iatom] += delx*fpair + fexternal[2, iatom] += dely*fpair + fexternal[3, iatom] += delz*fpair + if jatom <= nlocal # newton_pair + fexternal[1, jatom] -= delx*fpair + fexternal[2, jatom] -= dely*fpair + fexternal[3, jatom] -= delz*fpair + end + + # todo call compute_energy + # TODO eflag + # TODO evflag + end + end + end + end +end