Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fix external #28

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
95 changes: 95 additions & 0 deletions examples/fix_external.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using LAMMPS
using Test
using AtomsCalculators: potential_energy, forces
using AtomsBase
using AtomsBuilder
using ACEpotentials
using ExtXYZ
using AtomsBase: AbstractSystem
using LinearAlgebra: norm
using Unitful

lmp = LMP(["-screen", "none"])


command(lmp, "units lj")
command(lmp, "atom_style atomic")
command(lmp, "atom_modify map array sort 0 0")
command(lmp, "box tilt large")

# Setup box
x_hi = 10.0
y_hi = 10.0
z_hi = 10.0
command(lmp, "boundary p p p")
command(lmp, "region cell block 0 $x_hi 0 $y_hi 0 $z_hi units box")
command(lmp, "create_box 1 cell")

# Use `pair_style zero` to create neighbor list for `julia_lj`
cutoff = 2.5
command(lmp, "pair_style zero $cutoff")
command(lmp, "pair_coeff * *")
command(lmp, "fix julia_lj all external pf/callback 1 1")

if !isfile("Si_dataset.xyz")
download("https://www.dropbox.com/scl/fi/z6lvcpx3djp775zenz032/Si-PRX-2018.xyz?rlkey=ja5e9z99c3ta1ugra5ayq5lcv&st=cs6g7vbu&dl=1",
"Si_dataset.xyz");
end

Si_dataset = ExtXYZ.load("Si_dataset.xyz");

Si_tiny_dataset, _, _ = ACEpotentials.example_dataset("Si_tiny");

deleteat!(Si_dataset, 1);

hyperparams = (elements = [:Si,],
order = 3,
totaldegree = 8,
rcut = 2.5,
Eref = Dict(:Si => -158.54496821))
model = ACEpotentials.ace1_model(;hyperparams...);
solver = ACEfit.QR(lambda=1e-1)
data_keys = (energy_key = "dft_energy", force_key = "dft_force", virial_key = "dft_virial")
acefit!(Si_tiny_dataset, model;
solver=solver, data_keys...);
labelmap = Dict(1 => :Si)

@inline function compute_force(pos, types)
sys_size = size(pos, 2)
particles = [AtomsBase.Atom(ChemicalSpecies(labelmap[types[i]]), pos[:, i].*u"Å") for i in 1:sys_size]
cell = AtomsBuilder.bulk(:Si, cubic=true) * 3
sys = FlexibleSystem(particles, cell)
f = forces(sys, model)
return ustrip.(f)
end

@inline function compute_energy(pos, types)
sys_size = size(pos, 2)
particles = [AtomsBase.Atom(ChemicalSpecies(labelmap[types[i]]), pos[:, i].*u"Å") for i in 1:sys_size]
cell = AtomsBuilder.bulk(:Si, cubic=true) * 3
sys = FlexibleSystem(particles, cell)
energy = potential_energy(sys, model)
return ustrip(energy)
end

# Register external fix
lj = LAMMPS.PairExternal(lmp, "julia_lj", "zero", compute_force, compute_energy, cutoff, true, true)

# Setup atoms
natoms = 54
command(lmp, "labelmap atom 1 Si")
command(lmp, "create_atoms Si random $natoms 1 NULL")
positions = rand(3, natoms) .* 5
command(lmp, "mass 1 28.0855")
LAMMPS.scatter!(lmp, "x", positions)

command(lmp, "run 0")

# extract forces
forces_julia = gather(lmp, "f", Float64)

particles = [AtomsBase.Atom(ChemicalSpecies(labelmap[1]), positions[:, i].*u"Å") for i in 1:54]
cell = AtomsBuilder.bulk(:Si, cubic=true) * 3
sys = FlexibleSystem(particles, cell)
f = forces(sys, model)

49 changes: 47 additions & 2 deletions src/LAMMPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ end

mutable struct LMP
@atomic handle::Ptr{Cvoid}
external_fixes::Dict{String, Any}

function LMP(args::Vector{String}=String[], comm::Union{Nothing, MPI.Comm}=nothing)
if !isempty(args)
Expand All @@ -67,7 +68,7 @@ mutable struct LMP
end
end

this = new(handle)
this = new(handle, Dict{String, Any}())
finalizer(close!, this)
return this
end
Expand All @@ -82,8 +83,10 @@ Shutdown an LMP instance.
function close!(lmp::LMP)
handle = @atomicswap lmp.handle = C_NULL
if handle !== C_NULL
API.lammps_close(handle)
empty!(lmp.external_fixes)
API.lammps_close(handle)
end
return nothing
end

