Skip to content

Commit

Permalink
Merge pull request #123 from ModelOriented/Hubert_issues
Browse files Browse the repository at this point in the history
Update to forester 1.5.0 with multiclass classification
  • Loading branch information
HubertR21 authored Feb 23, 2024
2 parents 71c68e6 + 745a570 commit 0c05ed8
Show file tree
Hide file tree
Showing 144 changed files with 8,724 additions and 7,193 deletions.
6 changes: 6 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
^\.Rproj\.user$
^\docs$
^docs$
^misc$
^_pkgdown\.yml$
_pkgdown.yaml
^pkgdown
^catboost_info$
^\catboost_info$
^tests/testthat/catboost_info
^report_binary.pdf
^report_multiclass.pdf
^report_regression.pdf
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ R/not_to_run.R
*.tfevents
*.tsv
*.tsv
tests/testthat/checkpoints/*
tests/testthat/catboost_info/*
tests/testthat/testing_data_report.pdf
tests/testthat/lisbon_report.pdf
tests/testthat/iris_report.pdf
tests/testthat/compas_report.pdf
catboost_info
26 changes: 16 additions & 10 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
Package: forester
Type: Package
Title: Quick and Simple Tools for Training and Testing of Tree-based Models
Version: 1.4.2
Version: 1.5.0
Authors@R:
c(person("Hubert", "Ruczyński", role = c("aut", "cre"), email = "[email protected]"),
person("Anna", "Kozak", role = c("aut", "ths"), email = "[email protected]"),
person("Patryk", "Słowakiewicz", role = c("aut"), email = "[email protected]"),
person("Adrianna", "Grudzień", role = c("aut"), email = "[email protected]"),
person("Przemysław", "Biecek", role = c("aut", "ths"), email = "[email protected]"))
Description: Extendable set of tools to facilitate training, tuning and testing
of tree-based models.
Designed to quickly produce a good baseline model, that can be further tuned manually.
The primary value generated by the forester package is to promptly build a baseline for other,
more complex models.
Description: The forester package is an open-source AutoML package implemented in R designed for
training high-quality tree-based models on tabular data. It fully supports regression, binary
classification, and multiclass classification tasks, and provides a limited support for
the survival analysis task. A single line of code allows the use of unprocessed datasets, informs
about potential issues concerning them, and handles feature engineering automatically.
Moreover, hyperparameter tuning is performed by Bayesian optimization, which provides
high-quality outcomes. The results are later served as a ranked list of models. Finally, the
forester package offers a vast training report, including the ranked list, a comparison of
trained models, and explanations for the best one.
License: GPL-3
Depends:
R (>= 3.5.0),
patchwork
Imports:
ranger,
xgboost,
Expand All @@ -34,15 +41,14 @@ Imports:
stats,
tibble,
crayon,
VIM,
patchwork
VIM
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.3.1
URL: https://github.com/ModelOriented/forester
BugReports: https://github.com/ModelOriented/forester/issues
Suggests:
testthat (>= 3.0.0),
testthat (>= 3.5.0),
catboost,
DALEX,
ggradar,
Expand All @@ -56,6 +62,6 @@ Suggests:
survival,
randomForestSRC,
sivs
Config/testthat/edition: 3
Config/testthat/edition: 3.5
VignetteBuilder:
knitr
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,binary_clf)
S3method(plot,multiclass)
S3method(plot,regression)
export(basic_info)
export(best_model_predict)
Expand Down Expand Up @@ -35,7 +36,6 @@ export(impute_knn)
export(impute_mice)
export(manage_missing)
export(pre_rm_static_cols)
export(predict_models)
export(predict_models_all)
export(predict_new)
export(prepare_data)
Expand Down
19 changes: 19 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
# forester 1.5.0

- Updated `.Rbuildignore` and `.gitignore`.
- Updated package `DESCRIPTION`, and `NAMESPACE`.
- Removed: `choose_best_models.R`, `forester_palette.R`, `format_model_details.R`
- Updated `check_data()`:
- Added more checks of input parameters,
- Modified `check_y_balance()` sub-function for multiclass classification.
- Updated `custom_preprocessing()`:
- Added parameter type,
- Added more checks of input parameters,
- Added more advanced verbose options.
- Added `multiclass classification` task, with all extensional functionalities included (report, explainability, plots, etc), which led to the abundance of changes in majority of functions.
- Added more input checks for `train_test_balance()`.
- Updated documentation.
- Added manual test for multiclass classification task.
- Moved old tests into `misc/old_tests_03_02_2023` folder.
- Created a new, and more in depth tests for the package from scratch for regression, binary classification, and multiclass classification tasks.

# forester 1.4.2

- In the DESCRIPTION updated the description, and RoxygenNote version.
Expand Down
126 changes: 101 additions & 25 deletions R/check_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,66 @@
#' Either y, or pair: time, status can be used. By default NULL.
#' @param status A string that indicates a status column name for survival analysis task.
#' Either y, or pair: time, status can be used. By default NULL.
#' @param type A character, one of `binary_clf`/`regression`/`survival`/`auto`/`multiclass` that
#' sets the type of the task. If `auto` (the default option) then
#' the function will figure out `type` based on the number of unique values
#' in the `y` variable, or the presence of time/status columns.
#' @param verbose A logical value, if set to TRUE, provides all information about
#' the process, if FALSE gives none.
#'
#' @return A list with two vectors: lines of the report (str) and the outliers (outliers).
#' @export
#'
#' @examples
#' check_data(iris[1:100, ], 'Species')
#' check_data(lisbon, 'Price')
#' check_data(compas, 'Two_yr_Recidivism')
#' check_data(iris, 'Species')
#' check_data(lymph, 'class')
#' @importFrom stats IQR cor median sd
#' @importFrom utils capture.output
check_data <- function(data, y = NULL, time = NULL, status = NULL, verbose = TRUE) {
check_data <- function(data, y = NULL, time = NULL, status = NULL, type = 'auto', verbose = TRUE) {
options(warn = -1)

if (is.null(y)) {
if (is.null(time) | is.null(status)) {
verbose_cat(crayon::red('\u2716'), 'Lack of target variables. Please specify',
'either y (for classification or regression tasks), or time and',
'status (for survival analysis). \n\n', verbose = verbose)
return(NULL)
stop('Lack of target variables. Please specify either y (for classification
or regression tasks), or time and status (for survival analysis)')
}
} else {
if (!is.null(time) | !is.null(status)) {
verbose_cat(crayon::red('\u2716'), 'Provided too many targets. Please specify',
'either y (for classification or regression tasks), or time and',
'status (for survival analysis). \n\n', verbose = verbose)
return(NULL)
stop('Provided too many targets. Please specify either y (for classification
or regression tasks), or time and status (for survival analysis).')
}
}

if (type == 'auto') {
type <- guess_type(data, y)
if (type == 'regression') {
data[[y]] <- as.numeric(data[[y]])
}
verbose_cat(crayon::green('\u2714'), 'Type guessed as: ', type, '\n\n', verbose = verbose)
} else if (!type %in% c('regression', 'binary_clf', 'survival', 'multiclass')) {
verbose_cat(crayon::red('\u2716'), 'Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification \n\n', verbose = verbose)
stop('Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification')
} else {
verbose_cat(crayon::green('\u2714'), 'Type provided as: ', type, '\n\n', verbose = verbose)
}

if (type == 'survial') {
if (!status %in% colnames(data) || !time %in% colnames(data)) {
verbose_cat(crayon::red('\u2716'), 'Provided target column name for time or status parameters',
status, time, 'is not present in the datataset. \n\n', verbose = verbose)
stop('Provided target column name for time or status parameter is not present in the datataset.')
}
} else if (!y %in% colnames(data)) {
verbose_cat(crayon::red('\u2716'), 'Provided target column name for y parameter', y,
'is not present in the datataset. \n\n', verbose = verbose)
stop('Provided target column name for y parameter is not present in the datataset.')
}

df <- as.data.frame(data)
str <- capture.output(cat(' -------------------- **CHECK DATA REPORT** -------------------- \n \n'))
verbose_cat(' -------------------- CHECK DATA REPORT -------------------- \n \n', verbose = verbose)
Expand All @@ -53,7 +80,7 @@ check_data <- function(data, y = NULL, time = NULL, status = NULL, verbose = TRU
str <- c(str, check_cor(df, y, time, status, verbose)$str)
rtr <- check_outliers(df, verbose)
str <- c(str, rtr$str)
str <- c(str, check_y_balance(df, y, time, status, verbose))
str <- c(str, check_y_balance(df, y, time, status, type, verbose))
str <- c(str, detect_id_columns(df, verbose))
verbose_cat(' -------------------- CHECK DATA REPORT END -------------------- \n \n', verbose = verbose)
str <- c(str,
Expand Down Expand Up @@ -88,15 +115,13 @@ verbose_cat <- function(..., sep = ' ', verbose = TRUE) {
#' @param data A data source, that is one of the major R formats: data.table, data.frame,
#' matrix, and so on.
#' @param y A string that indicates a target column name.
#' @param max_unique_numeric An integer describing the maximal number of unique
#' values in `y` if `y` is numeric.
#' @param max_unique_not_numeric An integer describing the maximal number of unique
#' values in `y` if `y` is NOT numeric.
#' @param max_unique An integer describing the maximal number of unique
#' values in `y`.
#'
#' @return A string describing the type of ml task: `binary_clf`, `multi_clf`,
#' @return A string describing the type of ml task: `binary_clf`, `multiclass`,
#' `regression`, or `survival`.
#' @export
guess_type <- function(data, y, max_unique_numeric = 5, max_unique_not_numeric = 15) {
guess_type <- function(data, y, max_unique = 15) {

if (is.null(y)) {
type <- 'survival'
Expand All @@ -105,16 +130,17 @@ guess_type <- function(data, y, max_unique_numeric = 5, max_unique_not_numeric =
if (is.numeric(target)) {
if (length(unique(target)) == 2) {
type <- 'binary_clf'
} else if (length(unique(target)) <= max_unique_numeric) {
type <- 'multi_clf'
} else if (length(unique(target)) <= max_unique) {
type <- 'multiclass'
} else {
type <- 'regression'
}
} else if (length(unique(target)) == 2) {
type <- 'binary_clf'
} else if (length(unique(target)) <= max_unique_not_numeric) {
type <- 'multi_clf'
} else if (length(unique(target)) <= max_unique) {
type <- 'multiclass'
} else {
data[[y]] <- as.numeric(data[[y]])
type <- 'regression'
}
}
Expand Down Expand Up @@ -309,11 +335,11 @@ check_dim <- function(df, verbose = TRUE) {
str <- capture.output(cat('**Too big dimensionality with ', cols, ' colums. Forest models wont use so many of them. **\n', sep = ''))
}
if (cols >= rows) {
verbose_cat(crayon::red('\u2716'), ' More features than observations, try reducing dimensionality or add new observations. \n', verbose = verbose)
verbose_cat(crayon::red('\u2716'), 'More features than observations, try reducing dimensionality or add new observations. \n', verbose = verbose)
str <- capture.output(cat('**More features than observations, try reducing dimensionality or add new observations. **\n'))
}
if (cols < rows && cols <= 30) {
verbose_cat(crayon::green('\u2714'), ' No issues with dimensionality. \n', verbose = verbose)
verbose_cat(crayon::green('\u2714'), 'No issues with dimensionality. \n', verbose = verbose)
str <- capture.output(cat('**No issues with dimensionality. **\n'))
}
verbose_cat('\n', verbose = verbose)
Expand Down Expand Up @@ -563,14 +589,28 @@ check_outliers <- function(df, verbose = TRUE) {
#' Either y, or pair: time, status can be used. By default NULL.
#' @param status A string that indicates a status column name for survival analysis task.
#' Either y, or pair: time, status can be used. By default NULL.
#' @param type A character, one of `binary_clf`/`regression`/`survival`/`auto`/`multiclass` that
#' sets the type of the task. If `auto` (the default option) then
#' the function will figure out `type` based on the number of unique values
#' in the `y` variable, or the presence of time/status columns.
#' @param verbose A logical value, if set to TRUE, provides all information about
#' the process, if FALSE gives none.
#'
#' @return A list with every line of the sub-report.
#'
#' @export
check_y_balance <- function(df, y = NULL, time = NULL, status = NULL, verbose = TRUE) {
type <- guess_type(df, y)
check_y_balance <- function(df, y = NULL, time = NULL, status = NULL, type = 'auto', verbose = TRUE) {
# This part should not be provided in string for the report, as it only enhances
# the quality of standalone check_data function.
if (type == 'auto') {
type <- guess_type(df, y)
verbose_cat(crayon::green('\u2714'), 'Type guessed as:', type, '\n\n', verbose = verbose)
} else if (!type %in% c('regression', 'binary_clf', 'survival', 'multiclass')) {
verbose_cat(crayon::red('\u2716'), 'Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification \n\n', verbose = verbose)
} else {
verbose_cat(crayon::green('\u2714'), 'Type provided as:', type, '\n\n', verbose = verbose)
}

# Distinction between survival analysis and other tasks.
if (!is.null(y)) {
target <- df[[y]]
Expand Down Expand Up @@ -635,9 +675,45 @@ check_y_balance <- function(df, y = NULL, time = NULL, status = NULL, verbose =
str <- capture.output(cat('**Target data is not evenly distributed with quantile bins:**', perc_bins, '\n'))
}

} else if (type == 'multi_clf') {
verbose_cat(crayon::green('\u2716'), 'Multilabel classification is not supported yet. \n', verbose = verbose)
str <- capture.output(cat('**Multilabel classification is not supported yet. **\n'))
} else if (type == 'multiclass') {

distribution <- table(target) / length(target)
num_class <- length(unique(target))
equal <- 1 / num_class
dominating <- c()
dominating_name <- c()
underrep <- c()
underrep_name <- c()

for (i in 1:length(distribution)) {
if (distribution[i] > 1.5 * equal) {
dominating_name <- c(dominating_name, names(distribution)[i])
dominating <- c(dominating, distribution[i])
} else if (distribution[i] < 0.75 * equal) {
underrep_name <- c(underrep_name, names(distribution)[i])
underrep <- c(underrep, distribution[i])
}
}

if (length(dominating) > 0 || length(underrep) > 0) {
verbose_cat(crayon::red('\u2716'), 'Target data is not evenly distributed. \n', verbose = verbose)
str <- capture.output(cat('**Target data is not evenly distributed. **\n\n'))
if (length(dominating) > 0) {
verbose_cat(crayon::red('\u2716'), 'The dominating classes are:', dominating_name, 'with shares equal to:', dominating,
'where the targeted ratio is', equal, '\n', verbose = verbose)
str <- c(str, capture.output(cat('**The dominating classes are:**', dominating_name, '\n\n**With shares equal to:**', dominating,
'\n\n**Where the targeted ratio is**', equal, '\n\n')))
}
if (length(underrep) > 0) {
verbose_cat(crayon::red('\u2716'), 'The underrepresented classes are:', underrep_name, 'with shares equal to:', underrep,
'where the targeted ratio is', equal, '\n', verbose = verbose)
str <- c(str, capture.output(cat('**The underrepresented classes are:**', underrep_name, '\n\n**With shares equal to:**', underrep,
'\n\n**Where the targeted ratio is**', equal, '\n')))
}
} else {
verbose_cat(crayon::green('\u2714'), 'Target data is evenly distributed. \n', verbose = verbose)
str <- capture.output(cat('**Target data is evenly distributed. **\n'))
}
}
verbose_cat('\n', verbose = verbose)
str <- c(str, capture.output(cat('\n')))
Expand Down
Loading

0 comments on commit 0c05ed8

Please sign in to comment.