diff --git a/R/min_depth_distribution.R b/R/min_depth_distribution.R index c82bb3e..41db93e 100644 --- a/R/min_depth_distribution.R +++ b/R/min_depth_distribution.R @@ -4,11 +4,8 @@ calculate_tree_depth <- function(frame){ stop("The data frame has to contain columns called 'right daughter' and 'left daughter'! It should be a product of the function getTree(..., labelVar = T).") } - # Both child values of leaf nodes are 0, i.e., lower than min(node_id) frame[["depth"]] <- calculate_tree_depth_( - node_id = seq_len(nrow(frame)), - left_child = frame[["left daughter"]], - right_child = frame[["right daughter"]] + frame[, c("left daughter", "right daughter")] ) return(frame) } @@ -19,22 +16,30 @@ calculate_tree_depth_ranger <- function(frame){ stop("The data frame has to contain columns called 'rightChild' and 'leftChild'! It should be a product of the function ranger::treeInfo().") } + # Child nodes are zero based, so we increase them by 1 frame[["depth"]] <- calculate_tree_depth_( - node_id = frame[["nodeID"]], - left_child = frame[["leftChild"]], - right_child = frame[["rightChild"]] + frame[, c("leftChild", "rightChild")] + 1 ) return(frame) } -# Internal function used to determine the depth of each node -calculate_tree_depth_ <- function(node_id, left_child, right_child) { - n <- length(node_id) - depth <- numeric(n) - for (i in 2:n) { - parent_node <- left_child %in% node_id[i] | right_child %in% node_id[i] - depth[i] <- depth[parent_node] + 1 +# Internal function used to determine the depth of each node. +# The input is a data.frame with left and right child nodes in 1:nrow(childs). +calculate_tree_depth_ <- function(childs) { + childs <- as.matrix(childs) + n <- nrow(childs) + depth <- rep(NA, times = n) + j <- depth[1L] <- 0 + ix <- 1L # current nodes, initialized with root node index + + # j loops over tree depth + while(anyNA(depth) && j < n) { # The second condition is never used + ix <- as.integer(childs[ix, ]) + ix <- ix[!is.na(ix) & ix >= 1L] # leaf nodes do not have childs + j <- j + 1 + depth[ix] <- j } + return(depth) }