Skip to content

Commit

Permalink
Add DummyModel for loading unknown models (#150)
Browse files Browse the repository at this point in the history
* add DummyModel

* start implementing tests

* reorganize for DummyModel

* fixes

* fix required methods

* update dummy model save

* fix names

* cleanup model compat

* autoconvert to Float64

* fix test errors
  • Loading branch information
ffreyer authored Jan 22, 2022
1 parent 592e824 commit 5c2c75c
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 122 deletions.
33 changes: 14 additions & 19 deletions src/FileIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,19 @@ function save(
backend.jldopen(filename, mode, compress=compress; kwargs...), filename
)

write(file, "VERSION", 1)
save_mc(file, mc, "MC")
save_rng(file)
close(file.file)
try
write(file, "VERSION", 1)
save_mc(file, mc, "MC")
save_rng(file)
catch e
if overwrite && !isempty(temp_filename) && isfile(temp_filename)
rm(filename)
mv(temp_filename, filename)
end
@error exception = e
finally
close(file.file)
end

if overwrite && !isempty(temp_filename) && isfile(temp_filename)
rm(temp_filename)
Expand Down Expand Up @@ -272,21 +281,7 @@ function save_model(file::JLDFile, model, entryname::String)
nothing
end

"""
_load(data, ::Type{...})
Loads `data` where `data` is either a `JLD2.JLDFile`, `JLD2.Group` or a `Dict`.
The default `_load` will check that `data["VERSION"] == 0` and simply return
`data["data"]`. You may implement `_load(data, ::Type{<: MyType})` to add
specialized loading behavior.
"""
function _load(data, ::Val{:Generic}) where T
data["VERSION"] == 0 || throw(ErrorException(
"Version $(data["VERSION"]) incompatabile with default _load for $T."
))
data["data"]
end
_load(data, ::Val{:Generic}) = data["data"]


# save_lattice(filename, lattice, entryname)
Expand Down
4 changes: 3 additions & 1 deletion src/MonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ Base.keys(fw::FileWrapper) = keys(fw.file)

# To allow switching between JLD and JLD2:
const UnknownType = Union{JLD.UnsupportedType, JLD2.UnknownType}
const JLDFile = Union{FileWrapper{<: JLD.JldFile}, FileWrapper{<: JLD2.JLDFile}, JLD.JldFile, JLD2.JLDFile}
const _JLDFile = Union{JLD.JldFile, JLD2.JLDFile, Dict, JLD2.Group}
const JLDFile = Union{FileWrapper{<: _JLDFile}, _JLDFile}



Expand Down Expand Up @@ -103,6 +104,7 @@ export ReplicaExchange, ReplicaPull, connect, disconnect

include("models/Ising/IsingModel.jl")
include("models/HubbardModel.jl")
include("models/DummyModel.jl")
export IsingEnergyMeasurement, IsingMagnetizationMeasurement

include("FileIO.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/flavors/DQMC/FileIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ function _load(data, ::Val{:DQMC})
analysis = _load(data["Analysis"], Val(:DQMCAnalysis))
recorder = _load(data["configs"], to_tag(data["configs"]))
last_sweep = data["last_sweep"]
model = _load(data["Model"], to_tag(data["Model"]))
model = _load_model(data["Model"], to_tag(data["Model"]))
if haskey(data, "field")
field = load_field(data["field"], Val(:Field), parameters, model)
else
conf = data["conf"]
field = choose_field(model)(parameters, model)
field = field_hint(model, to_tag(data["Model"]))(parameters, model)
conf!(field, conf)
end
scheduler = if haskey(data, "Scheduler")
Expand Down
3 changes: 1 addition & 2 deletions src/flavors/MC/MC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,7 @@ function _load(data, ::Val{:MC})
conf = data["conf"]
configs = _load(data["configs"], to_tag(data["configs"]))
last_sweep = data["last_sweep"]
model = _load(data["Model"], to_tag(data["Model"]))

model = _load_model(data["Model"], to_tag(data["Model"]))
measurements = _load(data["Measurements"], Val(:Measurements))

MC(
Expand Down
59 changes: 59 additions & 0 deletions src/models/DummyModel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
DummyModel()
If an unknown or undefined model is loaded a `DummyModel` will be created to
allow the load to succeed and keep the data in check.
"""
struct DummyModel <: Model
data::Dict{String, Any}
end

Base.show(io::IO, model::DummyModel) = print(io, "DummyModel()")

function Base.getproperty(obj::DummyModel, field::Symbol)
if hasfield(DummyModel, field)
return getfield(obj, field)
else
return getfield(obj, :data)[string(field)]
end
end


choose_field(::DummyModel) = DensityHirschField
nflavors(::DummyModel) = 1
lattice(m::DummyModel) = m.l


function save_model(file::JLDFile, m::DummyModel, entryname::String="Model")
close(file)
error("DummyModel cannot be saved.")
end

function _load_model(data, ::Val)
tag = to_tag(data)
@warn "Failed to load $tag, creating DummyModel"
dict = _load_to_dict(data)
if haskey(dict, "data")
x = pop!(dict, "data")
push!(dict, x...)
end
DummyModel(dict)
end

_load_to_dict(file::FileWrapper) = _load_to_dict(file.file)
function _load_to_dict(data)
if parentmodule(typeof(data)) == JLD2.ReconstructedTypes
Dict(map(fieldnames(typeof(data))) do f
string(f) => _load_to_dict(getfield(data, f))
end)
else
data
end
end
function _load_to_dict(data::Union{JLD.JldFile, JLD2.JLDFile, JLD2.Group})
output = Dict{String, Any}()
for key in keys(data)
push!(output, key => _load_to_dict(data[key]))
end
output
end
16 changes: 13 additions & 3 deletions src/models/HubbardModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ struct HubbardModel{LT <: AbstractLattice} <: Model
l::LT
end

@inline function HubbardModel(t::Real, mu::Real, U::Real, l::AbstractLattice)
HubbardModel(Float64(t), Float64(mu), Float64(U), l)
end

function HubbardModel(;
dims = 2, L = 2, l = choose_lattice(HubbardModel, dims, L),
U = 1.0, mu = 0.0, t = 1.0
Expand Down Expand Up @@ -148,15 +152,21 @@ function save_model(file::JLDFile, m::HubbardModel, entryname::String = "Model")
nothing
end

function _load(data, ::Val{:HubbardModel})
# compat
function _load_model(data, ::Val{:HubbardModel})
l = _load(data["l"], to_tag(data["l"]))
HubbardModel(data["t"], data["mu"], data["U"], l)
end
_load(data, ::Val{:HubbardModelAttractive}) = _load(data, Val(:HubbardModel))
function _load(data, ::Val{:HubbardModelRepulsive})
_load_model(data, ::Val{:HubbardModelAttractive}) = _load_model(data, Val(:HubbardModel))
function _load_model(data, ::Val{:HubbardModelRepulsive})
l = _load(data["l"], to_tag(data["l"]))
HubbardModel(data["t"], 0.0, -data["U"], l)
end
_load_model(data, ::Val{:AttractiveGHQHubbardModel}) = _load_model(data, Val(:HubbardModelAttractive))
_load_model(data, ::Val{:RepulsiveGHQHubbardModel}) = _load_model(data, Val(:HubbardModelRepulsive))
field_hint(m, ::Val) = choose_field(m)
field_hint(m, ::Val{:AttractiveGHQHubbardModel}) = MagneticGHQField
field_hint(m, ::Val{:RepulsiveGHQHubbardModel}) = MagneticGHQField

function intE_kernel(mc, model::HubbardModel, G::GreensMatrix, ::Val{1})
# ⟨U (n↑ - 1/2)(n↓ - 1/2)⟩ = ...
Expand Down
2 changes: 1 addition & 1 deletion src/models/Ising/IsingModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ end
#
# Loads an IsingModel from a given `data` dictionary produced by
# `JLD.load(filename)`.
function _load(data, ::Val{:IsingModel})
function _load_model(data, ::Val{:IsingModel})
if !(data["VERSION"] == 1)
throw(ErrorException("Failed to load IsingModel version $(data["VERSION"])"))
end
Expand Down
Loading

0 comments on commit 5c2c75c

Please sign in to comment.