function LMP(f::Function, args=String[], comm=nothing)
Expand Down Expand Up @@ -154,6 +157,7 @@ end
function extract_global(lmp::LMP, name, dtype=nothing)
if dtype === nothing
dtype = API.lammps_extract_global_datatype(lmp, name)
dtype == -1 && error("Could not find dataype for global $name")
end
dtype = API._LMP_DATATYPE_CONST(dtype)
type = dtype2type(dtype)
Expand Down Expand Up @@ -198,6 +202,7 @@ function extract_atom(lmp::LMP, name,

if dtype === nothing
dtype = API.lammps_extract_atom_datatype(lmp, name)
dtype == -1 && error("Could not find dataype for atom $name")
dtype = API._LMP_DATATYPE_CONST(dtype)
end

Expand Down Expand Up @@ -504,7 +509,47 @@ function _get_T(lmp::LMP, name::String)
else
error("Unkown per atom property $name")
end
end

function extract_setting(lmp, name)
val = API.lammps_extract_setting(lmp, name)
val == -1 && error("Could not find setting $name")
return val
end

function pair_neighbor_list(lmp, name, exact, nsub, request)
idx = API.lammps_find_pair_neighlist(lmp, name, exact, nsub, request)
if idx == -1
error("Could not find neighbor list for pair $(name)")
end
return idx
end

function fix_neighbor_list(lmp, name, request)
idx = API.lammps_find_fix_neighlist(lmp, name, request)
if idx == -1
error("Could not find neighbor list for fix $(name)")
end
return idx
end


"""
neighbors(lmb::LMP, idx, element)

Given a neighbor list `idx` and the element therein,
return the atom index, and it's neigbors.
"""
function neighbors(lmp, idx, element)
r_iatom = Ref{Cint}()
r_numneigh = Ref{Cint}()
r_neighbors = Ref{Ptr{Cint}}(0)

API.lammps_neighlist_element_neighbors(lmp, idx, element - 1, r_iatom, r_numneigh, r_neighbors)

return Int(r_iatom[]), Base.unsafe_wrap(Array, r_neighbors[], r_numneigh[]; own = false)
end

include("external.jl")

end # module
4 changes: 4 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ function lammps_flush_buffers(ptr)
ccall((:lammps_flush_buffers, liblammps), Cvoid, (Ptr{Cvoid},), ptr)
end

# function lammps_fix_external_set_energy_peratom(handle, id, eng)
# ccall((:lammps_fix_external_set_energy_peratom, liblammps), Cvoid, (Ptr{Cvoid}, Ptr{Cchar}, Ptr{Cdouble}), handle, id, eng)
# end

function lammps_free(ptr)
ccall((:lammps_free, liblammps), Cvoid, (Ptr{Cvoid},), ptr)
end
Expand Down
141 changes: 141 additions & 0 deletions src/external.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
function fix_external_callback end

mutable struct FixExternal
lmp::LMP
name::String
callback

function FixExternal(lmp::LMP, name::String, callback)
if haskey(lmp.external_fixes, name)
error("FixExternal has already been registered with $name")
end

this = new(lmp, name, callback)
lmp.external_fixes[name] = this # preserves pair globally

ctx = Base.pointer_from_objref(this)
callback = @cfunction(fix_external_callback, Cvoid, (Ptr{Cvoid}, Int64, Cint, Ptr{Cint}, Ptr{Ptr{Float64}}, Ptr{Ptr{Float64}}))
API.lammps_set_fix_external_callback(lmp, name, callback, ctx)

# Ensure function is compiled before timestep 0
if !precompile(this.callback, (FixExternal, Int, Int, Int, Vector{Int32}, Matrix{Float64}, Matrix{Float64}))
@warn "Failed to precompile the callback" this.callback
end
return this
end
end

FixExternal(callback, lmp::LMP, name::String) = FixExternal(lmp::LMP, name::String, callback)

function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, ids::Ptr{Cint}, x::Ptr{Ptr{Float64}}, fexternal::Ptr{Ptr{Float64}})
fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal
nlocal = Int(nlocal)

nghost = Int(extract_global(fix.lmp, "nghost"))

@debug "Calling fix_external_callback on" fix timestep nlocal
shape = (nlocal+nghost, 3)
x = unsafe_wrap(x, shape)
fexternal = unsafe_wrap(fexternal, shape)
ids = unsafe_wrap(ids, (nlocal+nghost,))

# necessary dynamic
fix.callback(fix, timestep, nlocal, nghost, ids, x, fexternal)
return nothing
end

