diff --git a/src/composition/learning_networks/nodes.jl b/src/composition/learning_networks/nodes.jl index 5ede32aa..0733b211 100644 --- a/src/composition/learning_networks/nodes.jl +++ b/src/composition/learning_networks/nodes.jl @@ -27,9 +27,9 @@ See also [`node`](@ref), [`Source`](@ref), [`origins`](@ref), [`sources`](@ref), [`fit!`](@ref). """ -struct Node{T<:Union{Machine, Nothing}} <: AbstractNode +struct Node{T<:Union{Machine, Nothing},Oper} <: AbstractNode - operation # eg, `predict` or a static operation, such as `exp` + operation::Oper # eg, `predict` or a static operation, such as `exp` machine::T # is `nothing` for static operations # nodes called to get args for `operation(model, ...) ` or @@ -43,9 +43,11 @@ struct Node{T<:Union{Machine, Nothing}} <: AbstractNode # order consistent with extended graph, excluding self nodes::Vector{AbstractNode} - function Node(operation, - machine::T, - args::AbstractNode...) where T<:Union{Machine, Nothing} + function Node( + operation::Oper, + machine::T, + args::AbstractNode..., + ) where {T<:Union{Machine, Nothing}, Oper} # check the number of arguments: # if machine === nothing && isempty(args) @@ -70,7 +72,7 @@ struct Node{T<:Union{Machine, Nothing}} <: AbstractNode vcat(nodes_, (nodes(n) for n in machine.args)...) |> unique end - return new{T}(operation, machine, args, origins_, nodes_) + return new{T,Oper}(operation, machine, args, origins_, nodes_) end end