Skip to content

Commit

Permalink
Merge pull request #55 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.4.1 release
  • Loading branch information
ablaom authored Feb 29, 2024
2 parents 9232329 + 9e3146d commit a819b1f
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 6 deletions.
8 changes: 8 additions & 0 deletions .github/codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
coverage:
status:
project:
default:
threshold: 0.5%
patch:
default:
target: 80%
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJDecisionTreeInterface"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.4.0"
version = "0.4.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -21,7 +21,8 @@ julia = "1.6"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["MLJBase", "MLJTestInterface", "StableRNGs", "Test"]
test = ["MLJBase", "MLJTestInterface", "StableRNGs", "StatisticalMeasures", "Test"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# DecisionTree.jl
# MLJDecisionTreeInterface.jl

Repository implementing the MLJ model interface for
[DecisionTree](https://github.com/bensadeghi/DecisionTree.jl) models.
Expand Down
15 changes: 12 additions & 3 deletions src/MLJDecisionTreeInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ const PKG = "MLJDecisionTreeInterface"

struct TreePrinter{T}
tree::T
features::Vector{Symbol}
end
(c::TreePrinter)(depth) = DT.print_tree(c.tree, depth)
(c::TreePrinter)() = DT.print_tree(c.tree, 5)
(c::TreePrinter)(depth) = DT.print_tree(c.tree, depth, feature_names = c.features)
(c::TreePrinter)() = DT.print_tree(c.tree, 5, feature_names = c.features)

Base.show(stream::IO, c::TreePrinter) =
print(stream, "TreePrinter object (call with display depth)")
Expand Down Expand Up @@ -71,7 +72,7 @@ function MMI.fit(
cache = nothing
report = (
classes_seen=classes_seen,
print_tree=TreePrinter(tree),
print_tree=TreePrinter(tree, features),
features=features,
)
return fitresult, cache, report
Expand Down Expand Up @@ -765,6 +766,8 @@ The fields of `fitted_params(mach)` are:
# Report
The fields of `report(mach)` are:
- `features`: the names of the features encountered in training
Expand Down Expand Up @@ -862,6 +865,8 @@ The fields of `fitted_params(mach)` are:
# Report
The fields of `report(mach)` are:
- `features`: the names of the features encountered in training
Expand Down Expand Up @@ -968,6 +973,8 @@ The fields of `fitted_params(mach)` are:
# Report
The fields of `report(mach)` are:
- `features`: the names of the features encountered in training
Expand Down Expand Up @@ -1079,6 +1086,8 @@ The fields of `fitted_params(mach)` are:
# Report
The fields of `report(mach)` are:
- `features`: the names of the features encountered in training
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using StableRNGs
using Random
using Tables
import MLJTestInterface
using StatisticalMeasures

# load code to be tested:
import DecisionTree
Expand Down

0 comments on commit a819b1f

Please sign in to comment.