function energy_global!(fix::FixExternal, energy)
API.lammps_fix_external_set_energy_global(fix.lmp, fix.name, energy)
end

function neighbor_list(fix::FixExternal, request)
fix_neighbor_list(fix.lmp, fix.name, request)
end

function virial_global!(fix::FixExternal, virial)
API.lammps_fix_external_set_virial_global(fix.lmp, fix.name, virial)
end

# TODO
# virial_global!
# function virial_global!(fix::FixExternal, )

const NEIGHMASK = 0x3FFFFFFF
const SBBITS = 30
sbmask(atom) = (atom >> SBBITS) & 3
const special_lj = [1.0, 0.0, 0.0 ,0.0]

function virial_fdotr_compute(fexternal::Matrix{Float64}, x::Matrix{Float64}, nall)
#TODO: discuss include_group flag
virial = Array{Float64}(undef, 6)
for i in 1:nall
virial[1] = fexternal[1, i] * x[1, i]
virial[2] = fexternal[2, i] * x[2, i]
virial[3] = fexternal[3, i] * x[3, i]
virial[1] = fexternal[2, i] * x[1, i]
virial[2] = fexternal[3, i] * x[1, i]
virial[3] = fexternal[3, i] * x[2, i]
end
return virial
end

function PairExternal(lmp, name, neigh_name, compute_force::F, compute_energy::E, cut_global, eflag, vflag) where {E, F}
cutsq = cut_global^2
function pair(fix::FixExternal, timestep::Int, nlocal::Int, nghost::Int, ids::Vector{Int32}, x::Matrix{Float64}, fexternal::Matrix{Float64})
# Full neighbor list

idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0)
nelements = API.lammps_neighlist_num_elements(fix.lmp, idx)
newton_pair = extract_setting(fix.lmp, "newton_pair") == 1
# special_lj = extract_global(fix.lmp, "special_lj")
type = LAMMPS.extract_atom(lmp, "type", API.LAMMPS_INT, nlocal+nghost)::Vector{Int32}

# zero-out fexternal (noticed some undef memory)
fexternal .= 0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? At least it allows me to use += later.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes and no. The expectation for the design of fix external is that it is used to couple some other software that can compute forces to LAMMPS, feed the atom positions and (global) indexes to it and let it compute the forces. Those are then copied to the fexternal array. Typically the program would have its own input file(s) and read the topology and initial geometry from that and then the data from fix external is used to update the positions (hence the atom-IDs since the order in the local arrays can change all the time) and then compute the forces.

So if you collect your forces in a buffer of your own, you only need to copy them. If you want to collect them directly into the array passed from LAMMPS you need to zero it out first.

Another item to take care of are forces between pairs of atoms that straddle subdomain boundaries (or periodic domain boundaries in case you run in serial). You cannot return forces on "ghost" atoms so you either have to use a full neighbor list (have each pair listed twice) or use "newton off" so that pairs across domain boundaries are listed twice and then store forces only with local atoms (hence the passing of nlocal). The alternative would be to implement some communication.

At this point, it is probably a good idea to read through sections 4.1 to 4.6 here: https://docs.lammps.org/Developer.html
Fix external is called at the "post_force" step.


energies = zeros(nlocal)


#API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies)
x = gather(lmp, "x", Float64)

for ii in 1:Int(nelements)
# local atom index (i.e. in the range [0, nlocal + nghost)
types = []
iatom, neigh = LAMMPS.neighbors(lmp, idx, ii)
pt = []
iatom += 1 # 1-based indexing
xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray?
append!(types, type[iatom])
push!(pt, x[:, iatom])
incut = 1
for jj in 1:length(neigh)
jatom = Int(neigh[jj])
jatom &= NEIGHMASK
jatom += 1 # 1-based indexing
jtype = type[jatom]
delx = xtmp - x[1, jatom]
dely = ytmp - x[2, jatom]
delz = ztmp - x[3, jatom]
rsq = delx*delx + dely*dely + delz*delz
if rsq < cutsq
append!(types, jtype)
push!(pt, x[:, jatom])
incut += 1
end
end
fexternal[:, iatom] = compute_force(reshape(pt, (3, incut)), types)[1]
if eflag
energies[iatom] = compute_energy(reshape(pt, (3, incut)), types)[1]
end
end
if eflag
API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies)
energy_global!(fix, sum(energies))
end
if vflag
virial = virial_fdotr_compute(fexternal, x, nlocal+nghost)
end
end
FixExternal(pair, lmp, name)
end
Loading