diff --git a/src/train.jl b/src/train.jl index 568cab4..b865600 100644 --- a/src/train.jl +++ b/src/train.jl @@ -28,7 +28,7 @@ GMM(x::Vector{T}) where T <: AbstractFloat = GMM(reshape(x, length(x), 1)) # st ## constructors based on data or matrix function GMM(n::Int, x::DataOrMatrix{T}; method::Symbol=:kmeans, kind=:diag, - nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T + nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T <: AbstractFloat if n < 2 GMM(x, kind=kind) elseif method==:split @@ -40,7 +40,7 @@ function GMM(n::Int, x::DataOrMatrix{T}; method::Symbol=:kmeans, kind=:diag, end end ## a 1-dimensional Gaussian can be initialized with a vector, skip kind= -GMM(n::Int, x::Vector{T}; method::Symbol=:kmeans, nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T = GMM(n, reshape(x, length(x), 1); method=method, kind=:diag, nInit=nInit, nIter=nIter, nFinal=nFinal, sparse=sparse) +GMM(n::Int, x::Vector{T}; method::Symbol=:kmeans, nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T <: AbstractFloat = GMM(n, reshape(x, length(x), 1); method=method, kind=:diag, nInit=nInit, nIter=nIter, nFinal=nFinal, sparse=sparse) ## we sometimes end up with pathological gmms function sanitycheck!(gmm::GMM) @@ -72,7 +72,7 @@ end ## initialize GMM using Clustering.kmeans (which uses a method similar to kmeans++) -function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=10, sparse=0) where T +function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=10, sparse=0) where T <: AbstractFloat nₓ, d = size(x) hist = [History(@sprintf("Initializing GMM, %d Gaussians %s covariance %d dimensions using %d data points", n, diag, d, nₓ))] @info(last(hist).s)