From dcacb6b62cc0acc5d751c51b88465fbd58272209 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Mon, 11 Mar 2024 23:19:36 +0100 Subject: [PATCH] Speed up calculate_tree_depth() (#35) --- R/min_depth_distribution.R | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/R/min_depth_distribution.R b/R/min_depth_distribution.R index 5beb787..7581eae 100644 --- a/R/min_depth_distribution.R +++ b/R/min_depth_distribution.R @@ -4,13 +4,12 @@ calculate_tree_depth <- function(frame){ stop("The data frame has to contain columns called 'right daughter' and 'left daughter'! It should be a product of the function getTree(..., labelVar = T).") } - frame$depth <- NA - frame$depth[1] <- 0 - for(i in 2:nrow(frame)){ - frame[i, "depth"] <- - frame[frame[, "left daughter"] == as.numeric(rownames(frame[i,])) | - frame[, "right daughter"] == as.numeric(rownames(frame[i,])), "depth"] + 1 - } + # Both child values of leaf nodes are 0, i.e., lower than min(node_id) + frame[["depth"]] <- calculate_tree_depth_( + node_id = seq_len(nrow(frame)), + left_child = frame[["left daughter"]], + right_child = frame[["right daughter"]] + ) return(frame) } @@ -20,16 +19,25 @@ calculate_tree_depth_ranger <- function(frame){ stop("The data frame has to contain columns called 'rightChild' and 'leftChild'! 
# Internal function used to determine the depth of each node.
#
# Arguments:
#   node_id     Vector of node identifiers, one entry per node. The first
#               entry is treated as the root and is assigned depth 0.
#   left_child  Vector (same length as node_id) holding the identifier of each
#               node's left child. Values that match no node_id (e.g. 0 or NA
#               for leaves) are simply never matched.
#   right_child Same as left_child, for the right child.
#
# Returns a numeric vector of length(node_id) with the depth of every node.
#
# Nodes are processed in the order given, so a parent has to appear before its
# children; otherwise the parent's depth is still the initial 0 when it is read.
calculate_tree_depth_ <- function(node_id, left_child, right_child) {
  n <- length(node_id)
  depth <- numeric(n)
  # seq_len(n)[-1] is empty when n < 2. The former 2:n counted *down* in that
  # case (2:1 is c(2, 1)), which made single-node trees fail with an error.
  for (i in seq_len(n)[-1]) {
    # A node's parent is the one row listing it as left or right child.
    # %in% (rather than ==) never yields NA, so NA children are safe.
    parent_node <- left_child %in% node_id[i] | right_child %in% node_id[i]
    depth[i] <- depth[parent_node] + 1
  }
  return(depth)
}