Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedb committed Oct 18, 2022
1 parent a4f7be0 commit e352808
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 155 deletions.
45 changes: 23 additions & 22 deletions experiments/benchmarks_v2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,32 @@ num_feat = Int(100)
x_train = rand(nobs, num_feat)
y_train = rand(size(x_train, 1))

@info "xgboost train:"
@time m_xgb = xgboost(x_train, nrounds, label=y_train, param=params_xgb, metrics=metrics, nthread=nthread, silent=1);
@btime xgboost($x_train, $nrounds, label=$y_train, param=$params_xgb, metrics=$metrics, nthread=$nthread, silent=1);
@info "xgboost predict:"
@time pred_xgb = XGBoost.predict(m_xgb, x_train);
@btime XGBoost.predict($m_xgb, $x_train);
# @info "xgboost train:"
# @time m_xgb = xgboost(x_train, nrounds, label=y_train, param=params_xgb, metrics=metrics, nthread=nthread, silent=1);
# @btime xgboost($x_train, $nrounds, label=$y_train, param=$params_xgb, metrics=$metrics, nthread=$nthread, silent=1);
# @info "xgboost predict:"
# @time pred_xgb = XGBoost.predict(m_xgb, x_train);
# @btime XGBoost.predict($m_xgb, $x_train);

@info "evotrees train CPU:"
params_evo.device = "cpu"
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, metric=metric_evo, print_every_n=50);
@btime fit_evotree($params_evo; x_train=$x_train, y_train=$y_train, x_eval=$x_train, y_eval=$y_train, metric=metric_evo);
@info "evotrees predict CPU:"
@time pred_evo = EvoTrees.predict(m_evo, x_train);
@btime EvoTrees.predict($m_evo, $x_train);
# @info "evotrees train CPU:"
# params_evo.device = "cpu"
# @time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, metric=metric_evo, print_every_n=100);
# @btime fit_evotree($params_evo; x_train=$x_train, y_train=$y_train, x_eval=$x_train, y_eval=$y_train, metric=metric_evo);
# @btime fit_evotree($params_evo; x_train=$x_train, y_train=$y_train);
# @info "evotrees predict CPU:"
# @time pred_evo = EvoTrees.predict(m_evo, x_train);
# @btime EvoTrees.predict($m_evo, $x_train);

CUDA.allowscalar(true)
@info "evotrees train GPU:"
params_evo.device = "gpu"
@time m_evo_gpu = fit_evotree(params_evo; x_train, y_train);
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, metric=metric_evo, print_every_n=50);
@btime fit_evotree($params_evo; x_train=$x_train, y_train=$y_train, x_eval=$x_train, y_eval=$y_train, metric=metric_evo);
@info "evotrees predict GPU:"
@time pred_evo = EvoTrees.predict(m_evo_gpu, x_train);
@btime EvoTrees.predict($m_evo_gpu, $x_train);

