diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index a9620d5..dab5a73 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -3,7 +3,7 @@ import MPI include("api.jl") export LMP, command, get_natoms, extract_atom, extract_compute, extract_global, - gather_atoms + gather_atoms, gather, scatter! using Preferences @@ -361,19 +361,134 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end -function gather_atoms(lmp::LMP, name, T, count) - if T === Int32 - dtype = 0 - elseif T === Float64 - dtype = 1 +@deprecate gather_atoms gather + +""" + gather(lmp::LMP, name::String, + T::Union{Nothing, Type{Int32}, Type{Float64}}=nothing, + count::Union{Nothing, Integer}=nothing; + ids::Union{Nothing, Array{Int32}}=nothing, + ) + +Gather the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entities from all processes. +By default (when `ids=nothing`), this method collects data from all atoms in consecutive order according to their IDs. +The optional parameter `ids` determines for which subset of atoms the requested data will be gathered. The returned data will then be ordered according to `ids` + +`T` and `count` are not optional for fix or compute entities. + +The returned Array is decoupled from the internal state of the LAMMPS instance. +""" +function gather(lmp::LMP, name::String, + T::Union{Type{Int32}, Type{Float64}, Nothing}=nothing, + count::Union{Nothing, Integer}=nothing; + ids::Union{Nothing, Array{Int32}}=nothing, + ) + + @assert name !== "mass" "masses can not be gathered. Use `extract_atom` instead!" + + _scatter_gather_check_natoms(lmp) + + _T_count = _name_to_T_and_count(name) + + if !isnothing(_T_count) + (_T, _count) = _T_count + + !isnothing(count) && @assert _count == count "the defined count ($count) doesn't match the count ($_count) assotiated with $name" + !isnothing(T) && @assert _T === T "the defined data type ($T) doesn't match the data type ($_T) assotiated with `$name`" else - error("Only Int32 or Float64 allowed as T, got $T") + @assert !isnothing(count) "count couldn't be determinated through name; It's nessecary to specify it explicitly" + @assert !isnothing(T) "type couldn't be determinated through name; It's nessecary to specify it explicitly" + + (_T, _count) = (T, count) end + + dtype = _T === Float64 natoms = get_natoms(lmp) - data = Array{T, 2}(undef, (count, natoms)) - API.lammps_gather_atoms(lmp, name, dtype, count, data) + + if ids === nothing + data = Array{_T, 2}(undef, (_count, natoms)) + API.lammps_gather(lmp, name, dtype, _count, data) + else + @assert all(1 .<= ids .<= natoms) + + ndata = length(ids) + data = Array{_T, 2}(undef, (_count, ndata)) + API.lammps_gather_subset(lmp, name, dtype, _count, ndata, ids, data) + end + check(lmp) return data end +""" + scatter!(lmp::LMP, name::String, data::Array{T}; ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + +Scatter the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entity in data to all processes. +By default (when `ids=nothing`), this method scatters data to all atoms in consecutive order according to their IDs. +The optional parameter `ids` determines to which subset of atoms the data will be scattered. +""" +function scatter!(lmp::LMP, name::String, data::Array{T}; ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + @assert name !== "mass" "masses can not be scattered. Use `extract_atom` instead!" + + _scatter_gather_check_natoms(lmp) + + ndata = isnothing(ids) ? get_natoms(lmp) : length(ids) + (count, r) = divrem(length(data), ndata) + @assert r == 0 "illegal length of data" + + _T_count = _name_to_T_and_count(name) + + if !isnothing(_T_count) + (_T, _count) = _T_count + + @assert _T === T "the array data type ($T) doesn't match the data type ($_T) assotiated with $name" + @assert _count == count "the length of data ($(count*ndata)) doesn't match the length ($(_count*ndata)) assotiated with `$name`" + end + + dtype = (T === Float64) + + if ids === nothing + @assert ndata == get_natoms(lmp) + API.lammps_scatter(lmp, name, dtype, count, data) + else + @assert all(1 .<= ids .<= natoms) + API.lammps_scatter_subset(lmp, name, dtype, count, ndata, ids, data) + end + + check(lmp) + return nothing +end + +_scatter_gather_check_natoms(lmp) = @assert get_natoms(lmp) <= typemax(Int32) "scatter/gather operations only work on systems with less than 2^31 atoms!" + +function _name_to_T_and_count(name::String) + # values taken from: https://docs.lammps.org/Classes_atom.html#_CPPv4N9LAMMPS_NS4Atom7extractEPKc + + name == "mass" && return (Float64, 1) # should be handeled seperately + name == "id" && return (Int32, 1) + name == "type" && return (Int32, 1) + name == "mask" && return (Int32, 1) + name == "image" && return (Int32, 1) + name == "x" && return (Float64, 3) + name == "v" && return (Float64, 3) + name == "f" && return (Float64, 3) + name == "molecule" && return (Int32, 1) + name == "q" && return (Float64, 1) + name == "mu" && return (Float64, 3) + name == "omega" && return (Float64, 3) + name == "angmom" && return (Float64, 3) + name == "torque" && return (Float64, 3) + name == "radius" && return (Float64, 1) + name == "rmass" && return (Float64, 1) + name == "ellipsoid" && return (Int32, 1) + name == "line" && return (Int32, 1) + name == "tri" && return (Int32, 1) + name == "body" && return (Int32, 1) + name == "quat" && return (Float64, 4) + name == "temperature" && return (Float64, 1) + name == "heatflow" && return (Float64, 1) + + return nothing +end + end # module