diff --git a/Project.toml b/Project.toml index d1b325d..9dbfe21 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJDecisionTreeInterface" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" authors = ["Anthony D. Blaom "] -version = "0.4.1" +version = "0.4.2" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index ca639bf..cd5357b 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -452,7 +452,10 @@ const RandomForestModel = Union{ # # DATA FRONT END -_columnnames(X) = Tables.columnnames(Tables.columns(X)) |> collect +# to get column names based on table access type: +_columnnames(X) = _columnnames(X, Val(Tables.columnaccess(X))) |> collect +_columnnames(X, ::Val{true}) = Tables.columnnames(Tables.columns(X)) +_columnnames(X, ::Val{false}) = Tables.columnnames(first(Tables.rows(X))) # for fit: MMI.reformat(::Classifier, X, y) =