From 05b846dfe8277306ba8235cea10dd91f05e7bd5e Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 14:26:41 +0200 Subject: [PATCH 1/7] support gather/scatter --- src/LAMMPS.jl | 65 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 10 deletions(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index a9620d5..fad1e99 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -3,7 +3,7 @@ import MPI include("api.jl") export LMP, command, get_natoms, extract_atom, extract_compute, extract_global, - gather_atoms + gather_atoms, gather, scatter using Preferences @@ -361,19 +361,64 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end -function gather_atoms(lmp::LMP, name, T, count) - if T === Int32 - dtype = 0 - elseif T === Float64 - dtype = 1 +""" + gather_atoms(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) +""" +function gather_atoms(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) + dtype = (T === Float64) + + if ids === nothing + natoms = get_natoms(lmp) + data = Array{T, 2}(undef, (count, natoms)) + API.lammps_gather_atoms(lmp, name, dtype, count, data) + else + ndata = length(ids) + data = Array{T, 2}(undef, (count, ndata)) + API.lammps_gather_atoms_subset(lmp, name, dtype, count, ndata, ids, data) + end + + check(lmp) + return data +end + +""" + gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) +""" +function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) + dtype = (T === Float64) + + if ids === nothing + natoms = get_natoms(lmp) + data = Array{T, 2}(undef, (count, natoms)) + API.lammps_gather(lmp, name, dtype, count, data) else - error("Only Int32 or Float64 allowed as T, got $T") + ndata = length(ids) + data = Array{T, 2}(undef, (count, ndata)) + API.lammps_gather_subset(lmp, name, dtype, count, ndata, ids, data) end - natoms = get_natoms(lmp) - data = Array{T, 2}(undef, (count, natoms)) - API.lammps_gather_atoms(lmp, name, dtype, count, data) + check(lmp) return data end +""" + scatter(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} +""" +function scatter(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + dtype = (T === Float64) + + if ids === nothing + (count, natoms) = size(data) + @assert natoms == get_natoms(lmp) + API.lammps_scatter(lmp, name, dtype, count, data) + else + (count, ndata) = size(data) + @assert ndata == length(ids) + API.lammps_scatter_subset(lmp, name, dtype, count, ndata, ids, data) + end + + check(lmp) + return nothing +end + end # module From f16f368bc0c4c2f7ba01e1eaa10b7a8d50ca1d15 Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 14:29:57 +0200 Subject: [PATCH 2/7] naming convention --- src/LAMMPS.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index fad1e99..67bb1a3 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -3,7 +3,7 @@ import MPI include("api.jl") export LMP, command, get_natoms, extract_atom, extract_compute, extract_global, - gather_atoms, gather, scatter + gather_atoms, gather, scatter! using Preferences @@ -404,7 +404,7 @@ end """ scatter(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} """ -function scatter(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} +function scatter!(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} dtype = (T === Float64) if ids === nothing From ab7ed1b93c0a5c6801d69d50df8e1389cb284d5f Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 20:35:00 +0200 Subject: [PATCH 3/7] add utility functions --- src/LAMMPS.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 67bb1a3..bdbb902 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -420,5 +420,36 @@ function scatter!(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, A check(lmp) return nothing end +_scatter_gather_check_natoms(lmp) = @assert get_natoms(lmp) <= typemax(Int32) "scatter/gather operations only work on systems with less than 2^31 atoms!" + +function _name_to_T_and_count(name::String) + # values taken from: https://docs.lammps.org/Classes_atom.html#_CPPv4N9LAMMPS_NS4Atom7extractEPKc + + name == "mass" && return (Float64, 1) # should be handeled seperately + name == "id" && return (Int32, 1) + name == "type" && return (Int32, 1) + name == "mask" && return (Int32, 1) + name == "image" && return (Int32, 1) + name == "x" && return (Float64, 3) + name == "v" && return (Float64, 3) + name == "f" && return (Float64, 3) + name == "molecule" && return (Int32, 1) + name == "q" && return (Float64, 1) + name == "mu" && return (Float64, 3) + name == "omega" && return (Float64, 3) + name == "angmom" && return (Float64, 3) + name == "torque" && return (Float64, 3) + name == "radius" && return (Float64, 1) + name == "rmass" && return (Float64, 1) + name == "ellipsoid" && return (Int32, 1) + name == "line" && return (Int32, 1) + name == "tri" && return (Int32, 1) + name == "body" && return (Int32, 1) + name == "quat" && return (Float64, 4) + name == "temperature" && return (Float64, 1) + name == "heatflow" && return (Float64, 1) + + return nothing +end end # module From f0d1d059fdf1dce85b8a5e52ba53bce7bb889f5b Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 20:37:03 +0200 Subject: [PATCH 4/7] change gather/skatter API --- src/LAMMPS.jl | 80 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index bdbb902..76d46e8 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -361,40 +361,42 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end -""" - gather_atoms(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) -""" -function gather_atoms(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) - dtype = (T === Float64) +function gather(lmp::LMP, name::String; + T::Union{Type{Int32}, Type{Float64}, Nothing}=nothing, + count::Union{Nothing, Integer}=nothing, + ids::Union{Nothing, Array{Int32}}=nothing, + ) - if ids === nothing - natoms = get_natoms(lmp) - data = Array{T, 2}(undef, (count, natoms)) - API.lammps_gather_atoms(lmp, name, dtype, count, data) + @assert name !== "mass" "masses can not be gathered. Use `extract_atom` instead!" + + _scatter_gather_check_natoms(lmp) + + _T_count = _name_to_T_and_count(name) + + if !isnothing(_T_count) + (_T, _count) = _T_count + + !isnothing(count) && @assert _count == count "the defined count ($count) doesn't match the count ($_count) assotiated with $name" + !isnothing(T) && @assert _T === T "the defined data type ($T) doesn't match the data type ($_T) assotiated with `$name`" else - ndata = length(ids) - data = Array{T, 2}(undef, (count, ndata)) - API.lammps_gather_atoms_subset(lmp, name, dtype, count, ndata, ids, data) + @assert !isnothing(count) "count couldn't be determinated through name; It's nessecary to specify it explicitly" + @assert !isnothing(T) "type couldn't be determinated through name; It's nessecary to specify it explicitly" + + (_T, _count) = (T, count) end - - check(lmp) - return data -end -""" - gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) -""" -function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, count::Integer, ids::Union{Nothing, Array{Int32}}=nothing) - dtype = (T === Float64) + dtype = _T === Float64 + natoms = get_natoms(lmp) if ids === nothing - natoms = get_natoms(lmp) - data = Array{T, 2}(undef, (count, natoms)) - API.lammps_gather(lmp, name, dtype, count, data) + data = Array{_T, 2}(undef, (_count, natoms)) + API.lammps_gather(lmp, name, dtype, _count, data) else + @assert all(1 .<= ids .<= natoms) + ndata = length(ids) - data = Array{T, 2}(undef, (count, ndata)) - API.lammps_gather_subset(lmp, name, dtype, count, ndata, ids, data) + data = Array{_T, 2}(undef, (_count, ndata)) + API.lammps_gather_subset(lmp, name, dtype, _count, ndata, ids, data) end check(lmp) @@ -404,22 +406,38 @@ end """ scatter(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} """ -function scatter!(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} +function scatter!(lmp::LMP, name::String, data::Array{T}; ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + @assert name !== "mass" "masses can not be scattered. Use `extract_atom` instead!" + + _scatter_gather_check_natoms(lmp) + + ndata = isnothing(ids) ? get_natoms(lmp) : length(ids) + (count, r) = divrem(length(data), ndata) + @assert r == 0 "illegal length of data" + + _T_count = _name_to_T_and_count(name) + + if !isnothing(_T_count) + (_T, _count) = _T_count + + @assert _T === T "the array data type ($T) doesn't match the data type ($_T) assotiated with $name" + @assert _count == count "the length of data ($(count*ndata)) doesn't match the length ($(_count*ndata)) assotiated with `$name`" + end + dtype = (T === Float64) if ids === nothing - (count, natoms) = size(data) - @assert natoms == get_natoms(lmp) + @assert ndata == get_natoms(lmp) API.lammps_scatter(lmp, name, dtype, count, data) else - (count, ndata) = size(data) - @assert ndata == length(ids) + @assert all(1 .<= ids .<= natoms) API.lammps_scatter_subset(lmp, name, dtype, count, ndata, ids, data) end check(lmp) return nothing end + _scatter_gather_check_natoms(lmp) = @assert get_natoms(lmp) <= typemax(Int32) "scatter/gather operations only work on systems with less than 2^31 atoms!" function _name_to_T_and_count(name::String) From c8f7209202c02ba296a9384dc38be19f1466fc74 Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 20:46:56 +0200 Subject: [PATCH 5/7] change kwargs --- src/LAMMPS.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 76d46e8..b28773b 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -361,9 +361,9 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end -function gather(lmp::LMP, name::String; +function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}, Nothing}=nothing, - count::Union{Nothing, Integer}=nothing, + count::Union{Nothing, Integer}=nothing; ids::Union{Nothing, Array{Int32}}=nothing, ) From 81b86ed9b7b7e6460c1f2c6b615b12f9c9c3bf71 Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 20:48:15 +0200 Subject: [PATCH 6/7] docstrings --- src/LAMMPS.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index b28773b..1a7c1fa 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -361,6 +361,21 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end +""" + gather(lmp::LMP, name::String, + T::Union{Nothing, Type{Int32}, Type{Float64}}=nothing, + count::Union{Nothing, Integer}=nothing; + ids::Union{Nothing, Array{Int32}}=nothing, + ) + +Gather the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entities from all processes. +By default (when `ids=nothing`), this method collects data from all atoms in consecutive order according to their IDs. +The optional parameter `ids` determines for which subset of atoms the requested data will be gathered. The returned data will then be ordered according to `ids` + +`T` and `count` are not optional for fix or compute entities. + +The returned Array is decoupled from the internal state of the LAMMPS instance. +""" function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}, Nothing}=nothing, count::Union{Nothing, Integer}=nothing; @@ -404,7 +419,11 @@ function gather(lmp::LMP, name::String, end """ - scatter(lmp::LMP, name::String, data::Matrix{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + scatter!(lmp::LMP, name::String, data::Array{T}; ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} + +Scatter the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entity in data to all processes. +By default (when `ids=nothing`), this method scatters data to all atoms in consecutive order according to their IDs. +The optional parameter `ids` determines to which subset of atoms the data will be scattered. """ function scatter!(lmp::LMP, name::String, data::Array{T}; ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} @assert name !== "mass" "masses can not be scattered. Use `extract_atom` instead!" From 47c03294ec486d152845d0fcdb0318fc385815de Mon Sep 17 00:00:00 2001 From: Joroks Date: Thu, 20 Jun 2024 20:51:54 +0200 Subject: [PATCH 7/7] deprecate gather_atoms --- src/LAMMPS.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index 1a7c1fa..dab5a73 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -361,6 +361,8 @@ function extract_variable(lmp::LMP, name::String, group=nothing) end end +@deprecate gather_atoms gather + """ gather(lmp::LMP, name::String, T::Union{Nothing, Type{Int32}, Type{Float64}}=nothing,