Skip to content

Commit

Permalink
Merge pull request #43 from crstnbr/cbdev
Browse files Browse the repository at this point in the history
#42 and #40
  • Loading branch information
carstenbauer authored Mar 7, 2019
2 parents a35637b + 9034bd9 commit 8096f84
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 37 deletions.
1 change: 0 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
julia 1.0
EllipsisNotation
Reexport
Lazy
RecursiveArrayTools
26 changes: 13 additions & 13 deletions src/log/binning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct Compressor{T}
end


struct LogBinner{N, T}
struct LogBinner{T, N}
# list of Compressors, one per level
compressors::NTuple{N, Compressor{T}}

Expand All @@ -23,21 +23,21 @@ end


# Overload some basic Base functions
Base.eltype(B::LogBinner{N,T}) where {N,T} = T
Base.eltype(B::LogBinner{T,N}) where {T,N} = T
Base.length(B::LogBinner) = B.count[1]
Base.ndims(B::LogBinner{N,T}) where {N,T} = ndims(eltype(B))
Base.ndims(B::LogBinner{T,N}) where {T,N} = ndims(eltype(B))
Base.isempty(B::LogBinner) = length(B) == 0





function _print_header(io::IO, B::LogBinner{N,T}) where {N, T}
print(io, "LogBinner{$(N),$(T)}")
function _print_header(io::IO, B::LogBinner{T,N}) where {T,N}
print(io, "LogBinner{$(T),$(N)}")
nothing
end

function _println_body(io::IO, B::LogBinner{N,T}) where {N, T}
function _println_body(io::IO, B::LogBinner{T,N}) where {T,N}
n = length(B)
println(io)
print(io, "| Count: ", n)
Expand All @@ -49,7 +49,7 @@ function _println_body(io::IO, B::LogBinner{N,T}) where {N, T}
end

# short version (shows up in arrays etc.)
Base.show(io::IO, B::LogBinner{N,T}) where {N, T} = print(io, "LogBinner{$(N),$(T)}()")
Base.show(io::IO, B::LogBinner{T,N}) where {T,N} = print(io, "LogBinner{$(T),$(N)}()")
# verbose version (shows up in the REPL)
Base.show(io::IO, m::MIME"text/plain", B::LogBinner) = (_print_header(io, B); _println_body(io, B))

Expand Down Expand Up @@ -107,8 +107,8 @@ _capacity2nlvls(capacity::Int) = ceil(Int, log(2, capacity + 1))
Capacity of the binner, i.e. how many values can be handled before overflowing.
"""
capacity(B::LogBinner{N, T}) where {N,T} = _nlvls2capacity(N)
nlevels(B::LogBinner{N, T}) where {N,T} = N
capacity(B::LogBinner{T,N}) where {T,N} = _nlvls2capacity(N)
nlevels(B::LogBinner{T,N}) where {T,N} = N



Expand Down Expand Up @@ -171,7 +171,7 @@ function LogBinner(x::T;
el = x
end

B = LogBinner{N, S}(
B = LogBinner{S, N}(
tuple([Compressor{S}(copy(el), false) for i in 1:N]...),
[copy(el) for _ in 1:N],
[copy(el) for _ in 1:N],
Expand Down Expand Up @@ -215,7 +215,7 @@ end
Pushes a new value into the Binning Analysis.
"""
function Base.push!(B::LogBinner{N, T}, value::S) where {N, T, S}
function Base.push!(B::LogBinner{T,N}, value::S) where {N, T, S}
ndims(T) == ndims(S) || throw(DimensionMismatch("Expected $(ndims(T)) dimensions but got $(ndims(S))."))

_push!(B, 1, value)
Expand All @@ -227,7 +227,7 @@ _square(x::Complex) = Complex(real(x)^2, imag(x)^2)
_square(x::AbstractArray) = _square.(x)

# recursion, back-end function
function _push!(B::LogBinner{N, T}, lvl::Int64, value::S) where {N, T <: Number, S}
function _push!(B::LogBinner{T,N}, lvl::Int64, value::S) where {N, T <: Number, S}
C = B.compressors[lvl]

