forked from KaroRonty/StockMarketLongTermForecast
-
Notifications
You must be signed in to change notification settings - Fork 0
/
XGBoostExplainer.R
40 lines (35 loc) · 1.45 KB
/
XGBoostExplainer.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Explaining XGBoost results ----
library(xgboost)
library(xgboostExplainer)
# Make matrices for training and test data.
# Each feature matrix is built once and reused below, so the model, the
# explainer and the waterfall plot are guaranteed to see the same columns in
# the same order.
training_matrix <- training %>%
  select(-dates, -tenyear) %>%
  as.matrix()
future_matrix <- future %>%
  select(-dates, -tenyear) %>%
  as.matrix()

training_xgb <- xgb.DMatrix(training_matrix,
                            label = training %>% pull(tenyear))
test_xgb <- xgb.DMatrix(future_matrix)

# Train XGBoost model using the same hyperparameters as caret.
# base_score is set explicitly so that it equals the value handed to
# buildExplainer() below; newer xgboost releases may otherwise estimate the
# intercept from the data, and the explainer's per-feature breakdown would
# then no longer add up to the model's prediction.
xgb_explain <- xgboost(data = training_xgb,
                       nrounds = 150,
                       max_depth = 1,
                       eta = 0.3,
                       gamma = 0,
                       colsample_bytree = 0.8,
                       min_child_weight = 1,
                       subsample = 0.75,
                       base_score = 0.5)

# Make explainer object; type and base_score must agree with how the model
# above was trained.
explainer <- buildExplainer(xgb_explain,
                            training_xgb,
                            base_score = 0.5,
                            trees = NULL,
                            type = "regression")

# Plot waterfall chart for the first row of `future`.
# showWaterfall() defaults to type = "binary", which would pass the summed
# contributions through an inverse-logit. This is a regression model, so the
# type is given explicitly, matching the buildExplainer() call above.
showWaterfall(xgb_explain, explainer, test_xgb, future_matrix,
              idx = 1, type = "regression")