Skip to content

Commit

Permalink
Merge pull request #37 from Evovest/dev
Browse files Browse the repository at this point in the history
MLJModelInterface integration
  • Loading branch information
jeremiedb authored Feb 16, 2020
2 parents e3c60a7 + 15d8620 commit ad9fe6f
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 596 deletions.
106 changes: 18 additions & 88 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,11 @@ git-tree-sha1 = "23d7324164c89638c18f6d7f90d972fa9c4fa9fb"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.7.7"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "7b62b728a5f3dd6ee3b23910303ccf27e82fad5e"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.8.1"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "06be57f11a029927e10d050a6c5496a8695a5437"
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "a4839bd26e3e7f4869a4cf6c31f9f93f47aac7c5"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.3.0"

[[ComputationalResources]]
deps = ["Test"]
git-tree-sha1 = "89e7e7ed20af73d9f78877d2b8d1194e7b6ff13d"
uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3"
version = "0.3.0"

[[Crayons]]
git-tree-sha1 = "cb7a62895da739fe5bb43f1a26d4292baf4b3dc0"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.1"
version = "3.5.0"

[[DataAPI]]
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
Expand Down Expand Up @@ -80,27 +63,16 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "e063d0b5d27180b98edacd2b1cb90ecfbc171385"
git-tree-sha1 = "6b19601c0e98de3a8964ed33ad73e130c7165b1d"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.21.12"
version = "0.22.4"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "fec413d4fc547992eb62a5c544cedb6d7853c1f5"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.8.4"

[[FixedPointNumbers]]
git-tree-sha1 = "4aaea64dd0c30ad79037084f8ca2b94348e65eaa"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.7.1"

[[Formatting]]
deps = ["Printf"]
git-tree-sha1 = "a0c901c29c0e7c763342751c0a94211d56c0de5c"
uuid = "59287772-0a20-5a39-b81b-1366585eb4c0"
version = "0.4.1"

[[Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Expand All @@ -109,12 +81,6 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[InvertedIndices]]
deps = ["Test"]
git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc"
uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
version = "1.0.0"

[[IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
Expand All @@ -126,12 +92,6 @@ git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.0"

[[LearnBase]]
deps = ["LinearAlgebra", "SparseArrays", "StatsBase", "Test"]
git-tree-sha1 = "c4b5da6d68517f46f70ed5157b28336b56cd2ff3"
uuid = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
version = "0.2.2"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

Expand All @@ -145,17 +105,11 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[LossFunctions]]
deps = ["InteractiveUtils", "LearnBase", "Markdown", "Random", "RecipesBase", "SparseArrays", "Statistics", "StatsBase", "Test"]
git-tree-sha1 = "08d87fec43e7d335811dfae5b55dbfc5690e915b"
uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
version = "0.5.1"

[[MLJBase]]
deps = ["CategoricalArrays", "ComputationalResources", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "450fa34dcb0005d0799ffcf9cca5f40aa6d83059"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.10.1"
[[MLJModelInterface]]
deps = ["ScientificTypes"]
git-tree-sha1 = "269deeabed43d68656c80fa57a83fb53ad202728"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
version = "0.1.5"

[[Markdown]]
deps = ["Base64"]
Expand Down Expand Up @@ -194,12 +148,6 @@ git-tree-sha1 = "5f303510529486bb02ac4d70da8295da38302194"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.11"

[[Parameters]]
deps = ["OrderedCollections"]
git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38"
uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a"
version = "0.12.0"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "d112c19ccca00924d5d3a38b11ae2b4b268dda39"
Expand All @@ -210,22 +158,10 @@ version = "0.3.11"
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[PrettyTables]]
deps = ["Crayons", "Formatting", "Parameters", "Reexport", "Tables"]
git-tree-sha1 = "2268242f037e0290e87d55c02060320c1d0d6b03"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "0.6.0"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[ProgressMeter]]
deps = ["Distributed", "Printf"]
git-tree-sha1 = "ea1f4fa0ff5e8b771bf130d87af5b7ef400760bd"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.2.0"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "dc84e810393cfc6294248c9032a9cdacc14a3db4"
Expand All @@ -240,11 +176,6 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RecipesBase]]
git-tree-sha1 = "b4ed4a7f988ea2340017916f7c9e5d7560b52cae"
uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
version = "0.8.0"

