From c8beaa5464d61d248e81f2eac38b50ad384c4250 Mon Sep 17 00:00:00 2001 From: Joroks <32484985+Joroks@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:00:32 +0200 Subject: [PATCH] Rework extract functions. (#51) Co-authored-by: Valentin Churavy --- Project.toml | 2 +- examples/lj_forces.jl | 2 +- src/LAMMPS.jl | 551 ++++++++++++++++++++++++++++-------------- test/runtests.jl | 178 +++++++++++--- 4 files changed, 524 insertions(+), 209 deletions(-) diff --git a/Project.toml b/Project.toml index 31ab158..305ad92 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LAMMPS" uuid = "ee2e13b9-eee9-4449-aafa-cfa6a2dbe14d" authors = ["Valentin Churavy "] -version = "0.4.2" +version = "0.5.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/examples/lj_forces.jl b/examples/lj_forces.jl index aaf62eb..1374552 100644 --- a/examples/lj_forces.jl +++ b/examples/lj_forces.jl @@ -43,4 +43,4 @@ command(lmp, "run 0") # extract output forces = gather(lmp, "f", Float64) -energies = extract_compute(lmp, "pot_e", LAMMPS.API.LMP_STYLE_GLOBAL, LAMMPS.API.LMP_TYPE_SCALAR) \ No newline at end of file +energies = extract_compute(lmp, "pot_e", STYLE_GLOBAL, TYPE_SCALAR) \ No newline at end of file diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index ffdc91f..2cb2502 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -3,10 +3,78 @@ import MPI include("api.jl") export LMP, command, get_natoms, extract_atom, extract_compute, extract_global, - gather, scatter!, group_to_atom_ids, get_category_ids + extract_setting, gather, scatter!, group_to_atom_ids, get_category_ids, + extract_variable, + + # _LMP_DATATYPE + LAMMPS_NONE, + LAMMPS_INT, + LAMMPS_INT_2D, + LAMMPS_DOUBLE, + LAMMPS_DOUBLE_2D, + LAMMPS_INT64, + LAMMPS_INT64_2D, + LAMMPS_STRING, + + # _LMP_TYPE + TYPE_SCALAR, + TYPE_VECTOR, + TYPE_ARRAY, + SIZE_VECTOR, + SIZE_ROWS, + SIZE_COLS, + + # _LMP_VARIABLE + VAR_EQUAL, + VAR_ATOM, + VAR_VECTOR, + VAR_STRING, + + # _LMP_STYLE_CONST + STYLE_GLOBAL, + STYLE_ATOM, + STYLE_LOCAL + using Preferences +abstract type TypeEnum{N} end +get_enum(::TypeEnum{N}) where N = N + +struct _LMP_DATATYPE{N} <: TypeEnum{N} end + +const LAMMPS_NONE = _LMP_DATATYPE{API.LAMMPS_NONE}() +const LAMMPS_INT = _LMP_DATATYPE{API.LAMMPS_INT}() +const LAMMPS_INT_2D = _LMP_DATATYPE{API.LAMMPS_INT_2D}() +const LAMMPS_DOUBLE = _LMP_DATATYPE{API.LAMMPS_DOUBLE}() +const LAMMPS_DOUBLE_2D = _LMP_DATATYPE{API.LAMMPS_DOUBLE_2D}() +const LAMMPS_INT64 = _LMP_DATATYPE{API.LAMMPS_INT64}() +const LAMMPS_INT64_2D = _LMP_DATATYPE{API.LAMMPS_INT64_2D}() +const LAMMPS_STRING = _LMP_DATATYPE{API.LAMMPS_STRING}() + +struct _LMP_TYPE{N} <: TypeEnum{N} end + +const TYPE_SCALAR = _LMP_TYPE{API.LMP_TYPE_SCALAR}() +const TYPE_VECTOR = _LMP_TYPE{API.LMP_TYPE_VECTOR}() +const TYPE_ARRAY = _LMP_TYPE{API.LMP_TYPE_ARRAY}() +const SIZE_VECTOR = _LMP_TYPE{API.LMP_SIZE_VECTOR}() +const SIZE_ROWS = _LMP_TYPE{API.LMP_SIZE_ROWS}() +const SIZE_COLS = _LMP_TYPE{API.LMP_SIZE_COLS}() + +struct _LMP_VARIABLE{N} <: TypeEnum{N} end + +const VAR_EQUAL = _LMP_VARIABLE{API.LMP_VAR_EQUAL}() +const VAR_ATOM = _LMP_VARIABLE{API.LMP_VAR_ATOM}() +const VAR_VECTOR = _LMP_VARIABLE{API.LMP_VAR_VECTOR}() +const VAR_STRING = _LMP_VARIABLE{API.LMP_VAR_STRING}() + +# these are not defined as TypeEnum as they don't carry type information +const _LMP_STYLE_CONST = API._LMP_STYLE_CONST + +const STYLE_GLOBAL = API.LMP_STYLE_GLOBAL +const STYLE_ATOM = API.LMP_STYLE_ATOM +const STYLE_LOCAL = API.LMP_STYLE_LOCAL + """ locate() @@ -177,240 +245,363 @@ function get_natoms(lmp::LMP) Int64(API.lammps_get_natoms(lmp)) end -function dtype2type(dtype::API._LMP_DATATYPE_CONST) - if dtype == API.LAMMPS_INT - type = Ptr{Int32} - elseif dtype == API.LAMMPS_INT_2D - type = Ptr{Ptr{Int32}} - elseif dtype == API.LAMMPS_INT64 - type = Ptr{Int64} - elseif dtype == API.LAMMPS_INT64_2D - type = Ptr{Ptr{Int64}} - elseif dtype == API.LAMMPS_DOUBLE - type = Ptr{Float64} - elseif dtype == API.LAMMPS_DOUBLE_2D - type = Ptr{Ptr{Float64}} - elseif dtype == API.LAMMPS_STRING - type = Ptr{Cchar} - else - @assert false "Unknown dtype: $dtype" - end - return type +function _string(ptr::Ptr) + ptr == C_NULL && error("Wrapping NULL-pointer!") + return Base.unsafe_string(ptr) end -""" - extract_global(lmp, name, dtype=nothing) -""" -function extract_global(lmp::LMP, name, dtype=nothing) - if dtype === nothing - dtype = API.lammps_extract_global_datatype(lmp, name) - end - dtype = API._LMP_DATATYPE_CONST(dtype) - type = dtype2type(dtype) +function _extract(ptr::Ptr{<:Real}, shape::Integer; copy=false, own=false) + ptr == C_NULL && error("Wrapping NULL-pointer!") + result = Base.unsafe_wrap(Array, ptr, shape; own=false) - ptr = API.lammps_extract_global(lmp, name) - ptr = reinterpret(type, ptr) + if own && copy + result_copy = Base.copy(result) + API.lammps_free(result) + return result_copy + end - if ptr !== C_NULL - if dtype == API.LAMMPS_STRING - return Base.unsafe_string(ptr) + if own + @static if VERSION >= v"1.11-dev" + finalizer(API.lammps_free, result.ref.mem) + else + finalizer(API.lammps_free, result) end - # TODO: deal with non-scalar data - return Base.unsafe_load(ptr) end + + return copy ? Base.copy(result) : result end -function unsafe_wrap(ptr, shape) - if length(shape) > 1 - # We got a list of ptrs, - # but the first pointer points to the whole data - ptr = Base.unsafe_load(ptr) +function _extract(ptr::Ptr{<:Ptr{T}}, shape::NTuple{2}; copy=false) where T + # The `own` kwarg is not implemented for 2D data, as this is currently not used anywhere - @assert length(shape) == 2 + ptr == C_NULL && error("Wrapping NULL-pointer!") - # Note: Julia like Fortran is column-major - # so the data is transposed from Julia's perspective - shape = reverse(shape) - end + prod(shape) == 0 && return Matrix{T}(undef, shape) # There is no data that can be wrapped + + # TODO: implement a debug mode to do this assert + # pointers = Base.unsafe_wrap(Array, ptr, ndata) + # @assert all(diff(pointers) .== count*sizeof(T)) - # TODO: Who is responsible for freeing this data - array = Base.unsafe_wrap(Array, ptr, shape, own=false) - return array + # If the pointers are evenly spaced, we can simply use the first pointer to wrap our matrix. + first_pointer = unsafe_load(ptr) + result = Base.unsafe_wrap(Array, first_pointer, shape; own=false) + return copy ? Base.copy(result) : result end +function _reinterpret(T::_LMP_DATATYPE, ptr::Ptr) + T === LAMMPS_INT && return Base.reinterpret(Ptr{Int32}, ptr) + T === LAMMPS_INT_2D && return Base.reinterpret(Ptr{Ptr{Int32}}, ptr) + T === LAMMPS_DOUBLE && return Base.reinterpret(Ptr{Float64}, ptr) + T === LAMMPS_DOUBLE_2D && return Base.reinterpret(Ptr{Ptr{Float64}}, ptr) + T === LAMMPS_INT64 && return Base.reinterpret(Ptr{Int64}, ptr) + T === LAMMPS_INT64_2D && return Base.reinterpret(Ptr{Ptr{Int64}}, ptr) + T === LAMMPS_STRING && return Base.reinterpret(Ptr{UInt8}, ptr) +end + +_is_2D_datatype(lmp_dtype::_LMP_DATATYPE) = lmp_dtype in (LAMMPS_INT_2D, LAMMPS_DOUBLE_2D, LAMMPS_INT64_2D) + """ - extract_atom(lmp, name, dtype=nothing, axes1, axes2) -""" -function extract_atom(lmp::LMP, name, - dtype::Union{Nothing, API._LMP_DATATYPE_CONST} = nothing, - axes1=nothing, axes2=nothing) + extract_setting(lmp::LMP, name::String)::Int32 +Query LAMMPS about global settings. - if dtype === nothing - dtype = API.lammps_extract_atom_datatype(lmp, name) - dtype = API._LMP_DATATYPE_CONST(dtype) - end +A full list of settings can be found here: - if axes1 === nothing - if name == "mass" - axes1 = extract_global(lmp, "ntypes") + 1 - else - axes1 = extract_global(lmp, "nlocal") % Int - end +# Examples +```julia + LMP(["-screen", "none"]) do lmp + command(lmp, \""" + region cell block 0 3 0 3 0 3 + create_box 1 cell + lattice sc 1 + create_atoms 1 region cell + \""") + + extract_setting(lmp, "dimension") |> println # 3 + extract_setting(lmp, "nlocal") |> println # 27 end +``` +""" +function extract_setting(lmp::LMP, name::String)::Int32 + return API.lammps_extract_setting(lmp, name) +end - if axes2 === nothing - if dtype in (API.LAMMPS_INT_2D, API.LAMMPS_INT64_2D, API.LAMMPS_DOUBLE_2D) - # TODO: Other fields? - if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce") - axes2 = 3 - else - axes2 = 2 - end - end - end +""" + extract_global(lmp::LMP, name::String, lmp_type::_LMP_DATATYPE; copy::Bool=false) + +Extract a global property from a LAMMPS instance. + +| valid values for `lmp_type`: | resulting return type: | +| :--------------------------- | :------------------------ | +| `LAMMPS_INT` | `Vector{Int32}` | +| `LAMMPS_INT_2D` | `Matrix{Int32}` | +| `LAMMPS_DOUBLE` | `Vector{Float64}` | +| `LAMMPS_DOUBLE_2D` | `Matrix{Float64}` | +| `LAMMPS_INT64` | `Vector{Int64}` | +| `LAMMPS_INT64_2D` | `Matrix{Int64}` | +| `LAMMPS_STRING` | `String` (allways a copy) | + +Scalar values get returned as a vector with a single element. This way it's possible to +modify the internal state of the LAMMPS instance even if the data is scalar. + +!!! info + Closing the LAMMPS instance or issuing a clear command after calling this method + will result in the returned data becoming invalid. To prevent this, set `copy=true`. + +!!! warning + Modifying the data through `extract_global` may lead to inconsistent internal data and thus may cause failures or crashes or bogus simulations. + In general it is thus usually better to use a LAMMPS input command that sets or changes these parameters. + Those will take care of all side effects and necessary updates of settings derived from such settings. + +A full list of global variables can be found here: +""" +function extract_global(lmp::LMP, name::String, lmp_type::_LMP_DATATYPE; copy::Bool=false) + void_ptr = API.lammps_extract_global(lmp, name) + void_ptr == C_NULL && throw(KeyError("Unknown global variable $name")) + + expect = extract_global_datatype(lmp, name) + receive = get_enum(lmp_type) + expect != receive && error("TypeMismatch: Expected $expect got $receive instead!") + + ptr = _reinterpret(lmp_type, void_ptr) + + lmp_type == LAMMPS_STRING && return _string(ptr) - if axes2 !== nothing - shape = (axes1, axes2) + if name in ("boxlo", "boxhi", "sublo", "subhi", "sublo_lambda", "subhi_lambda", "periodicity") + length = 3 + elseif name in ("special_lj", "special_coul") + length = 4 else - shape = (axes1, ) + length = 1 end - type = dtype2type(dtype) - ptr = API.lammps_extract_atom(lmp, name) - ptr = reinterpret(type, ptr) + return _extract(ptr, length; copy=copy) +end - unsafe_wrap(ptr, shape) +function extract_global_datatype(lmp::LMP, name) + return API._LMP_DATATYPE_CONST(API.lammps_extract_global_datatype(lmp, name)) end -function unsafe_extract_compute(lmp::LMP, name, style, type) - if type == API.LMP_TYPE_SCALAR - if style == API.LMP_STYLE_GLOBAL - dtype = Ptr{Float64} - elseif style == API.LMP_STYLE_LOCAL - dtype = Ptr{Cint} - elseif style == API.LMP_STYLE_ATOM - return nothing - end - extract = true - elseif type == API.LMP_TYPE_VECTOR - dtype = Ptr{Float64} - extract = false - elseif type == API.LMP_TYPE_ARRAY - dtype = Ptr{Ptr{Float64}} - extract = false - elseif type == API.LMP_SIZE_COLS - dtype = Ptr{Cint} - extract = true - elseif type == API.LMP_SIZE_ROWS || - type == API.LMP_SIZE_VECTOR - if style == API.LMP_STYLE_ATOM - return nothing - end - dtype = Ptr{Cint} - extract = true - else - @assert false "Unknown type: $type" - end +""" + extract_atom(lmp::LMP, name::String, lmp_type::_LMP_DATATYPE; copy=false) + +Extract per-atom data from the lammps instance. + +| valid values for `lmp_type`: | resulting return type: | +| :--------------------------- | :--------------------- | +| `LAMMPS_INT` | `Vector{Int32}` | +| `LAMMPS_INT_2D` | `Matrix{Int32}` | +| `LAMMPS_DOUBLE` | `Vector{Float64}` | +| `LAMMPS_DOUBLE_2D` | `Matrix{Float64}` | +| `LAMMPS_INT64` | `Vector{Int64}` | +| `LAMMPS_INT64_2D` | `Matrix{Int64}` | + +the kwarg `copy`, which defaults to true, determies wheter a copy of the underlying data is made. +As the pointer to the underlying data is not persistent, it's highly recommended to only disable this, +if you wish to modify the internal state of the LAMMPS instance. + +!!! info + The returned data may become invalid if a re-neighboring operation + is triggered at any point after calling this method. If this has happened, + trying to read from this data will likely cause julia to crash. + To prevent this, set `copy=true`. + +A table with suported name keywords can be found here: +""" +function extract_atom(lmp::LMP, name::String, lmp_type::_LMP_DATATYPE; copy=false) + void_ptr = API.lammps_extract_atom(lmp, name) + void_ptr == C_NULL && throw(KeyError("Unknown per-atom variable $name")) + + expect = extract_atom_datatype(lmp, name) + receive = get_enum(lmp_type) + expect != receive && error("TypeMismatch: Expected $expect got $receive instead!") - ptr = API.lammps_extract_compute(lmp, name, style, type) - ptr == C_NULL && check(lmp) + ptr = _reinterpret(lmp_type, void_ptr) - if ptr == C_NULL - error("Could not extract_compute $name with $style and $type") + if name == "mass" + length = extract_global(lmp, "ntypes", LAMMPS_INT)[] + ptr += sizeof(eltype(ptr)) # Scarry pointer arithemtic; The first entry in the array is unused + return _extract(ptr, length; copy=copy) end - ptr = reinterpret(dtype, ptr) - if extract - return Base.unsafe_load(ptr) + length = extract_setting(lmp, "nlocal") + + if _is_2D_datatype(lmp_type) + # only Quaternions have 4 entries + # length is a Int32 and lammps_wrap expects a NTuple, so it's + # neccecary to use Int32 for count as well + count = name == "quat" ? Int32(4) : Int32(3) + return _extract(ptr, (count, length); copy=copy) end - return ptr + + return _extract(ptr, length; copy=copy) +end + +function extract_atom_datatype(lmp::LMP, name) + return API._LMP_DATATYPE_CONST(API.lammps_extract_atom_datatype(lmp, name)) end """ - extract_compute(lmp, name, style, type) + extract_compute(lmp::LMP, name::String, style::_LMP_STYLE_CONST, lmp_type::_LMP_TYPE; copy::Bool=true) + +Extract data provided by a compute command identified by the compute-ID. +Computes may provide global, per-atom, or local data, and those may be a scalar, a vector or an array. +Since computes may provide multiple kinds of data, it is required to set style and type flags representing what specific data is desired. + +| valid values for `style`: | +| :------------------------ | +| `STYLE_GLOBAL` | +| `STYLE_ATOM` | +| `STYLE_LOCAL` | + +| valid values for `lmp_type`: | resulting return type: | +| :--------------------------- | :--------------------- | +| `TYPE_SCALAR` | `Vector{Float64}` | +| `TYPE_VECTOR` | `Vector{Float64}` | +| `TYPE_ARRAY` | `Matrix{Float64}` | +| `SIZE_VECTOR` | `Vector{Int32}` | +| `SIZE_COLS` | `Vector{Int32}` | +| `SIZE_ROWS` | `Vector{Int32}` | + +Scalar values get returned as a vector with a single element. This way it's possible to +modify the internal state of the LAMMPS instance even if the data is scalar. + +!!! info + The returned data may become invalid as soon as another LAMMPS command has been issued at any point after calling this method. + If this has happened, trying to read from this data will likely cause julia to crash. + To prevent this, set `copy=true`. + +# Examples + +```julia + LMP(["-screen", "none"]) do lmp + extract_compute(lmp, "thermo_temp", LMP_STYLE_GLOBAL, TYPE_VECTOR, copy=true)[2] = 2 + extract_compute(lmp, "thermo_temp", LMP_STYLE_GLOBAL, TYPE_VECTOR, copy=false)[3] = 3 + + extract_compute(lmp, "thermo_temp", LMP_STYLE_GLOBAL, TYPE_SCALAR) |> println # [0.0] + extract_compute(lmp, "thermo_temp", LMP_STYLE_GLOBAL, TYPE_VECTOR) |> println # [0.0, 0.0, 3.0, 0.0, 0.0, 0.0] + end +``` """ -function extract_compute(lmp::LMP, name, style, type) - ptr_or_value = unsafe_extract_compute(lmp, name, style, type) - if style == API.LMP_TYPE_SCALAR - return ptr_or_value +function extract_compute(lmp::LMP, name::String, style::_LMP_STYLE_CONST, lmp_type::_LMP_TYPE; copy::Bool=false) + API.lammps_has_id(lmp, "compute", name) != 1 && throw(KeyError("Unknown compute $name")) + + void_ptr = API.lammps_extract_compute(lmp, name, style, get_enum(lmp_type)) + void_ptr == C_NULL && error("Compute $name doesn't have data matching $style, $(get_enum(lmp_type))") + + # `lmp_type in (SIZE_COLS, SIZE_ROWS, SIZE_VECTOR)` causes type instability for some reason + if lmp_type == SIZE_COLS || lmp_type == SIZE_ROWS || lmp_type == SIZE_VECTOR + ptr = _reinterpret(LAMMPS_INT, void_ptr) + return _extract(ptr, 1; copy=copy) end - if ptr_or_value === nothing - return nothing + + if lmp_type == TYPE_SCALAR + ptr = _reinterpret(LAMMPS_DOUBLE, void_ptr) + return _extract(ptr, 1; copy=copy) end - ptr = ptr_or_value::Ptr - - if style in (API.LMP_STYLE_GLOBAL, API.LMP_STYLE_LOCAL) - if type == API.LMP_TYPE_VECTOR - nrows = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_VECTOR) - return unsafe_wrap(ptr, (nrows,)) - elseif type == API.LMP_TYPE_ARRAY - nrows = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_ROWS) - ncols = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_COLS) - return unsafe_wrap(ptr, (nrows, ncols)) - end - else style = API.LMP_STYLE_ATOM - nlocal = extract_global(lmp, "nlocal") - if type == API.LMP_TYPE_VECTOR - return unsafe_wrap(ptr, (nlocal,)) - elseif type == API.LMP_TYPE_ARRAY - ncols = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_COLS) - return unsafe_wrap(ptr, (nlocal, ncols)) - end + + if lmp_type == TYPE_VECTOR + ndata = (style == STYLE_ATOM) ? + extract_setting(lmp, "nlocal") : + extract_compute(lmp, name, style, SIZE_VECTOR)[] + + ptr = _reinterpret(LAMMPS_DOUBLE, void_ptr) + return _extract(ptr, ndata; copy=copy) end - return nothing + + ndata = (style == STYLE_ATOM) ? + extract_setting(lmp, "nlocal") : + extract_compute(lmp, name, style, SIZE_ROWS)[] + + count = extract_compute(lmp, name, style, SIZE_COLS)[] + ptr = _reinterpret(LAMMPS_DOUBLE_2D, void_ptr) + + return _extract(ptr, (count, ndata); copy=copy) end """ - extract_variable(lmp::LMP, name, group) + extract_variable(lmp::LMP, name::String, lmp_variable::LMP_VARIABLE, group::Union{String, Nothing}=nothing; copy::Bool=false) Extracts the data from a LAMMPS variable. When the variable is either an `equal`-style compatible variable, a `vector`-style variable, or an `atom`-style variable, the variable is evaluated and the corresponding value(s) returned. Variables of style `internal` are compatible with `equal`-style variables, if they return a numeric value. For other variable styles, their string value is returned. + +| valid values for `lmp_variable`: | return type | +| :------------------------------- | :--------------- | +| `VAR_ATOM` | `Vector{Float64}` | +| `VAR_EQUAL` | `Float64` | +| `VAR_STRING` | `String` | +| `VAR_VECTOR` | `Vector{Float64}` | + +the kwarg `copy` determies wheter a copy of the underlying data is made. +`copy` is only aplicable for `VAR_VECTOR` and `VAR_ATOM`. For all other variable types, a copy will be made regardless. +The underlying LAMMPS API call for `VAR_ATOM` internally allways creates a copy of the data. As the memory for this gets allocated by LAMMPS instead of julia, +it needs to be dereferenced using `LAMMPS.API.lammps_free` instead of through the garbage collector. +If `copy=false` this gets acieved by registering `LAMMPS.API.lammps_free` as a finalizer for the returned data. +Alternatively, setting `copy=true` will instead create a new copy of the data. The lammps allocated block of memory will then be freed immediately. + +the kwarg `group` determines for which atoms the variable will be extracted. It's only aplicable for +`VAR_ATOM` and will cause an error if used for other variable types. The entires for all atoms not in the group +will be zeroed out. By default, all atoms will be extracted. """ -function extract_variable(lmp::LMP, name::String, group=nothing) - var = API.lammps_extract_variable_datatype(lmp, name) - if var == -1 - throw(KeyError(name)) - end - if group === nothing +function extract_variable(lmp::LMP, name::String, lmp_variable::_LMP_VARIABLE, group::Union{String, Nothing}=nothing; copy::Bool=false) + lmp_variable != VAR_ATOM && !isnothing(group) && throw(ArgumentError("the group parameter is only supported for per atom variables!")) + + if isnothing(group) group = C_NULL end - if var == API.LMP_VAR_EQUAL - ptr = API.lammps_extract_variable(lmp, name, C_NULL) - val = Base.unsafe_load(Base.unsafe_convert(Ptr{Float64}, ptr)) - API.lammps_free(ptr) - return val - elseif var == API.LMP_VAR_ATOM - nlocal = extract_global(lmp, "nlocal") - ptr = API.lammps_extract_variable(lmp, name, group) - if ptr == C_NULL - error("Group $group for variable $name with style atom not available.") - end - # LAMMPS uses malloc, so and we are taking ownership of this buffer - val = copy(Base.unsafe_wrap(Array, Base.unsafe_convert(Ptr{Float64}, ptr), nlocal; own=false)) - API.lammps_free(ptr) - return val - elseif var == API.LMP_VAR_VECTOR - # TODO Fix lammps docs `GET_VECTOR_SIZE` - ptr = API.lammps_extract_variable(lmp, name, "LMP_SIZE_VECTOR") - if ptr == C_NULL - error("$name is a vector style variable but has no size.") + void_ptr = API.lammps_extract_variable(lmp, name, group) + void_ptr == C_NULL && throw(KeyError("Unknown variable $name")) + + expect = extract_variable_datatype(lmp, name) + receive = get_enum(lmp_variable) + if expect != receive + # the documentation instructs us to free the pointers for these styles specifically + if expect in (API.LMP_VAR_ATOM, API.LMP_VAR_EQUAL) + API.lammps_free(void_ptr) end - sz = Base.unsafe_load(Base.unsafe_convert(Ptr{Cint}, ptr)) + + error("TypeMismatch: Expected $expect got $receive instead!") + end + + if lmp_variable == VAR_EQUAL + ptr = _reinterpret(LAMMPS_DOUBLE, void_ptr) + result = unsafe_load(ptr) API.lammps_free(ptr) - ptr = API.lammps_extract_variable(lmp, name, C_NULL) - return Base.unsafe_wrap(Array, Base.unsafe_convert(Ptr{Float64}, ptr), sz, own=false) - elseif var == API.LMP_VAR_STRING - ptr = API.lammps_extract_variable(lmp, name, C_NULL) - return Base.unsafe_string(Base.unsafe_convert(Ptr{Cchar}, ptr)) - else - error("Unkown variable style $var") + return result end + + if lmp_variable == VAR_VECTOR + # Calling lammps_extract_variable directly through the API instead of the higher level wrapper, as + # "LMP_SIZE_VECTOR" is the only group name that won't be ignored for Vector Style Variables. + # This isn't exposed to the high level API as it causes type instability for something that probably won't + # ever be used outside of this implementation + ndata_ptr = _reinterpret(LAMMPS_INT, API.lammps_extract_variable(lmp, name, "LMP_SIZE_VECTOR")) + ndata = unsafe_load(ndata_ptr) + API.lammps_free(ndata_ptr) + + ptr = _reinterpret(LAMMPS_DOUBLE, void_ptr) + return _extract(ptr, ndata; copy=copy) + end + + if lmp_variable == VAR_ATOM + ndata = extract_setting(lmp, "nlocal") + ptr = _reinterpret(LAMMPS_DOUBLE, void_ptr) + # lammps expects us to take ownership of the data + return _extract(ptr, ndata; copy=copy, own=true) + end + + ptr = _reinterpret(LAMMPS_STRING, void_ptr) + return _string(ptr) end +function extract_variable_datatype(lmp::LMP, name) + return API._LMP_VAR_CONST(API.lammps_extract_variable_datatype(lmp, name)) +end + + @deprecate gather_atoms(lmp::LMP, name, T, count) gather(lmp, name, T) diff --git a/test/runtests.jl b/test/runtests.jl index 8cd826b..267cfcc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,46 +15,136 @@ LMP(["-screen", "none"]) do lmp @test_throws ErrorException command(lmp, "nonsense") end +@testset "Extract Setting/Global" begin + LMP(["-screen", "none"]) do lmp + command(lmp, """ + atom_modify map yes + region cell block 0 1 0 2 0 3 + create_box 1 cell + """) + + @test extract_global(lmp, "dt", LAMMPS_DOUBLE)[] isa Float64 + @test extract_global(lmp, "boxhi", LAMMPS_DOUBLE) == [1, 2, 3] + @test extract_global(lmp, "nlocal", LAMMPS_INT)[] == extract_setting(lmp, "nlocal") == 0 + + with_copy1 = extract_global(lmp, "periodicity", LAMMPS_INT, copy=true) + with_copy2 = extract_global(lmp, "periodicity", LAMMPS_INT, copy=true) + + @test pointer(with_copy1) != pointer(with_copy2) + + without_copy1 = extract_global(lmp, "periodicity", LAMMPS_INT, copy=false) + without_copy2 = extract_global(lmp, "periodicity", LAMMPS_INT, copy=false) + + @test pointer(with_copy1) != pointer(with_copy2) + + # verify that no errors were missed + @test LAMMPS.API.lammps_has_error(lmp) == 0 + end +end + +@testset "Extract Atom" begin + LMP(["-screen", "none"]) do lmp + command(lmp, """ + atom_modify map yes + region cell block 0 3 0 3 0 3 + create_box 1 cell + lattice sc 1 + create_atoms 1 region cell + mass 1 1 + """) + + @test extract_atom(lmp, "mass", LAMMPS_DOUBLE) isa Vector{Float64} + @test extract_atom(lmp, "mass", LAMMPS_DOUBLE) == [1] + + x = extract_atom(lmp, "x", LAMMPS_DOUBLE_2D) + @test size(x) == (3, 27) + + @test extract_atom(lmp, "image", LAMMPS_INT) isa Vector{Int32} + + @test_throws ErrorException extract_atom(lmp, "v", LAMMPS_DOUBLE) + + # verify that no errors were missed + @test LAMMPS.API.lammps_has_error(lmp) == 0 + end +end + @testset "Variables" begin LMP(["-screen", "none"]) do lmp - command(lmp, "box tilt large") - command(lmp, "region cell block 0 1.0 0 1.0 0 1.0 units box") - command(lmp, "create_box 1 cell") - command(lmp, "create_atoms 1 random 10 1 NULL") - command(lmp, "compute press all pressure NULL pair"); - command(lmp, "fix press all ave/time 1 1 1 c_press mode vector"); - - command(lmp, "variable var1 equal 1.0") - command(lmp, "variable var2 string \"hello\"") - command(lmp, "variable var3 atom x") - # TODO: x is 3d, how do we access more than the first dims - command(lmp, "variable var4 vector f_press") - - @test LAMMPS.extract_variable(lmp, "var1") == 1.0 - @test LAMMPS.extract_variable(lmp, "var2") == "hello" - x = LAMMPS.extract_atom(lmp, "x") - x_var = LAMMPS.extract_variable(lmp, "var3") + command(lmp, """ + box tilt large + region cell block 0 1.0 0 1.0 0 1.0 units box + create_box 1 cell + create_atoms 1 random 10 1 NULL + compute press all pressure NULL pair + fix press all ave/time 1 1 1 c_press mode vector + + variable var1 equal 1.0 + variable var2 string \"hello\" + variable var3 atom x + # TODO: x is 3d, how do we access more than the first dims + variable var4 vector f_press + group odd id 1 3 5 7 + """) + + @test extract_variable(lmp, "var1", VAR_EQUAL) == 1.0 + @test extract_variable(lmp, "var2", VAR_STRING) == "hello" + x = extract_atom(lmp, "x", LAMMPS_DOUBLE_2D) + x_var = extract_variable(lmp, "var3", VAR_ATOM) @test length(x_var) == 10 @test x_var == x[1, :] - press = LAMMPS.extract_variable(lmp, "var4") + press = extract_variable(lmp, "var4", VAR_VECTOR) @test press isa Vector{Float64} + + x_var_group = extract_variable(lmp, "var3", VAR_ATOM, "odd") + in_group = BitVector((1, 0, 1, 0, 1, 0, 1, 0, 0, 0)) + + @test x_var_group[in_group] == x[1, in_group] + @test all(x_var_group[.!in_group] .== 0) + + @test_throws ErrorException extract_variable(lmp, "var3", VAR_EQUAL) + + # verify that no errors were missed + @test LAMMPS.API.lammps_has_error(lmp) == 0 end + + # check if the memory allocated by LAMMPS persists after closing the instance + lmp = LMP(["-screen", "none"]) + command(lmp, """ + atom_modify map yes + region cell block 0 3 0 3 0 3 + create_box 1 cell + lattice sc 1 + create_atoms 1 region cell + mass 1 1 + + variable var atom id + """) + + var = extract_variable(lmp, "var", VAR_ATOM) + var_copy = copy(var) + LAMMPS.close!(lmp) + + @test var == var_copy + end @testset "gather/scatter" begin LMP(["-screen", "none"]) do lmp # setting up example data - command(lmp, "atom_modify map yes") - command(lmp, "region cell block 0 3 0 3 0 3") - command(lmp, "create_box 1 cell") - command(lmp, "lattice sc 1") - command(lmp, "create_atoms 1 region cell") - command(lmp, "mass 1 1") + command(lmp, """ + atom_modify map yes + region cell block 0 3 0 3 0 3 + create_box 1 cell + lattice sc 1 + create_atoms 1 region cell + mass 1 1 + + compute pos all property/atom x y z + fix pos all ave/atom 10 1 10 c_pos[1] c_pos[2] c_pos[3] - command(lmp, "compute pos all property/atom x y z") - command(lmp, "fix pos all ave/atom 10 1 10 c_pos[1] c_pos[2] c_pos[3]") + run 10 + """) - command(lmp, "run 10") data = zeros(Float64, 3, 27) subset = Int32.([2,5,10, 5]) data_subset = ones(Float64, 3, 4) @@ -98,6 +188,37 @@ end @test gather(lmp, "x", Float64, subset) == gather(lmp, "c_pos", Float64, subset) == gather(lmp, "f_pos", Float64, subset) == data_subset + # verify that no errors were missed + @test LAMMPS.API.lammps_has_error(lmp) == 0 + end +end + +@testset "Extract Compute" begin + LMP(["-screen", "none"]) do lmp + command(lmp, """ + atom_modify map yes + region cell block 0 3 0 3 0 3 + create_box 1 cell + lattice sc 1 + create_atoms 1 region cell + mass 1 1 + + compute pos all property/atom x y z + """) + + @test extract_compute(lmp, "pos", STYLE_ATOM, TYPE_ARRAY) == extract_atom(lmp, "x", LAMMPS_DOUBLE_2D) + + extract_compute(lmp, "thermo_temp", STYLE_GLOBAL, TYPE_VECTOR, copy=true)[2] = 2 + extract_compute(lmp, "thermo_temp", STYLE_GLOBAL, TYPE_VECTOR, copy=false)[3] = 3 + + @test extract_compute(lmp, "thermo_temp", STYLE_GLOBAL, TYPE_SCALAR) == [0.0] + @test extract_compute(lmp, "thermo_temp", STYLE_GLOBAL, TYPE_VECTOR) == [0.0, 0.0, 3.0, 0.0, 0.0, 0.0] + + @test_throws ErrorException extract_compute(lmp, "thermo_temp", STYLE_ATOM, TYPE_SCALAR) + @test_throws ErrorException extract_compute(lmp, "thermo_temp", STYLE_GLOBAL, TYPE_ARRAY) + + # verify that no errors were missed + @test LAMMPS.API.lammps_has_error(lmp) == 0 end end @@ -133,6 +254,9 @@ end @test get_category_ids(lmp, "compute") == ["thermo_temp", "thermo_press", "thermo_pe", "pos"] # some of these computes are there by default it seems @test get_category_ids(lmp, "fix") == ["1"] @test_throws ErrorException get_category_ids(lmp, "nonesense") + + # verify that no errors were missed + @test LAMMPS.API.lammps_has_error(lmp) == 0 end end