From ae5cccf28636f67d1a6ae2823b4dcd6a00fd211c Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 15 Mar 2023 13:08:35 -0400 Subject: [PATCH 01/13] Add support for fix external --- examples/fix_external.jl | 38 ++++++++++++++++++++++++++++++++++++ src/LAMMPS.jl | 9 +++++++-- src/external.jl | 42 ++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 24 +++++++++++++++++++++++ 4 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 examples/fix_external.jl create mode 100644 src/external.jl diff --git a/examples/fix_external.jl b/examples/fix_external.jl new file mode 100644 index 0000000..dd849d0 --- /dev/null +++ b/examples/fix_external.jl @@ -0,0 +1,38 @@ +using LAMMPS + +lmp = LMP() + +command(lmp, "units lj") +command(lmp, "atom_style atomic") +command(lmp, "atom_modify map array sort 0 0") +command(lmp, "box tilt large") + +# Setup box +x_hi = 10.0 +y_hi = 10.0 +z_hi = 10.0 +command(lmp, "boundary p p p") +command(lmp, "region cell block 0 $x_hi 0 $y_hi 0 $z_hi units box") +command(lmp, "create_box 1 cell") + +command(lmp, "fix julia_lj all external pf/callback 1 1") + +# Register external fix +lj = LAMMPS.FixExternal(lmp, "julia_lj") do fix, timestep, nlocal, ids, x, fexternal + @info "julia lj called" timestep nlocal + LAMMPS.energy_global!(fix, 0.0) +end + +# Setup atoms +natoms = 10 +command(lmp, "create_atoms 1 random $natoms 1 NULL") +command(lmp, "mass 1 1.0") + +# (x,y,z), natoms +positions = rand(3, 10) .* 5 +LAMMPS.API.lammps_scatter_atoms(lmp, "x", 1, 3, positions) + +command(lmp, "run 0") + +# extract forces +forces = extract_atom(lmp, "f") \ No newline at end of file diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index c829086..522bd6f 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -49,6 +49,7 @@ end mutable struct LMP @atomic handle::Ptr{Cvoid} + external_fixes::Dict{String, Any} function LMP(args::Vector{String}=String[], comm::Union{Nothing, MPI.Comm}=nothing) if !isempty(args) @@ -67,7 +68,7 @@ mutable struct LMP end end - this = new(handle) + this = new(handle, Dict{String, Any}()) finalizer(close!, this) return this end @@ -82,8 +83,10 @@ Shutdown an LMP instance. function close!(lmp::LMP) handle = @atomicswap lmp.handle = C_NULL if handle !== C_NULL - API.lammps_close(handle) + empty!(lmp.external_fixes) + API.lammps_close(handle) end + return nothing end function LMP(f::Function, args=String[], comm=nothing) @@ -507,4 +510,6 @@ function _get_T(lmp::LMP, name::String) end +include("external.jl") + end # module diff --git a/src/external.jl b/src/external.jl new file mode 100644 index 0000000..1499c07 --- /dev/null +++ b/src/external.jl @@ -0,0 +1,42 @@ +function fix_external_callback end + +mutable struct FixExternal + lmp::LMP + name::String + callback + + function FixExternal(lmp::LMP, name::String, callback) + if haskey(lmp.external_fixes, name) + error("FixExternal has already been registered with $name") + end + + this = new(lmp, name, callback) + lmp.external_fixes[name] = this # preserves pair globally + + ctx = Base.pointer_from_objref(this) + callback = @cfunction(fix_external_callback, Cvoid, (Ptr{Cvoid}, Int64, Cint, Ptr{Cint}, Ptr{Ptr{Float64}}, Ptr{Ptr{Float64}})) + API.lammps_set_fix_external_callback(lmp, name, callback, ctx) + + return this + end +end + +FixExternal(callback, lmp::LMP, name::String) = FixExternal(lmp::LMP, name::String, callback) + +function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, ids::Ptr{Cint}, x::Ptr{Ptr{Float64}}, fexternal::Ptr{Ptr{Float64}}) + fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal + + @debug "Calling fix_external_callback on" fix timestep ids x fexternal + # TODO wrap as arrays + fix.callback(fix, timestep, nlocal, ids, x, fexternal) + return nothing +end + +function energy_global!(fix::FixExternal, energy) + API.lammps_fix_external_set_energy_global(fix.lmp, fix.name, energy) +end + +# TODO +# virial_global! +# function virial_global!(fix::FixExternal, ) + diff --git a/test/runtests.jl b/test/runtests.jl index 01d1dd9..9a84b5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,16 @@ LMP(["-screen", "none"]) do lmp @test_throws ErrorException command(lmp, "nonsense") end + +function f() + lmp = LMP(["-screen", "none"]) + @test LAMMPS.version(lmp) >= 0 + command(lmp, "clear") + @test_throws ErrorException command(lmp, "nonsense") + LAMMPS.close!(lmp) +end + + @testset "Variables" begin LMP(["-screen", "none"]) do lmp command(lmp, "box tilt large") @@ -101,4 +111,18 @@ end end end +LMP(["-screen", "none"]) do lmp + called = Ref(false) + command(lmp, "boundary p p p") + command(lmp, "region cell block 0 1 0 1 0 1 units box") + command(lmp, "create_box 1 cell") + command(lmp, "fix julia all external pf/callback 1 1") + LAMMPS.FixExternal(lmp, "julia") do fix, timestep, nlocal, ids, x, fexternal + called[] = true + end + command(lmp, "mass 1 1.0") + command(lmp, "run 0") + @test called[] == true +end + @test success(pipeline(`$(MPI.mpiexec()) -n 2 $(Base.julia_cmd()) mpitest.jl`, stderr=stderr, stdout=stdout)) From 4357bad41c8113b03887eaaa41022c78acc63edb Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 15 Mar 2023 14:27:25 -0400 Subject: [PATCH 02/13] Update test/runtests.jl --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 9a84b5f..1fe82cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -118,6 +118,7 @@ LMP(["-screen", "none"]) do lmp command(lmp, "create_box 1 cell") command(lmp, "fix julia all external pf/callback 1 1") LAMMPS.FixExternal(lmp, "julia") do fix, timestep, nlocal, ids, x, fexternal + LAMMPS.energy_global!(fix, 0.0) called[] = true end command(lmp, "mass 1 1.0") From 410e748b8bc4a4449e32dead3c0a1399fb3a77f2 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 15 Mar 2023 16:08:41 -0400 Subject: [PATCH 03/13] finish example --- examples/fix_external.jl | 33 +++++++++++++++++++--- src/LAMMPS.jl | 31 +++++++++++++++++++++ src/external.jl | 59 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 117 insertions(+), 6 deletions(-) diff --git a/examples/fix_external.jl b/examples/fix_external.jl index dd849d0..44e0e21 100644 --- a/examples/fix_external.jl +++ b/examples/fix_external.jl @@ -15,14 +15,39 @@ command(lmp, "boundary p p p") command(lmp, "region cell block 0 $x_hi 0 $y_hi 0 $z_hi units box") command(lmp, "create_box 1 cell") +# Use `pair_style zero` to create neighbor list for `julia_lj` +cutoff = 2.5 +command(lmp, "pair_style zero $cutoff") +command(lmp, "pair_coeff * *") command(lmp, "fix julia_lj all external pf/callback 1 1") -# Register external fix -lj = LAMMPS.FixExternal(lmp, "julia_lj") do fix, timestep, nlocal, ids, x, fexternal - @info "julia lj called" timestep nlocal - LAMMPS.energy_global!(fix, 0.0) +const coefficients = Dict( + 1 => Dict( + 1 => [48.0, 24.0, 4.0,4.0] + ) +) + +function compute_force(rsq, itype, jtype) + coeff = coefficients[itype][jtype] + r2inv = 1.0/rsq + r6inv = r2inv^3 + lj1 = coeff[1] + lj2 = coeff[2] + return (r6inv * (lj1*r6inv - lj2))*r2inv +end + +function compute_energy(rsq, itype, jtype) + coeff = coefficients[itype][jtype] + r2inv = 1.0/rsq + r6inv = r2inv^3 + lj3 = coeff[3] + lj4 = coeff[4] + return (r6inv * (lj3*r6inv - lj4)) end +# Register external fix +lj = LAMMPS.PairExternal(lmp, "julia_lj", "zero", compute_force, compute_energy, cutoff) + # Setup atoms natoms = 10 command(lmp, "create_atoms 1 random $natoms 1 NULL") diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 522bd6f..1491ea7 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -508,6 +508,37 @@ function _get_T(lmp::LMP, name::String) error("Unkown per atom property $name") end +function pair_neighbor_list(lmp, name, exact, nsub, request) + idx = API.lammps_find_pair_neighlist(lmp, name, exact, nsub, request) + if idx == -1 + error("Could not find neighbor list for pair $(name)") + end + return idx +end + +function fix_neighbor_list(lmp, name, request) + idx = API.lammps_find_fix_neighlist(lmp, name, request) + if idx == -1 + error("Could not find neighbor list for fix $(name)") + end + return idx +end + + +""" + neighbors(lmb::LMP, idx, element) + +Given a neighbor list `idx` and the element therein, +return the atom index, and it's neigbors. +""" +function neighbors(lmp, idx, element) + r_iatom = Ref{Cint}() + r_numneigh = Ref{Cint}() + r_neighbors = Ref{Ptr{Cint}}(0) + + API.lammps_neighlist_element_neighbors(lmp, idx, element - 1, r_iatom, r_numneigh, r_neighbors) + + return r_iatom[], Base.unsafe_wrap(Array, r_neighbors[], r_numneigh[]; own = false) end include("external.jl") diff --git a/src/external.jl b/src/external.jl index 1499c07..57db4fc 100644 --- a/src/external.jl +++ b/src/external.jl @@ -25,9 +25,14 @@ FixExternal(callback, lmp::LMP, name::String) = FixExternal(lmp::LMP, name::Stri function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, ids::Ptr{Cint}, x::Ptr{Ptr{Float64}}, fexternal::Ptr{Ptr{Float64}}) fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal + nlocal = Int(nlocal) + + @debug "Calling fix_external_callback on" fix timestep nlocal + shape = (nlocal, 3) + x = unsafe_wrap(x, shape) + fexternal = unsafe_wrap(fexternal, shape) + ids = unsafe_wrap(ids, (nlocal,)) - @debug "Calling fix_external_callback on" fix timestep ids x fexternal - # TODO wrap as arrays fix.callback(fix, timestep, nlocal, ids, x, fexternal) return nothing end @@ -36,7 +41,57 @@ function energy_global!(fix::FixExternal, energy) API.lammps_fix_external_set_energy_global(fix.lmp, fix.name, energy) end +function neighbor_list(fix::FixExternal, request) + fix_neighbor_list(fix.lmp, fix.name, request) +end + # TODO # virial_global! # function virial_global!(fix::FixExternal, ) +function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global) + cutsq = cut_global^2 + FixExternal(lmp, name) do fix, timestep, nlocal, ids, x, fexternal + # Full neighbor list + idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0) + nelements = API.lammps_neighlist_num_elements(fix.lmp, idx) + + type = LAMMPS.extract_atom(lmp, "type") + + # zero-out fexternal (noticed some undef memory) + fexternal .= 0 + + for ii in 1:nelements + iatom, neigh = LAMMPS.neighbors(lmp, idx, ii) + iatom += 1 # 1-based indexing + xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray? + itype = type[iatom] + for jj in 1:length(neigh) + jatom = neigh[jj] + 1 + delx = xtmp - x[1, jatom] + dely = ytmp - x[2, jatom] + delz = ztmp = x[3, jatom] + jtype = type[jatom] + + rsq = delx*delx + dely*dely + delz*delz; + + if rsq < cutsq + fpair = compute_force(rsq, itype, jtype) + + fexternal[1, iatom] += delx*fpair + fexternal[2, iatom] += dely*fpair + fexternal[3, iatom] += delz*fpair + if jatom <= nlocal # newton_pair + fexternal[1, jatom] -= delx*fpair + fexternal[2, jatom] -= dely*fpair + fexternal[3, jatom] -= delz*fpair + end + + # todo call compute_energy + # TODO eflag + # TODO evflag + end + end + end + end +end From 362070950705fa96330614648f0ef9c741079762 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 15 Mar 2023 17:05:57 -0400 Subject: [PATCH 04/13] support newton_pair --- src/LAMMPS.jl | 8 ++++++++ src/external.jl | 19 +++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 1491ea7..566f7a7 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -157,6 +157,7 @@ end function extract_global(lmp::LMP, name, dtype=nothing) if dtype === nothing dtype = API.lammps_extract_global_datatype(lmp, name) + dtype == -1 && error("Could not find dataype for global $name") end dtype = API._LMP_DATATYPE_CONST(dtype) type = dtype2type(dtype) @@ -507,6 +508,13 @@ function _get_T(lmp::LMP, name::String) else error("Unkown per atom property $name") end +end + +function extract_setting(lmp, name) + val = API.lammps_extract_setting(lmp, name) + val == -1 && error("Could not find setting $name") + return val +end function pair_neighbor_list(lmp, name, exact, nsub, request) idx = API.lammps_find_pair_neighlist(lmp, name, exact, nsub, request) diff --git a/src/external.jl b/src/external.jl index 57db4fc..85a7afa 100644 --- a/src/external.jl +++ b/src/external.jl @@ -49,6 +49,10 @@ end # virial_global! # function virial_global!(fix::FixExternal, ) +const NEIGHMASK = 0x3FFFFFFF +const SBBITS = 30 +sbmask(atom) = (atom >> SBBITS) & 3 + function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global) cutsq = cut_global^2 FixExternal(lmp, name) do fix, timestep, nlocal, ids, x, fexternal @@ -56,6 +60,13 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0) nelements = API.lammps_neighlist_num_elements(fix.lmp, idx) + # TODO how to obtain in fix + eflag = false + evflag = false + + # How to get special_lj. + newton_pair = extract_setting(fix.lmp, "newton_pair") == 1 + type = LAMMPS.extract_atom(lmp, "type") # zero-out fexternal (noticed some undef memory) @@ -67,7 +78,11 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray? itype = type[iatom] for jj in 1:length(neigh) - jatom = neigh[jj] + 1 + jatom = neigh[jj] + factor_lj = 1.0 + jatom &= NEIGHMASK + jatom += 1 # 1-based indexing + delx = xtmp - x[1, jatom] dely = ytmp - x[2, jatom] delz = ztmp = x[3, jatom] @@ -81,7 +96,7 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ fexternal[1, iatom] += delx*fpair fexternal[2, iatom] += dely*fpair fexternal[3, iatom] += delz*fpair - if jatom <= nlocal # newton_pair + if jatom <= nlocal || newton_pair fexternal[1, jatom] -= delx*fpair fexternal[2, jatom] -= dely*fpair fexternal[3, jatom] -= delz*fpair From cab4c3f3e2b2cf4edd851ce92a3a5001c5446a0e Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 15 Mar 2023 20:33:55 -0400 Subject: [PATCH 05/13] Full test and fix silly bug --- examples/fix_external.jl | 6 ++-- src/LAMMPS.jl | 1 + src/api.jl | 4 +++ src/external.jl | 38 +++++++++++---------- test/external_pair.jl | 72 ++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 6 files changed, 104 insertions(+), 19 deletions(-) create mode 100644 test/external_pair.jl diff --git a/examples/fix_external.jl b/examples/fix_external.jl index 44e0e21..27210e5 100644 --- a/examples/fix_external.jl +++ b/examples/fix_external.jl @@ -54,10 +54,12 @@ command(lmp, "create_atoms 1 random $natoms 1 NULL") command(lmp, "mass 1 1.0") # (x,y,z), natoms -positions = rand(3, 10) .* 5 +# positions = rand(3, 10) .* 5 +positions = [4.4955289268519625 3.3999909266656836 4.420245465344918 2.3923580632470216 1.9933183377321746 2.3367019702697096 0.014668174434679937 4.5978923623562356 2.9389893820585025 4.800351333939365; 4.523573662784505 3.1582899538900304 2.5562765646443 3.199496583966941 4.891026316235915 4.689641854106464 2.7591724192198575 0.7491156338926308 1.258994308308421 2.0419941687773937; 2.256261603545908 0.694847945108647 4.058244561946366 3.044596885569421 2.60225212714946 4.0030490608195555 0.9941423774290642 1.8076536961230087 1.9712395260164222 1.2705916409499818] + LAMMPS.API.lammps_scatter_atoms(lmp, "x", 1, 3, positions) command(lmp, "run 0") # extract forces -forces = extract_atom(lmp, "f") \ No newline at end of file +forces = extract_atom(lmp, "f") diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 566f7a7..5c7effd 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -202,6 +202,7 @@ function extract_atom(lmp::LMP, name, if dtype === nothing dtype = API.lammps_extract_atom_datatype(lmp, name) + dtype == -1 && error("Could not find dataype for atom $name") dtype = API._LMP_DATATYPE_CONST(dtype) end diff --git a/src/api.jl b/src/api.jl index d40d62c..a099383 100644 --- a/src/api.jl +++ b/src/api.jl @@ -390,6 +390,10 @@ function lammps_flush_buffers(ptr) ccall((:lammps_flush_buffers, liblammps), Cvoid, (Ptr{Cvoid},), ptr) end +function lammps_fix_external_set_energy_peratom(handle, id, eng) + ccall((:lammps_fix_external_set_energy_peratom, liblammps), Cvoid, (Ptr{Cvoid}, Ptr{Cchar}, Ptr{Cdouble}), handle, id, eng) +end + function lammps_free(ptr) ccall((:lammps_free, liblammps), Cvoid, (Ptr{Cvoid},), ptr) end diff --git a/src/external.jl b/src/external.jl index 85a7afa..58bbd5c 100644 --- a/src/external.jl +++ b/src/external.jl @@ -52,6 +52,7 @@ end const NEIGHMASK = 0x3FFFFFFF const SBBITS = 30 sbmask(atom) = (atom >> SBBITS) & 3 +const special_lj = [1.0, 0.0, 0.0 ,0.0] function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global) cutsq = cut_global^2 @@ -64,49 +65,52 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ eflag = false evflag = false - # How to get special_lj. newton_pair = extract_setting(fix.lmp, "newton_pair") == 1 + # special_lj = extract_global(fix.lmp, "special_lj") type = LAMMPS.extract_atom(lmp, "type") # zero-out fexternal (noticed some undef memory) fexternal .= 0 + energies = zeros(nlocal) + for ii in 1:nelements + # local atom index (i.e. in the range [0, nlocal + nghost) iatom, neigh = LAMMPS.neighbors(lmp, idx, ii) iatom += 1 # 1-based indexing xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray? itype = type[iatom] for jj in 1:length(neigh) jatom = neigh[jj] - factor_lj = 1.0 + factor_lj = special_lj[sbmask(jatom) + 1] jatom &= NEIGHMASK jatom += 1 # 1-based indexing delx = xtmp - x[1, jatom] dely = ytmp - x[2, jatom] - delz = ztmp = x[3, jatom] + delz = ztmp - x[3, jatom] jtype = type[jatom] rsq = delx*delx + dely*dely + delz*delz; - if rsq < cutsq - fpair = compute_force(rsq, itype, jtype) - - fexternal[1, iatom] += delx*fpair - fexternal[2, iatom] += dely*fpair - fexternal[3, iatom] += delz*fpair - if jatom <= nlocal || newton_pair - fexternal[1, jatom] -= delx*fpair - fexternal[2, jatom] -= dely*fpair - fexternal[3, jatom] -= delz*fpair + fpair = factor_lj * compute_force(rsq, itype, jtype) + + if iatom <= nlocal + fexternal[1, iatom] += delx*fpair + fexternal[2, iatom] += dely*fpair + fexternal[3, iatom] += delz*fpair + if jatom <= nlocal || newton_pair + fexternal[1, jatom] -= delx*fpair + fexternal[2, jatom] -= dely*fpair + fexternal[3, jatom] -= delz*fpair + end + energies[iatom] += compute_energy(rsq, itype, jtype) end - - # todo call compute_energy - # TODO eflag - # TODO evflag end end end + API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies) + energy_global!(fix, sum(energies)) end end diff --git a/test/external_pair.jl b/test/external_pair.jl new file mode 100644 index 0000000..9fb18ce --- /dev/null +++ b/test/external_pair.jl @@ -0,0 +1,72 @@ +using LAMMPS +using Test + +lmp_native = LMP() +lmp_julia = LMP() + +for lmp in (lmp_native, lmp_julia) + command(lmp, "units lj") + command(lmp, "atom_style atomic") + command(lmp, "atom_modify map array sort 0 0") + command(lmp, "box tilt large") + + # Setup box + command(lmp, "boundary p p p") + command(lmp, "region cell block 0 10.0 0 10.0 0 10.0 units box") + command(lmp, "create_box 1 cell") +end + +cutoff = 2.5 +command(lmp_julia, "pair_style zero $cutoff") +command(lmp_julia, "pair_coeff * *") +command(lmp_julia, "fix julia_lj all external pf/callback 1 1") + +command(lmp_native, "pair_style lj/cut $cutoff") +command(lmp_native, "pair_coeff * * 1 1") + +const coefficients = Dict( + 1 => Dict( + 1 => [48.0, 24.0, 4.0,4.0] + ) +) + +function compute_force(rsq, itype, jtype) + coeff = coefficients[itype][jtype] + r2inv = 1.0/rsq + r6inv = r2inv^3 + lj1 = coeff[1] + lj2 = coeff[2] + return (r6inv * (lj1*r6inv - lj2))*r2inv +end + +function compute_energy(rsq, itype, jtype) + coeff = coefficients[itype][jtype] + r2inv = 1.0/rsq + r6inv = r2inv^3 + lj3 = coeff[3] + lj4 = coeff[4] + return (r6inv * (lj3*r6inv - lj4)) +end + +# Register external fix +lj = LAMMPS.PairExternal(lmp_julia, "julia_lj", "zero", compute_force, compute_energy, cutoff) + +# Setup atoms +natoms = 10 +positions = rand(3, 10) .* 5 +for lmp in (lmp_native, lmp_julia) + command(lmp, "create_atoms 1 random $natoms 1 NULL") + command(lmp, "mass 1 1.0") + + LAMMPS.API.lammps_scatter_atoms(lmp, "x", 1, 3, positions) + + command(lmp, "run 0") +end + +# extract forces +forces_native = extract_atom(lmp_native, "f") +forces_julia = extract_atom(lmp_julia, "f") + +@testset "External Pair" begin + @test forces_native == forces_julia +end diff --git a/test/runtests.jl b/test/runtests.jl index 1fe82cf..d1ab065 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -127,3 +127,5 @@ LMP(["-screen", "none"]) do lmp end @test success(pipeline(`$(MPI.mpiexec()) -n 2 $(Base.julia_cmd()) mpitest.jl`, stderr=stderr, stdout=stdout)) + +include("external_pair.jl") From 2bc86a95b3303046b97b9cdf42225b0af6f21c64 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 8 May 2024 10:24:14 -0400 Subject: [PATCH 06/13] handle ghost cells --- src/external.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/external.jl b/src/external.jl index 58bbd5c..8fb83d5 100644 --- a/src/external.jl +++ b/src/external.jl @@ -27,13 +27,15 @@ function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, i fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal nlocal = Int(nlocal) + nghost = extract_global(fix.lmp, "nghost") + @debug "Calling fix_external_callback on" fix timestep nlocal - shape = (nlocal, 3) + shape = (nlocal+nghost, 3) x = unsafe_wrap(x, shape) fexternal = unsafe_wrap(fexternal, shape) - ids = unsafe_wrap(ids, (nlocal,)) + ids = unsafe_wrap(ids, (nlocal+nghost,)) - fix.callback(fix, timestep, nlocal, ids, x, fexternal) + fix.callback(fix, timestep, nlocal, nghost, ids, x, fexternal) return nothing end @@ -56,7 +58,7 @@ const special_lj = [1.0, 0.0, 0.0 ,0.0] function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global) cutsq = cut_global^2 - FixExternal(lmp, name) do fix, timestep, nlocal, ids, x, fexternal + FixExternal(lmp, name) do fix, timestep, nlocal, nghost, ids, x, fexternal # Full neighbor list idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0) nelements = API.lammps_neighlist_num_elements(fix.lmp, idx) @@ -68,7 +70,7 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ newton_pair = extract_setting(fix.lmp, "newton_pair") == 1 # special_lj = extract_global(fix.lmp, "special_lj") - type = LAMMPS.extract_atom(lmp, "type") + type = LAMMPS.extract_atom(lmp, "type", API.LAMMPS_INT, nlocal+nghost) # zero-out fexternal (noticed some undef memory) fexternal .= 0 From 62bec52d8ab1b07c5baa59c38fbd5ce531513c7d Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 8 May 2024 12:23:41 -0400 Subject: [PATCH 07/13] small amount of performance engineering --- src/LAMMPS.jl | 2 +- src/api.jl | 6 +++--- src/external.jl | 16 ++++++++++------ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 5c7effd..1eac469 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -547,7 +547,7 @@ function neighbors(lmp, idx, element) API.lammps_neighlist_element_neighbors(lmp, idx, element - 1, r_iatom, r_numneigh, r_neighbors) - return r_iatom[], Base.unsafe_wrap(Array, r_neighbors[], r_numneigh[]; own = false) + return Int(r_iatom[]), Base.unsafe_wrap(Array, r_neighbors[], r_numneigh[]; own = false) end include("external.jl") diff --git a/src/api.jl b/src/api.jl index a099383..4a1b3fa 100644 --- a/src/api.jl +++ b/src/api.jl @@ -390,9 +390,9 @@ function lammps_flush_buffers(ptr) ccall((:lammps_flush_buffers, liblammps), Cvoid, (Ptr{Cvoid},), ptr) end -function lammps_fix_external_set_energy_peratom(handle, id, eng) - ccall((:lammps_fix_external_set_energy_peratom, liblammps), Cvoid, (Ptr{Cvoid}, Ptr{Cchar}, Ptr{Cdouble}), handle, id, eng) -end +# function lammps_fix_external_set_energy_peratom(handle, id, eng) +# ccall((:lammps_fix_external_set_energy_peratom, liblammps), Cvoid, (Ptr{Cvoid}, Ptr{Cchar}, Ptr{Cdouble}), handle, id, eng) +# end function lammps_free(ptr) ccall((:lammps_free, liblammps), Cvoid, (Ptr{Cvoid},), ptr) diff --git a/src/external.jl b/src/external.jl index 8fb83d5..45311a7 100644 --- a/src/external.jl +++ b/src/external.jl @@ -17,6 +17,8 @@ mutable struct FixExternal callback = @cfunction(fix_external_callback, Cvoid, (Ptr{Cvoid}, Int64, Cint, Ptr{Cint}, Ptr{Ptr{Float64}}, Ptr{Ptr{Float64}})) API.lammps_set_fix_external_callback(lmp, name, callback, ctx) + # Ensure function is compiled before timestep 0 + @assert precompile(this.callback, (FixExternal, Int, Int, Int, Vector{Int32}, Matrix{Float64}, Matrix{Float64})) return this end end @@ -27,7 +29,7 @@ function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, i fix = Base.unsafe_pointer_to_objref(ctx)::FixExternal nlocal = Int(nlocal) - nghost = extract_global(fix.lmp, "nghost") + nghost = Int(extract_global(fix.lmp, "nghost")) @debug "Calling fix_external_callback on" fix timestep nlocal shape = (nlocal+nghost, 3) @@ -35,6 +37,7 @@ function fix_external_callback(ctx::Ptr{Cvoid}, timestep::Int64, nlocal::Cint, i fexternal = unsafe_wrap(fexternal, shape) ids = unsafe_wrap(ids, (nlocal+nghost,)) + # necessary dynamic fix.callback(fix, timestep, nlocal, nghost, ids, x, fexternal) return nothing end @@ -56,9 +59,9 @@ const SBBITS = 30 sbmask(atom) = (atom >> SBBITS) & 3 const special_lj = [1.0, 0.0, 0.0 ,0.0] -function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_global) +function PairExternal(lmp, name, neigh_name, compute_force::F, compute_energy::E, cut_global) where {E, F} cutsq = cut_global^2 - FixExternal(lmp, name) do fix, timestep, nlocal, nghost, ids, x, fexternal + function pair(fix::FixExternal, timestep::Int, nlocal::Int, nghost::Int, ids::Vector{Int32}, x::Matrix{Float64}, fexternal::Matrix{Float64}) # Full neighbor list idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0) nelements = API.lammps_neighlist_num_elements(fix.lmp, idx) @@ -70,21 +73,21 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ newton_pair = extract_setting(fix.lmp, "newton_pair") == 1 # special_lj = extract_global(fix.lmp, "special_lj") - type = LAMMPS.extract_atom(lmp, "type", API.LAMMPS_INT, nlocal+nghost) + type = LAMMPS.extract_atom(lmp, "type", API.LAMMPS_INT, nlocal+nghost)::Vector{Int32} # zero-out fexternal (noticed some undef memory) fexternal .= 0 energies = zeros(nlocal) - for ii in 1:nelements + for ii in 1:Int(nelements) # local atom index (i.e. in the range [0, nlocal + nghost) iatom, neigh = LAMMPS.neighbors(lmp, idx, ii) iatom += 1 # 1-based indexing xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray? itype = type[iatom] for jj in 1:length(neigh) - jatom = neigh[jj] + jatom = Int(neigh[jj]) factor_lj = special_lj[sbmask(jatom) + 1] jatom &= NEIGHMASK jatom += 1 # 1-based indexing @@ -115,4 +118,5 @@ function PairExternal(lmp, name, neigh_name, compute_force, compute_energy, cut_ API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies) energy_global!(fix, sum(energies)) end + FixExternal(pair, lmp, name) end From de157f8a1e4207049d77d1ffa6a927b5e8d3d8b5 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 10 May 2024 09:44:25 -0400 Subject: [PATCH 08/13] reduce overhead further --- examples/fix_external.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/fix_external.jl b/examples/fix_external.jl index 27210e5..8ec1e82 100644 --- a/examples/fix_external.jl +++ b/examples/fix_external.jl @@ -21,13 +21,13 @@ command(lmp, "pair_style zero $cutoff") command(lmp, "pair_coeff * *") command(lmp, "fix julia_lj all external pf/callback 1 1") -const coefficients = Dict( - 1 => Dict( - 1 => [48.0, 24.0, 4.0,4.0] +const coefficients = Base.ImmutableDict( + 1 => Base.ImmutableDict( + 1 => [48.0, 24.0, 4.0, 4.0] ) ) -function compute_force(rsq, itype, jtype) +@inline function compute_force(rsq, itype, jtype) coeff = coefficients[itype][jtype] r2inv = 1.0/rsq r6inv = r2inv^3 @@ -36,7 +36,7 @@ function compute_force(rsq, itype, jtype) return (r6inv * (lj1*r6inv - lj2))*r2inv end -function compute_energy(rsq, itype, jtype) +@inline function compute_energy(rsq, itype, jtype) coeff = coefficients[itype][jtype] r2inv = 1.0/rsq r6inv = r2inv^3 From 13e2294cda72fe3f9bb9dd941d5c24f935d575e1 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 20 Jun 2024 10:50:41 -0400 Subject: [PATCH 09/13] fix test --- src/external.jl | 4 +++- test/runtests.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/external.jl b/src/external.jl index 45311a7..b124706 100644 --- a/src/external.jl +++ b/src/external.jl @@ -18,7 +18,9 @@ mutable struct FixExternal API.lammps_set_fix_external_callback(lmp, name, callback, ctx) # Ensure function is compiled before timestep 0 - @assert precompile(this.callback, (FixExternal, Int, Int, Int, Vector{Int32}, Matrix{Float64}, Matrix{Float64})) + if !precompile(this.callback, (FixExternal, Int, Int, Int, Vector{Int32}, Matrix{Float64}, Matrix{Float64})) + @warn "Failed to precompile the callback" this.callback + end return this end end diff --git a/test/runtests.jl b/test/runtests.jl index d1ab065..7a61765 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -117,7 +117,7 @@ LMP(["-screen", "none"]) do lmp command(lmp, "region cell block 0 1 0 1 0 1 units box") command(lmp, "create_box 1 cell") command(lmp, "fix julia all external pf/callback 1 1") - LAMMPS.FixExternal(lmp, "julia") do fix, timestep, nlocal, ids, x, fexternal + LAMMPS.FixExternal(lmp, "julia") do fix, timestep, nlocal, nghost, ids, x, fexternal LAMMPS.energy_global!(fix, 0.0) called[] = true end From 5ba4d36ef8d9dc79d1b0d9db4b6e10409ff2c32a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 22 Jun 2024 20:16:48 -0400 Subject: [PATCH 10/13] use gather scatter API --- test/external_pair.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/external_pair.jl b/test/external_pair.jl index 9fb18ce..950d1cc 100644 --- a/test/external_pair.jl +++ b/test/external_pair.jl @@ -1,8 +1,8 @@ using LAMMPS using Test -lmp_native = LMP() -lmp_julia = LMP() +lmp_native = LMP(["-screen", "none"]) +lmp_julia = LMP(["-screen", "none"]) for lmp in (lmp_native, lmp_julia) command(lmp, "units lj") @@ -24,13 +24,13 @@ command(lmp_julia, "fix julia_lj all external pf/callback 1 1") command(lmp_native, "pair_style lj/cut $cutoff") command(lmp_native, "pair_coeff * * 1 1") -const coefficients = Dict( - 1 => Dict( - 1 => [48.0, 24.0, 4.0,4.0] +const coefficients = Base.ImmutableDict( + 1 => Base.ImmutableDict( + 1 => [48.0, 24.0, 4.0, 4.0] ) ) -function compute_force(rsq, itype, jtype) +@inline function compute_force(rsq, itype, jtype) coeff = coefficients[itype][jtype] r2inv = 1.0/rsq r6inv = r2inv^3 @@ -39,7 +39,7 @@ function compute_force(rsq, itype, jtype) return (r6inv * (lj1*r6inv - lj2))*r2inv end -function compute_energy(rsq, itype, jtype) +@inline function compute_energy(rsq, itype, jtype) coeff = coefficients[itype][jtype] r2inv = 1.0/rsq r6inv = r2inv^3 @@ -58,14 +58,14 @@ for lmp in (lmp_native, lmp_julia) command(lmp, "create_atoms 1 random $natoms 1 NULL") command(lmp, "mass 1 1.0") - LAMMPS.API.lammps_scatter_atoms(lmp, "x", 1, 3, positions) + scatter!(lmp, "x", positions) command(lmp, "run 0") end # extract forces -forces_native = extract_atom(lmp_native, "f") -forces_julia = extract_atom(lmp_julia, "f") +forces_native = gather(lmp_native, "f", Float64) +forces_julia = gather(lmp_julia, "f", Float64) @testset "External Pair" begin @test forces_native == forces_julia From affdabe9a342b17fc1c8749eb890ad3d492528fd Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 22 Jun 2024 20:28:12 -0400 Subject: [PATCH 11/13] Update examples/fix_external.jl --- examples/fix_external.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/fix_external.jl b/examples/fix_external.jl index 8ec1e82..08f1a3b 100644 --- a/examples/fix_external.jl +++ b/examples/fix_external.jl @@ -54,8 +54,7 @@ command(lmp, "create_atoms 1 random $natoms 1 NULL") command(lmp, "mass 1 1.0") # (x,y,z), natoms -# positions = rand(3, 10) .* 5 -positions = [4.4955289268519625 3.3999909266656836 4.420245465344918 2.3923580632470216 1.9933183377321746 2.3367019702697096 0.014668174434679937 4.5978923623562356 2.9389893820585025 4.800351333939365; 4.523573662784505 3.1582899538900304 2.5562765646443 3.199496583966941 4.891026316235915 4.689641854106464 2.7591724192198575 0.7491156338926308 1.258994308308421 2.0419941687773937; 2.256261603545908 0.694847945108647 4.058244561946366 3.044596885569421 2.60225212714946 4.0030490608195555 0.9941423774290642 1.8076536961230087 1.9712395260164222 1.2705916409499818] +positions = rand(3, 10) .* 5 LAMMPS.API.lammps_scatter_atoms(lmp, "x", 1, 3, positions) From 33b89b317ee4c3317e0890b3cca0fd6e4e73d705 Mon Sep 17 00:00:00 2001 From: Felipe Tome Date: Thu, 7 Nov 2024 15:33:34 -0300 Subject: [PATCH 12/13] Acomodate ML potentials --- src/external.jl | 79 ++++++++++++++++++++++++++----------------- test/external_pair.jl | 2 +- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/external.jl b/src/external.jl index b124706..505f94f 100644 --- a/src/external.jl +++ b/src/external.jl @@ -52,6 +52,10 @@ function neighbor_list(fix::FixExternal, request) fix_neighbor_list(fix.lmp, fix.name, request) end +function virial_global!(fix::FixExternal, virial) + API.lammps_fix_external_set_virial_global(fix.lmp, fix.name, virial) +end + # TODO # virial_global! # function virial_global!(fix::FixExternal, ) @@ -61,64 +65,77 @@ const SBBITS = 30 sbmask(atom) = (atom >> SBBITS) & 3 const special_lj = [1.0, 0.0, 0.0 ,0.0] -function PairExternal(lmp, name, neigh_name, compute_force::F, compute_energy::E, cut_global) where {E, F} +function virial_fdotr_compute(fexternal::Matrix{Float64}, x::Matrix{Float64}, nall) + #TODO: discuss include_group flag + virial = Array{Float64}(undef, 6) + for i in 1:nall + virial[1] = fexternal[1, i] * x[1, i] + virial[2] = fexternal[2, i] * x[2, i] + virial[3] = fexternal[3, i] * x[3, i] + virial[1] = fexternal[2, i] * x[1, i] + virial[2] = fexternal[3, i] * x[1, i] + virial[3] = fexternal[3, i] * x[2, i] + end + return virial +end + +function PairExternal(lmp, name, neigh_name, compute_force::F, compute_energy::E, cut_global, eflag, vflag) where {E, F} cutsq = cut_global^2 function pair(fix::FixExternal, timestep::Int, nlocal::Int, nghost::Int, ids::Vector{Int32}, x::Matrix{Float64}, fexternal::Matrix{Float64}) # Full neighbor list + idx = pair_neighbor_list(fix.lmp, neigh_name, 1, 0, 0) nelements = API.lammps_neighlist_num_elements(fix.lmp, idx) - - # TODO how to obtain in fix - eflag = false - evflag = false - newton_pair = extract_setting(fix.lmp, "newton_pair") == 1 # special_lj = extract_global(fix.lmp, "special_lj") - type = LAMMPS.extract_atom(lmp, "type", API.LAMMPS_INT, nlocal+nghost)::Vector{Int32} # zero-out fexternal (noticed some undef memory) fexternal .= 0 - + energies = zeros(nlocal) + + #API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies) + x = gather(lmp, "x", Float64) + for ii in 1:Int(nelements) # local atom index (i.e. in the range [0, nlocal + nghost) - iatom, neigh = LAMMPS.neighbors(lmp, idx, ii) + types = [] + iatom, neigh = LAMMPS.neighbors(lmp, idx, ii) + pt = [] iatom += 1 # 1-based indexing xtmp, ytmp, ztmp = view(x, :, iatom) # TODO SArray? - itype = type[iatom] + append!(types, type[iatom]) + push!(pt, x[:, iatom]) + incut = 1 for jj in 1:length(neigh) jatom = Int(neigh[jj]) - factor_lj = special_lj[sbmask(jatom) + 1] jatom &= NEIGHMASK - jatom += 1 # 1-based indexing - + jatom += 1 # 1-based indexing + jtype = type[jatom] delx = xtmp - x[1, jatom] dely = ytmp - x[2, jatom] delz = ztmp - x[3, jatom] - jtype = type[jatom] - - rsq = delx*delx + dely*dely + delz*delz; + rsq = delx*delx + dely*dely + delz*delz if rsq < cutsq - fpair = factor_lj * compute_force(rsq, itype, jtype) - - if iatom <= nlocal - fexternal[1, iatom] += delx*fpair - fexternal[2, iatom] += dely*fpair - fexternal[3, iatom] += delz*fpair - if jatom <= nlocal || newton_pair - fexternal[1, jatom] -= delx*fpair - fexternal[2, jatom] -= dely*fpair - fexternal[3, jatom] -= delz*fpair - end - energies[iatom] += compute_energy(rsq, itype, jtype) - end + append!(types, jtype) + push!(pt, x[:, jatom]) + incut += 1 end end + fexternal[:, iatom] = compute_force(reshape(pt, (3, incut)), types)[1] + if eflag + energies[iatom] = compute_energy(reshape(pt, (3, incut)), types)[1] + end + end + if eflag + API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies) + energy_global!(fix, sum(energies)) + end + if vflag + virial = virial_fdotr_compute(fexternal, x, nlocal+nghost) end - API.lammps_fix_external_set_energy_peratom(fix.lmp, fix.name, energies) - energy_global!(fix, sum(energies)) end FixExternal(pair, lmp, name) end diff --git a/test/external_pair.jl b/test/external_pair.jl index 950d1cc..bd608dd 100644 --- a/test/external_pair.jl +++ b/test/external_pair.jl @@ -49,7 +49,7 @@ end end # Register external fix -lj = LAMMPS.PairExternal(lmp_julia, "julia_lj", "zero", compute_force, compute_energy, cutoff) +lj = LAMMPS.PairExternal(lmp_julia, "julia_lj", "zero", compute_force, compute_energy, cutoff, true, true) # Setup atoms natoms = 10 From 36f4139215c79dfdefc6f34470ecaca4f078bfa0 Mon Sep 17 00:00:00 2001 From: Felipe Tome Date: Thu, 7 Nov 2024 15:43:11 -0300 Subject: [PATCH 13/13] Minimal working example with ML potentials --- examples/fix_external.jl | 93 ++++++++++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 31 deletions(-) diff --git a/examples/fix_external.jl b/examples/fix_external.jl index 08f1a3b..9fab73a 100644 --- a/examples/fix_external.jl +++ b/examples/fix_external.jl @@ -1,6 +1,16 @@ using LAMMPS +using Test +using AtomsCalculators: potential_energy, forces +using AtomsBase +using AtomsBuilder +using ACEpotentials +using ExtXYZ +using AtomsBase: AbstractSystem +using LinearAlgebra: norm +using Unitful + +lmp = LMP(["-screen", "none"]) -lmp = LMP() command(lmp, "units lj") command(lmp, "atom_style atomic") @@ -21,44 +31,65 @@ command(lmp, "pair_style zero $cutoff") command(lmp, "pair_coeff * *") command(lmp, "fix julia_lj all external pf/callback 1 1") -const coefficients = Base.ImmutableDict( - 1 => Base.ImmutableDict( - 1 => [48.0, 24.0, 4.0, 4.0] - ) -) - -@inline function compute_force(rsq, itype, jtype) - coeff = coefficients[itype][jtype] - r2inv = 1.0/rsq - r6inv = r2inv^3 - lj1 = coeff[1] - lj2 = coeff[2] - return (r6inv * (lj1*r6inv - lj2))*r2inv +if !isfile("Si_dataset.xyz") + download("https://www.dropbox.com/scl/fi/z6lvcpx3djp775zenz032/Si-PRX-2018.xyz?rlkey=ja5e9z99c3ta1ugra5ayq5lcv&st=cs6g7vbu&dl=1", + "Si_dataset.xyz"); end -@inline function compute_energy(rsq, itype, jtype) - coeff = coefficients[itype][jtype] - r2inv = 1.0/rsq - r6inv = r2inv^3 - lj3 = coeff[3] - lj4 = coeff[4] - return (r6inv * (lj3*r6inv - lj4)) +Si_dataset = ExtXYZ.load("Si_dataset.xyz"); + +Si_tiny_dataset, _, _ = ACEpotentials.example_dataset("Si_tiny"); + +deleteat!(Si_dataset, 1); + +hyperparams = (elements = [:Si,], + order = 3, + totaldegree = 8, + rcut = 2.5, + Eref = Dict(:Si => -158.54496821)) +model = ACEpotentials.ace1_model(;hyperparams...); +solver = ACEfit.QR(lambda=1e-1) +data_keys = (energy_key = "dft_energy", force_key = "dft_force", virial_key = "dft_virial") +acefit!(Si_tiny_dataset, model; + solver=solver, data_keys...); +labelmap = Dict(1 => :Si) + +@inline function compute_force(pos, types) + sys_size = size(pos, 2) + particles = [AtomsBase.Atom(ChemicalSpecies(labelmap[types[i]]), pos[:, i].*u"Å") for i in 1:sys_size] + cell = AtomsBuilder.bulk(:Si, cubic=true) * 3 + sys = FlexibleSystem(particles, cell) + f = forces(sys, model) + return ustrip.(f) +end + +@inline function compute_energy(pos, types) + sys_size = size(pos, 2) + particles = [AtomsBase.Atom(ChemicalSpecies(labelmap[types[i]]), pos[:, i].*u"Å") for i in 1:sys_size] + cell = AtomsBuilder.bulk(:Si, cubic=true) * 3 + sys = FlexibleSystem(particles, cell) + energy = potential_energy(sys, model) + return ustrip(energy) end # Register external fix -lj = LAMMPS.PairExternal(lmp, "julia_lj", "zero", compute_force, compute_energy, cutoff) +lj = LAMMPS.PairExternal(lmp, "julia_lj", "zero", compute_force, compute_energy, cutoff, true, true) # Setup atoms -natoms = 10 -command(lmp, "create_atoms 1 random $natoms 1 NULL") -command(lmp, "mass 1 1.0") - -# (x,y,z), natoms -positions = rand(3, 10) .* 5 - -LAMMPS.API.lammps_scatter_atoms(lmp, "x", 1, 3, positions) +natoms = 54 +command(lmp, "labelmap atom 1 Si") +command(lmp, "create_atoms Si random $natoms 1 NULL") +positions = rand(3, natoms) .* 5 +command(lmp, "mass 1 28.0855") +LAMMPS.scatter!(lmp, "x", positions) command(lmp, "run 0") # extract forces -forces = extract_atom(lmp, "f") +forces_julia = gather(lmp, "f", Float64) + +particles = [AtomsBase.Atom(ChemicalSpecies(labelmap[1]), positions[:, i].*u"Å") for i in 1:54] +cell = AtomsBuilder.bulk(:Si, cubic=true) * 3 +sys = FlexibleSystem(particles, cell) +f = forces(sys, model) +