[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
Expand All @@ -261,10 +192,9 @@ version = "0.6.0"
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[ScientificTypes]]
deps = ["CategoricalArrays", "ColorTypes", "PrettyTables", "Tables"]
git-tree-sha1 = "20fa7448b38ea42eb40da1d66c83cf67d626964a"
git-tree-sha1 = "9c232034bbee8c53173cdce83787bf8968b09d31"
uuid = "321657f4-b219-11e9-178b-2701a2544e81"
version = "0.5.1"
version = "0.7.1"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -288,9 +218,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["OpenSpecFun_jll"]
git-tree-sha1 = "268052ee908b2c086cc0011f528694f02f3e2408"
git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.9.0"
version = "0.10.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
Expand All @@ -304,15 +234,15 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
git-tree-sha1 = "be5c7d45daa449d12868f4466dbf5882242cf2d9"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.32.0"
version = "0.32.1"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions"]
git-tree-sha1 = "79982835d2ff3970685cb704500909c94189bde9"
git-tree-sha1 = "f290ddd5fdedeadd10e961eb3f4d3340f09d030a"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.3"
version = "0.9.4"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
Expand Down
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name = "EvoTrees"
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
authors = ["jeremiedb <[email protected]>"]
version = "0.4.2"
version = "0.4.3"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SortingAlgorithms = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -18,9 +18,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
CategoricalArrays = "0.7"
Distributions = "0.21, 0.22"
MLJBase = "0.10"
SortingAlgorithms = "0.3"
StaticArrays = "0.12"
StatsBase = "0.32"
Tables = "0.2"
julia = "1"

[extras]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "MLJBase"]
4 changes: 1 addition & 3 deletions experiments/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ tree.model.nrounds += 10
pred = predict(tree, selectrows(X,train))
pred_mean = predict_mean(tree, selectrows(X,train))
pred_mode = predict_mode(tree, selectrows(X,train))
pred_median = predict_median(tree, selectrows(X,train))

##################################################
### Gaussian - Larger data
Expand Down Expand Up @@ -231,8 +230,7 @@ tree.model.nrounds += 10
pred = predict(tree, selectrows(X,train))
pred_mean = predict_mean(tree, selectrows(X,train))
pred_mode = predict_mode(tree, selectrows(X,train))
pred_median = predict_median(tree, selectrows(X,train))
mean(abs.(pred_train - selectrows(Y,train)))
mean(abs.(pred_mean - selectrows(Y,train)))

# Element-wise 20% and 80% quantiles of `pred`. Each is bound to its own name so
# the upper quantile no longer overwrites the lower one.
q_20 = quantile.(pred, 0.20)
q_80 = quantile.(pred, 0.80)
7 changes: 3 additions & 4 deletions src/EvoTrees.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
module EvoTrees

export init_evotree, grow_evotree!, grow_tree, predict, fit_evotree,
export init_evotree, grow_evotree!, grow_tree, fit_evotree, predict,
EvoTreeRegressor, EvoTreeCount, EvoTreeClassifier, EvoTreeGaussian,
EvoTreeRModels, importance

using Statistics
using Base.Threads: @threads
using StatsBase: sample, quantile
import StatsBase: predict
using Random: seed!
using StaticArrays
using Distributions
using CategoricalArrays
import MLJBase
# import MLJ
import MLJModelInterface: predict
import MLJModelInterface

include("models.jl")
include("structs.jl")
Expand Down
Loading

0 comments on commit ad9fe6f

Please sign in to comment.