# w_train = ones(length(y_train))
# @time m_evo_gpu = fit_evotree(params_evo, x_train, y_train);
# @time m_evo_gpu = fit_evotree(params_evo, x_train, y_train, w_train);
@time m_evo_gpu = fit_evotree(params_evo; x_train, y_train);
@time m_evo_gpu = fit_evotree(params_evo; x_train, y_train);
@time m_evo_gpu = fit_evotree(params_evo; x_train, y_train);
@time m_evo_gpu = fit_evotree(params_evo; x_train, y_train);
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, metric=metric_evo, print_every_n=100);
# @btime fit_evotree($params_evo; x_train=$x_train, y_train=$y_train, x_eval=$x_train, y_eval=$y_train, metric=metric_evo);
# @info "evotrees predict GPU:"
# @time pred_evo = EvoTrees.predict(m_evo_gpu, x_train);
# @btime EvoTrees.predict($m_evo_gpu, $x_train);
10 changes: 5 additions & 5 deletions experiments/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ params1 = EvoTreeRegressor(T=Float32,

# asus laptopt: for 1.25e6 no eval: 9.650007 seconds (893.53 k allocations: 2.391 GiB, 5.52% gc time)
@time model = fit_evotree(params1; x_train, y_train);
@time model = fit_evotree(params1; x_train, y_train, metric=:mse, x_eval, y_eval, print_every_n=10);
@time model = fit_evotree(params1; x_train, y_train, metric=:mse, x_eval, y_eval, print_every_n=100);
@btime model = fit_evotree(params1; x_train, y_train);
@time pred_train = predict(model, x_train);
@btime pred_train = predict(model, x_train);
Expand Down Expand Up @@ -77,7 +77,7 @@ params1 = EvoTreeGaussian(T=Float32,
# train model
params1 = EvoTreeRegressor(T=Float32,
loss=:linear, metric=:mse,
nrounds=10,
nrounds=100,
lambda=1.0, gamma=0, eta=0.1,
max_depth=6, min_weight=1.0,
rowsample=0.5, colsample=0.5, nbins=64,
Expand All @@ -86,7 +86,7 @@ params1 = EvoTreeRegressor(T=Float32,
# Asus laptop: 10.015568 seconds (13.80 M allocations: 1.844 GiB, 4.00% gc time)
@time model = EvoTrees.fit_evotree(params1; x_train, y_train);
@btime model = EvoTrees.fit_evotree(params1; x_train, y_train);
@time model, cache = EvoTrees.init_evotree_gpu(params1, X_train, Y_train);
@time model, cache = EvoTrees.init_evotree_gpu(params1; x_train, y_train);
@time EvoTrees.grow_evotree!(model, cache);

using MLJBase
Expand Down Expand Up @@ -118,14 +118,14 @@ params1 = EvoTreeRegressor(T=Float32,
# GPU - Gaussian
################################
params1 = EvoTreeGaussian(T=Float32,
loss=:gaussian, metric=:gaussian,
loss=:gaussian,
nrounds=100,
lambda=1.0, gamma=0, eta=0.1,
max_depth=6, min_weight=1.0,
rowsample=0.5, colsample=0.5, nbins=32,
device="gpu")
# Asus laptop: 14.304369 seconds (24.81 M allocations: 2.011 GiB, 1.90% gc time)
@time model = EvoTrees.fit_evotree(params1, X_train, Y_train);
@time model = EvoTrees.fit_evotree(params1; x_train, y_train);
# Auss laptop: 1.888472 seconds (8.40 k allocations: 1.613 GiB, 14.86% gc time)
@time model, cache = EvoTrees.init_evotree(params1, X_train, Y_train);

Expand Down
12 changes: 8 additions & 4 deletions src/find_split.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,10 @@ function hist_gains_cpu!(
if bin == params.nbins
gains[bin] = hL[i]^2 / (hL[i+1] + params.lambda * hL[i+2]) / 2
elseif hL[i+2] > params.min_weight && hR[i+2] > params.min_weight
predL = pred_scalar_cpu!(hL[i:i+2], params, K)
predR = pred_scalar_cpu!(hR[i:i+2], params, K)
if monotone_constraint != 0
predL = pred_scalar_cpu!(view(hL, i:i+2), params, K)
predR = pred_scalar_cpu!(view(hR, i:i+2), params, K)
end
if (monotone_constraint == 0) ||
(monotone_constraint == -1 && predL > predR) ||
(monotone_constraint == 1 && predL < predR)
Expand Down Expand Up @@ -362,8 +364,10 @@ function hist_gains_cpu!(
hL[i+1]^2 / (hL[i+3] + params.lambda * hL[i+4])
) / 2
elseif hL[i+4] > params.min_weight && hR[i+4] > params.min_weight
predL = pred_scalar_cpu!(hL[i:i+4], params, K)
predR = pred_scalar_cpu!(hR[i:i+4], params, K)
if monotone_constraint != 0
predL = pred_scalar_cpu!(view(hL, i:i+4), params, K)
predR = pred_scalar_cpu!(view(hR, i:i+4), params, K)
end
if (monotone_constraint == 0) ||
(monotone_constraint == -1 && predL > predR) ||
(monotone_constraint == 1 && predL < predR)
Expand Down
2 changes: 1 addition & 1 deletion src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function init_evotree(

# assign monotone contraints in constraints vector
monotone_constraints = zeros(Int32, x_size[2])
hasproperty(params, :monotone_constraint) && for (k, v) in params.monotone_constraints
hasproperty(params, :monotone_constraints) && for (k, v) in params.monotone_constraints
monotone_constraints[k] = v
end

Expand Down
12 changes: 8 additions & 4 deletions src/gpu/find_split_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ function hist_gains_gpu_kernel!(gains::CuDeviceMatrix{T}, hL::CuDeviceArray{T,3}
if i == nbins
gains[i, j] = hL[1, i, j]^2 / (hL[2, i, j] + lambda * hL[3, i, j]) / 2
elseif hL[3, i, j] > min_weight && hR[3, i, j] > min_weight
predL = -hL[1, i, j] / (hL[2, i, j] + lambda * hL[3, i, j])
predR = -hR[1, i, j] / (hR[2, i, j] + lambda * hR[3, i, j])
if monotone_constraint != 0
predL = -hL[1, i, j] / (hL[2, i, j] + lambda * hL[3, i, j])
predR = -hR[1, i, j] / (hR[2, i, j] + lambda * hR[3, i, j])
end
if (monotone_constraint == 0) ||
(monotone_constraint == -1 && predL > predR) ||
(monotone_constraint == 1 && predL < predR)
Expand Down Expand Up @@ -281,8 +283,10 @@ function hist_gains_gpu_kernel_gauss!(gains::CuDeviceMatrix{T}, hL::CuDeviceArra
if i == nbins
gains[i, j] = (hL[1, i, j]^2 / (hL[3, i, j] + lambda * hL[5, i, j]) + hL[2, i, j]^2 / (hL[4, i, j] + lambda * hL[5, i, j])) / 2
elseif hL[5, i, j] > min_weight && hR[5, i, j] > min_weight
predL = -hL[1, i, j] / (hL[3, i, j] + lambda * hL[5, i, j])
predR = -hR[1, i, j] / (hR[3, i, j] + lambda * hR[5, i, j])
if monotone_constraint != 0
predL = -hL[1, i, j] / (hL[3, i, j] + lambda * hL[5, i, j])
predR = -hR[1, i, j] / (hR[3, i, j] + lambda * hR[5, i, j])
end
if (monotone_constraint == 0) ||
(monotone_constraint == -1 && predL > predR) ||
(monotone_constraint == 1 && predL < predR)
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/fit_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function init_evotree_gpu(

# assign monotone contraints in constraints vector
monotone_constraints = zeros(Int32, x_size[2])
hasproperty(params, :monotone_constraint) && for (k, v) in params.monotone_constraints
hasproperty(params, :monotone_constraints) && for (k, v) in params.monotone_constraints
monotone_constraints[k] = v
end

Expand Down
6 changes: 3 additions & 3 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function pred_leaf_cpu!(
pred[1, n] = -params.eta * ∑[1] / (∑[2] + params.lambda * ∑[3])
end
function pred_scalar_cpu!(
::Vector{T},
::AbstractVector{T},
params::EvoTypes,
K,
) where {L<:GradientRegression,T,S}
Expand All @@ -123,7 +123,7 @@ function pred_leaf_cpu!(
pred[1, n] = -params.eta * ∑[1] / (∑[3] + params.lambda * ∑[5])
pred[2, n] = -params.eta * ∑[2] / (∑[4] + params.lambda * ∑[5])
end
function pred_scalar_cpu!(∑::Vector{T}, params::EvoTypes{L,T,S}, K) where {L<:MLE2P,T,S}
function pred_scalar_cpu!(∑::AbstractVector{T}, params::EvoTypes{L,T,S}, K) where {L<:MLE2P,T,S}
-params.eta * ∑[1] / (∑[3] + params.lambda * ∑[5])
end

Expand Down Expand Up @@ -171,6 +171,6 @@ function pred_leaf_cpu!(
) where {L<:L1Regression,T,S}
pred[1, n] = params.eta * ∑[1] / (∑[3] * (1 + params.lambda))
end
function pred_scalar_cpu!(∑::Vector, params::EvoTypes{L,T,S}, K) where {L<:L1Regression,T,S}
function pred_scalar_cpu!(∑::AbstractVector{T}, params::EvoTypes{L,T,S}, K) where {L<:L1Regression,T,S}
params.eta * ∑[1] / (∑[3] * (1 + params.lambda))
end
Loading

0 comments on commit e352808

Please sign in to comment.