Skip to content

Commit

Permalink
commented out segaulting test; cf #81
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Dec 7, 2021
1 parent f6caa86 commit 3eda4ca
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 18 deletions.
17 changes: 17 additions & 0 deletions src/ACE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,23 @@ include("testing/testing.jl")
include("ad.jl")


# ---------------- some extra experimental dispatching

evaluate(basis::SymmetricBasis, Xs::AbstractVector) =
evaluate(basis, ACEConfig(Xs))

evaluate_d(basis::SymmetricBasis, Xs::AbstractVector) =
evaluate_d(basis, ACEConfig(Xs))

evaluate(model::LinearACEModel, Xs::AbstractVector) =
evaluate(model, ACEConfig(Xs))

grad_config(model::LinearACEModel, Xs::AbstractVector) =
grad_config(model, ACEConfig(Xs))

grad_params(model::LinearACEModel, Xs::AbstractVector) =
grad_params(model, ACEConfig(Xs))

end # module


Expand Down
11 changes: 11 additions & 0 deletions src/properties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ struct EuclideanVector{T} <: AbstractProperty
val::SVector{3, T}
end

Base.show(io::IO, φ::EuclideanVector) =
print(io, "e[$(φ.val[1]), $(φ.val[2]), $(φ.val[3])]")

real::EuclideanVector) = EuclideanVector(real.(φ.val))
complex::EuclideanVector) = EuclideanVector(complex.val))
Expand Down Expand Up @@ -254,6 +256,7 @@ end

SphericalVector(val::SVector) = SphericalVector(val, Val(__getL(val)))


# # differentiation - cf #27
# *(φ::SphericalVector, dAA::SVector) = φ.val * dAA'

Expand Down Expand Up @@ -356,6 +359,14 @@ struct SphericalMatrix{L1, L2, LEN1, LEN2, T, LL} <: AbstractProperty
_valL2::Val{L2}
end

function Base.show(io::IO, φ::Union{SphericalVector, SphericalMatrix})
buffer = IOBuffer()
print(buffer, round.(φ.val, digits=4))
str = String(take!(buffer))
lenn = findfirst('[', str)
print(io, "y" * str[lenn:end])
end

SphericalMatrix(val::SMatrix) = SphericalMatrix(val, Val.(__getL1L2(val))...)

# differentiation - cf #27
Expand Down
36 changes: 18 additions & 18 deletions test/test_admodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,28 @@ grad_fsmodelp(θ)
ACEbase.Testing.fdtest(fsmodelp, grad_fsmodelp, θ)


##
## THIS TEST CURRENTLY THROWS A SEGFAULT

@info("Check AD for a second partial derivative w.r.t cfg and params")
# @info("Check AD for a second partial derivative w.r.t cfg and params")

fsmodel1 = (model, cfg) -> FS(evaluate(model, cfg))
grad_fsmodel1 = (model, cfg) -> Zygote.gradient(x -> fsmodel1(model, x), cfg)[1]
# fsmodel1 = (model, cfg) -> FS(evaluate(model, cfg))
# grad_fsmodel1 = (model, cfg) -> Zygote.gradient(x -> fsmodel1(model, x), cfg)[1]

y = randn(SVector{3, Float64}, length(cfg))
loss1 = model -> sum(sum(abs2, g.rr - y)
for (g, y) in zip(grad_fsmodel1(model, cfg), y))
# y = randn(SVector{3, Float64}, length(cfg))
# loss1 = model -> sum(sum(abs2, g.rr - y)
# for (g, y) in zip(grad_fsmodel1(model, cfg), y))

# check that loss and gradient evaluate ok
loss1(model)
g = Zygote.gradient(loss1, model)[1]
# # check that loss and gradient evaluate ok
# loss1(model)
# g = Zygote.gradient(loss1, model)[1] # SEGFAULT IN THIS LINE ON J1.7!!!

# wrappers to take derivatives w.r.t. the vector or parameters
F1 = θ -> ( ACE.set_params!(model, mat2svecs(θ));
loss1(model) )
# # wrappers to take derivatives w.r.t. the vector or parameters
# F1 = θ -> ( ACE.set_params!(model, mat2svecs(θ));
# loss1(model) )

dF1 = θ -> ( ACE.set_params!(model, mat2svecs(θ));
Zygote.gradient(loss1, model)[1] |> svecs2vec )
# dF1 = θ -> ( ACE.set_params!(model, mat2svecs(θ));
# Zygote.gradient(loss1, model)[1] |> svecs2vec )

F1(θ)
dF1(θ)
ACEbase.Testing.fdtest(F1, dF1, θ; verbose=true)
# F1(θ)
# dF1(θ)
# ACEbase.Testing.fdtest(F1, dF1, θ; verbose=true)
12 changes: 12 additions & 0 deletions test/test_linearmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using ACE: evaluate, evaluate_d, SymmetricBasis, PIBasis,
grad_config, grad_params, O3
using ACEbase.Testing: fdtest


randconfig(B1p, nX) = ACEConfig( rand(PositionState{Float64}, B1p.bases[1], nX) )

##
Expand All @@ -25,6 +26,7 @@ B1p = ACE.Utils.RnYlm_1pbasis(; maxdeg=maxdeg)

# generate a configuration
cfg = randconfig(B1p, 10)
Xs = cfg.Xs

φ = ACE.Invariant()
basis = SymmetricBasis(φ, B1p, O3(), Bsel)
Expand Down Expand Up @@ -80,6 +82,16 @@ end

##

@info("Evaluate LinearACEModel with vector")
for _ = 1:20
cfg = randconfig(B1p, rand(8:15))
print_tf(@test evaluate(standard, cfg) evaluate(standard, cfg.Xs))
print_tf(@test grad_config(standard, cfg) grad_config(standard, cfg.Xs))
print_tf(@test grad_params(standard, cfg) grad_params(standard, cfg.Xs))
end

##

@info("Test a Linear Model with EuclideanVector output")
maxdeg = 6; ord = 3
Bsel = SimpleSparseBasis(ord, maxdeg)
Expand Down
4 changes: 4 additions & 0 deletions test/test_symmbasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ AA = evaluate(basis.pibasis, cfg)
BB1 = basis.A2Bmap * AA
println(@test isapprox(BB, BB1, rtol=1e-10))

@info("evaluate with vector vs config")
println(@test BB evaluate(basis, Xs))
println(@test evaluate_d(basis, cfg) evaluate_d(basis, Xs))

# check there are no superfluous columns
Iz = findall(iszero, sum(norm, basis.A2Bmap, dims=1)[:])
if !isempty(Iz)
Expand Down

0 comments on commit 3eda4ca

Please sign in to comment.