From df3f0e44669d18da9b8676494ad40bad58b0b211 Mon Sep 17 00:00:00 2001 From: pluto-azzaare Date: Mon, 12 Aug 2024 09:53:07 +0000 Subject: [PATCH 1/2] Fixes for learning ICNs with CBLS --- src/icn.jl | 19 ++++++++++--------- src/layer.jl | 7 ++++++- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/icn.jl b/src/icn.jl index a0047eb..79ef929 100644 --- a/src/icn.jl +++ b/src/icn.jl @@ -17,11 +17,11 @@ mutable struct ICN weights::BitVector function ICN(; - param = Vector{Symbol}(), - tr_layer = transformation_layer(param), - ar_layer = arithmetic_layer(), - ag_layer = aggregation_layer(), - co_layer = comparison_layer(param), + param=Vector{Symbol}(), + tr_layer=transformation_layer(param), + ar_layer=arithmetic_layer(), + ag_layer=aggregation_layer(), + co_layer=comparison_layer(param), ) w = generate_weights([tr_layer, ar_layer, ag_layer, co_layer]) return new(tr_layer, ar_layer, ag_layer, co_layer, w) @@ -107,7 +107,7 @@ function regularization(icn) return Σop / (Σmax + 1) end -max_icn_length(icn = ICN(; param = [:val])) = length(icn.transformation) +max_icn_length(icn=ICN(; param=[:val])) = length(icn.transformation) """ _compose(icn) @@ -116,7 +116,7 @@ Internal function called by `compose` and `show_composition`. function _compose(icn::ICN) !is_viable(icn) && ( return ( - (x; X = zeros(length(x), max_icn_length()), param = nothing, dom_size = 0) -> typemax(Float64) + (x; X=zeros(length(x), max_icn_length()), param=nothing, dom_size=0) -> typemax(Float64) ), [] ) @@ -133,6 +133,7 @@ function _compose(icn::ICN) if exclu(layer) f_id = as_int(@view weights(icn)[_start:_end]) + # @warn "debug" f_id _end _start weights(icn) (exclu(layer) ? "nbits_exclu(layer)" : "length(layer)") (@view weights(icn)[_start:_end]) s = symbol(layer, f_id + 1) push!(funcs, [functions(layer)[s]]) push!(symbols, [s]) @@ -151,11 +152,11 @@ function _compose(icn::ICN) end end - function composition(x; X = zeros(length(x), length(funcs[1])), dom_size, params...) + function composition(x; X=zeros(length(x), length(funcs[1])), dom_size, params...) tr_in(Tuple(funcs[1]), X, x; params...) X[1:length(x), 1] .= 1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])])) - return (y -> funcs[4][1](y; dom_size, nvars = length(x), params...))( + return (y -> funcs[4][1](y; dom_size, nvars=length(x), params...))( funcs[3][1](@view X[:, 1]), ) end diff --git a/src/layer.jl b/src/layer.jl index 53646f0..18be36f 100644 --- a/src/layer.jl +++ b/src/layer.jl @@ -30,7 +30,12 @@ exclu(layer) = layer.exclusive symbol(layer, i) Return the i-th symbols of the operations in a given layer. """ -symbol(layer, i) = collect(keys(functions(layer)))[i] +symbol(layer, i) = begin + if i > length(layer) + @info layer i functions(layer) + end + collect(keys(functions(layer)))[i] +end """ nbits_exclu(layer) From a2c476018858a6748b2a633b9c9bbf63eac78f68 Mon Sep 17 00:00:00 2001 From: Azzaare Date: Mon, 12 Aug 2024 10:03:37 +0000 Subject: [PATCH 2/2] format fix --- src/icn.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/icn.jl b/src/icn.jl index 79ef929..2c21632 100644 --- a/src/icn.jl +++ b/src/icn.jl @@ -17,11 +17,11 @@ mutable struct ICN weights::BitVector function ICN(; - param=Vector{Symbol}(), - tr_layer=transformation_layer(param), - ar_layer=arithmetic_layer(), - ag_layer=aggregation_layer(), - co_layer=comparison_layer(param), + param = Vector{Symbol}(), + tr_layer = transformation_layer(param), + ar_layer = arithmetic_layer(), + ag_layer = aggregation_layer(), + co_layer = comparison_layer(param), ) w = generate_weights([tr_layer, ar_layer, ag_layer, co_layer]) return new(tr_layer, ar_layer, ag_layer, co_layer, w) @@ -107,7 +107,7 @@ function regularization(icn) return Σop / (Σmax + 1) end -max_icn_length(icn=ICN(; param=[:val])) = length(icn.transformation) +max_icn_length(icn = ICN(; param = [:val])) = length(icn.transformation) """ _compose(icn) @@ -116,7 +116,7 @@ Internal function called by `compose` and `show_composition`. function _compose(icn::ICN) !is_viable(icn) && ( return ( - (x; X=zeros(length(x), max_icn_length()), param=nothing, dom_size=0) -> typemax(Float64) + (x; X = zeros(length(x), max_icn_length()), param = nothing, dom_size = 0) -> typemax(Float64) ), [] ) @@ -152,11 +152,11 @@ function _compose(icn::ICN) end end - function composition(x; X=zeros(length(x), length(funcs[1])), dom_size, params...) + function composition(x; X = zeros(length(x), length(funcs[1])), dom_size, params...) tr_in(Tuple(funcs[1]), X, x; params...) X[1:length(x), 1] .= 1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])])) - return (y -> funcs[4][1](y; dom_size, nvars=length(x), params...))( + return (y -> funcs[4][1](y; dom_size, nvars = length(x), params...))( funcs[3][1](@view X[:, 1]), ) end