# any value propagating through this function is new to lvl. Therefore we
Expand Down Expand Up @@ -259,7 +259,7 @@ function _push!(B::LogBinner{N, T}, lvl::Int64, value::S) where {N, T <: Number,
end

function _push!(
B::LogBinner{N, T},
B::LogBinner{T,N},
lvl::Int64,
value::S
) where {N, T <: AbstractArray, S}
Expand Down
37 changes: 21 additions & 16 deletions src/log/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Calculates the variance of a given level in the Binning Analysis.
function var(B::LogBinner) end

function var(
B::LogBinner{N, T},
B::LogBinner{T,N},
lvl::Integer = _reliable_level(B)
) where {N, T <: Real}

Expand All @@ -32,7 +32,7 @@ function var(
end

function var(
B::LogBinner{N, T},
B::LogBinner{T,N},
lvl::Integer = _reliable_level(B)
) where {N, T <: Complex}

Expand All @@ -45,7 +45,7 @@ function var(
end

function var(
B::LogBinner{N, <: AbstractArray{T, D}},
B::LogBinner{<: AbstractArray{T, D}, N},
lvl::Integer = _reliable_level(B)
) where {N, D, T <: Real}

Expand All @@ -57,7 +57,7 @@ function var(
end

function var(
B::LogBinner{N, <: AbstractArray{T, D}},
B::LogBinner{<: AbstractArray{T, D}, N},
lvl::Integer = _reliable_level(B)
) where {N, D, T <: Complex}

Expand All @@ -74,7 +74,7 @@ end
Calculates the variance for each level of the Binning Analysis.
"""
function all_vars(B::LogBinner{N}) where {N}
function all_vars(B::LogBinner{T,N}) where {T,N}
[var(B, lvl) for lvl in 1:N if B.count[lvl] > 1]
end

Expand All @@ -84,7 +84,7 @@ end
Calculates the variance/N for each level of the Binning Analysis.
"""
function all_varNs(B::LogBinner{N}) where {N}
function all_varNs(B::LogBinner{T,N}) where {T,N}
[varN(B, lvl) for lvl in 1:N if B.count[lvl] > 1]
end

Expand All @@ -109,7 +109,7 @@ end
Calculates the mean for each level of the `LogBinner`.
"""
function all_means(B::LogBinner{N}) where {N}
function all_means(B::LogBinner{T,N}) where {T,N}
[mean(B, lvl) for lvl in 1:N if B.count[lvl] > 1]
end

Expand All @@ -122,19 +122,24 @@ end
Calculates the autocorrelation time tau.
"""
function tau(B::LogBinner, lvl::Integer = _reliable_level(B))
function tau(B::LogBinner{T,N}, lvl::Integer = _reliable_level(B)) where {N , T <: Number}
var_0 = varN(B, 1)
var_l = varN(B, lvl)
0.5 * (var_l / var_0 - 1)
end
function tau(B::LogBinner{T,N}, lvl::Integer = _reliable_level(B)) where {N , T <: AbstractArray}
var_0 = varN(B, 1)
var_l = varN(B, lvl)
@. 0.5 * (var_l / var_0 - 1)
end


"""
all_taus(B::LogBinner)
Calculates the autocorrelation time tau for each level of the `LogBinner`.
"""
function all_taus(B::LogBinner{N}) where {N}
function all_taus(B::LogBinner{T,N}) where {T,N}
[tau(B, lvl) for lvl in 1:N if B.count[lvl] > 1]
end

Expand All @@ -146,7 +151,7 @@ end
# standard error estimate:
# Take the highest lvl with at least 32 bins.
# (Chose 32 based on https://doi.org/10.1119/1.3247985)
function _reliable_level(B::LogBinner{N,T})::Int64 where {N, T}
function _reliable_level(B::LogBinner{T,N})::Int64 where {T,N}
isempty(B) && (return 1) # results in NaN in std_error
i = findlast(x -> x >= 32, B.count)
something(i, 1)
Expand All @@ -159,10 +164,10 @@ Calculates the standard error of the mean.
"""
function std_error(B::LogBinner) end

function std_error(B::LogBinner{N, T}, lvl::Integer=_reliable_level(B)) where {N, T <: Number}
function std_error(B::LogBinner{T,N}, lvl::Integer=_reliable_level(B)) where {N, T <: Number}
sqrt(varN(B, lvl))
end
function std_error(B::LogBinner{N, T}, lvl::Integer=_reliable_level(B)) where {N, T <: AbstractArray}
function std_error(B::LogBinner{T,N}, lvl::Integer=_reliable_level(B)) where {N, T <: AbstractArray}
sqrt.(varN(B, lvl))
end

Expand All @@ -174,8 +179,8 @@ Calculates the standard error for each level of the Binning Analysis.
"""
function all_std_errors(B::LogBinner) end

all_std_errors(B::LogBinner{N, T}) where {N, T <: Number} = sqrt.(all_varNs(B))
all_std_errors(B::LogBinner{N, T}) where {N, T <: AbstractArray} = (x -> sqrt.(x)).(all_varNs(B))
all_std_errors(B::LogBinner{T,N}) where {N, T <: Number} = sqrt.(all_varNs(B))
all_std_errors(B::LogBinner{T,N}) where {N, T <: AbstractArray} = (x -> sqrt.(x)).(all_varNs(B))


"""
Expand All @@ -188,10 +193,10 @@ converged.
function convergence(B::LogBinner) end


function convergence(B::LogBinner{N, T}, lvl::Integer=_reliable_level(B)) where {N, T <: Number}
function convergence(B::LogBinner{T,N}, lvl::Integer=_reliable_level(B)) where {N, T <: Number}
abs((varN(B, lvl+1) - varN(B, lvl)) / varN(B, lvl))
end
function convergence(B::LogBinner{N, T}, lvl::Integer=_reliable_level(B)) where {N, T <: AbstractArray}
function convergence(B::LogBinner{T,N}, lvl::Integer=_reliable_level(B)) where {N, T <: AbstractArray}
mean(abs.((varN(B, lvl+1) .- varN(B, lvl)) ./ varN(B, lvl)))
end

Expand Down
20 changes: 13 additions & 7 deletions test/logbinning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ end
@test isapprox(all_varNs(BA), zero(all_varNs(BA)), atol=1e-6)
@test isapprox(all_taus(BA), [0.0, -0.00065328850247931, -0.0018180968845809553, 0.00019817179932868356, 0.0005312186016332987, -0.009476150581268994, -0.008794711536776634, -0.007346737569564443, -0.014542478848703244, -0.030064159934323653, -0.01599670814224563, -0.007961042363178128, -0.03167873601168558, -0.056188229083248886, -0.008218660661725774, 0.05035147373711113, 0.0756019296606737, 0.2387501479629205, 0.289861051172009])
@test isapprox(all_std_errors(BA), [0.0004083071646092097, 0.0004080403350462593, 0.00040756414657076514, 0.0004083880715587553, 0.0004085240073900606, 0.00040441947616030687, 0.00040470028980091994, 0.0004052963382191276, 0.00040232555162611277, 0.00039584146247874917, 0.00040172249945971054, 0.00040504357104527986, 0.00039516087376314515, 0.0003846815937765092, 0.00040493752222959487, 0.0004283729758565588, 0.00043808977717638777, 0.0004963074437049236, 0.0005131890106236348])

@test isapprox(tau(BA), -0.008218660661725774)
@test isapprox(std_error(BA), 0.00040493752222959487)
end


Expand All @@ -147,6 +150,9 @@ end

# all_std_errors for <:AbstractArray
@test all(isapprox.(all_std_errors(BA), Ref(zeros(3)), atol=1e-2))

@test all(isapprox.(tau(BA), [-0.101203, -0.0831874, -0.0112827], atol=1e-6))
@test all(isapprox.(std_error(BA), [0.000364498, 0.00037268, 0.000403603], atol=1e-6))
end


Expand Down Expand Up @@ -235,12 +241,12 @@ end

@testset "Sum-type heuristic" begin
# numbers
@test typeof(LogBinner(zero(Int64))) == LogBinner{32,Float64}
@test typeof(LogBinner(zero(ComplexF16))) == LogBinner{32,ComplexF64}
@test typeof(LogBinner(zero(Int64))) == LogBinner{Float64, 32}
@test typeof(LogBinner(zero(ComplexF16))) == LogBinner{ComplexF64, 32}

# arrays
@test typeof(LogBinner(zeros(Int64, 2,2))) == LogBinner{32,Matrix{Float64}}
@test typeof(LogBinner(zeros(ComplexF16, 2,2))) == LogBinner{32,Matrix{ComplexF64}}
@test typeof(LogBinner(zeros(Int64, 2,2))) == LogBinner{Matrix{Float64}, 32}
@test typeof(LogBinner(zeros(ComplexF16, 2,2))) == LogBinner{Matrix{ComplexF64}, 32}
end


Expand Down Expand Up @@ -301,9 +307,9 @@ end
close(write_pipe);

# compact
@test readline(read_pipe) == "LogBinner{32,Float64}()"
@test readline(read_pipe) == "LogBinner{Float64,32}()"
# full
@test readline(read_pipe) == "LogBinner{32,Float64}"
@test readline(read_pipe) == "LogBinner{Float64,32}"
@test readline(read_pipe) == "| Count: 0"
@test length(readlines(read_pipe)) == 0
close(read_pipe);
Expand All @@ -315,7 +321,7 @@ end
show(write_pipe, MIME"text/plain"(), B)
redirect_stdout(oldstdout);
close(write_pipe);
@test readline(read_pipe) == "LogBinner{32,Float64}"
@test readline(read_pipe) == "LogBinner{Float64,32}"
@test readline(read_pipe) == "| Count: 1000"
@test readline(read_pipe) == "| Mean: 0.49685"
@test readline(read_pipe) == "| StdError: 0.00733"
Expand Down

0 comments on commit 8096f84

Please sign in to comment.