From 5c2c75c1ca57519ffd01e19adff1fb9a6538556d Mon Sep 17 00:00:00 2001 From: Frederic Freyer Date: Sat, 22 Jan 2022 15:23:59 +0100 Subject: [PATCH] Add DummyModel for loading unknown models (#150) * add DummyModel * start implementing tests * reorganize for DummyModel * fixes * fix required methods * update dummy model save * fix names * cleanup model compat * autoconvert to Float64 * fix test errors --- src/FileIO.jl | 33 ++--- src/MonteCarlo.jl | 4 +- src/flavors/DQMC/FileIO.jl | 4 +- src/flavors/MC/MC.jl | 3 +- src/models/DummyModel.jl | 59 +++++++++ src/models/HubbardModel.jl | 16 ++- src/models/Ising/IsingModel.jl | 2 +- test/FileIO.jl | 234 ++++++++++++++++++++------------- test/assets/dummy_in.jld2 | Bin 0 -> 12175 bytes 9 files changed, 233 insertions(+), 122 deletions(-) create mode 100644 src/models/DummyModel.jl create mode 100644 test/assets/dummy_in.jld2 diff --git a/src/FileIO.jl b/src/FileIO.jl index 0890a93c..d3938ab6 100644 --- a/src/FileIO.jl +++ b/src/FileIO.jl @@ -65,10 +65,19 @@ function save( backend.jldopen(filename, mode, compress=compress; kwargs...), filename ) - write(file, "VERSION", 1) - save_mc(file, mc, "MC") - save_rng(file) - close(file.file) + try + write(file, "VERSION", 1) + save_mc(file, mc, "MC") + save_rng(file) + catch e + if overwrite && !isempty(temp_filename) && isfile(temp_filename) + rm(filename) + mv(temp_filename, filename) + end + @error exception = e + finally + close(file.file) + end if overwrite && !isempty(temp_filename) && isfile(temp_filename) rm(temp_filename) @@ -272,21 +281,7 @@ function save_model(file::JLDFile, model, entryname::String) nothing end -""" - _load(data, ::Type{...}) - -Loads `data` where `data` is either a `JLD2.JLDFile`, `JLD2.Group` or a `Dict`. - -The default `_load` will check that `data["VERSION"] == 0` and simply return -`data["data"]`. You may implement `_load(data, ::Type{<: MyType})` to add -specialized loading behavior. -""" -function _load(data, ::Val{:Generic}) where T - data["VERSION"] == 0 || throw(ErrorException( - "Version $(data["VERSION"]) incompatabile with default _load for $T." - )) - data["data"] -end +_load(data, ::Val{:Generic}) = data["data"] # save_lattice(filename, lattice, entryname) diff --git a/src/MonteCarlo.jl b/src/MonteCarlo.jl index b42a37eb..39898778 100644 --- a/src/MonteCarlo.jl +++ b/src/MonteCarlo.jl @@ -53,7 +53,8 @@ Base.keys(fw::FileWrapper) = keys(fw.file) # To allow switching between JLD and JLD2: const UnknownType = Union{JLD.UnsupportedType, JLD2.UnknownType} -const JLDFile = Union{FileWrapper{<: JLD.JldFile}, FileWrapper{<: JLD2.JLDFile}, JLD.JldFile, JLD2.JLDFile} +const _JLDFile = Union{JLD.JldFile, JLD2.JLDFile, Dict, JLD2.Group} +const JLDFile = Union{FileWrapper{<: _JLDFile}, _JLDFile} @@ -103,6 +104,7 @@ export ReplicaExchange, ReplicaPull, connect, disconnect include("models/Ising/IsingModel.jl") include("models/HubbardModel.jl") +include("models/DummyModel.jl") export IsingEnergyMeasurement, IsingMagnetizationMeasurement include("FileIO.jl") diff --git a/src/flavors/DQMC/FileIO.jl b/src/flavors/DQMC/FileIO.jl index 198c6d4f..5c649e2a 100644 --- a/src/flavors/DQMC/FileIO.jl +++ b/src/flavors/DQMC/FileIO.jl @@ -49,12 +49,12 @@ function _load(data, ::Val{:DQMC}) analysis = _load(data["Analysis"], Val(:DQMCAnalysis)) recorder = _load(data["configs"], to_tag(data["configs"])) last_sweep = data["last_sweep"] - model = _load(data["Model"], to_tag(data["Model"])) + model = _load_model(data["Model"], to_tag(data["Model"])) if haskey(data, "field") field = load_field(data["field"], Val(:Field), parameters, model) else conf = data["conf"] - field = choose_field(model)(parameters, model) + field = field_hint(model, to_tag(data["Model"]))(parameters, model) conf!(field, conf) end scheduler = if haskey(data, "Scheduler") diff --git a/src/flavors/MC/MC.jl b/src/flavors/MC/MC.jl index 738ca421..d696b170 100644 --- a/src/flavors/MC/MC.jl +++ b/src/flavors/MC/MC.jl @@ -470,8 +470,7 @@ function _load(data, ::Val{:MC}) conf = data["conf"] configs = _load(data["configs"], to_tag(data["configs"])) last_sweep = data["last_sweep"] - model = _load(data["Model"], to_tag(data["Model"])) - + model = _load_model(data["Model"], to_tag(data["Model"])) measurements = _load(data["Measurements"], Val(:Measurements)) MC( diff --git a/src/models/DummyModel.jl b/src/models/DummyModel.jl new file mode 100644 index 00000000..9fa177b3 --- /dev/null +++ b/src/models/DummyModel.jl @@ -0,0 +1,59 @@ +""" + DummyModel() + +If an unknown or undefined model is loaded a `DummyModel` will be created to +allow the load to succeed and keep the data in check. +""" +struct DummyModel <: Model + data::Dict{String, Any} +end + +Base.show(io::IO, model::DummyModel) = print(io, "DummyModel()") + +function Base.getproperty(obj::DummyModel, field::Symbol) + if hasfield(DummyModel, field) + return getfield(obj, field) + else + return getfield(obj, :data)[string(field)] + end +end + + +choose_field(::DummyModel) = DensityHirschField +nflavors(::DummyModel) = 1 +lattice(m::DummyModel) = m.l + + +function save_model(file::JLDFile, m::DummyModel, entryname::String="Model") + close(file) + error("DummyModel cannot be saved.") +end + +function _load_model(data, ::Val) + tag = to_tag(data) + @warn "Failed to load $tag, creating DummyModel" + dict = _load_to_dict(data) + if haskey(dict, "data") + x = pop!(dict, "data") + push!(dict, x...) + end + DummyModel(dict) +end + +_load_to_dict(file::FileWrapper) = _load_to_dict(file.file) +function _load_to_dict(data) + if parentmodule(typeof(data)) == JLD2.ReconstructedTypes + Dict(map(fieldnames(typeof(data))) do f + string(f) => _load_to_dict(getfield(data, f)) + end) + else + data + end +end +function _load_to_dict(data::Union{JLD.JldFile, JLD2.JLDFile, JLD2.Group}) + output = Dict{String, Any}() + for key in keys(data) + push!(output, key => _load_to_dict(data[key])) + end + output +end \ No newline at end of file diff --git a/src/models/HubbardModel.jl b/src/models/HubbardModel.jl index eca8c1a8..eb569106 100644 --- a/src/models/HubbardModel.jl +++ b/src/models/HubbardModel.jl @@ -27,6 +27,10 @@ struct HubbardModel{LT <: AbstractLattice} <: Model l::LT end +@inline function HubbardModel(t::Real, mu::Real, U::Real, l::AbstractLattice) + HubbardModel(Float64(t), Float64(mu), Float64(U), l) +end + function HubbardModel(; dims = 2, L = 2, l = choose_lattice(HubbardModel, dims, L), U = 1.0, mu = 0.0, t = 1.0 @@ -148,15 +152,21 @@ function save_model(file::JLDFile, m::HubbardModel, entryname::String = "Model") nothing end -function _load(data, ::Val{:HubbardModel}) +# compat +function _load_model(data, ::Val{:HubbardModel}) l = _load(data["l"], to_tag(data["l"])) HubbardModel(data["t"], data["mu"], data["U"], l) end -_load(data, ::Val{:HubbardModelAttractive}) = _load(data, Val(:HubbardModel)) -function _load(data, ::Val{:HubbardModelRepulsive}) +_load_model(data, ::Val{:HubbardModelAttractive}) = _load_model(data, Val(:HubbardModel)) +function _load_model(data, ::Val{:HubbardModelRepulsive}) l = _load(data["l"], to_tag(data["l"])) HubbardModel(data["t"], 0.0, -data["U"], l) end +_load_model(data, ::Val{:AttractiveGHQHubbardModel}) = _load_model(data, Val(:HubbardModelAttractive)) +_load_model(data, ::Val{:RepulsiveGHQHubbardModel}) = _load_model(data, Val(:HubbardModelRepulsive)) +field_hint(m, ::Val) = choose_field(m) +field_hint(m, ::Val{:AttractiveGHQHubbardModel}) = MagneticGHQField +field_hint(m, ::Val{:RepulsiveGHQHubbardModel}) = MagneticGHQField function intE_kernel(mc, model::HubbardModel, G::GreensMatrix, ::Val{1}) # ⟨U (n↑ - 1/2)(n↓ - 1/2)⟩ = ... diff --git a/src/models/Ising/IsingModel.jl b/src/models/Ising/IsingModel.jl index f1cecb69..ae9e22c3 100644 --- a/src/models/Ising/IsingModel.jl +++ b/src/models/Ising/IsingModel.jl @@ -205,7 +205,7 @@ end # # Loads an IsingModel from a given `data` dictionary produced by # `JLD.load(filename)`. -function _load(data, ::Val{:IsingModel}) +function _load_model(data, ::Val{:IsingModel}) if !(data["VERSION"] == 1) throw(ErrorException("Failed to load IsingModel version $(data["VERSION"])")) end diff --git a/test/FileIO.jl b/test/FileIO.jl index cfd55381..bcb8fa06 100644 --- a/test/FileIO.jl +++ b/test/FileIO.jl @@ -181,38 +181,40 @@ end end -function test_mc(mc, x) - # Check if loaded/replayed mc matches original - for f in fieldnames(typeof(mc.p)) - @test getfield(mc.p, f) == getfield(x.p, f) - end - @test mc.conf == x.conf - @test mc.model.L == x.model.L - @test mc.model.dims == x.model.dims - for f in fieldnames(typeof(mc.model.l)) - @test getfield(mc.model.l, f) == getfield(x.model.l, f) - end - # @test mc.model.neighs == x.model.neighs - @test mc.model.energy[] == x.model.energy[] - for (k, v) in mc.thermalization_measurements - for f in fieldnames(typeof(v)) - r = getfield(v, f) == getfield(x.thermalization_measurements[k], f) - r != true && @info "Check failed for $k -> $f" - @test r + + +@time @testset "MC" begin + + function test_mc(mc, x) + # Check if loaded/replayed mc matches original + for f in fieldnames(typeof(mc.p)) + @test getfield(mc.p, f) == getfield(x.p, f) end - end - for (k, v) in mc.measurements - for f in fieldnames(typeof(v)) - r = getfield(v, f) == getfield(x.measurements[k], f) - r != true && @info "Check failed for $k -> $f" - @test r + @test mc.conf == x.conf + @test mc.model.L == x.model.L + @test mc.model.dims == x.model.dims + for f in fieldnames(typeof(mc.model.l)) + @test getfield(mc.model.l, f) == getfield(x.model.l, f) end + # @test mc.model.neighs == x.model.neighs + @test mc.model.energy[] == x.model.energy[] + for (k, v) in mc.thermalization_measurements + for f in fieldnames(typeof(v)) + r = getfield(v, f) == getfield(x.thermalization_measurements[k], f) + r != true && @info "Check failed for $k -> $f" + @test r + end + end + for (k, v) in mc.measurements + for f in fieldnames(typeof(v)) + r = getfield(v, f) == getfield(x.measurements[k], f) + r != true && @info "Check failed for $k -> $f" + @test r + end + end + nothing end - nothing -end - -@time @testset "MC" begin model = IsingModel(dims=2, L=2) mc = MC( model, beta = 0.66, thermalization = 33, sweeps = 123, @@ -279,72 +281,6 @@ end rm("resumable_testfile.jld") end - -function test_dqmc(mc, x) - for f in fieldnames(typeof(mc.parameters)) - @test getfield(mc.parameters, f) == getfield(x.parameters, f) - end - @test mc.field.conf == x.field.conf - @test mc.model.mu == x.model.mu - @test mc.model.t == x.model.t - @test mc.model.U == x.model.U - for f in fieldnames(typeof(mc.model.l)) - @test getfield(mc.model.l, f) == getfield(x.model.l, f) - end - @test MonteCarlo.nflavors(mc.field) == MonteCarlo.nflavors(x.field) - @test mc.scheduler == x.scheduler - for (k, v) in mc.thermalization_measurements - for f in fieldnames(typeof(v)) - r = if getfield(v, f) isa LightObservable - # TODO - # implement == for LightObservable in MonteCarloObservable - getfield(v, f).B == getfield(x.measurements[k], f).B - else - getfield(v, f) == getfield(x.measurements[k], f) - end - r != true && @info "Check failed for $k -> $f" - @test r - end - end - for (k, v) in mc.measurements - for f in fieldnames(typeof(v)) - v isa MonteCarlo.DQMCMeasurement && f == :temp && continue - v isa MonteCarlo.DQMCMeasurement && f == :kernel && continue - r = if getfield(v, f) isa LightObservable - # TODO - # implement == for LightObservable in MonteCarloObservable - # TODO: implement ≈ for LightObservable, LogBinner, etc - r = true - a = getfield(v, f) - b = getfield(x.measurements[k], f) - for i in eachindex(getfield(v, f).B.compressors) - r = r && (a.B.compressors[i].value ≈ b.B.compressors[i].value) - r = r && (a.B.compressors[i].switch ≈ b.B.compressors[i].switch) - end - r = r && (a.B.x_sum ≈ b.B.x_sum) - r = r && (a.B.x2_sum ≈ b.B.x2_sum) - r = r && (a.B.count ≈ b.B.count) - elseif getfield(v, f) isa LogBinner - r = true - a = getfield(v, f) - b = getfield(x.measurements[k], f) - for i in eachindex(a.compressors) - r = r && (a.compressors[i].value ≈ b.compressors[i].value) - r = r && (a.compressors[i].switch ≈ b.compressors[i].switch) - end - r = r && (a.x_sum ≈ b.x_sum) - r = r && (a.x2_sum ≈ b.x2_sum) - r = r && (a.count ≈ b.count) - else - getfield(v, f) == getfield(x.measurements[k], f) - end - r != true && @info "Check failed for $k -> $f" - @test r - end - end - nothing -end - for file in readdir() if endswith(file, "jld") || endswith(file, "jld2") || endswith(file, ".confs") rm(file) @@ -352,7 +288,74 @@ for file in readdir() end + @time @testset "DQMC" begin + + function test_dqmc(mc, x) + for f in fieldnames(typeof(mc.parameters)) + @test getfield(mc.parameters, f) == getfield(x.parameters, f) + end + @test mc.field.conf == x.field.conf + @test mc.model.mu == x.model.mu + @test mc.model.t == x.model.t + @test mc.model.U == x.model.U + for f in fieldnames(typeof(mc.model.l)) + @test getfield(mc.model.l, f) == getfield(x.model.l, f) + end + @test MonteCarlo.nflavors(mc.field) == MonteCarlo.nflavors(x.field) + @test mc.scheduler == x.scheduler + for (k, v) in mc.thermalization_measurements + for f in fieldnames(typeof(v)) + r = if getfield(v, f) isa LightObservable + # TODO + # implement == for LightObservable in MonteCarloObservable + getfield(v, f).B == getfield(x.measurements[k], f).B + else + getfield(v, f) == getfield(x.measurements[k], f) + end + r != true && @info "Check failed for $k -> $f" + @test r + end + end + for (k, v) in mc.measurements + for f in fieldnames(typeof(v)) + v isa MonteCarlo.DQMCMeasurement && f == :temp && continue + v isa MonteCarlo.DQMCMeasurement && f == :kernel && continue + r = if getfield(v, f) isa LightObservable + # TODO + # implement == for LightObservable in MonteCarloObservable + # TODO: implement ≈ for LightObservable, LogBinner, etc + r = true + a = getfield(v, f) + b = getfield(x.measurements[k], f) + for i in eachindex(getfield(v, f).B.compressors) + r = r && (a.B.compressors[i].value ≈ b.B.compressors[i].value) + r = r && (a.B.compressors[i].switch ≈ b.B.compressors[i].switch) + end + r = r && (a.B.x_sum ≈ b.B.x_sum) + r = r && (a.B.x2_sum ≈ b.B.x2_sum) + r = r && (a.B.count ≈ b.B.count) + elseif getfield(v, f) isa LogBinner + r = true + a = getfield(v, f) + b = getfield(x.measurements[k], f) + for i in eachindex(a.compressors) + r = r && (a.compressors[i].value ≈ b.compressors[i].value) + r = r && (a.compressors[i].switch ≈ b.compressors[i].switch) + end + r = r && (a.x_sum ≈ b.x_sum) + r = r && (a.x2_sum ≈ b.x2_sum) + r = r && (a.count ≈ b.count) + else + getfield(v, f) == getfield(x.measurements[k], f) + end + r != true && @info "Check failed for $k -> $f" + @test r + end + end + nothing + end + isfile("testfile.confs") && rm("testfile.confs") model = HubbardModel(4, 2, t = 1.7, U = 2.5) mc = DQMC( @@ -432,4 +435,47 @@ end @test matches rm("resumable_testfile.jld2") isfile("testfile.confs") && rm("testfile.confs") +end + + +function is_file_content_equal(file1, file2) + data1 = open(readavailable, file1, "r") + data2 = open(readavailable, file2, "r") + if length(data1) != length(data2) + return false + end + return all((a == b for (a, b) in zip(data1, data2))) +end + +@testset "DummyModel" begin + cp("assets/dummy_in.jld2", "dummy_in.jld2", force = true) + mc = MonteCarlo.load("dummy_in.jld2") + @test mc.model isa MonteCarlo.DummyModel + @test mc.model.data["x"] == 7 + @test mc.model.data["y"] == "foo" + + MonteCarlo.save("dummy_in.jld2", mc, overwrite=true) + @test isfile("dummy_in.jld2") + @test is_file_content_equal("dummy_in.jld2", "assets/dummy_in.jld2") + + #= + # Generated with + struct TestModel <: MonteCarlo.Model + l::AbstractLattice + U::Float64 + x::Int64 + y::String + end + + Base.rand(::Type{DQMC}, ::TestModel, n::Int64) = Base.rand(4, n) + MonteCarlo.choose_field(::TestModel) = DensityHirschField + MonteCarlo.lattice(m::TestModel) = m.l + MonteCarlo.nflavors(::TestModel) = 1 + MonteCarlo.hopping_matrix(::DQMC, ::TestModel) = ones(4, 4) + + mc = DQMC(TestModel(SquareLattice(2), 1.0, 7, "foo"), beta=1.0, recorder=Discarder()) + MonteCarlo.save("assets/dummy_in.jld2", mc, overwrite=true) + =# + + rm("dummy_in.jld2") end \ No newline at end of file diff --git a/test/assets/dummy_in.jld2 b/test/assets/dummy_in.jld2 new file mode 100644 index 0000000000000000000000000000000000000000..c282fd09ec4b8b3cf7d4e7120b552aae46b0535d GIT binary patch literal 12175 zcmeHNeQX@X6`%X?d42xKC4?pgnp{2_JH-uiu+h#wG z-93nrNFo|y1*G!PkYWUCLE?{=wp2w*%Ren5s9zGar4dToq!Ni5wIC$75(N>c`)21o z-`)C}-q|ez=_Wbv&6}C`e(%kjw{vf1+rVJ=x>42PnfT}Px~9elR9B4;TDGCO?eROf z?P!)6??`keI$69;cIiy4Pjtk4y4HcH!KuVKq`n0!@q%BM*17}T1Vu7R^I@C*ae*>B!ww;Vi45cQg1V= z!Zur}4xcE*)*`Si5@K3BY&Wa6YH$~p)XOVQ4pk*HgNA(OyI<=1a0a44*7m{PmZ~Jq zHqZatS)z#b(BYQrCDb0@ge+!(&Uo*Pdv-OluIUg=4!Y@MlcGG|8$0nWHYWxWz*O%S z?p&Fx%2>0ZS7@vcf8gwSod|7aeLb*LnG)o=Qs-myT!tWA_74(e13^_LLu6ye8Bb`W<&iuhP_6? z9!db^i@~DwTQ*OKciS!@Kroi9mUyjBzEn0B z<4NA0S8YD5x~`Vy_1F*G*DL!h%X?C=CCBh^?#kzMp)8aJ9JgFni}`O!r+Lm*%`{KC zs;fC5#(^1;_PKI)|6wbw>d8Af&p}%w_LHHXu~RGL00_6<7oxjJZ>|zYi;oZSUx4&0 z5Z_MK%veTZm+I{I@a^0_ge~&%S8{|BYOPYUK+OU*3)C!7vp~%PH4D@%P_sbI0yPWN zEbwu+fK0;-_!2=;7AoWxl+_|P7iS^`K=}nqJQl0}&|4esokv=TvHDBLFx^*WVv11A zlfl&VSDxB*X&x#}9r}Lo@dxI?px6c#X1GZALIQG5#;fT*k7@g8%zK3nBlAtH=&)gL zXff+6my)a~?I<#S6c#AkMouZ7e||B2zqD;vxhV4*X7ZH9BFeBC2Q(+G+8J(l-|^~`Rf!X}6s54x=Zw=o5OG49pDBBkg;_lt;wHDX^fPZvPDx>tp_?IZ4%6EPy&&~V$?)Aw zDR)Y7fW?Z>*2dY=IlHt-K9di^`PZy6@sopRB*tN%Ufzq-xLS@N58V2xVh&TiLTgdP zY9TR#_V+dFA0^vYL&rdMplty1Fq{)WdRS(N`xvCnxxqd`|q zHQ7>X(Peg=#ff01MlCbru(o1;lGH(cDhdN)8Iv**`0g){blE?b%IzszJ+z^*OxAlaK~H(}Z(@DJBhSAr#hgSj7%qqrJteEfW=e&oIPPD=%TiVA>bf^fV-9q37dKB|hm{_S%or6RYM&IBny z&shqr7BdmYg9~v4B@0=O6o2#huOF5I{#h!3I+Ab967Z4`@WPA$ni>*BPvS#Q5=>8j zKjK5}^dvg*p>e3!oFDJ5eLErz;1B2(-!te$%K=n@Dui!>1@K*gBcm`|z>=B)r(?-< zmS^%hxAoUIVSL^bx&nd}C@1XPS%}XSg$wcarf?y~3mZ{gyob)*_FS*jt)ky1V)fH- z@wzQ$CpIYOe*6xvf}J?cBX+dU0Mbl3ar#tNvTOGl(0c(iuaG_I;QY0fvi{jLE!Tl`u5L&qGzU(H%O#3pa zLE&%y($U{aT`F`w8d})DxZrMm-9Q)6D$+@uV$!?_AN7|eP&x=_#7*G2+{n4(B`55{ z!vm2&oVzG-uP&LoV91-g?W-hD1y+Puzh7H%=xVQf;7}mw9bV(5pP5Hguw@Z|7KjDv zZ{4~xdDD*VbwVXU29!V!x$4+vB3Xk`#JREaBR^Hs zVMF#NQQ=((eHX#fOz%%Q+L)Qj+19Aejny6|5D(fwk0r7PH0IPXPh5)Iwq@_7Zs70G zt-(?E@syF*-4nhBIC^*C&Vic$4K;w@!iexFchwliwFp7i;r!jXg{ZJGfSyf;+)$=$ z)#dk5+b6uXXwKG5*K0Y2E#B=1X~LPm?@WjH%b7cP*_l9-8|s1Y;Hf;#f}xpj_()FP zxE4@|5`lYCfV3`cH)mkr^16ewzK#fJtwO~7#s3o2v}k2%S1ohk%2kq7&beI(LM?)OB&rK_usoATapyJ^<3rEvH2eZ(3;wdkhV z%2vsYV$(ixcT8zRA_f?-p0qLNF(TtTs5q%{9bU&KKHmMZmxzdwWMSS!AI|+U(P;yH*G?F^d8#S`RMCu=Tpu( z4qn)g36Fa&t72(<@1SYNI0JXhUZz<^-x%rt3tE