diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index a9620d5..c829086 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -3,7 +3,7 @@ import MPI include("api.jl") export LMP, command, get_natoms, extract_atom, extract_compute, extract_global, - gather_atoms + gather, scatter! using Preferences @@ -361,19 +361,150 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end -function gather_atoms(lmp::LMP, name, T, count) - if T === Int32 - dtype = 0 - elseif T === Float64 - dtype = 1 +@deprecate gather_atoms(lmp::LMP, name, T, count) gather(lmp, name, T) + + +""" + gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, ids::Union{Nothing, Array{Int32}}=nothing) + +Gather the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entities from all processes. +By default (when `ids=nothing`), this method collects data from all atoms in consecutive order according to their IDs. +The optional parameter `ids` determines for which subset of atoms the requested data will be gathered. The returned data will then be ordered according to `ids` + +Compute entities have the prefix `c_`, fix entities use the prefix `f_`, and per-atom entites have no prefix. + +The returned Array is decoupled from the internal state of the LAMMPS instance. + +!!! warning "Type Verification" + Due to how the underlying C-API works, it's not possible to verify the element data-type of fix or compute style data. + Supplying the wrong data-type will not throw an error but will result in nonsensical output + +!!! warning "ids" + The optional parameter `ids` only works, if there is a map defined. For example by doing: + `command(lmp, "atom_modify map yes")` + However, LAMMPS only issues a warning if that's the case, which unfortuately cannot be detected through the underlying API. + Starting form LAMMPS version `17 Apr 2024` this should no longer be an issue, as LAMMPS then throws an error instead of a warning. +""" +function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, ids::Union{Nothing, Array{Int32}}=nothing) + name == "mass" && error("scattering/gathering mass is currently not supported! Use `extract_atom()` instead.") + + count = _get_count(lmp, name) + _T = _get_T(lmp, name) + + @assert ismissing(_T) || _T == T "Expected data type $_T got $T instead." + + dtype = (T === Float64) + natoms = get_natoms(lmp) + ndata = isnothing(ids) ? natoms : length(ids) + data = Matrix{T}(undef, (count, ndata)) + + if isnothing(ids) + API.lammps_gather(lmp, name, dtype, count, data) else - error("Only Int32 or Float64 allowed as T, got $T") + @assert all(1 <= id <= natoms for id in ids) + API.lammps_gather_subset(lmp, name, dtype, count, ndata, ids, data) end - natoms = get_natoms(lmp) - data = Array{T, 2}(undef, (count, natoms)) - API.lammps_gather_atoms(lmp, name, dtype, count, data) + check(lmp) return data end +""" + scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + +Scatter the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entity in data to all processes. +By default (when `ids=nothing`), this method scatters data to all atoms in consecutive order according to their IDs. +The optional parameter `ids` determines to which subset of atoms the data will be scattered. + +Compute entities have the prefix `c_`, fix entities use the prefix `f_`, and per-atom entites have no prefix. + +!!! warning "Type Verification" + Due to how the underlying C-API works, it's not possible to verify the element data-type of fix or compute style data. + Supplying the wrong data-type will not throw an error but will result in nonsensical date being supplied to the LAMMPS instance. + +!!! warning "ids" + The optional parameter `ids` only works, if there is a map defined. For example by doing: + `command(lmp, "atom_modify map yes")` + However, LAMMPS only issues a warning if that's the case, which unfortuately cannot be detected through the underlying API. + Starting form LAMMPS version `17 Apr 2024` this should no longer be an issue, as LAMMPS then throws an error instead of a warning. +""" +function scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + name == "mass" && error("scattering/gathering mass is currently not supported! Use `extract_atom()` instead.") + + count = _get_count(lmp, name) + _T = _get_T(lmp, name) + + @assert ismissing(_T) || _T == T "Expected data type $_T got $T instead." + + dtype = (T === Float64) + natoms = get_natoms(lmp) + ndata = isnothing(ids) ? natoms : length(ids) + + if data isa Vector + @assert count == 1 + @assert ndata == lenght(data) + else + @assert count == size(data,1) + @assert ndata == size(data,2) + end + + if isnothing(ids) + API.lammps_scatter(lmp, name, dtype, count, data) + else + @assert all(1 <= id <= natoms for id in ids) + API.lammps_scatter_subset(lmp, name, dtype, count, ndata, ids, data) + end + + check(lmp) +end + +function _get_count(lmp::LMP, name::String) + # values taken from: https://docs.lammps.org/Classes_atom.html#_CPPv4N9LAMMPS_NS4Atom7extractEPKc + + if startswith(name, r"[f,c]_") + if name[1] == 'c' + API.lammps_has_id(lmp, "compute", name[3:end]) != 1 && error("Unknown per atom compute $name") + + count_ptr = API.lammps_extract_compute(lmp::LMP, name[3:end], API.LMP_STYLE_ATOM, API.LMP_SIZE_COLS) + else + API.lammps_has_id(lmp, "fix", name[3:end]) != 1 && error("Unknown per atom fix $name") + + count_ptr = API.lammps_extract_fix(lmp::LMP, name[3:end], API.LMP_STYLE_ATOM, API.LMP_SIZE_COLS, 0, 0) + end + check(lmp) + + count_ptr = reinterpret(Ptr{Cint}, count_ptr) + count = unsafe_load(count_ptr) + + # a count of 0 indicates that the entity is a vector. In order to perserve type stability we just treat that as a 1xN Matrix. + return count == 0 ? 1 : count + elseif name in ("mass", "id", "type", "mask", "image", "molecule", "q", "radius", "rmass", "ellipsoid", "line", "tri", "body", "temperature", "heatflow") + return 1 + elseif name in ("x", "v", "f", "mu", "omega", "angmom", "torque") + return 3 + elseif name == "quat" + return 4 + else + error("Unknown per atom property $name") + end +end + +function _get_T(lmp::LMP, name::String) + if startswith(name, r"[f,c]_") + return missing # As far as I know, it's not possible to determine the datatype of computes or fixes at runtime + end + + type = API.lammps_extract_atom_datatype(lmp, name) + check(lmp) + + if type in (API.LAMMPS_INT, API.LAMMPS_INT_2D) + return Int32 + elseif type in (API.LAMMPS_DOUBLE, API.LAMMPS_DOUBLE_2D) + return Float64 + else + error("Unkown per atom property $name") + end + +end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index 061cc69..01d1dd9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,4 +41,64 @@ end end end +@testset "gather/scatter" begin + LMP(["-screen", "none"]) do lmp + # setting up example data + command(lmp, "atom_modify map yes") + command(lmp, "region cell block 0 3 0 3 0 3") + command(lmp, "create_box 1 cell") + command(lmp, "lattice sc 1") + command(lmp, "create_atoms 1 region cell") + command(lmp, "mass 1 1") + + command(lmp, "compute pos all property/atom x y z") + command(lmp, "fix pos all ave/atom 10 1 10 c_pos[1] c_pos[2] c_pos[3]") + + command(lmp, "run 10") + data = zeros(Float64, 3, 27) + subset = Int32.([2,5,10, 5]) + data_subset = ones(Float64, 3, 4) + + subset_bad1 = Int32.([28]) + subset_bad2 = Int32.([0]) + subset_bad_data = ones(Float64, 3,1) + + @test_throws AssertionError gather(lmp, "x", Int32) + @test_throws AssertionError gather(lmp, "id", Float64) + + @test_throws ErrorException gather(lmp, "nonesense", Float64) + @test_throws ErrorException gather(lmp, "c_nonsense", Float64) + @test_throws ErrorException gather(lmp, "f_nonesense", Float64) + + @test_throws AssertionError gather(lmp, "x", Float64, subset_bad1) + @test_throws AssertionError gather(lmp, "x", Float64, subset_bad2) + + @test_throws ErrorException scatter!(lmp, "nonesense", data) + @test_throws ErrorException scatter!(lmp, "c_nonsense", data) + @test_throws ErrorException scatter!(lmp, "f_nonesense", data) + + @test_throws AssertionError scatter!(lmp, "x", subset_bad_data, subset_bad1) + @test_throws AssertionError scatter!(lmp, "x", subset_bad_data, subset_bad2) + + @test gather(lmp, "x", Float64) == gather(lmp, "c_pos", Float64) == gather(lmp, "f_pos", Float64) + + @test gather(lmp, "x", Float64)[:,subset] == gather(lmp, "x", Float64, subset) + @test gather(lmp, "c_pos", Float64)[:,subset] == gather(lmp, "c_pos", Float64, subset) + @test gather(lmp, "f_pos", Float64)[:,subset] == gather(lmp, "f_pos", Float64, subset) + + scatter!(lmp, "x", data) + scatter!(lmp, "f_pos", data) + scatter!(lmp, "c_pos", data) + + @test gather(lmp, "x", Float64) == gather(lmp, "c_pos", Float64) == gather(lmp, "f_pos", Float64) == data + + scatter!(lmp, "x", data_subset, subset) + scatter!(lmp, "c_pos", data_subset, subset) + scatter!(lmp, "f_pos", data_subset, subset) + + @test gather(lmp, "x", Float64, subset) == gather(lmp, "c_pos", Float64, subset) == gather(lmp, "f_pos", Float64, subset) == data_subset + + end +end + @test success(pipeline(`$(MPI.mpiexec()) -n 2 $(Base.julia_cmd()) mpitest.jl`, stderr=stderr, stdout=stdout))