From c3903edeacff8db0622d7c6505cce8f9a2ee4d18 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:09:34 -0400 Subject: [PATCH 01/48] Cleaning --- R/get_weights_from_mm.R | 6 +- R/match.data.R | 13 ++- R/match.qoi.R | 13 +-- R/matchit2cem.R | 106 +++++++++++++---------- R/matchit2full.R | 24 ++++-- R/matchit2genetic.R | 26 ++++-- R/matchit2optimal.R | 21 +++-- R/matchit2quick.R | 14 +++- R/matchit2subclass.R | 9 +- R/plot.matchit.R | 23 ++--- R/rbind.matchdata.R | 4 +- R/summary.matchit.R | 180 ++++++++++++++++++++++------------------ 12 files changed, 261 insertions(+), 178 deletions(-) diff --git a/R/get_weights_from_mm.R b/R/get_weights_from_mm.R index a7713d0e..2cf75edc 100644 --- a/R/get_weights_from_mm.R +++ b/R/get_weights_from_mm.R @@ -1,6 +1,8 @@ get_weights_from_mm <- function(match.matrix, treat) { - if (!is.integer(match.matrix)) match.matrix <- charmm2nummm(match.matrix, treat) + if (!is.integer(match.matrix)) { + match.matrix <- charmm2nummm(match.matrix, treat) + } weights <- weights_matrixC(match.matrix, treat) @@ -12,4 +14,4 @@ get_weights_from_mm <- function(match.matrix, treat) { .err("No control units were matched") setNames(weights, names(treat)) -} +} \ No newline at end of file diff --git a/R/match.data.R b/R/match.data.R index ce538411..a4de88c1 100644 --- a/R/match.data.R +++ b/R/match.data.R @@ -181,10 +181,12 @@ match.data <- function(object, group = "all", distance = "distance", weights = " if (is.null(data)) { env <- environment(object$formula) data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { + if (length(data) == 0 || inherits(data, "try-error") || + length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { env <- parent.frame() data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { + if (length(data) == 0 || inherits(data, "try-error") || + length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { data <- object[["model"]][["data"]] if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { .err("a valid dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") @@ -194,9 +196,12 @@ match.data <- function(object, group = "all", distance = "distance", weights = " } if (!is.data.frame(data)) { - if (is.matrix(data)) data <- as.data.frame.matrix(data) - else .err("`data` must be a data frame") + if (!is.matrix(data)) { + .err("`data` must be a data frame") + } + data <- as.data.frame.matrix(data) } + if (nrow(data) != length(object$treat)) { .err("`data` must have as many rows as there were units in the original call to `matchit()`") } diff --git a/R/match.qoi.R b/R/match.qoi.R index c4b4f01e..da05afca 100644 --- a/R/match.qoi.R +++ b/R/match.qoi.R @@ -1,5 +1,6 @@ ## Functions to calculate summary stats -bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s.d.denom = "treated", standardize = FALSE, +bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, + s.d.denom = "treated", standardize = FALSE, compute.pair.dist = TRUE) { un <- is.null(ww) @@ -67,7 +68,8 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s. xsum } -bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", standardize = FALSE, which.subclass = NULL) { +bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", + standardize = FALSE, which.subclass = NULL) { #Within-subclass balance statistics bin.var <- all(xx == 0 | xx == 1) in.sub <- !is.na(subclass) & subclass == which.subclass @@ -157,10 +159,12 @@ pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL, fast = TRU } else { mpdiff <- pairdistsubC(as.numeric(xx), as.integer(tt), - as.integer(subclass), nlevels(subclass)) + as.integer(subclass)) } } - else return(NA_real_) + else { + return(NA_real_) + } if (!is.null(std) && abs(mpdiff) > 1e-8) { mpdiff <- mpdiff/std @@ -243,5 +247,4 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { } c(meandiff = mean(ediff), maxdiff = max(ediff)) - } \ No newline at end of file diff --git a/R/matchit2cem.R b/R/matchit2cem.R index 1f396b62..1f8c9faa 100644 --- a/R/matchit2cem.R +++ b/R/matchit2cem.R @@ -259,7 +259,7 @@ #' k2k = TRUE, k2k.method = "mahalanobis") #' m.out2 #' summary(m.out2, un = FALSE) -#' + NULL matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose = FALSE, ...) { @@ -314,25 +314,27 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k #k2k now works with single covariates (previously it was ignored). k2k uses original variables, not coarsened versions if (k2k) { - if (length(unique(treat)) > 2) { + if (length(unique(treat)) > 2L) { .err("`k2k` cannot be `TRUE` with a multi-category treatment") } + if (!is.null(k2k.method)) { - k2k.method <- tolower(k2k.method) - k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) + k2k.method <- tolower(k2k.method) + k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, - method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") + X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, + method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") } } for (i in names(X)) { if (is.ordered(X[[i]])) X[[i]] <- as.numeric(X[[i]]) } + is.numeric.cov <- setNames(vapply(X, is.numeric, logical(1L)), names(X)) #Process grouping - if (length(grouping) > 0) { + if (length(grouping) > 0L) { if (!is.list(grouping) || is.null(names(grouping))) { .err("`grouping` must be a named list of grouping values with an element for each variable whose values are to be binned") } @@ -353,8 +355,10 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k !is.list(g) || !all(vapply(g, function(gg) is.atomic(gg) && is.vector(gg), logical(1L))) }, logical(1L))] + nbg <- length(bag.groupings) - if (nbg > 0) { + + if (nbg > 0L) { .err(paste0("Each entry in the list supplied to `groupings` must be a list with entries containing values of the corresponding variable.", "\nIncorrectly specified variable%s:\n\t"), paste(bag.groupings, collapse = ", "), @@ -384,18 +388,23 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k if (is.null(names(cutpoints))) { .err("`cutpoints` must be a named list of binning values with an element for each numeric variable") } + bad.names <- setdiff(names(cutpoints), names(X)) + nb <- length(bad.names) - if (nb > 0) { + + if (nb > 0L) { .wrn(sprintf("the variable%%s %s named in `cutpoints` %%r not in the variables supplied to `matchit()` and will be ignored", word_list(bad.names, quotes = 2, and.or = "and")), n = nb) cutpoints[bad.names] <- NULL } - if (length(grouping) > 0) { + if (length(grouping) > 0L) { grouping.cutpoint.names <- intersect(names(grouping), names(cutpoints)) + ngc <- length(grouping.cutpoint.names) - if (ngc > 0) { + + if (ngc > 0L) { .wrn(sprintf("the variable%%s %s %%r named in both `grouping` and `cutpoints`; %s entr%%y%%s in `cutpoints` will be ignored", word_list(grouping.cutpoint.names, quotes = 2, and.or = "and"), ngettext(ngc, "its", "their")), n = ngc) @@ -404,43 +413,44 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k } non.numeric.in.cutpoints <- intersect(names(X)[!is.numeric.cov], names(cutpoints)) + nnnic <- length(non.numeric.in.cutpoints) - if (nnnic > 0) { + + if (nnnic > 0L) { .wrn(sprintf("the variable%%s %s named in `cutpoints` %%r not numeric and %s cutpoints will not be applied. Use `grouping` for non-numeric variables", word_list(non.numeric.in.cutpoints, quotes = 2, and.or = "and"), ngettext(nnnic, "its", "their")), n = nnnic) } bad.cuts <- setNames(rep(FALSE, length(cutpoints)), names(cutpoints)) + for (i in names(cutpoints)) { - if (length(cutpoints[[i]]) == 0) { + if (length(cutpoints[[i]]) == 0L) { cutpoints[[i]] <- "sturges" } - else if (length(cutpoints[[i]]) == 1) { - if (is.na(cutpoints[[i]])) is.numeric.cov[i] <- FALSE #Will not be binned - else if (is.character(cutpoints[[i]])) { - bad.cuts[i] <- !(startsWith(cutpoints[[i]], "q") && can_str2num(substring(cutpoints[[i]], 2))) && - is.na(pmatch(cutpoints[[i]], c("sturges", "fd", "scott"))) - } - else if (is.numeric(cutpoints[[i]])) { - if (!is.finite(cutpoints[[i]]) || cutpoints[[i]] < 0) { - bad.cuts[i] <- TRUE - } - else if (cutpoints[[i]] == 0) { - is.numeric.cov[i] <- FALSE #Will not be binned - } - else if (cutpoints[[i]] == 1) { - X[[i]] <- NULL #Removing from X, still in X.match - is.numeric.cov <- is.numeric.cov[names(is.numeric.cov) != i] - } - } - else { - bad.cuts[i] <- TRUE - } - } - else { + else if (length(cutpoints[[i]]) > 1L) { bad.cuts[i] <- !is.numeric(cutpoints[[i]]) } + else if (is.na(cutpoints[[i]])) { + is.numeric.cov[i] <- FALSE #Will not be binned + } + else if (is.character(cutpoints[[i]])) { + bad.cuts[i] <- !(startsWith(cutpoints[[i]], "q") && can_str2num(substring(cutpoints[[i]], 2))) && + is.na(pmatch(cutpoints[[i]], c("sturges", "fd", "scott"))) + } + else if (!is.numeric(cutpoints[[i]])) { + bad.cuts[i] <- TRUE + } + else if (!is.finite(cutpoints[[i]]) || cutpoints[[i]] < 0) { + bad.cuts[i] <- TRUE + } + else if (cutpoints[[i]] == 0) { + is.numeric.cov[i] <- FALSE #Will not be binned + } + else if (cutpoints[[i]] == 1) { + X[[i]] <- NULL #Removing from X, still in X.match + is.numeric.cov <- is.numeric.cov[names(is.numeric.cov) != i] + } } if (any(bad.cuts)) { @@ -449,14 +459,16 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k "\n\t- a single number corresponding to the number of bins", "\n\t- a numeric vector containing the cut points separating bins", "\nIncorrectly specified variable%s:\n\t"), - paste(names(cutpoints)[bad.cuts], collapse = ", "), + paste(names(cutpoints)[bad.cuts], collapse = ", "), tidy = FALSE, n = sum(bad.cuts)) } #Create bins for numeric variables for (i in names(X)[is.numeric.cov]) { - if (is.null(cutpoints) || !i %in% names(cutpoints)) bins <- "sturges" - else bins <- cutpoints[[i]] + bins <- { + if (!is.null(cutpoints) && i %in% names(cutpoints)) cutpoints[[i]] + else "sturges" + } if (is.character(bins)) { if (startsWith(bins, "q") || can_str2num(substring(bins, 2))) { @@ -474,7 +486,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k } } - if (length(bins) == 1) { + if (length(bins) == 1L) { #cutpoints is number of bins, unlike in cem breaks <- seq(min(X[[i]]), max(X[[i]]), length = bins + 1) breaks[c(1, bins + 1)] <- c(-Inf, Inf) @@ -486,15 +498,16 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k X[[i]] <- findInterval(X[[i]], breaks) } - if (length(X) == 0) { - subclass <- setNames(rep(1, length(treat)), names(treat)) + if (length(X) == 0L) { + subclass <- setNames(rep(1L, length(treat)), names(treat)) } else { #Exact match xx <- exactify(X, names(treat)) + cc <- do.call("intersect", unname(split(xx, treat))) - if (length(cc) == 0) { + if (length(cc) == 0L) { .err("no units were matched. Try coarsening the variables further or decrease the number of variables to match on") } @@ -525,7 +538,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k dist.mat <- eucdist_internal(X.match[in.sub,,drop = FALSE], treat[in.sub] == s) } else { - dist.mat <- dist_to_matrixC(dist(X.match[in.sub,,drop = FALSE], method = k2k.method, p = mpower)) + dist.mat <- as.matrix(dist(X.match[in.sub,,drop = FALSE], method = k2k.method, p = mpower)) #Put smaller group on rows d.rows <- which(rownames(dist.mat) %in% names(treat[in.sub])[treat[in.sub] == s]) @@ -544,8 +557,9 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k } #If any unmatched units remain, give them NA subclass - if (any(dim(dist.mat) > 0)) is.na(subclass)[unlist(dimnames(dist.mat))] <- TRUE - + if (any(dim(dist.mat) > 0)) { + is.na(subclass)[unlist(dimnames(dist.mat))] <- TRUE + } } } diff --git a/R/matchit2full.R b/R/matchit2full.R index 8da98e25..7fd625bd 100644 --- a/R/matchit2full.R +++ b/R/matchit2full.R @@ -272,8 +272,12 @@ matchit2full <- function(treat, formula, data, distance, discarded, if (!is.null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) + cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0) .err("No matches were found") + + if (length(cc) == 0L) { + .err("No matches were found") + } } else { ex <- factor(rep("_", length(treat_)), levels = "_") @@ -324,6 +328,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, cov.cals <- setdiff(names(caliper), "") calcovs <- get.covs.matrix(reformulate(cov.cals, intercept = FALSE), data = data) } + for (i in seq_along(caliper)) { if (names(caliper)[i] != "") { mo_cal <- optmatch::match_on(setNames(calcovs[!discarded, names(caliper)[i]], names(treat_)), z = treat_) @@ -337,6 +342,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, mo <- mo + optmatch::caliper(mo_cal, caliper[i]) } + rm(mo_cal) } @@ -356,8 +362,11 @@ matchit2full <- function(treat, formula, data, distance, discarded, } else mo_ <- mo - if (any(dim(mo_) == 0) || !any(is.finite(mo_))) next - else if (all(dim(mo_) == 1) && all(is.finite(mo_))) { + if (any(dim(mo_) == 0) || !any(is.finite(mo_))) { + next + } + + if (all(dim(mo_) == 1) && all(is.finite(mo_))) { pair[ex == e] <- paste(1, e, sep = "|") next } @@ -374,8 +383,13 @@ matchit2full <- function(treat, formula, data, distance, discarded, pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) .err("No matches were found") - if (length(p) == 1) p <- p[[1]] + if (all(is.na(pair))) { + .err("No matches were found") + } + + if (length(p) == 1L) { + p <- p[[1]] + } psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) diff --git a/R/matchit2genetic.R b/R/matchit2genetic.R index 9c5c3927..cd0fcec9 100644 --- a/R/matchit2genetic.R +++ b/R/matchit2genetic.R @@ -273,7 +273,7 @@ matchit2genetic <- function(treat, data, distance, discarded, if (!replace) { if (sum(!discarded & treat != focal) < sum(!discarded & treat == focal)) { .wrn(sprintf("fewer %s units than %s units; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } else if (sum(!discarded & treat != focal) < sum(!discarded & treat == focal)*ratio) { .err(sprintf("not enough %s units for %s matches for each %s unit", @@ -314,7 +314,7 @@ matchit2genetic <- function(treat, data, distance, discarded, X <- cbind(covs_to_balance, distance) } - if (ncol(covs_to_balance) == 0) { + if (ncol(covs_to_balance) == 0L) { .err("covariates must be specified in the input formula to use genetic matching") } @@ -330,7 +330,9 @@ matchit2genetic <- function(treat, data, distance, discarded, exact.log <- c(rep(FALSE, ncol(X) - 1), TRUE) } - else exact.log <- ex <- NULL + else { + exact.log <- ex <- NULL + } #Process caliper; cal will be supplied to GenMatch() and Match() if (!is.null(caliper)) { @@ -343,7 +345,9 @@ matchit2genetic <- function(treat, data, distance, discarded, #Expand exact.log for newly added covariates if (!is.null(exact.log)) exact.log <- c(exact.log, rep(FALSE, ncol(calcovs))) } - else cov.cals <- NULL + else { + cov.cals <- NULL + } #Matching::Match multiplies calipers by pop SD, so we need to divide by pop SD to unstandardize pop.sd <- function(x) sqrt(sum((x-mean(x))^2)/length(x)) @@ -376,8 +380,9 @@ matchit2genetic <- function(treat, data, distance, discarded, cal[ncol(covs_to_balance) + 1] <- dist.cal } } - else dist.cal <- NULL - + else { + dist.cal <- NULL + } } else { cal <- dist.cal <- cov.cals <- NULL @@ -399,6 +404,7 @@ matchit2genetic <- function(treat, data, distance, discarded, t(combn(which(antiexactcovs[,i] == u), 2)) })) })), -1) + if (!is.null(A[["restrict"]])) A[["restrict"]] <- rbind(A[["restrict"]], antiexact_restrict) else A[["restrict"]] <- antiexact_restrict } @@ -441,9 +447,11 @@ matchit2genetic <- function(treat, data, distance, discarded, replace = replace, estimand = "ATT", ties = FALSE, weights = s.weights, CommonSupport = FALSE, distance.tolerance = A[["distance.tolerance"]], Weight = 3, - Weight.matrix = if (use.genetic) g.out - else if (is.null(s.weights)) generalized_inverse(cor(X)) - else generalized_inverse(cov.wt(X, s.weights, cor = TRUE)$cor), + Weight.matrix = { + if (use.genetic) g.out + else if (is.null(s.weights)) generalized_inverse(cor(X)) + else generalized_inverse(cov.wt(X, s.weights, cor = TRUE)$cor) + }, restrict = A[["restrict"]], version = "fast") }, from = "Matching", dont_warn_if = "replace==FALSE, but there are more (weighted) treated obs than control obs.") diff --git a/R/matchit2optimal.R b/R/matchit2optimal.R index 89c63b7a..7ac0ed18 100644 --- a/R/matchit2optimal.R +++ b/R/matchit2optimal.R @@ -300,7 +300,10 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, sep = ", ", include_vars = TRUE)[!discarded]) cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0) .err("No matches were found") + + if (length(cc) == 0L) { + .err("No matches were found") + } e_ratios <- vapply(levels(ex), function(e) sum(treat_[ex == e] == 0)/sum(treat_[ex == e] == 1), numeric(1L)) @@ -387,8 +390,11 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, } else mo_ <- mo - if (any(dim(mo_) == 0) || !any(is.finite(mo_))) next - else if (all(dim(mo_) == 1) && all(is.finite(mo_))) { + if (any(dim(mo_) == 0) || !any(is.finite(mo_))) { + next + } + + if (all(dim(mo_) == 1) && all(is.finite(mo_))) { pair[ex == e] <- paste(1, e, sep = "|") next } @@ -425,8 +431,13 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) .err("No matches were found") - if (length(p) == 1) p <- p[[1]] + if (all(is.na(pair))) { + .err("No matches were found") + } + + if (length(p) == 1L) { + p <- p[[1]] + } psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) diff --git a/R/matchit2quick.R b/R/matchit2quick.R index cd310975..4cfef36d 100644 --- a/R/matchit2quick.R +++ b/R/matchit2quick.R @@ -176,9 +176,12 @@ matchit2quick <- function(treat, formula, data, distance, discarded, if (!is.null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) + cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0) .err("no matches were found") + if (length(cc) == 0L) { + .err("no matches were found") + } } else { ex <- factor(rep("_", length(treat_)), levels = "_") @@ -234,8 +237,13 @@ matchit2quick <- function(treat, formula, data, distance, discarded, pair[which(ex == e)[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) .err("no matches were found") - if (length(p) == 1) p <- p[[1]] + if (all(is.na(pair))) { + .err("no matches were found") + } + + if (length(p) == 1L) { + p <- p[[1]] + } psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) diff --git a/R/matchit2subclass.R b/R/matchit2subclass.R index aca5a792..b8a9c57a 100644 --- a/R/matchit2subclass.R +++ b/R/matchit2subclass.R @@ -212,7 +212,7 @@ matchit2subclass <- function(treat, distance, discarded, psclass[!discarded] <- as.integer(findInterval(distance[!discarded], q, all.inside = TRUE)) if (length(unique(na.omit(psclass))) != subclass){ - .wrn("Due to discreteness in the distance measure, fewer subclasses were generated than were requested") + .wrn("due to discreteness in the distance measure, fewer subclasses were generated than were requested") } if (min.n == 0) { @@ -220,11 +220,14 @@ matchit2subclass <- function(treat, distance, discarded, is.na(psclass)[!discarded & !psclass %in% intersect(psclass[!discarded & treat == 1], psclass[!discarded & treat == 0])] <- TRUE } - else if (any(table(treat, psclass) < min.n)) { + else { ## If any subclasses don't have members of a treatment group, fill them ## by "scooting" units from nearby subclasses until each subclass has a unit ## from each treatment group - psclass[!discarded] <- subclass_scoot(psclass[!discarded], treat[!discarded], distance[!discarded], min.n) + psclass[!discarded] <- subclass_scoot(psclass[!discarded], + treat[!discarded], + distance[!discarded], + min.n) } psclass <- setNames(factor(psclass, nmax = length(q)), names(treat)) diff --git a/R/plot.matchit.R b/R/plot.matchit.R index 464c8deb..2b839f8c 100644 --- a/R/plot.matchit.R +++ b/R/plot.matchit.R @@ -152,13 +152,13 @@ plot.matchit <- function(x, type = "qq", interactive = TRUE, which.xs = NULL, da if (is.null(x$distance)) { .err("`type = \"jitter\"` cannot be used if a distance measure is not estimated or supplied. No plots generated") } - jitter.pscore(x, interactive = interactive,...) + jitter_pscore(x, interactive = interactive,...) } else if (type == "histogram") { if (is.null(x$distance)) { .err("`type = \"hist\"` cannot be used if a distance measure is not estimated or supplied. No plots generated") } - hist.pscore(x,...) + hist_pscore(x,...) } invisible(x) } @@ -229,13 +229,13 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = if (is.null(x$distance)) { .err("`type = \"jitter\"` cannot be used when no distance variable was estimated or supplied") } - jitter.pscore(x, interactive = interactive, ...) + jitter_pscore(x, interactive = interactive, ...) } else if (type == "histogram") { if (is.null(x$distance)) { .err("`type = \"histogram\"` cannot be used when no distance variable was estimated or supplied") } - hist.pscore(x,...) + hist_pscore(x,...) } invisible(x) } @@ -244,7 +244,7 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = NULL, data = NULL, ...) { if (is.null(which.xs)) { - if (length(object$X) == 0) { + if (length(object$X) == 0L) { .wrn("No covariates to plot") return(invisible(NULL)) } @@ -283,7 +283,7 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = if (is.character(which.xs)) { if (!all(which.xs %in% names(data))) { - .err("All variables in `which.xs` must be in the supplied `matchit` object or in `data`") + .err("all variables in `which.xs` must be in the supplied `matchit` object or in `data`") } X <- data[which.xs] } @@ -400,7 +400,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, interactive = TRUE, which.xs = NULL, data = NULL, ...) { if (is.null(which.xs)) { - if (length(object$X) == 0) { + if (length(object$X) == 0L) { .wrn("No covariates to plot") return(invisible(NULL)) } @@ -439,7 +439,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, if (is.character(which.xs)) { if (!all(which.xs %in% names(data))) { - .err("All variables in `which.xs` must be in the supplied `matchit` object or in `data`") + .err("all variables in `which.xs` must be in the supplied `matchit` object or in `data`") } X <- data[which.xs] } @@ -448,7 +448,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, X <- model.frame(which.xs, data, na.action = "na.pass") if (anyNA(X)) { - .err("Missing values are not allowed in the covariates named in `which.xs`") + .err("missing values are not allowed in the covariates named in `which.xs`") } } else { @@ -466,6 +466,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, if (!is.atomic(which.subclass)) { .err("The argument to `subclass` must be NULL or the indices of the subclasses for which to display covariate distributions") } + if (!all(which.subclass %in% object$subclass[!is.na(object$subclass)])) { .err("The argument supplied to `subclass` is not the index of any subclass in the matchit object") } @@ -813,7 +814,7 @@ densityplot_match <- function(x, t, w, sw, ...) { } } -hist.pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...){ +hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { .pardefault <- par(no.readonly = TRUE) on.exit(par(.pardefault)) @@ -870,7 +871,7 @@ hist.pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...){ } -jitter.pscore <- function(x, interactive, pch = 1, ...){ +jitter_pscore <- function(x, interactive, pch = 1, ...) { .pardefault <- par(no.readonly = TRUE) on.exit(par(.pardefault)) diff --git a/R/rbind.matchdata.R b/R/rbind.matchdata.R index 9089155a..2fde0137 100644 --- a/R/rbind.matchdata.R +++ b/R/rbind.matchdata.R @@ -84,8 +84,8 @@ rbind.matchdata <- function(..., deparse.level = 1) { allargs$deparse.level <- deparse.level type <- intersect(c("matchdata", "getmatches"), unlist(lapply(md_list, class))) - if (length(type) == 0) .err("A `matchdata` or `getmatches` object must be supplied") - if (length(type) == 2) .err("Supplied objects must be all `matchdata` objects or all `getmatches` objects") + if (length(type) == 0L) .err("A `matchdata` or `getmatches` object must be supplied") + if (length(type) == 2L) .err("Supplied objects must be all `matchdata` objects or all `getmatches` objects") attrs <- c("distance", "weights", "subclass", "id") attr_list <- setNames(vector("list", length(attrs)), attrs) diff --git a/R/summary.matchit.R b/R/summary.matchit.R index 590a4a8b..81926909 100644 --- a/R/summary.matchit.R +++ b/R/summary.matchit.R @@ -272,8 +272,10 @@ summary.matchit <- function(object, if (!no_x && interactions) { n.int <- kk*(kk+1)/2 - if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), dimnames = list(NULL, names(aa.all[[1]]))) - if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), dimnames = list(NULL, names(aa.matched[[1]]))) + if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), + dimnames = list(NULL, names(aa.all[[1]]))) + if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), + dimnames = list(NULL, names(aa.matched[[1]]))) to.remove <- rep(FALSE, n.int) int.names <- character(n.int) @@ -312,6 +314,7 @@ summary.matchit <- function(object, rownames(sum.all.int) <- int.names sum.all <- rbind(sum.all, sum.all.int[!to.remove,,drop = FALSE]) } + if (matched) { rownames(sum.matched.int) <- int.names sum.matched <- rbind(sum.matched, sum.matched.int[!to.remove,,drop = FALSE]) @@ -411,13 +414,13 @@ summary.matchit.subclass <- function(object, chk::chk_flag(un) chk::chk_flag(improvement) - if (standardize) { - s.d.denom <- switch(object$estimand, - "ATT" = "treated", - "ATC" = "control", - "ATE" = "pooled") + s.d.denom <- { + if (standardize) switch(object$estimand, + "ATT" = "treated", + "ATC" = "control", + "ATE" = "pooled") + else NULL } - else s.d.denom <- NULL if (isTRUE(which.subclass)) which.subclass <- subclasses else if (isFALSE(which.subclass)) which.subclass <- NULL @@ -458,13 +461,15 @@ summary.matchit.subclass <- function(object, if (interactions) { n.int <- kk*(kk+1)/2 - if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), dimnames = list(NULL, names(aa.all[[1]]))) - if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), dimnames = list(NULL, names(aa.matched[[1]]))) + if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), + dimnames = list(NULL, names(aa.all[[1]]))) + if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), + dimnames = list(NULL, names(aa.matched[[1]]))) to.remove <- rep(FALSE, n.int) int.names <- character(n.int) k <- 1 - for (i in 1:kk) { + for (i in seq_len(kk)) { for (j in i:kk) { x2 <- X[,i] * X[,j] if (all(abs(x2) < sqrt(.Machine$double.eps)) || @@ -535,7 +540,9 @@ summary.matchit.subclass <- function(object, #bal1var.subclass only returns unmatched stats, which is all we need within #subclasses. Otherwise, identical to matched stats. aa <- setNames(lapply(seq_len(kk), function(i) { - bal1var.subclass(X[,i], tt = treat, s.weights = s.weights, subclass = subclass, s.d.denom = s.d.denom, standardize = standardize, which.subclass = s) + bal1var.subclass(X[,i], tt = treat, s.weights = s.weights, + subclass = subclass, s.d.denom = s.d.denom, + standardize = standardize, which.subclass = s) }), colnames(X)) sum.sub <- matrix(NA_real_, nrow = kk, ncol = ncol(aa[[1]]), dimnames = list(nam, colnames(aa[[1]]))) @@ -545,7 +552,8 @@ summary.matchit.subclass <- function(object, sum.sub[i,] <- aa[[i]] } if (interactions) { - sum.sub.int <- matrix(NA_real_, nrow = kk*(kk+1)/2, ncol = length(aa[[1]]), dimnames = list(NULL, names(aa[[1]]))) + sum.sub.int <- matrix(NA_real_, nrow = kk*(kk+1)/2, ncol = length(aa[[1]]), + dimnames = list(NULL, names(aa[[1]]))) to.remove <- rep(FALSE, nrow(sum.sub.int)) int.names <- character(nrow(sum.sub.int)) k <- 1 @@ -553,7 +561,9 @@ summary.matchit.subclass <- function(object, for (j in i:kk) { if (!to.remove[k]) { #to.remove defined above x2 <- X[,i] * X[,j] - jqoi <- bal1var.subclass(x2, tt = treat, s.weights = s.weights, subclass = subclass, s.d.denom = s.d.denom, standardize = standardize, which.subclass = s) + jqoi <- bal1var.subclass(x2, tt = treat, s.weights = s.weights, + subclass = subclass, s.d.denom = s.d.denom, + standardize = standardize, which.subclass = s) sum.sub.int[k,] <- jqoi if (i == j) { int.names[k] <- paste0(nam[i], "\u00B2") @@ -599,9 +609,11 @@ summary.matchit.subclass <- function(object, #' @exportS3Method print summary.matchit #' @rdname summary.matchit print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), - ...){ + ...) { - if (!is.null(x$call)) cat("\nCall:", deparse(x$call), sep = "\n") + if (!is.null(x$call)) { + cat("\nCall:", deparse(x$call), sep = "\n") + } if (!is.null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") @@ -641,7 +653,9 @@ print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), #' @exportS3Method print summary.matchit.subclass print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits") - 3), ...){ - if (!is.null(x$call)) cat("\nCall:", deparse(x$call), sep = "\n") + if (!is.null(x$call)) { + cat("\nCall:", deparse(x$call), sep = "\n") + } if (!is.null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") @@ -700,94 +714,94 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" else get.covs.matrix(data = object$X) } - if (!is.null(addlvariables)) { + if (is.null(addlvariables)) { + return(X) + } - #Attempt to extrct data from matchit object; same as match.data() - data.fram.matchit <- FALSE - if (is.null(data)) { - env <- environment(object$formula) + #Attempt to extrct data from matchit object; same as match.data() + data.fram.matchit <- FALSE + if (is.null(data)) { + env <- environment(object$formula) + data <- try(eval(object$call$data, envir = env), silent = TRUE) + if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { + env <- parent.frame() data <- try(eval(object$call$data, envir = env), silent = TRUE) if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - env <- parent.frame() - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - data <- object[["model"]][["data"]] - if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { - data <- NULL - } - else data.fram.matchit <- TRUE + data <- object[["model"]][["data"]] + if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { + data <- NULL } else data.fram.matchit <- TRUE } else data.fram.matchit <- TRUE } + else data.fram.matchit <- TRUE + } - if (is.character(addlvariables)) { - if (!is.null(data) && is.data.frame(data)) { - if (all(addlvariables %in% names(data))) { - addlvariables <- data[addlvariables] - } - else { - .err("All variables in `addlvariables` must be in `data`") - } - } - else { - .err("If `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") - } + if (is.character(addlvariables)) { + if (is.null(data) || !is.data.frame(data)) { + .err("If `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") } - else if (inherits(addlvariables, "formula")) { - vars.in.formula <- all.vars(addlvariables) - if (!is.null(data) && is.data.frame(data)) { - data <- data.frame(data[names(data) %in% vars.in.formula], - object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) - } - else data <- object$X - # addlvariables <- get.covs.matrix(addlvariables, data = data) - } - else if (!is.matrix(addlvariables) && !is.data.frame(addlvariables)) { - .err("The argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") + if (!all(addlvariables %in% names(data))) { + .err("All variables in `addlvariables` must be in `data`") } - - if (af <- inherits(addlvariables, "formula")) { - addvariables_f <- addlvariables - addlvariables <- model.frame(addvariables_f, data = data, na.action = "na.pass") + addlvariables <- data[addlvariables] + } + else if (inherits(addlvariables, "formula")) { + vars.in.formula <- all.vars(addlvariables) + if (!is.null(data) && is.data.frame(data)) { + data <- data.frame(data[names(data) %in% vars.in.formula], + object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) } + else data <- object$X - if (nrow(addlvariables) != length(object$treat)) { - if (is.null(data) || data.fram.matchit) { - .err("Variables specified in `addlvariables` must have the same number of units as are present in the original call to `matchit()`") - } - else { - .err("`data` must have the same number of units as are present in the original call to `matchit()`") - } - } + # addlvariables <- get.covs.matrix(addlvariables, data = data) + } + else if (!is.matrix(addlvariables) && !is.data.frame(addlvariables)) { + .err("The argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") + } - k <- ncol(addlvariables) - for (i in seq_len(k)) { - if (anyNA(addlvariables[[i]]) || (is.numeric(addlvariables[[i]]) && any(!is.finite(addlvariables[[i]])))) { - covariates.with.missingness <- names(addlvariables)[i:k][vapply(i:k, function(j) anyNA(addlvariables[[j]]) || - (is.numeric(addlvariables[[j]]) && - any(!is.finite(addlvariables[[j]]))), - logical(1L))] - .err(paste0("Missing and non-finite values are not allowed in `addlvariables`. Variables with missingness or non-finite values:\n\t", - paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) - } - if (is.character(addlvariables[[i]])) addlvariables[[i]] <- factor(addlvariables[[i]]) - } - if (af) { - addlvariables <- get.covs.matrix(addvariables_f, data = data) + if (af <- inherits(addlvariables, "formula")) { + addvariables_f <- addlvariables + addlvariables <- model.frame(addvariables_f, data = data, na.action = "na.pass") + } + + if (nrow(addlvariables) != length(object$treat)) { + if (is.null(data) || data.fram.matchit) { + .err("variables specified in `addlvariables` must have the same number of units as are present in the original call to `matchit()`") } else { - addlvariables <- get.covs.matrix(data = addlvariables) + .err("`data` must have the same number of units as are present in the original call to `matchit()`") + } + } + + k <- ncol(addlvariables) + for (i in seq_len(k)) { + if (anyNA(addlvariables[[i]]) || (is.numeric(addlvariables[[i]]) && any(!is.finite(addlvariables[[i]])))) { + covariates.with.missingness <- names(addlvariables)[i:k][vapply(i:k, function(j) anyNA(addlvariables[[j]]) || + (is.numeric(addlvariables[[j]]) && + any(!is.finite(addlvariables[[j]]))), + logical(1L))] + .err(paste0("Missing and non-finite values are not allowed in `addlvariables`. Variables with missingness or non-finite values:\n\t", + paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) } + if (is.character(addlvariables[[i]])) { + addlvariables[[i]] <- factor(addlvariables[[i]]) + } + } - # addl_assign <- get_assign(addlvariables) - X <- cbind(X, addlvariables[, setdiff(colnames(addlvariables), colnames(X)), drop = FALSE]) + if (af) { + addlvariables <- get.covs.matrix(addvariables_f, data = data) } + else { + addlvariables <- get.covs.matrix(data = addlvariables) + } + + # addl_assign <- get_assign(addlvariables) + cbind(X, addlvariables[, setdiff(colnames(addlvariables), colnames(X)), drop = FALSE]) - X } \ No newline at end of file From 099df172071692b656f39a20f5fdd26b61ff7801 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:10:28 -0400 Subject: [PATCH 02/48] Improvements to subclass_scoot (now in Rcpp) and exactify --- R/aux_functions.R | 203 ++++++++++++++++------------------------ src/subclass_scootC.cpp | 149 +++++++++++++++++++++++++++++ 2 files changed, 232 insertions(+), 120 deletions(-) create mode 100644 src/subclass_scootC.cpp diff --git a/R/aux_functions.R b/R/aux_functions.R index 54ffa38c..92d9e84b 100644 --- a/R/aux_functions.R +++ b/R/aux_functions.R @@ -4,87 +4,28 @@ #from other subclasses. From WeightIt. subclass_scoot <- function(sub, treat, x, min.n = 1) { #Reassigns subclasses so there are no empty subclasses - #for each treatment group. Copied from WeightIt with - #slight modifications. + #for each treatment group. + subtab <- table(treat, sub) - treat <- as.character(treat) - unique.treat <- unique(treat, nmax = 2) - - names(x) <- seq_along(x) - names(sub) <- seq_along(sub) - original.order <- names(x) - - nsub <- length(unique(sub)) + if (all(subtab >= min.n)) { + return(sub) + } - #Turn subs into a contiguous sequence - sub <- setNames(setNames(seq_len(nsub), sort(unique(sub)))[as.character(sub)], - original.order) + nsub <- ncol(subtab) - if (any(table(treat) < nsub * min.n)) { + if (any(rowSums(subtab) < nsub * min.n)) { .err(sprintf("not enough units to fit %s treated and control %s in each subclass", min.n, ngettext(min.n, "unit", "units"))) } - for (t in unique.treat) { - if (length(x[treat == t]) == nsub) { - sub[treat == t] <- seq_len(nsub) - } - } - - sub_tab <- table(treat, sub) - - if (any(sub_tab < min.n)) { - - soft_thresh <- function(x, minus = 1) { - x <- x - minus - x[x < 0] <- 0 - x - } - - for (t in unique.treat) { - for (n in seq_len(min.n)) { - while (any(sub_tab[t,] == 0)) { - first_0 <- which(sub_tab[t,] == 0)[1] - - if (first_0 == nsub || - (first_0 != 1 && - sum(soft_thresh(sub_tab[t, seq(1, first_0 - 1)]) / abs(first_0 - seq(1, first_0 - 1))) >= - sum(soft_thresh(sub_tab[t, seq(first_0 + 1, nsub)]) / abs(first_0 - seq(first_0 + 1, nsub))))) { - #If there are more and closer nonzero subs to the left... - first_non0_to_left <- max(seq(1, first_0 - 1)[sub_tab[t, seq(1, first_0 - 1)] > 0]) - - name_to_move <- names(sub)[which(x == max(x[treat == t & sub == first_non0_to_left]) & treat == t & sub == first_non0_to_left)[1]] - - sub[name_to_move] <- first_0 - sub_tab[t, first_0] <- 1L - sub_tab[t, first_non0_to_left] <- sub_tab[t, first_non0_to_left] - 1L - - } - else { - #If there are more and closer nonzero subs to the right... - first_non0_to_right <- min(seq(first_0 + 1, nsub)[sub_tab[t, seq(first_0 + 1, nsub)] > 0]) - - name_to_move <- names(sub)[which(x == min(x[treat == t & sub == first_non0_to_right]) & treat == t & sub == first_non0_to_right)[1]] - - sub[name_to_move] <- first_0 - sub_tab[t, first_0] <- 1L - sub_tab[t, first_non0_to_right] <- sub_tab[t, first_non0_to_right] - 1L - } - } - - sub_tab[t,] <- sub_tab[t,] - 1 - } - } - - #Unsort - sub <- sub[names(sub)] - } - - sub + subclass_scootC(as.integer(sub), as.integer(treat), + as.numeric(x), as.integer(min.n)) } #Create info component of matchit object -create_info <- function(method, fn1, link, discard, replace, ratio, mahalanobis, transform, subclass, antiexact, distance_is_matrix) { +create_info <- function(method, fn1, link, discard, replace, ratio, + mahalanobis, transform, subclass, antiexact, + distance_is_matrix) { info <- list(method = method, distance = if (is.null(fn1)) NULL else sub("distance2", "", fn1, fixed = TRUE), link = if (is.null(link)) NULL else link, @@ -104,24 +45,33 @@ create_info <- function(method, fn1, link, discard, replace, ratio, mahalanobis, info.to.method <- function(info) { out.list <- setNames(vector("list", 3), c("kto1", "type", "replace")) - out.list[["kto1"]] <- if (!is.null(info$ratio)) paste0(if (!is.null(info$max.controls)) "variable ratio ", round(info$ratio, 2), ":1") else NULL - out.list[["type"]] <- if (is.null(info$method)) "none (no matching)" else - switch(info$method, - "exact" = "exact matching", - "cem" = "coarsened exact matching", - "nearest" = "nearest neighbor matching", - "optimal" = "optimal pair matching", - "full" = "optimal full matching", - "quick" = "generalized full matching", - "genetic" = "genetic matching", - "subclass" = paste0("subclassification (", info$subclass, " subclasses)"), - "cardinality" = "cardinality matching", - if (is.null(attr(info$method, "method"))) "an unspecified matching method" - else attr(info$method, "method")) - out.list[["replace"]] <- if (!is.null(info$replace) && info$method %in% c("nearest", "genetic")) { - if (info$replace) "with replacement" + + out.list[["kto1"]] <- { + if (!is.null(info$ratio)) paste0(if (!is.null(info$max.controls)) "variable ratio ", round(info$ratio, 2), ":1") + else NULL + } + + out.list[["type"]] <- { + if (is.null(info$method)) "none (no matching)" + else switch(info$method, + "exact" = "exact matching", + "cem" = "coarsened exact matching", + "nearest" = "nearest neighbor matching", + "optimal" = "optimal pair matching", + "full" = "optimal full matching", + "quick" = "generalized full matching", + "genetic" = "genetic matching", + "subclass" = sprintf("subclassification (%s subclasses)", info$subclass), + "cardinality" = "cardinality matching", + if (is.null(attr(info$method, "method"))) "an unspecified matching method" + else attr(info$method, "method")) + } + + out.list[["replace"]] <- { + if (is.null(info$replace) || !info$method %in% c("nearest", "genetic")) NULL + else if (info$replace) "with replacement" else "without replacement" - } else NULL + } firstup(do.call("paste", c(unname(out.list), list(sep = " ")))) } @@ -270,8 +220,8 @@ match_arg <- function(arg, choices, several.ok = FALSE) { i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE) if (all(i == 0L)) .err(sprintf("the argument to `%s` should be %s%s.", - arg.name, ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), - word_list(choices, and.or = "or", quotes = 2))) + arg.name, ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), + word_list(choices, and.or = "or", quotes = 2))) i <- i[i > 0L] choices[i] @@ -318,34 +268,40 @@ binarize <- function(variable, zero = NULL, one = NULL) { } #Make interaction vector out of matrix of covs; similar to interaction() -exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE) { +exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = "right") { if (is.null(nam)) nam <- rownames(X) if (is.matrix(X)) X <- setNames(lapply(seq_len(ncol(X)), function(i) X[,i]), colnames(X)) if (!is.list(X)) stop("X must be a matrix, data frame, or list.") - if (include_vars) { - for (i in seq_along(X)) { - if (is.character(X[[i]]) || is.factor(X[[i]])) { - X[[i]] <- sprintf('%s = "%s"', names(X)[i], X[[i]]) - } - else { - X[[i]] <- sprintf('%s = %s', names(X)[i], X[[i]]) - } + for (i in seq_along(X)) { + unique_x <- { + if (is.factor(X[[i]])) levels(X[[i]]) + else sort(unique(X[[i]])) } - } - else { - for (i in seq_along(X)) { - if (is.factor(X[[i]])) { - X[[i]] <- format(levels(X[[i]]), justify = "right")[X[[i]]] - } - else { - X[[i]] <- format(X[[i]], justify = "right") + + lev <- { + if (include_vars) { + sprintf("%s = %s", + names(X)[i], + add_quotes(unique_x, is.character(X[[i]]) || is.factor(X[[i]]))) } + else if (is.null(justify)) unique_x + else format(unique_x, justify = justify) } + + X[[i]] <- factor(X[[i]], levels = unique_x, labels = lev) } + all_levels <- do.call("paste", c(rev(expand.grid(rev(lapply(X, levels)), + KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)), + sep = sep)) + out <- do.call("paste", c(X, sep = sep)) + + out <- factor(out, levels = all_levels[all_levels %in% out]) + if (!is.null(nam)) names(out) <- nam + out } @@ -353,7 +309,7 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE) { can_str2num <- function(x) { nas <- is.na(x) suppressWarnings(x_num <- as.numeric(as.character(x[!nas]))) - !anyNA(x_num) + !anyNA(x_num) } #Cleanly coerces a character vector to numeric; best to use after can_str2num() @@ -373,8 +329,8 @@ firstup <- function(x) { #Capitalize first letter of each word capwords <- function(s, strict = FALSE) { cap <- function(s) paste0(toupper(substring(s, 1, 1)), - {s <- substring(s, 2); if(strict) tolower(s) else s}, - collapse = " ") + {s <- substring(s, 2); if(strict) tolower(s) else s}, + collapse = " ") sapply(strsplit(s, split = " "), cap, USE.NAMES = !is.null(names(s))) } @@ -465,7 +421,7 @@ get.covs.matrix <- function(formula = NULL, data = NULL) { contrasts.arg = lapply(Filter(is.factor, mf), contrasts, contrasts = FALSE)) assign <- attr(X, "assign")[-1] - X <- X[,-1,drop=FALSE] + X <- X[,-1, drop = FALSE] attr(X, "assign") <- assign X @@ -481,13 +437,15 @@ get_assign <- function(mat) { #Convert match.matrix (mm) using numerical indices to using char rownames nummm2charmm <- function(nummm, treat) { #Assumes nummm has rownames - charmm <- array(NA_character_, dim = dim(nummm), dimnames = dimnames(nummm)) + charmm <- array(NA_character_, dim = dim(nummm), + dimnames = dimnames(nummm)) charmm[] <- names(treat)[nummm] charmm } charmm2nummm <- function(charmm, treat) { - nummm <- array(NA_integer_, dim = dim(charmm)) + nummm <- array(NA_integer_, dim = dim(charmm), + dimnames = dimnames(charmm)) n_index <- setNames(seq_along(treat), names(treat)) nummm[] <- n_index[charmm] nummm @@ -496,14 +454,16 @@ charmm2nummm <- function(charmm, treat) { #Get subclass from match.matrix. Only to be used if replace = FALSE. See subclass2mmC.cpp for reverse. mm2subclass <- function(mm, treat) { lab <- names(treat) - ind1 <- which(treat == 1) + mmlab <- rownames(mm) + no.match <- is.na(mm) subclass <- setNames(rep(NA_character_, length(treat)), lab) - no.match <- is.na(mm) - subclass[ind1[!no.match[,1]]] <- ind1[!no.match[,1]] - subclass[mm[!no.match]] <- ind1[row(mm)[!no.match]] - subclass <- setNames(factor(subclass, nmax = length(ind1)), lab) + subclass[mmlab[!no.match[,1]]] <- mmlab[!no.match[,1]] + + subclass[mm[!no.match]] <- mmlab[row(mm)[!no.match]] + + subclass <- setNames(factor(subclass, nmax = length(mmlab)), lab) levels(subclass) <- seq_len(nlevels(subclass)) subclass @@ -646,7 +606,8 @@ nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { if (is.null(discarded)) discarded <- rep(FALSE, length(treat)) if (is.null(s.weights)) s.weights <- rep(1, length(treat)) weights <- weights * s.weights - n <- matrix(0, ncol=2, nrow=6, dimnames = list(c("All (ESS)", "All", "Matched (ESS)","Matched", "Unmatched","Discarded"), + n <- matrix(0, ncol=2, nrow=6, dimnames = list(c("All (ESS)", "All", "Matched (ESS)", + "Matched", "Unmatched","Discarded"), c("Control", "Treated"))) # Control Treated @@ -671,10 +632,12 @@ qn <- function(treat, subclass, discarded = NULL) { qn <- cbind(qn, table(treat[is.na(subclass) & !discarded])) colnames(qn)[ncol(qn)] <- "Unmatched" } + if (any(discarded)) { qn <- cbind(qn, table(treat[discarded])) colnames(qn)[ncol(qn)] <- "Discarded" } + qn <- rbind(qn, colSums(qn)) rownames(qn)[nrow(qn)] <- "Total" @@ -805,7 +768,7 @@ pkg_caller_call <- function(start = 1) { matchit_try <- function(expr, from = NULL, dont_warn_if = NULL) { tryCatch({ withCallingHandlers({ - expr + expr }, warning = function(w) { if (is.null(dont_warn_if) || !grepl(dont_warn_if, conditionMessage(w), fixed = TRUE)) { diff --git a/src/subclass_scootC.cpp b/src/subclass_scootC.cpp new file mode 100644 index 00000000..9ea0ee98 --- /dev/null +++ b/src/subclass_scootC.cpp @@ -0,0 +1,149 @@ +#include +#include "internal.h" +using namespace Rcpp; + + +// [[Rcpp::export]] +IntegerVector subclass_scootC(const IntegerVector& subclass_, + const IntegerVector& treat_, + const NumericVector& x_, + const int& min_n) { + + if (min_n == 0) { + return subclass_; + } + + int m, i, s, s2; + int best_i, nt; + double best_x, score; + + LogicalVector na_sub = is_na(subclass_); + + IntegerVector subclass = subclass_[!na_sub]; + IntegerVector treat = treat_[!na_sub]; + NumericVector x = x_[!na_sub]; + + int n = subclass.size(); + + IntegerVector unique_sub = unique(subclass); + std::sort(unique_sub.begin(), unique_sub.end()); + + subclass = match(subclass, unique_sub) - 1; + + int nsub = unique_sub.size(); + + IntegerVector subtab(nsub); + IntegerVector indt; + bool left = false; + + IntegerVector ut = unique(treat); + + for (int t : ut) { + indt = which(treat == t); + nt = indt.size(); + + //Tabulate + subtab = rep(0, nsub); + for (int i : indt) { + subtab[subclass[i]]++; + } + + for (m = 0; m < min_n; m++) { + while (min(subtab) <= 0) { + for (s = 0; s < nsub; s++) { + if (subtab[s] == 0) { + break; + } + } + + //Find which way to look for new member + if (s == nsub - 1) { + left = true; + } + else if (s == 0) { + left = false; + } + else { + score = 0.; + + for (s2 = 0; s2 < nsub; s2++) { + if (subtab[s2] <= 1) { + continue; + } + + if (s2 == s) { + continue; + } + + score += static_cast(subtab[s2] - 1) / static_cast(s2 - s); + } + + left = (score <= 0); + } + + //Find which subclass to take from (s2) + if (left) { + for (s2 = s - 1; s2 >= 0; s2--) { + if (subtab[s2] > 0) { + break; + } + } + } + else { + for (s2 = s + 1; s2 < nsub; s2++) { + if (subtab[s2] > 0) { + break; + } + } + } + + //Find unit with closest x in that subclass to take + for (i = 0; i < nt; i++) { + if (subclass[indt[i]] == s2) { + best_i = i; + best_x = x[indt[i]]; + break; + } + } + + for (i = best_i + 1; i < nt; i++) { + if (subclass[indt[i]] != s2) { + continue; + } + + if (left) { + if (x[indt[i]] < best_x) { + continue; + } + } + else { + if (x[indt[i]] > best_x) { + continue; + } + } + + best_i = i; + best_x = x[indt[i]]; + } + + subclass[indt[best_i]] = s; + subtab[s]++; + subtab[s2]--; + } + + for (s = 0; s < nsub; s++) { + subtab[s]--; + } + } + } + + for (i = 0; i < n; i++) { + subclass[i] = unique_sub[subclass[i]]; + } + + IntegerVector sub_out = rep(NA_INTEGER, subclass_.size()); + + sub_out[!na_sub] = subclass; + + return sub_out; +} \ No newline at end of file From 1adff4bb22ec6b078ae8d835d7cb4bc88b6aa49a Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:12:17 -0400 Subject: [PATCH 03/48] dist_to_matrix removed (as.matrix now faster); Rcpp for computing n1xn0 distance matrix --- R/dist_functions.R | 18 ++++++++++-------- src/dist_to_matrix.cpp | 30 ------------------------------ src/eucdistC.cpp | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 38 deletions(-) delete mode 100644 src/dist_to_matrix.cpp create mode 100644 src/eucdistC.cpp diff --git a/R/dist_functions.R b/R/dist_functions.R index 98ab444f..b8174089 100644 --- a/R/dist_functions.R +++ b/R/dist_functions.R @@ -328,26 +328,28 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano #Internal function for fast(ish) Euclidean distance eucdist_internal <- function(X, treat = NULL) { - if (is.null(dim(X))) X <- as.matrix(X) - if (is.null(treat)) { - if (ncol(X) == 1) { + if (NCOL(X) == 1L) { d <- abs(outer(drop(X), drop(X), "-")) } else { - d <- dist_to_matrixC(dist(X)) + d <- as.matrix(dist(X)) } + dimnames(d) <- list(rownames(X), rownames(X)) } else { treat_l <- as.logical(treat) - if (ncol(X) == 1) { - d <- abs(outer(X[treat_l,], X[!treat_l,], "-")) + + if (NCOL(X) == 1L) { + d <- abs(outer(X[treat_l], X[!treat_l], "-")) } else { - d <- dist(X) - d <- dist_to_matrixC(d)[treat_l, !treat_l, drop = FALSE] + # d <- dist(X) + # d <- as.matrix(d)[treat_l, !treat_l, drop = FALSE] + d <- eucdistC_N1xN0(X, as.integer(treat)) } + dimnames(d) <- list(rownames(X)[treat_l], rownames(X)[!treat_l]) } diff --git a/src/dist_to_matrix.cpp b/src/dist_to_matrix.cpp deleted file mode 100644 index 4884023f..00000000 --- a/src/dist_to_matrix.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include -using namespace Rcpp; -//Faster alternative to stats:::as.matrix.dist(). - -// [[Rcpp::export]] -NumericMatrix dist_to_matrixC(const NumericVector& d) { - int n = d.attr("Size"); - - NumericMatrix m(n,n); - - double dk; - int i, j; - int k = 0; - - for (i = 0; i < n; i++) { - for (j = i + 1; j < n; j++) { - dk = d[k]; - m(i, j) = dk; - m(j, i) = dk; - k++; - } - } - - if (d.hasAttribute("Labels")) { - CharacterVector lab = d.attr("Labels"); - rownames(m) = lab; - colnames(m) = lab; - } - return m; -} diff --git a/src/eucdistC.cpp b/src/eucdistC.cpp new file mode 100644 index 00000000..509ca20a --- /dev/null +++ b/src/eucdistC.cpp @@ -0,0 +1,34 @@ +#include +#include "internal.h" +using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +NumericVector eucdistC_N1xN0(const NumericMatrix& x, + const IntegerVector& t) { + + IntegerVector ind0 = which(t == 0); + IntegerVector ind1 = which(t == 1); + int p = x.ncol(); + int i; + double d, di; + + NumericVector dist(ind1.size() * ind0.size()); + + int k = 0; + for (double i0 : ind0) { + for (double i1 : ind1) { + d = 0; + for (i = 0; i < p; i++) { + di = x(i0, i) - x(i1, i); + d += di * di; + } + dist[k] = sqrt(d); + k++; + } + } + + dist.attr("dim") = Dimension(ind1.size(), ind0.size()); + + return dist; +} \ No newline at end of file From 1995398a70a47b48688ec96f5b33272f04731327 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:12:56 -0400 Subject: [PATCH 04/48] Cleaning --- R/matchit2cardinality.R | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/R/matchit2cardinality.R b/R/matchit2cardinality.R index 12c65330..f3de1e6b 100644 --- a/R/matchit2cardinality.R +++ b/R/matchit2cardinality.R @@ -355,6 +355,8 @@ matchit2cardinality <- function(treat, data, discarded, formula, #Apply std.tols if (any(std.tols)) { + std.tols <- which(std.tols) + sds <- { if (estimand == "ATE") { pooled_sd(X[, std.tols, drop = FALSE], t = treat, @@ -366,10 +368,9 @@ matchit2cardinality <- function(treat, data, discarded, formula, } } - zero.sds <- sds < 1e-10 - - X[,std.tols][,!zero.sds] <- scale(X[, std.tols, drop = FALSE][,!zero.sds, drop = FALSE], - center = FALSE, scale = sds[!zero.sds]) + for (i in which(sds >= 1e-10)) { + X[,std.tols[i]] <- X[,std.tols[i]] / sds[i] + } } opt.out <- setNames(vector("list", nlevels(ex)), levels(ex)) @@ -457,9 +458,11 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight highs = "highs")) #Select match type - if (estimand == "ATE") match_type <- "profile_ate" - else if (!is.finite(ratio)) match_type <- "profile_att" - else match_type <- "cardinality" + match_type <- { + if (estimand == "ATE") "profile_ate" + else if (!is.finite(ratio)) "profile_att" + else "cardinality" + } #Set objective and constraints if (match_type == "profile_ate") { From da05977db4a374eab4284c5edc9a874922fb970f Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:14:40 -0400 Subject: [PATCH 05/48] Improvements to NN matching algorithms including better support for m.order = "closest" with PS in large samples --- src/internal.cpp | 392 +++++++++++++++++++++++- src/internal.h | 66 ++++ src/nn_matchC.cpp | 554 +++++++++++++++++++++------------- src/nn_matchC_closest.cpp | 150 +++++---- src/nn_matchC_vec.cpp | 538 ++++++++++----------------------- src/nn_matchC_vec_closest.cpp | 310 +++++++++++++++++++ 6 files changed, 1352 insertions(+), 658 deletions(-) create mode 100644 src/nn_matchC_vec_closest.cpp diff --git a/src/internal.cpp b/src/internal.cpp index 3c98d77b..63a717b9 100644 --- a/src/internal.cpp +++ b/src/internal.cpp @@ -1,6 +1,9 @@ #include +#include using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + // Rcpp internal functions //C implementation of tabulate(). Faster than base::tabulate(), but real @@ -8,17 +11,20 @@ using namespace Rcpp; // [[Rcpp::interfaces(cpp)]] IntegerVector tabulateC_(const IntegerVector& bins, - const Nullable& nbins = R_NilValue) { + const Nullable& nbins = R_NilValue) { int max_bin; + if (nbins.isNotNull()) max_bin = as(nbins); else max_bin = max(na_omit(bins)); IntegerVector counts(max_bin); int n = bins.size(); for (int i = 0; i < n; i++) { - if (bins[i] > 0 && bins[i] <= max_bin) + if (bins[i] > 0 && bins[i] <= max_bin) { counts[bins[i] - 1]++; + } } + return counts; } @@ -28,4 +34,386 @@ IntegerVector tabulateC_(const IntegerVector& bins, IntegerVector which(const LogicalVector& x) { IntegerVector ind = Range(0, x.size() - 1); return ind[x]; +} + +// [[Rcpp::interfaces(cpp)]] +bool antiexact_okay(const int& aenc, + const int& ii, + const int& i, + const IntegerMatrix& antiexact_covs) { + if (aenc == 0) { + return true; + } + + for (int j = 0; j < aenc; j++) { + if (antiexact_covs(ii, j) == antiexact_covs(i, j)) { + return false; + } + } + + return true; +} + +// [[Rcpp::interfaces(cpp)]] +bool caliper_covs_okay(const int& ncc, + const int& ii, + const int& i, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs) { + if (ncc == 0) { + return true; + } + + for (int j = 0; j < ncc; j++) { + if (caliper_covs[j] >= 0) { + if (std::abs(caliper_covs_mat(ii, j) - caliper_covs_mat(i, j)) > caliper_covs[j]) { + return false; + } + } + else { + if (std::abs(caliper_covs_mat(ii, j) - caliper_covs_mat(i, j)) <= -caliper_covs[j]) { + return false; + } + } + } + + return true; +} + +// [[Rcpp::interfaces(cpp)]] +bool mm_okay(const int& r, + const int& i, + const IntegerVector& mm_ordi) { + + if (r > 1) { + for (int j : mm_ordi) { + if (i == j) { + return false; + } + } + } + + return true; +} + +// [[Rcpp::interfaces(cpp)]] +bool exact_okay(const bool& use_exact, + const int& ii, + const int& i, + const IntegerVector& exact) { + + if (!use_exact) { + return true; + } + + return exact[ii] == exact[i]; +} + +// [[Rcpp::interfaces(cpp)]] +int find_both(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_ordi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& first_control, + const int& last_control) { + + int ii = match_d_ord[t_id]; + + int min_ii = first_control; + int max_ii = last_control; + + int iil = ii; + int iir = ii; + int il = -1; + int ir = -1; + + bool l_found = (iil <= min_ii); + bool r_found = (iir >= max_ii); + + double di = distance[t_id]; + double distl, distr; + + while (!l_found || !r_found) { + if (!l_found) { + if (iil == min_ii) { + l_found = true; + il = -1; + } + else { + iil--; + il = ind_d_ord[iil]; + + //Left + if (eligible[il]) { + if (treat[il] == gi) { + if (mm_okay(r, il, mm_ordi)) { + + distl = std::abs(di - distance[il]); + + if (r_found && ir >= 0 && distl > distr) { + return ir; + } + + if (distl > caliper_dist) { + il = -1; + l_found = true; + } + else { + if (exact_okay(use_exact, t_id, il, exact)) { + if (antiexact_okay(aenc, t_id, il, antiexact_covs)) { + if (caliper_covs_okay(ncc, t_id, il, caliper_covs_mat, caliper_covs)) { + l_found = true; + } + } + } + } + } + } + } + } + } + + if (!r_found) { + if (iir == max_ii) { + r_found = true; + ir = -1; + } + else { + iir++; + ir = ind_d_ord[iir]; + + //Right + if (eligible[ir]) { + if (treat[ir] == gi) { + if (mm_okay(r, ir, mm_ordi)) { + + distr = std::abs(di - distance[ir]); + + if (l_found && il >= 0 && distl <= distr) { + return il; + } + + if (distr > caliper_dist) { + ir = -1; + r_found = true; + } + else { + if (exact_okay(use_exact, t_id, ir, exact)) { + if (antiexact_okay(aenc, t_id, ir, antiexact_covs)) { + if (caliper_covs_okay(ncc, t_id, ir, caliper_covs_mat, caliper_covs)) { + r_found = true; + } + } + } + } + } + } + } + } + } + } + + if (il < 0) { + return ir; + } + + if (ir < 0) { + return il; + } + + if (distl <= distr) { + return il; + } + else { + return ir; + } +} + +// [[Rcpp::interfaces(cpp)]] +int find_lr(const int& prev_match, + const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& first_control, + const int& last_control) { + + int ik, iik; + double dist; + + int prev_pos; + int ii = match_d_ord[t_id]; + + int z; + if (prev_match < 0) { + if (prev_match == -1) { + z = -1; + } + else { + z = 1; + } + + prev_pos = ii + z; + } + else { + prev_pos = match_d_ord[prev_match]; + if (prev_pos < ii) { + z = -1; + } + else { + z = 1; + } + } + + int min_ii = 0; + int max_ii = ind_d_ord.size() - 1; + + if (z == -1) { + min_ii = first_control; + } + else { + max_ii = last_control; + } + + for (iik = prev_pos; iik >= min_ii && iik <= max_ii; iik = iik + z) { + ik = ind_d_ord[iik]; + + if (!eligible[ik]) { + continue; + } + + if (treat[ik] != gi) { + continue; + } + + dist = std::abs(distance[t_id] - distance[ik]); + + if (dist > caliper_dist) { + return -1; + } + + if (!exact_okay(use_exact, t_id, ik, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id, ik, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, ik, caliper_covs_mat, caliper_covs)) { + continue; + } + + return ik; + } + + return -1; +} + +// [[Rcpp::interfaces(cpp)]] +IntegerVector swap_pos(IntegerVector x, + const int& a, + const int& b) { + int xa = x[a]; + + x[a] = x[b]; + x[b] = xa; + + return x; +} + +// [[Rcpp::interfaces(cpp)]] +double max_finite(const NumericVector& x) { + double m = NA_REAL; + + int n = x.size(); + int i; + bool found = false; + + //Find first finite value + for (i = 0; i < n; i++) { + if (std::isfinite(x[i])) { + m = x[i]; + found = true; + break; + } + } + + //If none found, return NA + if (!found) { + return m; + } + + //Find largest finite value + for (int j = i + 1; j < n; j++) { + if (!std::isfinite(x[i])) { + continue; + } + + if (x[j] > m) { + m = x[j]; + } + } + + return m; +} + +// [[Rcpp::interfaces(cpp)]] +double min_finite(const NumericVector& x) { + double m = NA_REAL; + + int n = x.size(); + int i; + bool found = false; + + //Find first finite value + for (i = 0; i < n; i++) { + if (std::isfinite(x[i])) { + m = x[i]; + found = true; + break; + } + } + + //If none found, return NA + if (!found) { + return m; + } + + //Find smallest finite value + for (int j = i + 1; j < n; j++) { + if (!std::isfinite(x[i])) { + continue; + } + + if (x[j] < m) { + m = x[j]; + } + } + + return m; } \ No newline at end of file diff --git a/src/internal.h b/src/internal.h index 7ee17ce9..2cf86a16 100644 --- a/src/internal.h +++ b/src/internal.h @@ -9,5 +9,71 @@ IntegerVector tabulateC_(const IntegerVector& bins, IntegerVector which(const LogicalVector& x); +int find_both(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_ordi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& first_control, + const int& last_control); + +int find_lr(const int& prev_match, + const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& first_control, + const int& last_control); + +bool antiexact_okay(const int& aenc, + const int& ii, + const int& i, + const IntegerMatrix& antiexact_covs); + +bool caliper_covs_okay(const int& ncc, + const int& ii, + const int& i, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs); + +bool mm_okay(const int& r, + const int& i, + const IntegerVector& mm_ordi); + +bool exact_okay(const bool& use_exact, + const int& ii, + const int& i, + const IntegerVector& exact); + +IntegerVector swap_pos(IntegerVector x, + const int& a, + const int& b); + +double max_finite(const NumericVector& x); + +double min_finite(const NumericVector& x); #endif \ No newline at end of file diff --git a/src/nn_matchC.cpp b/src/nn_matchC.cpp index ccb0dac3..5141ebf0 100644 --- a/src/nn_matchC.cpp +++ b/src/nn_matchC.cpp @@ -1,16 +1,19 @@ // [[Rcpp::depends(RcppProgress)]] #include #include +#include "internal.h" +#include using namespace Rcpp; // [[Rcpp::plugins(cpp11)]] // [[Rcpp::export]] IntegerMatrix nn_matchC(const IntegerVector& treat_, - const IntegerVector& ord_, + const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, + const int& focal_, const Nullable& distance_ = R_NilValue, const Nullable& distance_mat_ = R_NilValue, const Nullable& exact_ = R_NilValue, @@ -20,295 +23,436 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, const Nullable& mah_covs_ = R_NilValue, const Nullable& antiexact_covs_ = R_NilValue, const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) - { + const bool& disl_prog = false) { + + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; + } + } - // Initialize + int n = treat.size(); + IntegerVector ind = Range(0, n - 1); - NumericVector distance, caliper_covs; - double caliper_dist; - NumericMatrix distance_mat, caliper_covs_mat, mah_covs, mah_covs_c; - IntegerMatrix antiexact_covs; - IntegerVector exact, exact_c, antiexact_col, unit_id, units_with_id_of_chosen_unit; - int id_of_chosen_unit; + int i, gi; + IntegerVector indt(n); + IntegerVector indt_begin(g), indt_end(g); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match = rep(NA_INTEGER, n); - bool use_dist_mat = false; - bool use_exact = false; - bool use_caliper_dist = false; - bool use_caliper_covs = false; - bool use_mah_covs = false; - bool use_antiexact = false; - bool use_reuse_max = false; - - // Info about original treat - int n_ = treat_.size(); - IntegerVector ind_ = Range(0, n_ - 1); - IntegerVector ind1_ = ind_[treat_ == 1]; - IntegerVector ind0_ = ind_[treat_ == 0]; - int n1_ = ind1_.size(); - int n0_ = n_ - n1_; - CharacterVector lab = treat_.names(); + IntegerVector times_matched = rep(0, n); + LogicalVector eligible = !discarded; - // Output matrix with sample indices of C units - int max_rat = max(ratio); - IntegerMatrix mm(n1_, max_rat); - mm.fill(NA_INTEGER); - // rownames(mm) = lab[ind1_]; + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; - // Store who has been matched - IntegerVector matched = rep(0, n_); - matched[discarded] = n1_; //discarded are unmatchable + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); + } - // After discarding + int nf = nt[focal]; - IntegerVector ind = ind_[!discarded]; - IntegerVector treat = treat_[!discarded]; - IntegerVector ind0 = ind[treat == 0]; - int n0 = ind0.size(); + int max_nc = max(as(nt[g_c])); - int t, t_ind, min_ind, c_chosen, num_eligible, cal_len, t_rat, n_anti, antiexact_t; - double dt, cal_var_t; + indt_begin[0] = 0; + indt_end[0] = nt[0]; - NumericVector cal_var, cal_diff, ps_diff, diff, dist_t, mah_covs_t, mah_covs_row, - match_distance(n0); + for (gi = 0; gi < g; gi++) { + if (gi > 0) { + indt_begin[gi] = indt_end[gi - 1]; + indt_end[gi] = indt_begin[gi] + nt[gi]; + } - IntegerVector c_eligible(n0), indices(n0); - LogicalVector finite_match_distance(n0); + indt_tmp = ind[treat == gi]; - if (distance_.isNotNull()) { - distance = distance_; + for (i = 0; i < nt[gi]; i++) { + indt[indt_begin[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_begin[focal], indt_end[focal] - 1)]; + + IntegerVector times_matched_allowed = rep(reuse_max, n); + times_matched_allowed[ind_focal] = ratio; + + IntegerVector n_eligible(unique_treat.size()); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); + + int min_ind, t_rat; + + //exact + bool use_exact = false; + IntegerVector exact; if (exact_.isNotNull()) { - exact = exact_; + exact = as(exact_); use_exact = true; } + + //caliper_dist + bool use_caliper_dist = false; + double caliper_dist; if (caliper_dist_.isNotNull()) { caliper_dist = as(caliper_dist_); use_caliper_dist = true; - ps_diff = NumericVector(n_); } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; if (caliper_covs_.isNotNull()) { - caliper_covs = caliper_covs_; - use_caliper_covs = true; - cal_len = caliper_covs.size(); - cal_diff = NumericVector(n0); - } - if (caliper_covs_mat_.isNotNull()) { + caliper_covs = as(caliper_covs_); caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); } + + //dsit_mat and mah_covs + bool use_dist_mat = false; + bool use_mah_covs = false; + NumericMatrix distance_mat, mah_covs; if (mah_covs_.isNotNull()) { mah_covs = as(mah_covs_); - NumericVector mah_covs_row(mah_covs.ncol()); use_mah_covs = true; - } else { - if (distance_mat_.isNotNull()) { - distance_mat = as(distance_mat_); + } + else if (distance_mat_.isNotNull()) { + distance_mat = as(distance_mat_); + use_dist_mat = true; + } - // IntegerVector ind0_ = ind_[treat_ == 0]; - NumericVector dist_t(n0_); - use_dist_mat = true; - } - ps_diff = NumericVector(n_); + //distance + NumericVector distance; + if (distance_.isNotNull()) { + distance = distance_; } + + //anitexact + IntegerMatrix antiexact_covs; + int aenc = 0; if (antiexact_covs_.isNotNull()) { antiexact_covs = as(antiexact_covs_); - n_anti = antiexact_covs.ncol(); - use_antiexact = true; - } - if (reuse_max < n1_) { - use_reuse_max = true; + aenc = antiexact_covs.ncol(); } + + //reuse_max + bool use_reuse_max = (reuse_max < nf); + + //unit_id + IntegerVector unit_id; + IntegerVector matched_unit_ids; + bool use_unit_id = false; if (unit_id_.isNotNull()) { unit_id = as(unit_id_); - } - else { - unit_id = ind_; + use_unit_id = true; + use_reuse_max = true; + matched_unit_ids = rep(NA_INTEGER, max_nc); } - bool ps_diff_assigned; + IntegerVector c_eligible(max_nc); + NumericVector match_distance(max_nc); + IntegerVector matches_i(1 + max_ratio * (g - 1)); + int k_total; //progress bar int prog_length; - if (!use_reuse_max) prog_length = n1_ + 1; - else prog_length = max_rat*n1_ + 1; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; Progress p(prog_length, disl_prog); //Counters - int rat, i, x, j, j_, a, k; - k = -1; + int r, t_id_t_i, t_id_i, c_id_i, c, k; + double ps_diff, dist; + IntegerVector ck_, top_r_matches; + bool ps_diff_calculated; - //Matching - for (rat = 0; rat < max_rat; ++rat) { - for (i = 0; i < n1_; ++i) { - - k++; - if (k % 500 == 0) Rcpp::checkUserInterrupt(); - - p.increment(); - - if (all(as(matched[ind0]) >= reuse_max).is_true()){ - break; - } - - t = ord_[i] - 1; // index among treated - t_ind = ind1_[t]; // index among sample - - // Skip discarded units (only discarded treated units have matched > 0) - if (matched[t_ind] > 0) { - continue; - } - - //Check if unit has enough matches - t_rat = ratio[t]; - - if (t_rat < rat + 1) { - continue; - } - - c_eligible = ind0; // index among sample + int counter = -1; - //Make ineligible control units that have been matched too many times - c_eligible = c_eligible[as(matched[c_eligible]) < reuse_max]; - - //Prevent control units being matched to same treated unit again - if (rat > 0) { - c_eligible = as(c_eligible[!in(c_eligible, na_omit(mm.row(t)))]); - } + //Matching + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (i = 0; i < nf && max(as(n_eligible[g_c])) > 0; i++) { - if (use_exact) { - exact_c = exact[c_eligible]; - c_eligible = c_eligible[exact_c == exact[t_ind]]; - } + counter++; + if (counter % 500 == 0) Rcpp::checkUserInterrupt(); - if (c_eligible.size() == 0) { - continue; - } + t_id_t_i = ord[i] - 1; // index among treated + t_id_i = ind_focal[t_id_t_i]; // index among sample - if (use_antiexact) { - for (a = 0; (a < n_anti) && (c_eligible.size() > 0); ++a) { - antiexact_col = antiexact_covs(_, a); - antiexact_t = antiexact_col[t_ind]; - antiexact_col = antiexact_col[c_eligible]; - c_eligible = c_eligible[antiexact_col != antiexact_t]; - } - if (c_eligible.size() == 0) { + if (r > times_matched_allowed[t_id_i]) { continue; } - } - ps_diff_assigned = false; + p.increment(); - if (use_caliper_dist) { - if (use_dist_mat) { - dist_t = distance_mat.row(t); - diff = dist_t[match(c_eligible, ind0_) - 1]; - } else { - dt = distance[t_ind]; - diff = Rcpp::abs(as(distance[c_eligible]) - dt); + if (!eligible[t_id_i]) { + continue; } - ps_diff[c_eligible] = diff; - ps_diff_assigned = true; - - c_eligible = c_eligible[diff <= caliper_dist]; + k_total = 0; + + for (int gi : g_c) { + k = 0; + + if (n_eligible[gi] > 0) { + for (c = indt_begin[gi]; c < indt_end[gi]; c++) { + c_id_i = indt[c]; + + if (!eligible[c_id_i]) { + continue; + } + + //Prevent control units being matched to same treated unit again + if (!mm_okay(r, c_id_i, mm.row(t_id_t_i))) { + continue; + } + + if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { + continue; + } + + ps_diff_calculated = false; + + if (use_caliper_dist) { + if (use_dist_mat) { + ps_diff = distance_mat(t_id_t_i, ind_match[c_id_i]); + } + else { + ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); + } + + if (ps_diff > caliper_dist) { + continue; + } + + ps_diff_calculated = true; + } + + if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; + } + + //Compute distances among eligible + if (use_mah_covs) { + dist = sqrt(sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0))); + } + else if (ps_diff_calculated) { + dist = ps_diff; + } + else if (use_dist_mat) { + dist = distance_mat(t_id_t_i, ind_match[c_id_i]); + } + else { + dist = std::abs(distance[c_id_i] - distance[t_id_i]); + } + + if (!std::isfinite(dist)) { + continue; + } + + c_eligible[k] = c_id_i; + match_distance[k] = dist; + k++; + } + } + + //If no matches... + if (k == 0) { + //If round 1, focal has no possible matches + if (r == 1) { + k_total = 0; + break; + } + continue; + } + + //Find minimum distance and assign + min_ind = 0; + for (c = 1; c < k; c++) { + if (match_distance[c] < match_distance[min_ind]) { + min_ind = c; + } + } + + matches_i[k_total] = c_eligible[min_ind]; + k_total++; + } - if (c_eligible.size() == 0) { + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; continue; } - } - if (use_caliper_covs) { - for (x = 0; (x < cal_len) && (c_eligible.size() > 0); ++x) { - cal_var = caliper_covs_mat( _ , x ); - - cal_var_t = cal_var[t_ind]; + //Assign to match matrix + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } - diff = Rcpp::abs(as(cal_var[c_eligible]) - cal_var_t); + matches_i[k_total] = t_id_i; - cal_diff = diff; + ck_ = matches_i[Range(0, k_total)]; - c_eligible = c_eligible[cal_diff <= caliper_covs[x]]; + if (use_unit_id) { + ck_ = ind[match(unit_id, as(unit_id[ck_])) > 0]; } - if (c_eligible.size() == 0) { - continue; + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } } } + } + } + else { + for (i = 0; i < nf; i++) { - //Compute distances among eligible - num_eligible = c_eligible.size(); + counter++; + if (counter % 500 == 0) Rcpp::checkUserInterrupt(); - //If replace and few eligible controls, assign all and move on - if (!use_reuse_max && (num_eligible <= t_rat)) { - for (j = 0; j < num_eligible; ++j) { - mm( t , j ) = c_eligible[j]; - } + t_id_t_i = ord[i] - 1; // index among treated + t_id_i = ind_focal[t_id_t_i]; // index among sample + + p.increment(); + + if (!eligible[t_id_i]) { continue; } - if (use_mah_covs) { - - match_distance = rep(0.0, num_eligible); - mah_covs_t = mah_covs.row(t_ind); + t_rat = ratio[t_id_t_i]; + + k_total = 0; + + for (int gi : g_c) { + k = 0; + + if (n_eligible[gi] > 0) { + for (c = indt_begin[gi]; c < indt_end[gi]; c++) { + c_id_i = indt[c]; + + if (!eligible[c_id_i]) { + continue; + } + + if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { + continue; + } + + ps_diff_calculated = false; + + if (use_caliper_dist) { + if (use_dist_mat) { + ps_diff = distance_mat(t_id_t_i, ind_match[c_id_i]); + } + else { + ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); + } + + if (ps_diff > caliper_dist) { + continue; + } + + ps_diff_calculated = true; + } + + if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; + } + + //Compute distances among eligible + if (use_mah_covs) { + dist = sqrt(sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0))); + } + else if (ps_diff_calculated) { + dist = ps_diff; + } + else if (use_dist_mat) { + dist = distance_mat(t_id_t_i, ind_match[c_id_i]); + } + else { + dist = std::abs(distance[c_id_i] - distance[t_id_i]); + } + + if (!std::isfinite(dist)) { + continue; + } + + c_eligible[k] = c_id_i; + match_distance[k] = dist; + k++; + } + } - for (j = 0; j < num_eligible; j++) { - j_ = c_eligible[j]; - mah_covs_row = mah_covs.row(j_); - match_distance[j] = sqrt(sum(pow(mah_covs_t - mah_covs_row, 2.0))); + //If no matches... + if (k == 0) { + k_total = 0; + break; } - } else if (ps_diff_assigned) { - match_distance = ps_diff[c_eligible]; //c_eligible might have shrunk since previous assignment - } else if (use_dist_mat) { - dist_t = distance_mat.row(t); - match_distance = dist_t[match(c_eligible, ind0_) - 1]; - } else { - dt = distance[t_ind]; - match_distance = Rcpp::abs(as(distance[c_eligible]) - dt); - } + //If replace and few eligible controls, assign all and move on - //Remove infinite distances - finite_match_distance = is_finite(match_distance); - c_eligible = c_eligible[finite_match_distance]; - num_eligible = c_eligible.size(); - if (num_eligible == 0) { - continue; - } - match_distance = match_distance[finite_match_distance]; + if (k < t_rat) { + t_rat = k; + } - if (!use_reuse_max) { - //When matching w/ replacement, get t_rat closest control units - indices = Range(0, num_eligible - 1); + //Sort distances and assign + top_r_matches = Range(0, k - 1); - std::partial_sort(indices.begin(), indices.begin() + t_rat, indices.end(), - [&match_distance](int k, int j) {return match_distance[k] < match_distance[j];}); + std::partial_sort(top_r_matches.begin(), top_r_matches.begin() + t_rat, top_r_matches.end(), + [&match_distance](int a, int b) {return match_distance[a] < match_distance[b];}); - for (j = 0; j < t_rat; ++j) { - min_ind = indices[j]; - mm( t , j ) = c_eligible[min_ind]; + for (c = 0; c < t_rat; c++) { + matches_i[k_total] = c_eligible[top_r_matches[c]]; + k_total++; } } - else { - min_ind = which_min(match_distance); - c_chosen = c_eligible[min_ind]; - mm( t , rat ) = c_chosen; + if (k_total == 0) { + continue; + } - id_of_chosen_unit = unit_id[c_chosen]; - units_with_id_of_chosen_unit = ind_[unit_id == id_of_chosen_unit]; - for (j = 0; j < units_with_id_of_chosen_unit.size(); ++j) { - matched[units_with_id_of_chosen_unit[j]] = matched[units_with_id_of_chosen_unit[j]] + 1; - } + //Assign to match matrix + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; } } - - if (!use_reuse_max) break; } p.update(prog_length); mm = mm + 1; // + 1 because C indexing starts at 0 but mm is sent to R - rownames(mm) = lab[ind1_]; + rownames(mm) = lab[ind_focal]; return mm; } diff --git a/src/nn_matchC_closest.cpp b/src/nn_matchC_closest.cpp index 22ce7282..d437bb04 100644 --- a/src/nn_matchC_closest.cpp +++ b/src/nn_matchC_closest.cpp @@ -22,38 +22,39 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, { int r = distance_mat.nrow(); - int c = distance_mat.ncol(); + + int n = treat.size(); IntegerMatrix mm(r, max(ratio)); mm.fill(NA_INTEGER); CharacterVector lab = treat.names(); - IntegerVector matched_t = rep(0, r); - IntegerVector matched_c = rep(0, c); - // IntegerVector ind = seq(0, treat.size() - 1); - IntegerVector ind0 = which(treat == 0); - IntegerVector ind1 = which(treat == 1); + IntegerVector ind = Range(0, n - 1); + IntegerVector ind0 = ind[treat == 0]; + IntegerVector ind1 = ind[treat == 1]; //caliper_dist - bool use_caliper_dist = false; double caliper_dist; if (caliper_dist_.isNotNull()) { caliper_dist = as(caliper_dist_); - use_caliper_dist = true; + } + else { + caliper_dist = max_finite(distance_mat) + 1; } //caliper_covs NumericVector caliper_covs; NumericMatrix caliper_covs_mat; - bool use_caliper_covs = false; - double n_cal_covs; + int ncc; if (caliper_covs_.isNotNull()) { caliper_covs = as(caliper_covs_); caliper_covs_mat = as(caliper_covs_mat_); - n_cal_covs = caliper_covs_mat.ncol(); - use_caliper_covs = true; + ncc = caliper_covs_mat.ncol(); + } + else { + ncc = 0; } //exact @@ -66,12 +67,13 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, //antiexact IntegerMatrix antiexact_covs; - bool use_antiexact = false; - int n_anti; + int aenc; if (antiexact_covs_.isNotNull()) { antiexact_covs = as(antiexact_covs_); - n_anti = antiexact_covs.ncol(); - use_antiexact = true; + aenc = antiexact_covs.ncol(); + } + else { + aenc = 0; } //unit_id @@ -82,6 +84,15 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, use_unit_id = true; } + IntegerVector times_matched = rep(0, n); + LogicalVector eligible = rep(true, n); + eligible[discarded] = false; + IntegerVector times_matched_allowed = rep(reuse_max, n); + times_matched_allowed[ind1] = ratio; + + int n_eligible0 = sum(as(eligible[treat == 0])); + int n_eligible1 = sum(as(eligible[treat == 1])); + //progress bar int prog_length; prog_length = sum(ratio) + 1; @@ -93,17 +104,22 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, IntegerVector d_ord = o(distance_mat); d_ord = d_ord - 1; //Because R uses 1-indexing - int rj, cj, dj, i, ind0i, ind1i; - bool okay; + int rj, cj, c_id_i, t_id_i; + + for (int dj : d_ord) { - for (int j = 0; j < d_ord.size(); j++) { + if (n_eligible1 <= 0) { + break; + } - dj = d_ord[j]; + if (n_eligible0 <= 0) { + break; + } - // If distance is greater tha distance caliper, stop the whole thing because + // If distance is greater than distance caliper, stop the whole thing because // no remaining distance will be smaller - if (use_caliper_dist) { - if (distance_mat[dj] > caliper_dist) break; + if (distance_mat[dj] > caliper_dist) { + break; } // Get row and column index of potential pair @@ -111,81 +127,63 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, cj = dj / r; // Get sample indices of members of potential pair - ind1i = ind1[rj]; - ind0i = ind0[cj]; + t_id_i = ind1[rj]; + c_id_i = ind0[cj]; // If either member is discarded, move on - if (discarded[ind1i]) continue; - if (discarded[ind0i]) continue; + if (!eligible[t_id_i]) { + continue; + } - // If either member has been matched enough times, move on - if (matched_t[rj] >= ratio[rj]) continue; - if (matched_c[cj] >= reuse_max) continue; + if (!eligible[c_id_i]) { + continue; + } // Exact matching criterion - if (use_exact) { - if (exact[ind1i] != exact[ind0i]) { - continue; - } + if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + continue; } - // Covariate caliper criterion - if (use_caliper_covs) { - i = 0; - okay = true; - while (okay && (i < n_cal_covs)) { - if (std::abs(caliper_covs_mat(ind1i, i) - caliper_covs_mat(ind0i, i)) > caliper_covs[i]) { - okay = false; - } - i++; - } - if (!okay) continue; + // Antiexact criterion + if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { + continue; } - // Antiexact criterion - if (use_antiexact) { - i = 0; - okay = true; - while (okay && (i < n_anti)) { - if (antiexact_covs(ind1i, i) == antiexact_covs(ind0i, i)) { - okay = false; - } - i++; - } - if (!okay) continue; + // Covariate caliper criterion + if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; } // If all criteria above are satisfied, potential pair becomes a pair! + mm(rj, sum(!is_na(mm(rj, _)))) = c_id_i; // If unit_id used, increase match count of all units with that ID if (use_unit_id) { - ck_ = which(as(unit_id[ind1]) == unit_id[ind1i]); - - for (i = 0; i < ck_.size(); i++) { - matched_t[ck_[i]]++; - } - - ck_ = which(as(unit_id[ind0]) == unit_id[ind0i]); - - for (i = 0; i < ck_.size(); i++) { - matched_c[ck_[i]]++; - } + ck_ = ind[unit_id == unit_id[t_id_i] | unit_id == unit_id[c_id_i]]; } else { - matched_t[rj]++; - matched_c[cj]++; + ck_ = {t_id_i, c_id_i}; } - mm(rj, matched_t[rj] - 1) = ind0i; + for (int ck : ck_) { - p.increment(); + if (!eligible[ck]) { + continue; + } - if (matched_t[rj] >= ratio[rj]) { - if (all(matched_t >= ratio).is_true()) break; - } - if (matched_c[cj] >= reuse_max) { - if (all(matched_c >= reuse_max).is_true()) break; + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + if (treat[ck] == 1) { + n_eligible1--; + } + else { + n_eligible0--; + } + } } + + p.increment(); } p.update(prog_length); diff --git a/src/nn_matchC_vec.cpp b/src/nn_matchC_vec.cpp index 43c04096..cc8ef424 100644 --- a/src/nn_matchC_vec.cpp +++ b/src/nn_matchC_vec.cpp @@ -4,471 +4,259 @@ #include "internal.h" using namespace Rcpp; -// Version of nn_matchC that works when `distance` is a vector. -// Doesn't accept `distance_mat_` or `mah_covs_`. - -bool check_in(int x, - IntegerVector table) { - int t = table.size(); - if (t < 1) return false; +// [[Rcpp::plugins(cpp11)]] - for (int j = 0; j < t; j++) { - if (x == table[j]) return true; - } - return false; -} - -int find_right(int ii, - int last_control, - IntegerVector treat, - LogicalVector can_be_matched, - int r, - IntegerVector row, - IntegerVector d_ord, - NumericVector distance, - bool use_caliper_dist, - double caliper_dist, - bool use_caliper_covs, - NumericVector caliper_covs, - NumericMatrix caliper_covs_mat, - bool use_exact, - IntegerVector exact, - bool use_antiexact, - IntegerMatrix antiexact_covs) { - - int k = ii + 1; - bool found = false, okay; - int i; - - int n_anti; - if (use_antiexact) n_anti = antiexact_covs.ncol(); - - int n_cal_covs; - if (use_caliper_covs) n_cal_covs = caliper_covs_mat.ncol(); - - while (!found && k <= last_control) { - if (treat[k] == 1) { - k++; //if unit is treated, move right - continue; - } +// [[Rcpp::export]] +IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, + const IntegerVector& ord, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const int& focal_, + const NumericVector& distance, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { - if (!can_be_matched[k]) { - k++; //if unit is matched, move right - continue; + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; } + } - //if unit has already been matched to unit i, skip - if (r > 0) { - if (check_in(d_ord[k], row)) { - k++; - continue; - } - } + int n = treat.size(); + IntegerVector ind = Range(0, n - 1); - if (use_caliper_dist) { - if (std::abs(distance[ii] - distance[k]) > caliper_dist) { - //if closest is outside caliper, break; none can be found - break; - } - } + int i, gi; + IntegerVector indt(n); + IntegerVector indt_begin(g), indt_end(g); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match = rep(NA_INTEGER, n); - if (use_exact) { - if (exact[ii] != exact[k]) { - k++; //if not exact match, move right - continue; - } - } + IntegerVector times_matched = rep(0, n); + LogicalVector eligible = !discarded; - if (use_antiexact) { - i = 0; - okay = true; - while (okay && (i < n_anti)) { - if (antiexact_covs(ii, i) == antiexact_covs(k, i)) { - okay = false; - } - i++; - } - if (!okay) { - k++; //if any antiexact checks failed, move right - continue; - } - } + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; - if (use_caliper_covs) { - i = 0; - okay = true; - while (okay && (i < n_cal_covs)) { - if (std::abs(caliper_covs_mat(ii, i) - caliper_covs_mat(k, i)) > caliper_covs[i]) { - okay = false; - } - i++; - } - if (!okay) { - k++; //if any cov caliper checks failed, move right - continue; - } - } - - found = true; + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); } - if (!found) k = -1; - - return k; -} - -int find_left(int ii, - int first_control, - IntegerVector treat, - LogicalVector can_be_matched, - int r, - IntegerVector row, - IntegerVector d_ord, - NumericVector distance, - bool use_caliper_dist, - double caliper_dist, - bool use_caliper_covs, - NumericVector caliper_covs, - NumericMatrix caliper_covs_mat, - bool use_exact, - IntegerVector exact, - bool use_antiexact, - IntegerMatrix antiexact_covs) { - - int k = ii - 1; - bool found = false, okay; - int i; - - int n_anti; - if (use_antiexact) n_anti = antiexact_covs.ncol(); - - int n_cal_covs; - if (use_caliper_covs) n_cal_covs = caliper_covs_mat.ncol(); - - while (!found && k >= first_control) { - if (treat[k] == 1) { - k--; //if unit is treated, move left - continue; - } - - if (!can_be_matched[k]) { - k--; //if unit is matched, move left - continue; - } - - //if unit has already been matched to unit i, skip - if (r > 0) { - if (check_in(d_ord[k], row)) { - k--; - continue; - } - } + int nf = nt[focal]; - if (use_caliper_dist) { - if (std::abs(distance[ii] - distance[k]) > caliper_dist) { - //if closest is outside caliper, break - break; - } - } + indt_begin[0] = 0; + indt_end[0] = nt[0]; - if (use_exact) { - if (exact[ii] != exact[k]) { - k--; //if not exact match, move left - continue; - } + for (gi = 0; gi < g; gi++) { + if (gi > 0) { + indt_begin[gi] = indt_end[gi - 1]; + indt_end[gi] = indt_begin[gi] + nt[gi]; } - if (use_antiexact) { - i = 0; - okay = true; - while (okay && (i < n_anti)) { - if (antiexact_covs(ii, i) == antiexact_covs(k, i)) { - okay = false; - } - i++; - } - if (!okay) { - k--; //if any antiexact checks failed, move left - continue; - } - } + indt_tmp = ind[treat == gi]; - if (use_caliper_covs) { - i = 0; - okay = true; - while (okay && (i < n_cal_covs)) { - if (std::abs(caliper_covs_mat(ii, i) - caliper_covs_mat(k, i)) > caliper_covs[i]) { - okay = false; - } - i++; - } - if (!okay) { - k--; //if any cov caliper checks failed, move left - continue; - } + for (i = 0; i < nt[gi]; i++) { + indt[indt_begin[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; } - - found = true; } - if (!found) k = -1; + IntegerVector ind_focal = indt[Range(indt_begin[focal], indt_end[focal] - 1)]; - return k; -} + IntegerVector times_matched_allowed = rep(reuse_max, n); + times_matched_allowed[ind_focal] = ratio; -// [[Rcpp::plugins(cpp11)]] - -// [[Rcpp::export]] -IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, - const IntegerVector& ord_, - const IntegerVector& ratio_, - const LogicalVector& discarded_, - const int& reuse_max, - const NumericVector& distance_, - const Nullable& exact_ = R_NilValue, - const Nullable& caliper_dist_ = R_NilValue, - const Nullable& caliper_covs_ = R_NilValue, - const Nullable& caliper_covs_mat_ = R_NilValue, - const Nullable& antiexact_covs_ = R_NilValue, - const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) { + IntegerVector n_eligible(unique_treat.size()); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } - int n = treat_.size(); + int max_ratio = max(ratio); - CharacterVector lab_ = treat_.names(); + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); //Use base::order() because faster than Rcpp implementation of order() Function o("order"); - IntegerVector d_ord = o(distance_, Named("decreasing") = false); - d_ord = d_ord - 1; - - IntegerVector treat = treat_[d_ord]; - NumericVector distance = distance_[d_ord]; - CharacterVector lab = lab_[d_ord]; - LogicalVector discarded = discarded_[d_ord]; + IntegerVector ind_d_ord = o(distance, Named("decreasing") = false); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting - IntegerVector ratio_tmp(n); - ratio_tmp[treat_ == 1] = ratio_; - IntegerVector ratio = ratio_tmp[d_ord]; + IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; - IntegerVector ord = ord_ - 1; - - int max_ratio = max(ratio); - - IntegerVector ind = Range(0, n - 1); - IntegerVector ind0 = ind[treat == 0]; - IntegerVector ind1 = ind[treat == 1]; - int n0 = ind0.size(); - int n1 = ind1.size(); - ind1.names() = lab[ind1]; - - IntegerVector t(n); - IntegerVector t0(n0), t1(n1); - int i; - - //ind: 1 2 3 4 5 6 7 8 - //ind1: 2 3 5 7 - - IntegerMatrix mm(n1, max_ratio); - mm.fill(NA_INTEGER); - CharacterVector lab1 = lab[ind1]; - CharacterVector mm_nm = lab_[treat_ == 1]; + IntegerVector first_control = rep(0, g); + IntegerVector last_control = rep(n - 1, g); + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } //caliper_dist double caliper_dist; - bool use_caliper_dist = false; if (caliper_dist_.isNotNull()) { caliper_dist = as(caliper_dist_); - use_caliper_dist = true; + } + else { + caliper_dist = max_finite(distance) - min_finite(distance) + 1; } //caliper_covs NumericVector caliper_covs; NumericMatrix caliper_covs_mat; - bool use_caliper_covs = false; + int ncc = 0; if (caliper_covs_.isNotNull()) { caliper_covs = as(caliper_covs_); caliper_covs_mat = as(caliper_covs_mat_); - NumericVector tmp_cc(caliper_covs_mat.nrow()); - for (int i = 0; i < caliper_covs_mat.ncol(); i++) { - tmp_cc = caliper_covs_mat(_, i); - tmp_cc = tmp_cc[d_ord]; - caliper_covs_mat(_, i) = tmp_cc; - } - use_caliper_covs = true; - } - - //exact - bool use_exact = false; - IntegerVector exact; - if (exact_.isNotNull()) { - exact = as(exact_)[d_ord]; - use_exact = true; + ncc = caliper_covs_mat.ncol(); } //antiexact IntegerMatrix antiexact_covs; - bool use_antiexact = false; + int aenc = 0; if (antiexact_covs_.isNotNull()) { antiexact_covs = as(antiexact_covs_); - NumericVector tmp_ae(antiexact_covs.nrow()); - for (int i = 0; i < antiexact_covs.ncol(); i++) { - tmp_ae = antiexact_covs(_, i); - tmp_ae = tmp_ae[d_ord]; - antiexact_covs(_, i) = tmp_ae; - } - use_antiexact = true; + aenc = antiexact_covs.ncol(); } //unit_id IntegerVector unit_id; bool use_unit_id = false; if (unit_id_.isNotNull()) { - unit_id = as(unit_id_)[d_ord]; + unit_id = as(unit_id_); use_unit_id = true; } - IntegerVector times_matched = rep(0, n); - LogicalVector can_be_matched = !as(discarded); - - IntegerVector ind_cbm = ind[can_be_matched]; - int first_control = ind_cbm[0]; - int last_control = ind_cbm[ind_cbm.size() - 1]; + IntegerVector matches_i(g); + int k_total; //progress bar - int prog_length; - prog_length = max_ratio*n1 + 1; + int prog_length = sum(ratio) + 1; Progress p(prog_length, disl_prog); - int ii, k, j, ck, r, row_to_fill; - int k_left, k_right; - double dti; - String labi; - IntegerVector mm_row, ck_; - bool done = false; - - for (r = 0; r < max_ratio; r++) { - for (i = 0; i < n1; i++) { - p.increment(); - - row_to_fill = ord[i]; + int r, t_id_t_i, t_id_i, c_id_i, c; + IntegerVector ck_; + // bool check = true; - labi = mm_nm[row_to_fill]; + int counter = -1; - ii = ind1[labi]; //ii'th unit overall + for (r = 1; r <= max_ratio; r++) { + for (i = 0; i < nf && max(as(n_eligible[g_c])) > 0; i++) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter % 500 == 0) Rcpp::checkUserInterrupt(); - if (!can_be_matched[ii]) continue; + t_id_t_i = ord[i] - 1; + t_id_i = ind_focal[t_id_t_i]; - mm_row = na_omit(mm(row_to_fill, _)); + if (r > times_matched_allowed[t_id_i]) { + continue; + } - //find control unit to left and right - k_left = find_left(ii, first_control, - treat, - can_be_matched, - r, mm_row, - d_ord, - distance, use_caliper_dist, caliper_dist, - use_caliper_covs, caliper_covs, caliper_covs_mat, - use_exact, exact, - use_antiexact, antiexact_covs); + p.increment(); - k_right = find_right(ii, last_control, - treat, - can_be_matched, - r, mm_row, - d_ord, - distance, use_caliper_dist, caliper_dist, - use_caliper_covs, caliper_covs, caliper_covs_mat, - use_exact, exact, - use_antiexact, antiexact_covs); - - if ((k_left >= 0) && (k_right >= 0)) { - dti = distance[ii]; - if (std::abs(distance[k_left] - dti) <= std::abs(distance[k_right] - dti)) { - k = k_left; - } - else { - k = k_right; - } - } - else if (k_left >= 0) { - k = k_left; - } - else if (k_right >= 0) { - k = k_right; - } - else { - can_be_matched[ii] = false; + if (!eligible[t_id_i]) { continue; } - mm( row_to_fill, r ) = d_ord[k]; + k_total = 0; - if (use_unit_id) { - ck_ = ind[unit_id == unit_id[ii]]; + for (int gi : g_c) { + while (!eligible[ind_d_ord[first_control[gi]]] || treat[ind_d_ord[first_control[gi]]] != gi) { + first_control[gi]++; + } + while (!eligible[ind_d_ord[last_control[gi]]] || treat[ind_d_ord[last_control[gi]]] != gi) { + last_control[gi]--; + } - for (j = 0; j < ck_.size(); j++) { - ck = ck_[j]; - times_matched[ck]++; - if (times_matched[ck] >= ratio[ck]) { - can_be_matched[ck] = false; + c_id_i = find_both(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control[gi], + last_control[gi]); + + if (c_id_i < 0) { + if (r == 1) { + k_total = 0; + break; } + continue; } - ck_ = ind[unit_id == unit_id[k]]; + matches_i[k_total] = c_id_i; + k_total++; + } - for (j = 0; j < ck_.size(); j++) { - ck = ck_[j]; - times_matched[ck]++; - if (times_matched[ck] >= reuse_max) { - can_be_matched[ck] = false; - } - } + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; } - else { - times_matched[ii]++; - if (times_matched[ii] >= ratio[ii]) { - can_be_matched[ii] = false; - } - times_matched[k]++; - if (times_matched[k] >= reuse_max) { - can_be_matched[k] = false; - } + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; } - if (!can_be_matched[ii]) { - if (any(as(can_be_matched[treat == 1])).is_false()) { - done = true; - break; - } + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = ind[match(unit_id, as(unit_id[ck_])) > 0]; } - if (!can_be_matched[k]) { - if (any(as(can_be_matched[treat == 0])).is_false()) { - done = true; - break; + for (int ck : ck_) { + if (!eligible[ck]) { + continue; } - } - ind_cbm = ind[can_be_matched & (treat == 0)]; - if (!can_be_matched[first_control]) { - first_control = ind_cbm[0]; - } - if (!can_be_matched[last_control]) { - last_control = ind_cbm[ind_cbm.size() - 1]; + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } } } - - if (done) break; } p.update(prog_length); mm = mm + 1; - rownames(mm) = mm_nm; + rownames(mm) = lab[ind_focal]; return mm; -} +} \ No newline at end of file diff --git a/src/nn_matchC_vec_closest.cpp b/src/nn_matchC_vec_closest.cpp new file mode 100644 index 00000000..32d930ba --- /dev/null +++ b/src/nn_matchC_vec_closest.cpp @@ -0,0 +1,310 @@ +// [[Rcpp::depends(RcppProgress)]] +#include +#include +#include "internal.h" +#include +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const NumericVector& distance, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { + + int n = treat.size(); + + //Use base::order() because faster than Rcpp implementation of order() + Function o("order"); + + IntegerVector ind_d_ord = o(distance, Named("decreasing") = false); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + + IntegerVector ind = Range(0, n - 1); + IntegerVector ind1 = ind[treat == 1]; + + int n1 = ind1.size(); + + int i, j; + + IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; + + int max_ratio = max(ratio); + + //ind: 1 2 3 4 5 6 7 8 + //ind1: 2 3 5 7 + + IntegerMatrix mm(n1, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //caliper_dist + double caliper_dist; + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + } + else { + caliper_dist = max_finite(distance) - min_finite(distance) + 1; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + } + + IntegerVector ind1_match = rep(NA_INTEGER, n); + for (i = 0; i < n1; i++) { + ind1_match[ind1[i]] = i; + } + + //storing closeness + IntegerVector t_id(2 * n1); + IntegerVector c_id(2 * n1); + NumericVector dist = rep(R_PosInf, 2 * n1); + for (i = 0; i < n1; i++) { + t_id[i] = ind1[i]; + t_id[i + n1] = ind1[i]; + c_id[i] = -1; + c_id[i + n1] = -2; + } + + IntegerVector times_matched = rep(0, n); + LogicalVector eligible = rep(true, n); + eligible[discarded] = false; + IntegerVector times_matched_allowed = rep(reuse_max, n); + times_matched_allowed[ind1] = ratio; + + IntegerVector times_skipped = rep(0, n); + + //progress bar + int prog_length = sum(ratio) + 1; + Progress p(prog_length, disl_prog); + p.increment(); + + IntegerVector ck_; + + int t_id_i, c_id_i, t_id_t_i, k; + + int first_control = 0; + int last_control = n - 1; + + while (!eligible[ind_d_ord[first_control]] || treat[ind_d_ord[first_control]] == 1) { + first_control++; + } + while (!eligible[ind_d_ord[last_control]] || treat[ind_d_ord[last_control]] == 1) { + last_control--; + } + + //Find left and right matches for each treated unit + + for (i = 0; i < (2 * n1); i++) { + t_id_i = t_id[i]; + + if (!eligible[t_id_i]) { + continue; + } + + c_id_i = c_id[i]; + + k = find_lr(c_id_i, + t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + 0, + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control); + + if (k < 0) { + times_skipped[t_id_i]++; + if (times_skipped[t_id_i] == 2) { + eligible[t_id_i] = false; + } + continue; + } + + c_id[i] = k; + dist[i] = std::abs(distance[t_id_i] - distance[k]); + } + + int n_eligible0 = 0; + int n_eligible1 = 0; + for (i = 0; i < n; i++) { + if (eligible[i]) { + if (treat[i] == 0) { + n_eligible0++; + } + else { + n_eligible1++; + } + } + } + + //Order the list + IntegerVector heap_ord = o(dist); + heap_ord = heap_ord - 1; + + //Go down the list; update as needed + int hi; + int counter = -1; + + for (i = 0; (i < 2 * n1) && n_eligible1 > 0 && n_eligible0 > 0; i++) { + counter++; + if (counter % 500 == 0) Rcpp::checkUserInterrupt(); + + hi = heap_ord[i]; + + if (dist[hi] >= caliper_dist) { + break; + } + + t_id_i = t_id[hi]; + + if (!eligible[t_id_i]) { + continue; + } + + c_id_i = c_id[hi]; + + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + + while (!eligible[ind_d_ord[first_control]] || treat[ind_d_ord[first_control]] == 1) { + first_control++; + } + while (!eligible[ind_d_ord[last_control]] || treat[ind_d_ord[last_control]] == 1) { + last_control--; + } + + k = find_lr(c_id_i, + t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + 0, + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control); + + //If no new control found, mark treated unit as ineligible and continue + if (k < 0) { + times_skipped[t_id_i]++; + if (times_skipped[t_id_i] == 2) { + eligible[t_id_i] = false; + n_eligible1--; + } + continue; + } + + c_id[hi] = k; + dist[hi] = std::abs(distance[t_id_i] - distance[k]); + + //Find new position of pair in heap + for (j = i; j < (2 * n1) - 1; j++) { + if (dist[heap_ord[j]] < dist[heap_ord[j + 1]]) { + break; + } + + swap_pos(heap_ord, j, j + 1); + } + + i--; + } + else { + t_id_t_i = ind1_match[t_id_i]; + + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + + if (use_unit_id) { + ck_ = ind[unit_id == unit_id[t_id_i] | unit_id == unit_id[c_id_i]]; + } + else { + ck_ = {c_id_i, t_id_i}; + } + + for (int ck : ck_) { + + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + if (treat[ck] == 1) { + n_eligible1--; + } + else { + n_eligible0--; + } + } + } + + p.increment(); + } + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = lab[ind1]; + + return mm; +} \ No newline at end of file From 37be6078be4a59bdf8f373acd5e03703befcb7a4 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:15:30 -0400 Subject: [PATCH 06/48] Improvements --- src/pairdistC.cpp | 40 ++++++++---------- src/subclass2mm.cpp | 83 ++++++++++++++++++++++++++----------- src/weights_matrixC.cpp | 92 ++++++++++++++++++++++++++++------------- 3 files changed, 140 insertions(+), 75 deletions(-) diff --git a/src/pairdistC.cpp b/src/pairdistC.cpp index 500b65a0..a4c158a2 100644 --- a/src/pairdistC.cpp +++ b/src/pairdistC.cpp @@ -3,39 +3,33 @@ using namespace Rcpp; // [[Rcpp::export]] -double pairdistsubC(const NumericVector& x_, - const IntegerVector& t_, - const IntegerVector& s_, - const int& num_sub) { +double pairdistsubC(const NumericVector& x, + const IntegerVector& t, + const IntegerVector& s) { double dist = 0; - LogicalVector not_na_sub = !is_na(s_); - NumericVector x = x_[not_na_sub]; - IntegerVector t = t_[not_na_sub]; - IntegerVector s = s_[not_na_sub]; - int n = t.size(); - LogicalVector in_s_i(n); - NumericVector x_t0(n); - IntegerVector t_ind_s(n), c_ind_s(n); - + int i, j; int k = 0; - int i, i1, n1_s; - for (i = 1; i <= num_sub; ++i) { - in_s_i = (s == i); - t_ind_s = which((t == 1) & in_s_i); - c_ind_s = which((t == 0) & in_s_i); + for (i = 0; i < n - 1; i++) { + if (!std::isfinite(s[i])) { + continue; + } - n1_s = t_ind_s.size(); + for (j = i + 1; j < n; j++) { + if (s[i] != s[j]) { + continue; + } - x_t0 = x[c_ind_s]; + if (t[i] == t[j]) { + continue; + } - for (i1 = 0; i1 < n1_s; ++i1) { - dist += sum(Rcpp::abs(x[t_ind_s[i1]] - x_t0)); + dist += std::abs(x[i] - x[j]); + k++; } - k += n1_s * c_ind_s.size(); } dist /= k; diff --git a/src/subclass2mm.cpp b/src/subclass2mm.cpp index bdc5622a..f4274384 100644 --- a/src/subclass2mm.cpp +++ b/src/subclass2mm.cpp @@ -6,42 +6,77 @@ using namespace Rcpp; //focal is the treatment level (0/1) that corresponds to the rownames. // [[Rcpp::export]] -IntegerMatrix subclass2mmC(const IntegerVector& subclass, +IntegerMatrix subclass2mmC(const IntegerVector& subclass_, const IntegerVector& treat, const int& focal) { - IntegerVector tab = tabulateC_(subclass); - int mm_col = max(tab) - 1; + LogicalVector na_sub = is_na(subclass_); + IntegerVector unique_sub = unique(as(subclass_[!na_sub])); + IntegerVector subclass = match(subclass_, unique_sub) - 1; - IntegerVector ind = Range(0, treat.size() - 1); - IntegerVector ind1 = ind[treat == focal]; - int n1 = ind1.size(); + int nsub = unique_sub.size(); + + int n = treat.size(); + IntegerVector ind = Range(0, n - 1); + IntegerVector ind_focal = ind[treat == focal]; + int n1 = ind_focal.size(); + + IntegerVector subtab = rep(-1, nsub); + + int i; + for (i = 0; i < n; i++) { + if (na_sub[i]) { + continue; + } + + subtab[subclass[i]]++; + } + + int mm_col = max(subtab); IntegerMatrix mm(n1, mm_col); mm.fill(NA_INTEGER); - rownames(mm) = as(treat.names())[ind1]; - - IntegerVector in_sub = ind[!is_na(subclass)]; - IntegerVector ind_in_sub = ind[in_sub]; - IntegerVector ind0_in_sub = ind_in_sub[as(treat[in_sub]) != focal]; - IntegerVector sub0_in_sub = subclass[ind0_in_sub]; + CharacterVector lab = treat.names(); - int i, t, s, nmc, mci; - IntegerVector mc(mm_col); + IntegerVector ss = rep(NA_INTEGER, n1); + int s, si; for (i = 0; i < n1; i++) { - t = ind1[i]; - s = subclass[t]; - - if (s != NA_INTEGER) { - mc = ind0_in_sub[sub0_in_sub == s]; - nmc = mc.size(); - for (mci = 0; mci < nmc; mci++) { - mm(i, mci) = mc[mci] + 1; + if (na_sub[ind_focal[i]]) { + continue; + } + + ss[i] = subclass[ind_focal[i]]; + } + + for (i = 0; i < n; i++) { + if (treat[i] == focal) { + continue; + } + + if (na_sub[i]) { + continue; + } + + si = subclass[i]; + + for (s = 0; s < n1; s++) { + if (!std::isfinite(ss[s])) { + continue; } + + if (si != ss[s]) { + continue; + } + + mm(s, sum(!is_na(mm(s, _)))) = i; + break; } } - return mm; -} + mm = mm + 1; + rownames(mm) = lab[ind_focal]; + + return mm; +} \ No newline at end of file diff --git a/src/weights_matrixC.cpp b/src/weights_matrixC.cpp index dbbaf85f..190ca1df 100644 --- a/src/weights_matrixC.cpp +++ b/src/weights_matrixC.cpp @@ -1,53 +1,89 @@ #include +#include "internal.h" +#include using namespace Rcpp; -// Computes matching weights from match.matrix +// [[Rcpp::plugins(cpp11)]] +// Computes matching weights from match.matrix // [[Rcpp::export]] NumericVector weights_matrixC(const IntegerMatrix& mm, - const IntegerVector& treat) { + const IntegerVector& treat_) { + + CharacterVector lab = treat_.names(); + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int n = treat.size(); - IntegerVector ind = Range(0, n - 1); - IntegerVector ind0 = ind[treat == 0]; - IntegerVector ind1 = ind[treat == 1]; + int gi; NumericVector weights = rep(0., n); - // weights.fill(0); + weights.names() = lab; + + IntegerVector row_ind = match(as(rownames(mm)), lab) - 1; + NumericVector matches_g = rep(0.0, g); int nr = mm.nrow(); - int nc = mm.ncol(); - int r, c, row_not_na, which_c, t_ind; - double weights_c, add_w; - IntegerVector row_r(nc); + int r, rn, i; + IntegerVector row_r(mm.ncol()); for (r = 0; r < nr; r++) { - row_r = na_omit(mm(r, _)); - row_not_na = row_r.size(); - if (row_not_na == 0) { + + row_r = na_omit(mm.row(r)); + + rn = row_r.size(); + + if (rn == 0) { continue; } - add_w = 1.0/static_cast(row_not_na); - for (c = 0; c < row_not_na; c++) { - which_c = row_r[c] - 1; - weights_c = weights[which_c]; - weights[which_c] = weights_c + add_w; + for (gi = 0; gi < g; gi++) { + matches_g[gi] = 0.0; + } + + for (i = 0; i < rn; i++) { + matches_g[treat[row_r[i] - 1]] += 1.0; } - t_ind = ind1[r]; - weights[t_ind] = 1; + for (i = 0; i < rn; i++) { + if (matches_g[treat[row_r[i] - 1]] == 0.0) { + continue; + } + + weights[row_r[i] - 1] += 1.0/matches_g[treat[row_r[i] - 1]]; + } + + weights[row_ind[r]] += 1.0; } - NumericVector c_weights = weights[ind0]; - double sum_c_w = sum(c_weights); - double sum_matched_c = sum(c_weights > 0); - int n0 = ind0.size(); + //Scale control weights to sum to number of matched controls + NumericVector weights_gi; + IntegerVector indg; + double sum_w; + double sum_matched; + + for (gi = 0; gi < g; gi++) { + indg = which(treat == gi); + + weights_gi = weights[indg]; + + sum_w = sum(weights_gi); + + if (sum_w == 0) { + continue; + } + + sum_matched = sum(weights_gi > 0); + + if (sum_matched == sum_w) { + continue; + } - if (sum_c_w > 0) { - for (int i = 0; i < n0; i++ ) { - which_c = ind0[i]; - weights[which_c] = c_weights[i] * sum_matched_c / sum_c_w; + for (int i : indg) { + weights[i] *= sum_matched / sum_w; } } From 4d2bb99b13844fe33047f5b9decc3b2ceee26bbd Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Wed, 16 Oct 2024 14:16:46 -0400 Subject: [PATCH 07/48] Improvements to NN matching using Rcpp; added caliper splitting feature, not implemented yet though. --- R/matchit2nearest.R | 189 +++++++++++++++++++++++++++++++------------- src/get_splitsC.cpp | 30 +++++++ 2 files changed, 164 insertions(+), 55 deletions(-) create mode 100644 src/get_splitsC.cpp diff --git a/R/matchit2nearest.R b/R/matchit2nearest.R index f8878cf1..77a158fa 100644 --- a/R/matchit2nearest.R +++ b/R/matchit2nearest.R @@ -279,11 +279,11 @@ NULL matchit2nearest <- function(treat, data, distance, discarded, - ratio = 1, s.weights = NULL, replace = FALSE, m.order = NULL, - caliper = NULL, mahvars = NULL, exact = NULL, - formula = NULL, estimand = "ATT", verbose = FALSE, - is.full.mahalanobis, - antiexact = NULL, unit.id = NULL, ...){ + ratio = 1, s.weights = NULL, replace = FALSE, m.order = NULL, + caliper = NULL, mahvars = NULL, exact = NULL, + formula = NULL, estimand = "ATT", verbose = FALSE, + is.full.mahalanobis, + antiexact = NULL, unit.id = NULL, ...){ if (verbose) { rlang::check_installed("RcppProgress") @@ -336,14 +336,28 @@ matchit2nearest <- function(treat, data, distance, discarded, distance_mat <- distance distance <- NULL - if (focal == 0) distance_mat <- t(distance_mat) + if (focal == 0) { + distance_mat <- t(distance_mat) + } } #Process caliper + ex.caliper.list <- list() if (!is.null(caliper)) { if (any(names(caliper) != "")) { caliper.covs <- caliper[names(caliper) != ""] caliper.covs.mat <- get.covs.matrix(reformulate(names(caliper.covs)), data = data) + + ex.caliper.list <- setNames(lapply(names(caliper.covs), function(i) { + splits <- get_splitsC(as.numeric(caliper.covs.mat[,i]), + as.numeric(caliper.covs[i])) + + if (length(splits) == 0) return(NULL) + + cut(caliper.covs.mat[,i], + breaks = splits, + include.lowest = TRUE) + }), names(caliper.covs)) } else { caliper.covs.mat <- caliper.covs <- NULL @@ -351,16 +365,44 @@ matchit2nearest <- function(treat, data, distance, discarded, if (any(names(caliper) == "")) { caliper.dist <- caliper[names(caliper) == ""] + + if (!is.null(distance)) { + splits <- get_splitsC(as.numeric(distance), + as.numeric(caliper.dist)) + + ex.caliper.list <- c(ex.caliper.list, + list(distance = cut(distance, + breaks = splits, + include.lowest = TRUE))) + } } else { caliper.dist <- NULL } + + if (length(ex.caliper.list) > 0L) { + ex.caliper.list <- ex.caliper.list[lengths(ex.caliper.list) > 0L] + + for (i in seq_along(ex.caliper.list)) { + levels(ex.caliper.list[[i]]) <- paste(names(ex.caliper.list)[i], "\u2208", levels(ex.caliper.list[[i]])) + } + } } else { caliper.dist <- caliper.covs <- NULL caliper.covs.mat <- NULL } + #### + ex.caliper.list <- list() + #### + + ex.caliper <- { + if (length(ex.caliper.list) == 0) NULL + else as.factor(exactify(ex.caliper.list, nam = lab, sep = ", ", + justify = NULL)) + } + #Process antiexact if (!is.null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) @@ -383,22 +425,49 @@ matchit2nearest <- function(treat, data, distance, discarded, unit.id <- factor(exactify(model.frame(unit.id, data = data), nam = lab, sep = ", ", include_vars = TRUE)) num_ctrl_unit.ids <- length(unique(unit.id[treat == 0])) + num_trt_unit.ids <- length(unique(unit.id[treat == 1])) - #If each control unit is a unit.id, unit.ids are meaningless - if (num_ctrl_unit.ids == n0) unit.id <- NULL + #If each unit is a unit.id, unit.ids are meaningless + if (num_ctrl_unit.ids == n0 && num_trt_unit.ids == n1) { + # unit.id <- NULL + } } else { unit.id <- NULL } + #Process exact if (!is.null(exact)) { - ex <- factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) + ex <- as.factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) + } + else { + ex <- NULL + } - cc <- intersect(as.integer(ex)[treat==1], as.integer(ex)[treat==0]) - if (length(cc) == 0) .err("No matches were found") + if (!is.null(ex) || !is.null(ex.caliper)) { + ex0 <- { + if (!is.null(ex) && !is.null(ex.caliper)) { + as.factor(exactify(list(ex, ex.caliper), nam = lab, sep = ", ", + justify = NULL)) + } + else if (!is.null(ex)) ex + else ex.caliper + } + + cc <- intersect(as.integer(ex0)[treat==1], as.integer(ex0)[treat==0]) - if (reuse.max < n1) { + if (length(cc) == 0L) { + .err("No matches were found") + } + + cc <- sort(cc) + } + else { + ex0 <- NULL + } + if (reuse.max < n1) { + if (!is.null(ex)) { e_ratios <- vapply(levels(ex), function(e) { if (is.null(unit.id)) sum(treat[ex == e] == 0)*(reuse.max/sum(treat[ex == e] == 1)) else length(unique(unit.id[treat == 0 & ex == e]))*(reuse.max/sum(treat[ex == e] == 1)) @@ -406,23 +475,18 @@ matchit2nearest <- function(treat, data, distance, discarded, if (any(e_ratios < 1)) { .wrn(sprintf("fewer %s units than %s units in some `exact` strata; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } if (ratio > 1 && any(e_ratios < ratio)) { if (is.null(max.controls) || ratio == max.controls) .wrn(sprintf("not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("not enough %s units for an average of %s matches per %s unit in all `exact` strata", - tc[2], ratio, tc[1])) + tc[2], ratio, tc[1])) } } - } - else { - ex <- NULL - - if (reuse.max < n1) { - + else { e_ratios <- { if (is.null(unit.id)) as.numeric(reuse.max)*n0/n1 else as.numeric(reuse.max)*num_ctrl_unit.ids/n1 @@ -430,15 +494,15 @@ matchit2nearest <- function(treat, data, distance, discarded, if (e_ratios < 1) { .wrn(sprintf("fewer %s %s than %s units; not all %s units will get a match", - tc[2], if (is.null(unit.id)) "units" else "unit IDs", tc[1], tc[1])) + tc[2], if (is.null(unit.id)) "units" else "unit IDs", tc[1], tc[1])) } else if (e_ratios < ratio) { if (is.null(max.controls) || ratio == max.controls) .wrn(sprintf("not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("not enough %s %s for an average of %s matches per %s unit", - tc[2], if (is.null(unit.id)) "units" else "unit IDs", ratio, tc[1])) + tc[2], if (is.null(unit.id)) "units" else "unit IDs", ratio, tc[1])) } } } @@ -447,8 +511,10 @@ matchit2nearest <- function(treat, data, distance, discarded, #Each treated unit get its own value of ratio if (!is.null(max.controls)) { if (is.null(distance)) { - if (is.full.mahalanobis) .err(sprintf("`distance` cannot be \"%s\" for variable ratio matching", - transform)) + if (is.full.mahalanobis) { + .err(sprintf("`distance` cannot be \"%s\" for variable ratio matching", + transform)) + } .err("`distance` cannot be supplied as a matrix for variable ratio matching") } @@ -486,7 +552,7 @@ matchit2nearest <- function(treat, data, distance, discarded, else "largest" } - if (is.null(ex) || !is.null(unit.id)) { + if (is.null(ex0) || !is.null(unit.id) || (is.null(mahcovs) && is.null(distance_mat))) { if (m.order == "closest") { if (!is.null(mahcovs)) { distance_mat <- eucdist_internal(mahcovs, treat) @@ -501,9 +567,6 @@ matchit2nearest <- function(treat, data, distance, discarded, caliper.dist <- NULL } } - else if (is.null(distance_mat)) { - distance_mat <- eucdist_internal(distance, treat) - } ord <- NULL } @@ -516,41 +579,48 @@ matchit2nearest <- function(treat, data, distance, discarded, } mm <- nn_matchC_dispatch(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, - ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, + ex0, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) - } else { distance_ <- caliper.covs.mat_ <- mahcovs_ <- antiexactcovs_ <- distance_mat_ <- NULL - mm_list <- lapply(levels(ex)[cc], function(e) { + + mm_list <- lapply(levels(ex0)[cc], function(e) { if (verbose) { cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) + match(e, levels(ex0)[cc]), length(cc), e)) } - .e <- which(ex == e) - .e1 <- which(ex[treat==1] == e) + .e <- which(ex0 == e) + .e1 <- which(ex0[treat==1] == e) treat_ <- treat[.e] + discarded_ <- discarded[.e] if (!is.null(distance)) distance_ <- distance[.e] if (!is.null(caliper.covs.mat)) caliper.covs.mat_ <- caliper.covs.mat[.e,,drop = FALSE] if (!is.null(mahcovs)) mahcovs_ <- mahcovs[.e,,drop = FALSE] if (!is.null(antiexactcovs)) antiexactcovs_ <- antiexactcovs[.e,,drop = FALSE] if (!is.null(distance_mat)) { - .e0 <- which(ex[treat==0] == e) + .e0 <- which(ex0[treat==0] == e) distance_mat_ <- distance_mat[.e1, .e0, drop = FALSE] } ratio_ <- ratio[.e1] n1_ <- sum(treat_ == 1) - if (m.order == "closest") { if (!is.null(mahcovs)) { distance_mat_ <- eucdist_internal(mahcovs_, treat_) - } - else if (is.null(distance_mat)) { - distance_mat_ <- eucdist_internal(distance_, treat_) + + # If caliper on PS (distance), treat it as a covariate + if (!is.null(distance) && !is.null(caliper.dist)) { + caliper.covs.mat <- { + if (is.null(caliper.covs)) as.matrix(distance_) + else cbind(caliper.covs.mat_, distance_) + } + caliper.covs <- c(caliper.covs, caliper.dist) + caliper.dist <- NULL + } } ord <- NULL @@ -564,8 +634,8 @@ matchit2nearest <- function(treat, data, distance, discarded, } mm_ <- nn_matchC_dispatch(treat_, ord_, ratio_, discarded_, reuse.max, distance_, distance_mat_, - NULL, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, - antiexactcovs_, NULL, m.order, verbose) + NULL, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, + antiexactcovs_, NULL, m.order, verbose) #Ensure matched indices correspond to indices in full sample, not subgroup mm_[] <- seq_along(treat)[.e][mm_] @@ -608,18 +678,27 @@ matchit2nearest <- function(treat, data, distance, discarded, nn_matchC_dispatch <- function(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) { if (m.order == "closest") { - nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse.max, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) - } - else if (is.null(distance_mat) && is.null(mahcovs)) { - nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, distance, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) + if (is.null(distance_mat) && is.null(mahcovs)) { + nn_matchC_vec_closest(treat, ratio, discarded, reuse.max, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } + else { + nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse.max, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } } else { - nn_matchC(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, - ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, - antiexactcovs, unit.id, verbose) + if (is.null(distance_mat) && is.null(mahcovs)) { + nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, 1L, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } + else { + nn_matchC(treat, ord, ratio, discarded, reuse.max, 1L, distance, distance_mat, + ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, + antiexactcovs, unit.id, verbose) + } } -} +} \ No newline at end of file diff --git a/src/get_splitsC.cpp b/src/get_splitsC.cpp new file mode 100644 index 00000000..5ecef79c --- /dev/null +++ b/src/get_splitsC.cpp @@ -0,0 +1,30 @@ +#include +using namespace Rcpp; + +// [[Rcpp::export]] +NumericVector get_splitsC(const NumericVector& x, + const double& caliper) { + + NumericVector splits; + + NumericVector x_ = unique(x); + NumericVector x_sorted = x_.sort(); + + int n = x_sorted.size(); + + if (n <= 1) { + return splits; + } + + splits = x_sorted[0]; + + for (int i = 1; i < x_sorted.length(); i++) { + if (x_sorted[i] - x_sorted[i - 1] <= caliper) continue; + + splits.push_back((x_sorted[i] + x_sorted[i - 1]) / 2); + } + + splits.push_back(x_sorted[n - 1]); + + return splits; +} \ No newline at end of file From 729b30fbf7fecae559fac4ec82f2ad9d9a125c80 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:41:24 -0400 Subject: [PATCH 08/48] Updates and improvements --- src/internal.cpp | 115 +++++++++++++++++++++++++++------------- src/internal.h | 31 ++++++----- src/pairdistC.cpp | 30 ++++++----- src/subclass_scootC.cpp | 13 +++-- src/tabulateC.cpp | 6 ++- src/weights_matrixC.cpp | 38 ++++++------- 6 files changed, 147 insertions(+), 86 deletions(-) diff --git a/src/internal.cpp b/src/internal.cpp index 63a717b9..1b399464 100644 --- a/src/internal.cpp +++ b/src/internal.cpp @@ -11,10 +11,10 @@ using namespace Rcpp; // [[Rcpp::interfaces(cpp)]] IntegerVector tabulateC_(const IntegerVector& bins, - const Nullable& nbins = R_NilValue) { + const int& nbins = 0) { int max_bin; - if (nbins.isNotNull()) max_bin = as(nbins); + if (nbins > 0) max_bin = nbins; else max_bin = max(na_omit(bins)); IntegerVector counts(max_bin); @@ -38,15 +38,15 @@ IntegerVector which(const LogicalVector& x) { // [[Rcpp::interfaces(cpp)]] bool antiexact_okay(const int& aenc, - const int& ii, const int& i, + const int& j, const IntegerMatrix& antiexact_covs) { if (aenc == 0) { return true; } - for (int j = 0; j < aenc; j++) { - if (antiexact_covs(ii, j) == antiexact_covs(i, j)) { + for (int k = 0; k < aenc; k++) { + if (antiexact_covs(i, k) == antiexact_covs(j, k)) { return false; } } @@ -56,22 +56,22 @@ bool antiexact_okay(const int& aenc, // [[Rcpp::interfaces(cpp)]] bool caliper_covs_okay(const int& ncc, - const int& ii, const int& i, + const int& j, const NumericMatrix& caliper_covs_mat, const NumericVector& caliper_covs) { if (ncc == 0) { return true; } - for (int j = 0; j < ncc; j++) { - if (caliper_covs[j] >= 0) { - if (std::abs(caliper_covs_mat(ii, j) - caliper_covs_mat(i, j)) > caliper_covs[j]) { + for (int k = 0; k < ncc; k++) { + if (caliper_covs[k] >= 0) { + if (std::abs(caliper_covs_mat(i, k) - caliper_covs_mat(j, k)) > caliper_covs[k]) { return false; } } else { - if (std::abs(caliper_covs_mat(ii, j) - caliper_covs_mat(i, j)) <= -caliper_covs[j]) { + if (std::abs(caliper_covs_mat(i, k) - caliper_covs_mat(j, k)) <= -caliper_covs[k]) { return false; } } @@ -83,10 +83,10 @@ bool caliper_covs_okay(const int& ncc, // [[Rcpp::interfaces(cpp)]] bool mm_okay(const int& r, const int& i, - const IntegerVector& mm_ordi) { + const IntegerVector& mm_rowi) { if (r > 1) { - for (int j : mm_ordi) { + for (int j : na_omit(mm_rowi)) { if (i == j) { return false; } @@ -98,15 +98,15 @@ bool mm_okay(const int& r, // [[Rcpp::interfaces(cpp)]] bool exact_okay(const bool& use_exact, - const int& ii, const int& i, + const int& j, const IntegerVector& exact) { if (!use_exact) { return true; } - return exact[ii] == exact[i]; + return exact[i] == exact[j]; } // [[Rcpp::interfaces(cpp)]] @@ -127,13 +127,13 @@ int find_both(const int& t_id, const IntegerVector& exact, const int& aenc, const IntegerMatrix& antiexact_covs, - const int& first_control, - const int& last_control) { + const IntegerVector& first_control, + const IntegerVector& last_control) { int ii = match_d_ord[t_id]; - int min_ii = first_control; - int max_ii = last_control; + int min_ii = first_control[gi]; + int max_ii = last_control[gi]; int iil = ii; int iir = ii; @@ -259,8 +259,8 @@ int find_lr(const int& prev_match, const IntegerVector& exact, const int& aenc, const IntegerMatrix& antiexact_covs, - const int& first_control, - const int& last_control) { + const IntegerVector& first_control, + const IntegerVector& last_control) { int ik, iik; double dist; @@ -277,7 +277,7 @@ int find_lr(const int& prev_match, z = 1; } - prev_pos = ii + z; + prev_pos = ii; } else { prev_pos = match_d_ord[prev_match]; @@ -293,13 +293,22 @@ int find_lr(const int& prev_match, int max_ii = ind_d_ord.size() - 1; if (z == -1) { - min_ii = first_control; + min_ii = first_control[gi]; } else { - max_ii = last_control; + max_ii = last_control[gi]; } - for (iik = prev_pos; iik >= min_ii && iik <= max_ii; iik = iik + z) { + int start = prev_pos + z; + if (start > last_control[gi]) { + start = last_control[gi]; + } + else if (start < first_control[gi]) { + start = first_control[gi]; + } + + for (iik = start; iik >= min_ii && iik <= max_ii; iik += z) { + ik = ind_d_ord[iik]; if (!eligible[ik]) { @@ -335,23 +344,21 @@ int find_lr(const int& prev_match, } // [[Rcpp::interfaces(cpp)]] -IntegerVector swap_pos(IntegerVector x, - const int& a, - const int& b) { +void swap_pos(IntegerVector x, + const int& a, + const int& b) { int xa = x[a]; x[a] = x[b]; x[b] = xa; - - return x; } // [[Rcpp::interfaces(cpp)]] double max_finite(const NumericVector& x) { double m = NA_REAL; - int n = x.size(); - int i; + R_xlen_t n = x.size(); + R_xlen_t i; bool found = false; //Find first finite value @@ -369,8 +376,8 @@ double max_finite(const NumericVector& x) { } //Find largest finite value - for (int j = i + 1; j < n; j++) { - if (!std::isfinite(x[i])) { + for (R_xlen_t j = i + 1; j < n; j++) { + if (!std::isfinite(x[j])) { continue; } @@ -386,8 +393,8 @@ double max_finite(const NumericVector& x) { double min_finite(const NumericVector& x) { double m = NA_REAL; - int n = x.size(); - int i; + R_xlen_t n = x.size(); + R_xlen_t i; bool found = false; //Find first finite value @@ -405,8 +412,8 @@ double min_finite(const NumericVector& x) { } //Find smallest finite value - for (int j = i + 1; j < n; j++) { - if (!std::isfinite(x[i])) { + for (R_xlen_t j = i + 1; j < n; j++) { + if (!std::isfinite(x[j])) { continue; } @@ -416,4 +423,38 @@ double min_finite(const NumericVector& x) { } return m; +} + +// [[Rcpp::interfaces(cpp)]] +void update_first_and_last_control(IntegerVector first_control, + IntegerVector last_control, + const IntegerVector& ind_d_ord, + const LogicalVector& eligible, + const IntegerVector& treat, + const int& gi) { + R_xlen_t c; + + // Update first_control + if (!eligible[ind_d_ord[first_control[gi]]]) { + for (c = first_control[gi] + 1; c <= last_control[gi]; c++) { + if (eligible[ind_d_ord[c]]) { + if (treat[ind_d_ord[c]] == gi) { + first_control[gi] = c; + break; + } + } + } + } + + // Update last_control + if (!eligible[ind_d_ord[last_control[gi]]]) { + for (c = last_control[gi] - 1; c >= first_control[gi]; c--) { + if (eligible[ind_d_ord[c]]) { + if (treat[ind_d_ord[c]] == gi) { + last_control[gi] = c; + break; + } + } + } + } } \ No newline at end of file diff --git a/src/internal.h b/src/internal.h index 2cf86a16..1a42e007 100644 --- a/src/internal.h +++ b/src/internal.h @@ -5,7 +5,7 @@ using namespace Rcpp; IntegerVector tabulateC_(const IntegerVector& bins, - const Nullable& nbins = R_NilValue); + const int& nbins = 0); IntegerVector which(const LogicalVector& x); @@ -26,8 +26,8 @@ int find_both(const int& t_id, const IntegerVector& exact, const int& aenc, const IntegerMatrix& antiexact_covs, - const int& first_control, - const int& last_control); + const IntegerVector& first_control, + const IntegerVector& last_control); int find_lr(const int& prev_match, const int& t_id, @@ -45,35 +45,42 @@ int find_lr(const int& prev_match, const IntegerVector& exact, const int& aenc, const IntegerMatrix& antiexact_covs, - const int& first_control, - const int& last_control); + const IntegerVector& first_control, + const IntegerVector& last_control); bool antiexact_okay(const int& aenc, - const int& ii, const int& i, + const int& j, const IntegerMatrix& antiexact_covs); bool caliper_covs_okay(const int& ncc, - const int& ii, const int& i, + const int& j, const NumericMatrix& caliper_covs_mat, const NumericVector& caliper_covs); bool mm_okay(const int& r, const int& i, - const IntegerVector& mm_ordi); + const IntegerVector& mm_rowi); bool exact_okay(const bool& use_exact, - const int& ii, const int& i, + const int& j, const IntegerVector& exact); -IntegerVector swap_pos(IntegerVector x, - const int& a, - const int& b); +void swap_pos(IntegerVector x, + const int& a, + const int& b); double max_finite(const NumericVector& x); double min_finite(const NumericVector& x); +void update_first_and_last_control(IntegerVector first_control, + IntegerVector last_control, + const IntegerVector& ind_d_ord, + const LogicalVector& eligible, + const IntegerVector& treat, + const int& gi); + #endif \ No newline at end of file diff --git a/src/pairdistC.cpp b/src/pairdistC.cpp index a4c158a2..8bc02aaf 100644 --- a/src/pairdistC.cpp +++ b/src/pairdistC.cpp @@ -9,30 +9,36 @@ double pairdistsubC(const NumericVector& x, double dist = 0; - int n = t.size(); - int i, j; + R_xlen_t i, j; + int s_i, ord_i, ord_j; int k = 0; - for (i = 0; i < n - 1; i++) { - if (!std::isfinite(s[i])) { - continue; - } + Function o("order"); + IntegerVector ord = o(s); + ord = ord - 1; + + R_xlen_t n = sum(!is_na(s)); + + + for (i = 0; i < n; i++) { + ord_i = ord[i]; + s_i = s[ord_i]; for (j = i + 1; j < n; j++) { - if (s[i] != s[j]) { - continue; + ord_j = ord[j]; + + if (s[ord_j] != s_i) { + break; } - if (t[i] == t[j]) { + if (t[ord_j] == t[ord_i]) { continue; } - dist += std::abs(x[i] - x[j]); k++; + dist += (std::abs(x[ord_j] - x[ord_i]) - dist) / k; } } - dist /= k; - return dist; } \ No newline at end of file diff --git a/src/subclass_scootC.cpp b/src/subclass_scootC.cpp index 9ea0ee98..91cb7430 100644 --- a/src/subclass_scootC.cpp +++ b/src/subclass_scootC.cpp @@ -2,6 +2,7 @@ #include "internal.h" using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] // [[Rcpp::export]] IntegerVector subclass_scootC(const IntegerVector& subclass_, @@ -14,8 +15,9 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, } int m, i, s, s2; - int best_i, nt; + int best_i; double best_x, score; + R_xlen_t nt; LogicalVector na_sub = is_na(subclass_); @@ -23,7 +25,7 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, IntegerVector treat = treat_[!na_sub]; NumericVector x = x_[!na_sub]; - int n = subclass.size(); + R_xlen_t n = subclass.size(); IntegerVector unique_sub = unique(subclass); std::sort(unique_sub.begin(), unique_sub.end()); @@ -45,7 +47,7 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, //Tabulate subtab = rep(0, nsub); for (int i : indt) { - subtab[subclass[i]]++; + subtab[subclass[i]]++; } for (m = 0; m < min_n; m++) { @@ -107,7 +109,7 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, } for (i = best_i + 1; i < nt; i++) { - if (subclass[indt[i]] != s2) { + if (subclass[indt[i]] != s2) { continue; } @@ -141,7 +143,8 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, subclass[i] = unique_sub[subclass[i]]; } - IntegerVector sub_out = rep(NA_INTEGER, subclass_.size()); + IntegerVector sub_out(subclass_.size()); + sub_out.fill(NA_INTEGER); sub_out[!na_sub] = subclass; diff --git a/src/tabulateC.cpp b/src/tabulateC.cpp index e1d06c50..5ec5419f 100644 --- a/src/tabulateC.cpp +++ b/src/tabulateC.cpp @@ -5,5 +5,9 @@ using namespace Rcpp; // [[Rcpp::export]] IntegerVector tabulateC(const IntegerVector& bins, const Nullable& nbins = R_NilValue) { - return tabulateC_(bins, nbins); + + int nbins_ = 0; + if (nbins.isNotNull()) nbins_ = as(nbins); + + return tabulateC_(bins, nbins_); } \ No newline at end of file diff --git a/src/weights_matrixC.cpp b/src/weights_matrixC.cpp index 190ca1df..f6b06fc5 100644 --- a/src/weights_matrixC.cpp +++ b/src/weights_matrixC.cpp @@ -8,7 +8,8 @@ using namespace Rcpp; // Computes matching weights from match.matrix // [[Rcpp::export]] NumericVector weights_matrixC(const IntegerMatrix& mm, - const IntegerVector& treat_) { + const IntegerVector& treat_, + const Nullable& focal = R_NilValue) { CharacterVector lab = treat_.names(); IntegerVector unique_treat = unique(treat_); @@ -16,44 +17,43 @@ NumericVector weights_matrixC(const IntegerMatrix& mm, int g = unique_treat.size(); IntegerVector treat = match(treat_, unique_treat) - 1; - int n = treat.size(); + R_xlen_t n = treat.size(); int gi; - NumericVector weights = rep(0., n); + NumericVector weights(n); + weights.fill(0.0); weights.names() = lab; - IntegerVector row_ind = match(as(rownames(mm)), lab) - 1; - NumericVector matches_g = rep(0.0, g); + IntegerVector row_ind; + if (focal.isNotNull()) { + row_ind = which(treat == as(focal)); + } + else { + row_ind = match(as(rownames(mm)), lab) - 1; + } - int nr = mm.nrow(); + NumericVector matches_g = rep(0.0, g); - int r, rn, i; IntegerVector row_r(mm.ncol()); - for (r = 0; r < nr; r++) { + for (int r : which(!is_na(mm(_, 0)))) { row_r = na_omit(mm.row(r)); - rn = row_r.size(); - - if (rn == 0) { - continue; - } - for (gi = 0; gi < g; gi++) { matches_g[gi] = 0.0; } - for (i = 0; i < rn; i++) { - matches_g[treat[row_r[i] - 1]] += 1.0; + for (int i : row_r - 1) { + matches_g[treat[i]] += 1.0; } - for (i = 0; i < rn; i++) { - if (matches_g[treat[row_r[i] - 1]] == 0.0) { + for (int i : row_r - 1) { + if (matches_g[treat[i]] == 0.0) { continue; } - weights[row_r[i] - 1] += 1.0/matches_g[treat[row_r[i] - 1]]; + weights[i] += 1.0/matches_g[treat[i]]; } weights[row_ind[r]] += 1.0; From eb6a7a5e871676f04b4c2f2f90a12e979ea6c33c Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:42:08 -0400 Subject: [PATCH 09/48] Added fast mm2subclassC --- src/subclass2mm.cpp | 54 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/src/subclass2mm.cpp b/src/subclass2mm.cpp index f4274384..ce574cbc 100644 --- a/src/subclass2mm.cpp +++ b/src/subclass2mm.cpp @@ -16,14 +16,14 @@ IntegerMatrix subclass2mmC(const IntegerVector& subclass_, int nsub = unique_sub.size(); - int n = treat.size(); + R_xlen_t n = treat.size(); IntegerVector ind = Range(0, n - 1); IntegerVector ind_focal = ind[treat == focal]; - int n1 = ind_focal.size(); + R_xlen_t n1 = ind_focal.size(); IntegerVector subtab = rep(-1, nsub); - int i; + R_xlen_t i; for (i = 0; i < n; i++) { if (na_sub[i]) { continue; @@ -38,7 +38,8 @@ IntegerMatrix subclass2mmC(const IntegerVector& subclass_, mm.fill(NA_INTEGER); CharacterVector lab = treat.names(); - IntegerVector ss = rep(NA_INTEGER, n1); + IntegerVector ss(n1); + ss.fill(NA_INTEGER); int s, si; for (i = 0; i < n1; i++) { @@ -46,7 +47,7 @@ IntegerMatrix subclass2mmC(const IntegerVector& subclass_, continue; } - ss[i] = subclass[ind_focal[i]]; + ss[i] = subclass[ind_focal[i]]; } for (i = 0; i < n; i++) { @@ -79,4 +80,47 @@ IntegerMatrix subclass2mmC(const IntegerVector& subclass_, rownames(mm) = lab[ind_focal]; return mm; +} + +// [[Rcpp::export]] +IntegerVector mm2subclassC(const IntegerMatrix& mm, + const IntegerVector& treat, + const Nullable& focal = R_NilValue) { + + CharacterVector lab = treat.names(); + + IntegerVector subclass(treat.size()); + subclass.fill(NA_INTEGER); + subclass.names() = lab; + + IntegerVector ind1; + if (focal.isNotNull()) { + ind1 = which(treat == as(focal)); + } + else { + ind1 = match(as(rownames(mm)), lab) - 1; + } + + int r = mm.nrow(); + int ki = 0; + + for (int i : which(!is_na(mm))) { + if (i / r == 0) { + //If first entry in row, increment ki and assign subclass of treated + ki++; + subclass[ind1[i % r]] = ki; + } + + subclass[mm[i] - 1] = ki; + } + + CharacterVector levs(ki); + for (int j = 0; j < ki; j++){ + levs[j] = std::to_string(j + 1); + } + + subclass.attr("class") = "factor"; + subclass.attr("levels") = levs; + + return subclass; } \ No newline at end of file From 0c44e4d810a847527b643a71bd33d4cf2db88534 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:44:10 -0400 Subject: [PATCH 10/48] Code cleaning and small improvements --- R/add_s.weights.R | 105 +++++++++------- R/discard.R | 4 +- R/dist_functions.R | 86 +++++++------- R/distance2_methods.R | 155 +++++++++++++----------- R/input_processing.R | 250 +++++++++++++++++++++++++-------------- R/match.data.R | 30 ++--- R/match.qoi.R | 33 ++---- R/matchit.R | 131 ++++++++++++-------- R/matchit2cardinality.R | 45 ++++--- R/matchit2exact.R | 8 +- R/matchit2full.R | 18 +-- R/matchit2optimal.R | 20 ++-- R/matchit2quick.R | 12 +- R/matchit2subclass.R | 18 +-- R/plot.matchit.R | 122 +++++++++++-------- R/plot.summary.matchit.R | 10 +- R/rbind.matchdata.R | 43 +++++-- R/summary.matchit.R | 70 ++++++----- 18 files changed, 675 insertions(+), 485 deletions(-) diff --git a/R/add_s.weights.R b/R/add_s.weights.R index 7754a255..366356fc 100644 --- a/R/add_s.weights.R +++ b/R/add_s.weights.R @@ -61,63 +61,80 @@ add_s.weights <- function(m, chk::chk_is(m, "matchit") - if (!is.null(s.weights)) { - if (!is.numeric(s.weights)) { - if (is.null(data)) { - if (!is.null(m$model)) { - env <- attributes(terms(m$model))$.Environment - } else { - env <- parent.frame() - } - data <- eval(m$call$data, envir = env) - if (length(data) == 0) { - .err("a dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") - } + if (is_null(s.weights)) { + return(m) + } + + if (!is.numeric(s.weights)) { + if (is_null(data)) { + if (is_not_null(m$model)) { + env <- attributes(terms(m$model))$.Environment + } else { + env <- parent.frame() } - else { - if (!is.data.frame(data)) { - if (is.matrix(data)) data <- as.data.frame.matrix(data) - else .err("`data` must be a data frame") - } - if (nrow(data) != length(m$treat)) { - .err("`data` must have as many rows as there were units in the original call to `matchit()`") + data <- eval(m$call$data, envir = env) + if (is_null(data)) { + .err("a dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") + } + } + else { + if (!is.data.frame(data)) { + if (!is.matrix(data)) { + .err("`data` must be a data frame") } + data <- as.data.frame.matrix(data) } - if (is.character(s.weights)) { - if (is.null(data) || !is.data.frame(data)) { - .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") - } - if (!all(s.weights %in% names(data))) { - .err("the name supplied to `s.weights` must be a variable in `data`") - } - s.weights.form <- reformulate(s.weights) - s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + if (nrow(data) != length(m$treat)) { + .err("`data` must have as many rows as there were units in the original call to `matchit()`") + } + } + + if (is.character(s.weights)) { + if (is_null(data) || !is.data.frame(data)) { + .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") } - else if (rlang::is_formula(s.weights)) { - s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) - s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + + if (!all(s.weights %in% names(data))) { + .err("the name supplied to `s.weights` must be a variable in `data`") + } + + s.weights.form <- reformulate(s.weights) + s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") } - else { - .err("`s.weights` must be supplied as a numeric vector, string, or one-sided formula") + + s.weights <- s.weights[[1L]] + } + else if (rlang::is_formula(s.weights)) { + s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) + s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") } + + s.weights <- s.weights[[1L]] + } + else { + .err("`s.weights` must be supplied as a numeric vector, string, or one-sided formula") } + } - chk::chk_not_any_na(s.weights) - if (length(s.weights) != length(m$treat)) .err("`s.weights` must be the same length as the treatment vector") + chk::chk_not_any_na(s.weights) + if (length(s.weights) != length(m$treat)) { + .err("`s.weights` must be the same length as the treatment vector") + } - names(s.weights) <- names(m$treat) + names(s.weights) <- names(m$treat) - attr(s.weights, "in_ps") <- isTRUE(all.equal(s.weights, m$s.weights)) + attr(s.weights, "in_ps") <- isTRUE(all.equal(s.weights, m$s.weights)) - m$s.weights <- s.weights + m$s.weights <- s.weights - m$nn <- nn(m$treat, m$weights, m$discarded, s.weights) - } + m$nn <- nn(m$treat, m$weights, m$discarded, s.weights) m } diff --git a/R/discard.R b/R/discard.R index 764673e3..74f2bf9a 100644 --- a/R/discard.R +++ b/R/discard.R @@ -2,7 +2,7 @@ discard <- function(treat, pscore = NULL, option = NULL) { n.obs <- length(treat) - if (length(option) == 0){ + if (is_null(option)){ # keep all units return(setNames(rep(FALSE, n.obs), names(treat))) } @@ -23,7 +23,7 @@ discard <- function(treat, pscore = NULL, option = NULL) { return(setNames(rep(FALSE, n.obs), names(treat))) } - if (is.null(pscore)) { + if (is_null(pscore)) { .err('`discard` must be a logical vector or "none" in the absence of a propensity score') } diff --git a/R/dist_functions.R b/R/dist_functions.R index b8174089..c879e22b 100644 --- a/R/dist_functions.R +++ b/R/dist_functions.R @@ -104,15 +104,14 @@ #' #' @author Noah Greifer #' @seealso [`distance`], [matchit()], [dist()] (which is used -#' internally to compute Euclidean distances) +#' internally to compute some Euclidean distances) #' #' \pkgfun{optmatch}{match_on}, which provides similar functionality but with fewer #' options and a focus on efficient storage of the output. #' #' @references #' -#' Rosenbaum, P. R. (2010). *Design of observational studies*. -#' Springer. +#' Rosenbaum, P. R. (2010). *Design of observational studies*. Springer. #' #' Rosenbaum, P. R., & Rubin, D. B. (1985). Constructing a Control Group Using #' Multivariate Matched Sampling Methods That Incorporate the Propensity Score. @@ -221,6 +220,7 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano #If allvariables have no variance, use Euclidean to avoid errors #If some have no variance, removes those to avoid messing up distances no_variance <- which(apply(X, 2, function(x) abs(max(x) - min(x)) < sqrt(.Machine$double.eps))) + if (length(no_variance) == ncol(X)) { method <- "euclidean" X <- X[, 1, drop = FALSE] @@ -231,22 +231,22 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano method <- match_arg(method, matchit_distances()) - if (is.null(discarded)) discarded <- rep(FALSE, nrow(X)) + if (is_null(discarded)) discarded <- rep(FALSE, nrow(X)) if (method == "mahalanobis") { # X <- sweep(X, 2, colMeans(X)) - if (is.null(var)) { + if (is_null(var)) { X <- scale(X) + #NOTE: optmatch and Rubin (1980) use pooled within-group covariance matrix - if (!is.null(treat)) { - var <- pooled_cov(X[!discarded,, drop = FALSE], treat[!discarded], s.weights[!discarded]) - } - else if (is.null(s.weights)) { - var <- cov(X[!discarded,, drop = FALSE]) - } - else { - var <- cov.wt(X[!discarded,, drop = FALSE], s.weights[!discarded])$cov + var <- { + if (is_not_null(treat)) + pooled_cov(X[!discarded,, drop = FALSE], treat[!discarded], s.weights[!discarded]) + else if (is_null(s.weights)) + cov(X[!discarded,, drop = FALSE]) + else + cov.wt(X[!discarded,, drop = FALSE], s.weights[!discarded])$cov } } else if (!is.cov_like(var)) { @@ -269,12 +269,16 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano #Rosenbaum (2010, ch8) X_r <- matrix(0, nrow = sum(!discarded), ncol = ncol(X), dimnames = list(rownames(X)[!discarded], colnames(X))) + for (i in seq_len(ncol(X_r))) X_r[,i] <- rank(X[!discarded, i]) - if (is.null(s.weights)) var_r <- cov(X_r) - else var_r <- cov.wt(X_r, s.weights[!discarded])$cov + var_r <- { + if (is_null(s.weights)) cov(X_r) + else cov.wt(X_r, s.weights[!discarded])$cov + } + + multiplier <- sd(seq_len(sum(!discarded))) / sqrt(diag(var_r)) - multiplier <- sd(seq_len(sum(!discarded)))/sqrt(diag(var_r)) var_r <- var_r * outer(multiplier, multiplier, "*") inv_var <- NULL @@ -294,12 +298,9 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano X <- mahalanobize(X_r, inv_var) } - else if (method == "euclidean") { - #Do nothing - } else if (method == "scaled_euclidean") { - if (is.null(var)) { - if (!is.null(treat)) { + if (is_null(var)) { + if (is_not_null(treat)) { sds <- pooled_sd(X[!discarded,, drop = FALSE], treat[!discarded], s.weights[!discarded]) } else { @@ -320,6 +321,9 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano X[,i] <- X[,i]/sds[i] } } + else if (method == "euclidean") { + #Do nothing + } attr(X, "treat") <- treat X @@ -328,12 +332,10 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano #Internal function for fast(ish) Euclidean distance eucdist_internal <- function(X, treat = NULL) { - if (is.null(treat)) { - if (NCOL(X) == 1L) { - d <- abs(outer(drop(X), drop(X), "-")) - } - else { - d <- as.matrix(dist(X)) + if (is_null(treat)) { + d <- { + if (NCOL(X) == 1L) abs(outer(drop(X), drop(X), "-")) + else as.matrix(dist(X)) } dimnames(d) <- list(rownames(X), rownames(X)) @@ -341,13 +343,9 @@ eucdist_internal <- function(X, treat = NULL) { else { treat_l <- as.logical(treat) - if (NCOL(X) == 1L) { - d <- abs(outer(X[treat_l], X[!treat_l], "-")) - } - else { - # d <- dist(X) - # d <- as.matrix(d)[treat_l, !treat_l, drop = FALSE] - d <- eucdistC_N1xN0(X, as.integer(treat)) + d <- { + if (NCOL(X) == 1L) abs(outer(X[treat_l], X[!treat_l], "-")) + else eucdistC_N1xN0(X, as.integer(treat)) } dimnames(d) <- list(rownames(X)[treat_l], rownames(X)[!treat_l]) @@ -360,8 +358,8 @@ eucdist_internal <- function(X, treat = NULL) { #to ensure same result as when non-factor binary variable supplied (see optmatch:::contr.match_on) get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { - if (is.null(formula)) { - if (is.null(colnames(data))) colnames(data) <- paste0("X", seq_len(ncol(data))) + if (is_null(formula)) { + if (is_null(colnames(data))) colnames(data) <- paste0("X", seq_len(ncol(data))) fnames <- colnames(data) fnames[!startsWith(fnames, "`")] <- add_quotes(fnames[!startsWith(fnames, "`")], "`") data <- as.data.frame(data) @@ -405,18 +403,26 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { treat <- attr(X, "treat") - if (is.data.frame(X)) X <- as.matrix(X) - else if (is.numeric(X) && is.null(dim(X))) { + if (is.data.frame(X)) { + X <- as.matrix(X) + } + else if (is.numeric(X) && is_null(dim(X))) { X <- matrix(X, nrow = length(X), dimnames = list(names(X), NULL)) } - if (anyNA(X)) .err("missing values are not allowed in the covariates") - if (any(!is.finite(X))) .err("Non-finite values are not allowed in the covariates") + if (anyNA(X)) { + .err("missing values are not allowed in the covariates") + } + + if (any(!is.finite(X))) { + .err("Non-finite values are not allowed in the covariates") + } if (!is.numeric(X) || length(dim(X)) != 2) { stop("bad X") } + attr(X, "checked") <- TRUE attr(X, "treat") <- treat X diff --git a/R/distance2_methods.R b/R/distance2_methods.R index fd58d9d0..232d2572 100644 --- a/R/distance2_methods.R +++ b/R/distance2_methods.R @@ -96,8 +96,11 @@ #' generalized boosted modeling as in *twang*; here, the number of trees is #' chosen based on cross-validation or out-of-bag error, rather than based on #' optimizing balance. \pkg{twang} should not be cited when using this method -#' to estimate propensity scores. } -#' \item{`"lasso"`, `"ridge"`, `"elasticnet"`}{ The propensity +#' to estimate propensity scores. Note that because there is a random component to choosing the tuning +#' parameter, results will vary across runs unless a [seed][set.seed] is +#' set.} +#' \item{`"lasso"`, `"ridge"`, `"elasticnet"`}{ +#' The propensity #' scores are estimated using a lasso, ridge, or elastic net model, #' respectively. The `formula` supplied to `matchit()` is processed #' with [model.matrix()] and passed to \pkgfun{glmnet}{cv.glmnet}, and @@ -128,7 +131,8 @@ #' directly to \pkgfun{randomForest}{randomForest}, and #' \pkgfun{randomForest}{predict.randomForest} is used to compute the propensity #' scores. The `link` argument is ignored, and predicted probabilities are -#' always returned as the distance measure.} +#' always returned as the distance measure. Note that because there is a random component, results will vary across runs unless a [seed][set.seed] is +#' set. } #' \item{`"nnet"`}{ The #' propensity scores are estimated using a single-hidden-layer neural network. #' The `formula` supplied to `matchit()` is passed directly to @@ -156,35 +160,35 @@ #' the linear predictor instead of the predicted probabilities. When #' `s.weights` is supplied to `matchit()`, it will not be passed to #' `bart2` because the `weights` argument in `bart2` does not -#' correspond to sampling weights. } +#' correspond to sampling weights. Note that because there is a random component to choosing the tuning +#' parameter, results will vary across runs unless the `seed` argument is supplied to `distance.options`. Note that setting a seed using [set.seed()] is not sufficient to guarantee reproducibility unless single-threading is used. See \pkgfun{dbarts}{bart2} for details.} #' } #' #' ## Methods for computing distances from covariates #' -#' The following methods involve computing a distance matrix from the covariates themselves -#' without estimating a propensity score. Calipers on the distance measure and -#' common support restrictions cannot be used, and the `distance` -#' component of the output object will be empty because no propensity scores -#' are estimated. The `link` and `distance.options` arguments are -#' ignored with these methods. See the individual matching methods pages for -#' whether these distances are allowed and how they are used. Each of these -#' distance measures can also be calculated outside `matchit()` using its -#' [corresponding function][euclidean_dist]. +#' The following methods involve computing a distance matrix from the covariates +#' themselves without estimating a propensity score. Calipers on the distance +#' measure and common support restrictions cannot be used, and the `distance` +#' component of the output object will be empty because no propensity scores are +#' estimated. The `link` and `distance.options` arguments are ignored with these +#' methods. See the individual matching methods pages for whether these +#' distances are allowed and how they are used. Each of these distance measures +#' can also be calculated outside `matchit()` using its [corresponding +#' function][euclidean_dist]. #' #' \describe{ #' \item{`"euclidean"`}{ The Euclidean distance is the raw -#' distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - -#' x_j)'}} It is sensitive to the scale of the covariates, so covariates with +#' distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - x_j)'}} It is sensitive to the scale of the covariates, so covariates with #' larger scales will take higher priority. } -#' \item{`"scaled_euclidean"`}{ The scaled Euclidean distance is the +#' \item{`"scaled_euclidean"`}{ +#' The scaled Euclidean distance is the #' Euclidean distance computed on the scaled (i.e., standardized) covariates. #' This ensures the covariates are on the same scale. The covariates are #' standardized using the pooled within-group standard deviations, computed by #' treatment group-mean centering each covariate before computing the standard -#' deviation in the full sample. } -#' \item{`"mahalanobis"`}{ The -#' Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - -#' x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group +#' deviation in the full sample. +#' } +#' \item{`"mahalanobis"`}{ The Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group #' covariance matrix of the covariates, computed by treatment group-mean #' centering each covariate before computing the covariance in the full sample. #' This ensures the variables are on the same scale and accounts for the @@ -197,37 +201,36 @@ #' Mahalanobis distance but is not affinely invariant. } #' } #' -#' To perform Mahalanobis distance matching *and* estimate propensity -#' scores to be used for a purpose other than matching, the `mahvars` -#' argument should be used along with a different specification to -#' `distance`. See the individual matching method pages for details on how -#' to use `mahvars`. +#' To perform Mahalanobis distance matching *and* estimate propensity scores to +#' be used for a purpose other than matching, the `mahvars` argument should be +#' used along with a different specification to `distance`. See the individual +#' matching method pages for details on how to use `mahvars`. #' #' ## Distances supplied as a numeric vector or matrix #' -#' `distance` can also be supplied as a numeric vector whose values will be taken to -#' function like propensity scores; their pairwise difference will define the -#' distance between units. This might be useful for supplying propensity scores -#' computed outside `matchit()` or resupplying `matchit()` with -#' propensity scores estimated previously without having to recompute them. +#' `distance` can also be supplied as a numeric vector whose values will be +#' taken to function like propensity scores; their pairwise difference will +#' define the distance between units. This might be useful for supplying +#' propensity scores computed outside `matchit()` or resupplying `matchit()` +#' with propensity scores estimated previously without having to recompute them. #' #' `distance` can also be supplied as a matrix whose values represent the #' pairwise distances between units. The matrix should either be a square, with #' a row and column for each unit (e.g., as the output of a call to -#' `as.matrix(`[`dist`]`(.))`), or have as many rows as there are treated -#' units and as many columns as there are control units (e.g., as the output of -#' a call to [mahalanobis_dist()] or \pkgfun{optmatch}{match_on}). Distance values -#' of `Inf` will disallow the corresponding units to be matched. When -#' `distance` is a supplied as a numeric vector or matrix, `link` and -#' `distance.options` are ignored. +#' `as.matrix(`[`dist`]`(.))`), or have as many rows as there are treated units +#' and as many columns as there are control units (e.g., as the output of a call +#' to [mahalanobis_dist()] or \pkgfun{optmatch}{match_on}). Distance values of +#' `Inf` will disallow the corresponding units to be matched. When `distance` is +#' a supplied as a numeric vector or matrix, `link` and `distance.options` are +#' ignored. #' -#' @note -#' In versions of *MatchIt* prior to 4.0.0, `distance` was -#' specified in a slightly different way. When specifying arguments using the -#' old syntax, they will automatically be converted to the corresponding method -#' in the new syntax but a warning will be thrown. `distance = "logit"`, -#' the old default, will still work in the new syntax, though `distance = "glm", link = "logit"` is preferred (note that these are the default -#' settings and don't need to be made explicit). +#' @note In versions of *MatchIt* prior to 4.0.0, `distance` was specified in a +#' slightly different way. When specifying arguments using the old syntax, they +#' will automatically be converted to the corresponding method in the new syntax +#' but a warning will be thrown. `distance = "logit"`, the old default, will +#' still work in the new syntax, though `distance = "glm", link = "logit"` is +#' preferred (note that these are the default settings and don't need to be made +#' explicit). #' #' @examples #' data("lalonde") @@ -293,7 +296,7 @@ NULL #distance2glm----------------- distance2glm <- function(formula, data = NULL, link = "logit", ...) { - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) A <- list(...) @@ -311,7 +314,7 @@ distance2glm <- function(formula, data = NULL, link = "logit", ...) { distance2gam <- function(formula, data = NULL, link = "logit", ...) { rlang::check_installed("mgcv") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) A <- list(...) @@ -357,25 +360,30 @@ distance2nnet <- function(formula, data = NULL, link = NULL, ...) { distance2cbps <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("CBPS") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") A <- list(...) A[["standardized"]] <- FALSE - if (is.null(A[["ATT"]])) { - if (is.null(A[["estimand"]])) A[["ATT"]] <- 1 + + if (is_null(A[["ATT"]])) { + if (is_null(A[["estimand"]])) { + A[["ATT"]] <- 1 + } else { estimand <- toupper(A[["estimand"]]) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) A[["ATT"]] <- switch(estimand, "ATT" = 1, "ATC" = 2, 0) } } - if (is.null(A[["method"]])) { + + if (is_null(A[["method"]])) { A[["method"]] <- if (isFALSE(A[["over"]])) "exact" else "over" } + A[c("estimand", "over")] <- NULL - if (!is.null(A[["weights"]])) { + if (is_not_null(A[["weights"]])) { A[["sample.weights"]] <- A[["weights"]] A[["weights"]] <- NULL } @@ -394,7 +402,7 @@ distance2cbps <- function(formula, data = NULL, link = NULL, ...) { distance2bart <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("dbarts") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") A <- list(...) A[!names(A) %in% c(names(formals(dbarts::bart2)), names(formals(dbarts::dbartsControl)))] <- NULL @@ -412,7 +420,7 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { # distance2bart <- function(formula, data, link = NULL, ...) { # rlang::check_installed("BART") # -# if (!is.null(link) && startsWith(as.character(link), "linear")) { +# if (is_not_null(link) && startsWith(as.character(link), "linear")) { # linear <- TRUE # link <- sub("linear.", "", as.character(link), fixed = TRUE) # } @@ -421,7 +429,7 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { # #Keep link probit because default in matchit is logit but probit is much faster with BART # link <- "probit" # -# # if (is.null(link)) link <- "probit" +# # if (is_null(link)) link <- "probit" # # else if (!link %in% c("probit", "logit")) { # # stop("'link' must be \"probit\" or \"logit\" with distance = \"bart\".", call. = FALSE) # # } @@ -436,7 +444,7 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { # # A <- list(...) # -# if (!is.null(A[["mc.cores"]]) && A[["mc.cores"]][1] > 1) fun <- BART::mc.gbart +# if (is_not_null(A[["mc.cores"]]) && A[["mc.cores"]][1] > 1) fun <- BART::mc.gbart # else fun <- BART::gbart # # res <- do.call(fun, c(list(X, @@ -459,26 +467,28 @@ distance2randomforest <- function(formula, data = NULL, link = NULL, ...) { newdata[[treatvar]] <- factor(newdata[[treatvar]], levels = c("0", "1")) res <- randomForest::randomForest(formula, data = newdata, ...) - list(model = res, distance = predict(res, type = "prob")[,"1"]) + list(model = res, distance = predict(res, type = "prob")[,"1"]) } #distance2glmnet-------------- distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("glmnet") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) A <- list(...) s <- A[["s"]] A[!names(A) %in% c(names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet)))] <- NULL - if (is.null(link)) link <- "logit" - if (link == "logit") A$family <- "binomial" - else if (link == "log") A$family <- "poisson" - else A$family <- binomial(link = link) + if (is_null(link)) link <- "logit" + + A$family <- switch(link, + "logit" = "binomial", + "log" = "poisson", + binomial(link = link)) - if (is.null(A[["alpha"]])) A[["alpha"]] <- .5 + if (is_null(A[["alpha"]])) A[["alpha"]] <- .5 mf <- model.frame(formula, data = data) @@ -487,10 +497,10 @@ distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { res <- do.call(glmnet::cv.glmnet, A) - if (is.null(s)) s <- "lambda.1se" + if (is_null(s)) s <- "lambda.1se" pred <- drop(predict(res, newx = A$x, s = s, - type = if (linear) "link" else "response")) + type = if (linear) "link" else "response")) list(model = res, distance = pred) } @@ -511,7 +521,7 @@ distance2ridge <- function(formula, data = NULL, link = NULL, ...) { distance2gbm <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("gbm") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") A <- list(...) @@ -522,17 +532,18 @@ distance2gbm <- function(formula, data = NULL, link = NULL, ...) { A$data <- data A$distribution <- "bernoulli" - if (is.null(A[["n.trees"]])) A[["n.trees"]] <- 1e4 - if (is.null(A[["interaction.depth"]])) A[["interaction.depth"]] <- 3 - if (is.null(A[["shrinkage"]])) A[["shrinkage"]] <- .01 - if (is.null(A[["bag.fraction"]])) A[["bag.fraction"]] <- 1 - if (is.null(A[["cv.folds"]])) A[["cv.folds"]] <- 5 - if (is.null(A[["keep.data"]])) A[["keep.data"]] <- FALSE + if (is_null(A[["n.trees"]])) A[["n.trees"]] <- 1e4 + if (is_null(A[["interaction.depth"]])) A[["interaction.depth"]] <- 3 + if (is_null(A[["shrinkage"]])) A[["shrinkage"]] <- .01 + if (is_null(A[["bag.fraction"]])) A[["bag.fraction"]] <- 1 + if (is_null(A[["cv.folds"]])) A[["cv.folds"]] <- 5 + if (is_null(A[["keep.data"]])) A[["keep.data"]] <- FALSE if (A[["cv.folds"]] <= 1 && A[["bag.fraction"]] == 1) { .err('either `bag.fraction` must be less than 1 or `cv.folds` must be greater than 1 when using `distance = "gbm"`') } - if (is.null(method)) { + + if (is_null(method)) { if (A[["bag.fraction"]] < 1) method <- "OOB" else method <- "cv" } @@ -547,5 +558,5 @@ distance2gbm <- function(formula, data = NULL, link = NULL, ...) { pred <- drop(predict(res, newdata = data, n.trees = best.tree, type = if (linear) "link" else "response")) - list(model = res, distance = pred) + list(model = res, distance = pred) } diff --git a/R/input_processing.R b/R/input_processing.R index 305229d1..61460e78 100644 --- a/R/input_processing.R +++ b/R/input_processing.R @@ -4,12 +4,13 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, ratio, m.order, estimand, ..., min.controls = NULL, max.controls = NULL) { - null.method <- is.null(method) + null.method <- is_null(method) if (null.method) { method <- "NULL" } else { - method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", "genetic", "subclass", "cardinality", + method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", + "genetic", "subclass", "cardinality", "quick")) } @@ -17,21 +18,21 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, error.inputs <- character(0) if (null.method) { for (i in c("exact", "mahvars", "antiexact", "caliper", "std.caliper", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "exact") { for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "cem") { for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -39,7 +40,7 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "nearest") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } @@ -48,14 +49,14 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "optimal") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "caliper", "std.caliper", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -64,14 +65,14 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "full") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "ratio", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -79,27 +80,27 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "genetic") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("min.controls", "max.controls")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "cardinality") { for (i in c("distance", "antiexact", "caliper", "std.caliper", "reestimate", "replace", "min.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "subclass") { for (i in c("exact", "mahvars", "antiexact", "caliper", "std.caliper", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -107,47 +108,67 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "quick") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "ratio", "min.controls", "max.controls", "m.order", "antiexact")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } - if (length(ignored.inputs) > 0) .wrn(sprintf("the %s %s not used with `method = %s` and will be ignored", - ngettext(length(ignored.inputs), "argument", "arguments"), - word_list(ignored.inputs, quotes = 1, is.are = TRUE), - add_quotes(method, quotes = !null.method))) - if (length(error.inputs) > 0) .err(sprintf("the %s %s not used with `method = %s` and `distance = \"%s\"`", - ngettext(length(error.inputs), "argument", "arguments"), - word_list(error.inputs, quotes = 1, is.are = TRUE), - add_quotes(method, quotes = !null.method), - distance)) + if (is_not_null(ignored.inputs)) { + .wrn(sprintf("the %s %s not used with `method = %s` and will be ignored", + ngettext(length(ignored.inputs), "argument", "arguments"), + word_list(ignored.inputs, quotes = 1, is.are = TRUE), + add_quotes(method, quotes = !null.method))) + } + + if (is_not_null(error.inputs)) { + .err(sprintf("the %s %s not used with `method = %s` and `distance = \"%s\"`", + ngettext(length(error.inputs), "argument", "arguments"), + word_list(error.inputs, quotes = 1, is.are = TRUE), + add_quotes(method, quotes = !null.method), + distance)) + } + ignored.inputs } #Check treatment for type, binary, missing, num. rows check_treat <- function(treat = NULL, X = NULL) { - if (is.null(treat)) { - if (is.null(X) || is.null(attr(X, "treat"))) return(NULL) + if (is_null(treat)) { + if (is_null(X) || is_null(attr(X, "treat"))) { + return(NULL) + } + treat <- attr(X, "treat") } - if (isTRUE(attr(treat, "checked"))) return(treat) - if (!is.atomic(treat) || !is.null(dim(treat))) { + if (isTRUE(attr(treat, "checked"))) { + return(treat) + } + + if (!is.atomic(treat) || is_not_null(dim(treat))) { .err("the treatment must be a vector") } - if (anyNA(treat)) .err("missing values are not allowed in the treatment") - if (length(unique(treat)) != 2) .err("the treatment must be a binary variable") - if (!is.null(X) && length(treat) != nrow(X)) .err("the treatment and covariates must have the same number of units") + if (anyNA(treat)) { + .err("missing values are not allowed in the treatment") + } + + if (!has_n_unique(treat, 2L)) { + .err("the treatment must be a binary variable") + } + + if (is_not_null(X) && length(treat) != nrow(X)) { + .err("the treatment and covariates must have the same number of units") + } treat <- binarize(treat) #make 0/1 attr(treat, "checked") <- TRUE @@ -156,15 +177,16 @@ check_treat <- function(treat = NULL, X = NULL) { #Function to process distance and give warnings about new syntax process.distance <- function(distance, method = NULL, treat) { - if (is.null(distance)) { - if (!is.null(method) && !method %in% c("cem", "exact", "cardinality")) { + if (is_null(distance)) { + if (is_not_null(method) && !method %in% c("cem", "exact", "cardinality")) { .err(sprintf("`distance` cannot be `NULL` with `method = \"%s\"`", method)) } + return(distance) } - if (is.character(distance) && length(distance) == 1) { + if (is.character(distance) && length(distance) == 1L) { allowable.distances <- c( #Propensity score methods "glm", "cbps", "gam", "nnet", "rpart", "bart", @@ -177,14 +199,14 @@ process.distance <- function(distance, method = NULL, treat) { "linear.cauchit", "log", "probit")) { link <- tolower(distance) .wrn(sprintf("`distance = \"%s\"` will be deprecated; please use `distance = \"glm\", link = \"%s\"` in the future", - distance, link)) + distance, link)) distance <- "glm" attr(distance, "link") <- link } else if (tolower(distance) %in% tolower(c("GAMcloglog", "GAMlog", "GAMlogit", "GAMprobit"))) { link <- tolower(substr(distance, 4, nchar(distance))) .wrn(sprintf("`distance = \"%s\"` will be deprecated; please use `distance = \"gam\", link = \"%s\"` in the future", - distance, link)) + distance, link)) distance <- "gam" attr(distance, "link") <- link } @@ -198,34 +220,45 @@ process.distance <- function(distance, method = NULL, treat) { else if (!tolower(distance) %in% allowable.distances) { .err("the argument supplied to `distance` is not an allowable value. See `help(\"distance\")` for allowable options") } - else if (!is.null(method) && method == "subclass" && tolower(distance) %in% matchit_distances()) { + else if (is_not_null(method) && method == "subclass" && tolower(distance) %in% matchit_distances()) { .err(sprintf("`distance` cannot be %s with `method = \"subclass\"`", add_quotes(distance))) } else { distance <- tolower(distance) } - } - else if (!is.numeric(distance) || (!is.null(dim(distance)) && length(dim(distance)) != 2)) { + else if (!is.numeric(distance) || (is_not_null(dim(distance)) && length(dim(distance)) != 2)) { .err("`distance` must be a string with the name of the distance measure to be used or a numeric vector or matrix containing distance measures") } - else if (is.matrix(distance) && (is.null(method) || !method %in% c("nearest", "optimal", "full"))) { + else if (is.matrix(distance) && (is_null(method) || !method %in% c("nearest", "optimal", "full"))) { .err(sprintf("`distance` cannot be supplied as a matrix with `method = %s`", - add_quotes(method, quotes = !is.null(method)))) + add_quotes(method, quotes = is_not_null(method)))) } if (is.numeric(distance)) { if (is.matrix(distance)) { dim.distance <- dim(distance) + if (all(dim.distance == length(treat))) { - if (!is.null(rownames(distance))) distance <- distance[names(treat),, drop = FALSE] - if (!is.null(colnames(distance))) distance <- distance[,names(treat), drop = FALSE] + if (is_not_null(rownames(distance))) { + distance <- distance[names(treat),, drop = FALSE] + } + + if (is_not_null(colnames(distance))) { + distance <- distance[,names(treat), drop = FALSE] + } + distance <- distance[treat == 1, treat == 0, drop = FALSE] } else if (all(dim.distance == c(sum(treat==1), sum(treat==0)))) { - if (!is.null(rownames(distance))) distance <- distance[names(treat)[treat == 1],, drop = FALSE] - if (!is.null(colnames(distance))) distance <- distance[,names(treat)[treat == 0], drop = FALSE] + if (is_not_null(rownames(distance))) { + distance <- distance[names(treat)[treat == 1],, drop = FALSE] + } + + if (is_not_null(colnames(distance))) { + distance <- distance[,names(treat)[treat == 0], drop = FALSE] + } } else { .err("when supplied as a matrix, `distance` must have dimensions NxN or N1xN0. See `help(\"distance\")` for details") @@ -246,10 +279,13 @@ process.distance <- function(distance, method = NULL, treat) { #Function to check ratio is acceptable process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.controls = NULL) { #Should be run after process.inputs() and ignored inputs set to NULL - ratio.null <- length(ratio) == 0 + ratio.null <- is_null(ratio) ratio.na <- !ratio.null && anyNA(ratio) - if (is.null(method)) return(1) + if (is_null(method)) { + return(1) + } + if (method %in% c("nearest", "optimal")) { if (ratio.null) ratio <- 1 else if (ratio.na) .err("`ratio` cannot be `NA`") @@ -257,7 +293,7 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co .err("`ratio` must be a single number greater than or equal to 1") } - if (is.null(max.controls)) { + if (is_null(max.controls)) { if (!chk::vld_whole_number(ratio)) { .err("`ratio` must be a whole number when `max.controls` is not specified") } @@ -267,28 +303,46 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co .err("`max.controls` must be a single positive number") } else { - if (ratio <= 1) .err("`ratio` must be greater than 1 for variable ratio matching") + if (ratio <= 1) { + .err("`ratio` must be greater than 1 for variable ratio matching") + } max.controls <- ceiling(max.controls) - if (max.controls <= ratio) .err("`max.controls` must be greater than `ratio` for variable ratio matching") - if (is.null(min.controls)) min.controls <- 1 - else if (anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1) { + if (max.controls <= ratio) { + .err("`max.controls` must be greater than `ratio` for variable ratio matching") + } + + if (is_null(min.controls)) { + min.controls <- 1 + } + else if (!anyNA(max.controls) && is.atomic(max.controls) && is.numeric(max.controls) && length(max.controls) == 1L) { + min.controls <- floor(min.controls) + } + else { .err("`max.controls` must be a single positive number") } - else min.controls <- floor(min.controls) - if (min.controls < 1) .err("`min.controls` cannot be less than 1 for variable ratio matching") - if (min.controls >= ratio) .err("`min.controls` must be less than `ratio` for variable ratio matching") + if (min.controls < 1) { + .err("`min.controls` cannot be less than 1 for variable ratio matching") + } + + if (min.controls >= ratio) { + .err("`min.controls` must be less than `ratio` for variable ratio matching") + } } } else if (method == "full") { - if (is.null(max.controls)) max.controls <- Inf + if (is_null(max.controls)) { + max.controls <- Inf + } else if ((anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1)) { .err("`max.controls` must be a single positive number") } - if (is.null(min.controls)) min.controls <- 0 + if (is_null(min.controls)) { + min.controls <- 0 + } else if ((anyNA(min.controls) || !is.atomic(min.controls) || !is.numeric(min.controls) || length(min.controls) > 1)) { .err("`min.controls` must be a single positive number") } @@ -296,18 +350,25 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co ratio <- 1 #Just to get min.controls and max.controls out } else if (method == "genetic") { - if (ratio.null) ratio <- 1 - else if (ratio.na) .err("`ratio` cannot be `NA`") + if (ratio.null) { + ratio <- 1 + } + else if (ratio.na) { + .err("`ratio` cannot be `NA`") + } else if (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 1 || !chk::vld_whole_number(ratio)) { .err("`ratio` must be a single whole number greater than or equal to 1") } + ratio <- round(ratio) min.controls <- max.controls <- NULL } else if (method == "cardinality") { - if (ratio.null) ratio <- 1 + if (ratio.null) { + ratio <- 1 + } else if (!ratio.na && (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 0)) { .err("`ratio` must be a single positive number or `NA`") } @@ -318,7 +379,7 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co min.controls <- max.controls <- NULL } - if (!is.null(ratio)) { + if (is_not_null(ratio)) { attr(ratio, "min.controls") <- min.controls attr(ratio, "max.controls") <- max.controls } @@ -339,16 +400,20 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N #If std, export standardized versions #Check need for caliper - if (length(caliper) == 0 || is.null(method) || !method %in% c("nearest", "genetic", "full", "quick")) return(NULL) + if (is_null(caliper) || is_null(method) || !method %in% c("nearest", "genetic", "full", "quick")) { + return(NULL) + } #Check if form of caliper is okay - if (!is.atomic(caliper) || !is.numeric(caliper)) .err("`caliper` must be a numeric vector") + if (!is.atomic(caliper) || !is.numeric(caliper)) { + .err("`caliper` must be a numeric vector") + } #Check caliper names - if (length(caliper) == 1 && (is.null(names(caliper)) || identical(names(caliper), ""))) { + if (length(caliper) == 1L && (is_null(names(caliper)) || identical(names(caliper), ""))) { names(caliper) <- "" } - else if (is.null(names(caliper))) { + else if (is_null(names(caliper))) { .err("`caliper` must be a named vector with names corresponding to the variables for which a caliper is to be applied") } else if (anyNA(names(caliper))) { @@ -358,7 +423,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N .err("no more than one entry in `caliper` can have no name") } - if (any(names(caliper) == "") && is.null(distance)) { + if (any(names(caliper) == "") && is_null(distance)) { .err("all entries in `caliper` must be named when `distance` does not correspond to a propensity score") } @@ -368,7 +433,8 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N cal.in.mahcovs <- setNames(names(caliper) %in% names(mahcovs), names(caliper)) if (any(names(caliper) != "" & !cal.in.covs & !cal.in.data)) { .err(paste0("All variables named in `caliper` must be in `data`. Variables not in `data`:\n\t", - paste0(names(caliper)[names(caliper) != "" & !cal.in.data & !cal.in.covs & !cal.in.mahcovs], collapse = ", ")), tidy = FALSE) + paste0(names(caliper)[names(caliper) != "" & !cal.in.data & !cal.in.covs & !cal.in.mahcovs], collapse = ", ")), + tidy = FALSE) } #Check std.caliper @@ -376,31 +442,22 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N if (length(std.caliper) == 1) { std.caliper <- setNames(rep.int(std.caliper, length(caliper)), names(caliper)) } - else if (length(std.caliper) != length(caliper)) { + else if (length(std.caliper) == length(caliper)) { + names(std.caliper) <- names(caliper) + } + else { .err("`std.caliper` must be the same length as `caliper`") } - else names(std.caliper) <- names(caliper) #Remove trivial calipers caliper <- caliper[is.finite(caliper)] - num.unique <- vapply(names(caliper), function(x) { - if (x == "") var <- distance - else if (cal.in.data[x]) var <- data[[x]] - else if (cal.in.covs[x]) var <- covs[[x]] - else var <- mahcovs[[x]] - - length(unique(var)) - }, integer(1L)) - - caliper <- caliper[num.unique > 1] - - if (length(caliper) == 0) return(NULL) + if (is_null(caliper)) { + return(NULL) + } #Ensure no calipers on categorical variables cat.vars <- vapply(names(caliper), function(x) { - if (num.unique[names(num.unique) == x] == 2) return(TRUE) - if (x == "") var <- distance else if (cal.in.data[x]) var <- data[[x]] else if (cal.in.covs[x]) var <- covs[[x]] @@ -410,7 +467,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N }, logical(1L)) if (any(cat.vars)) { - .err(paste0("Calipers cannot be used with binary, factor, or character variables. Offending variables:\n\t", + .err(paste0("Calipers cannot be used with factor or character variables. Offending variables:\n\t", paste0(ifelse(names(caliper) == "", "", names(caliper))[cat.vars], collapse = ", ")), tidy = FALSE) } @@ -420,9 +477,10 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N chk::chk_not_any_na(std.caliper) if (any(std.caliper)) { - if ("" %in% names(std.caliper) && isTRUE(std.caliper[names(std.caliper) == ""]) && is.matrix(distance)) { + if (any(names(std.caliper) == "") && isTRUE(std.caliper[names(std.caliper) == ""]) && is.matrix(distance)) { .err("when `distance` is supplied as a matrix and a caliper for it is specified, `std.caliper` must be `FALSE` for the distance measure") } + caliper[std.caliper] <- caliper[std.caliper] * vapply(names(caliper)[std.caliper], function(x) { if (x == "") sd(distance[!discarded]) else if (cal.in.data[x]) sd(data[[x]][!discarded]) @@ -442,13 +500,18 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N #Function to process replace argument process.replace <- function(replace, method = NULL, ..., reuse.max = NULL) { - if (is.null(method)) return(FALSE) + if (is_null(method)) { + return(FALSE) + } + + if (is_null(replace)) { + replace <- FALSE + } - if (is.null(replace)) replace <- FALSE chk::chk_flag(replace) if (method %in% c("nearest")) { - if (is.null(reuse.max)) { + if (is_null(reuse.max)) { if (replace) reuse.max <- .Machine$integer.max else reuse.max <- 1L } @@ -474,17 +537,21 @@ process.replace <- function(replace, method = NULL, ..., reuse.max = NULL) { process.variable.input <- function(x, data = NULL) { n <- deparse1(substitute(x)) - if (is.null(x)) return(NULL) + if (is_null(x)) { + return(NULL) + } if (is.character(x)) { - if (is.null(data) || !is.data.frame(data)) { + if (is_null(data) || !is.data.frame(data)) { .err(sprintf("if `%s` is specified as strings, a data frame containing the named variables must be supplied to `data`", n)) } + if (!all(x %in% names(data))) { .err(sprintf("All names supplied to `%s` must be variables in `data`. Variables not in `data`:\n\t%s", n, paste(add_quotes(setdiff(x, names(data))), collapse = ", ")), tidy = FALSE) } + x <- reformulate(x) } else if (rlang::is_formula(x)) { @@ -495,6 +562,7 @@ process.variable.input <- function(x, data = NULL) { } x_covs <- model.frame(x, data, na.action = "na.pass") + if (anyNA(x_covs)) { .err(sprintf("missing values are not allowed in the covariates named in `%s`", n)) } diff --git a/R/match.data.R b/R/match.data.R index a4de88c1..9f90198e 100644 --- a/R/match.data.R +++ b/R/match.data.R @@ -178,17 +178,17 @@ match.data <- function(object, group = "all", distance = "distance", weights = " chk::chk_is(object, "matchit") - if (is.null(data)) { + if (is_null(data)) { env <- environment(object$formula) data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || + if (null_or_error(data) || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { env <- parent.frame() data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || + if (null_or_error(data) || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { data <- object[["model"]][["data"]] - if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { + if (is_null(data) || nrow(data) != length(object[["treat"]])) { .err("a valid dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") } } @@ -206,7 +206,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " .err("`data` must have as many rows as there were units in the original call to `matchit()`") } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { chk::chk_not_null(distance) chk::chk_string(distance) if (distance %in% names(data)) { @@ -216,7 +216,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " data[[distance]] <- object$distance } - if (!is.null(object$weights)) { + if (is_not_null(object$weights)) { chk::chk_not_null(weights) chk::chk_string(weights) if (weights %in% names(data)) { @@ -225,12 +225,12 @@ match.data <- function(object, group = "all", distance = "distance", weights = " } data[[weights]] <- object$weights - if (!is.null(object$s.weights) && include.s.weights) { + if (is_not_null(object$s.weights) && include.s.weights) { data[[weights]] <- data[[weights]] * object$s.weights } } - if (!is.null(object$subclass)) { + if (is_not_null(object$subclass)) { chk::chk_not_null(subclass) chk::chk_string(subclass) if (subclass %in% names(data)) { @@ -242,7 +242,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " treat <- object$treat - if (drop.unmatched && !is.null(object$weights)) { + if (drop.unmatched && is_not_null(object$weights)) { data <- data[object$weights > 0,,drop = FALSE] treat <- treat[object$weights > 0] } @@ -251,9 +251,9 @@ match.data <- function(object, group = "all", distance = "distance", weights = " if (group == "treated") data <- data[treat == 1,,drop = FALSE] else if (group == "control") data <- data[treat == 0,,drop = FALSE] - if (!is.null(object$distance)) attr(data, "distance") <- distance - if (!is.null(object$weights)) attr(data, "weights") <- weights - if (!is.null(object$subclass)) attr(data, "subclass") <- subclass + if (is_not_null(object$distance)) attr(data, "distance") <- distance + if (is_not_null(object$weights)) attr(data, "weights") <- weights + if (is_not_null(object$subclass)) attr(data, "subclass") <- subclass class(data) <- c("matchdata", class(data)) @@ -267,7 +267,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc chk::chk_is(object, "matchit") - if (is.null(object$match.matrix)) { + if (is_null(object$match.matrix)) { .err("a match.matrix component must be present in the matchit object, which does not occur with all types of matching. Use `match.data()` instead") } @@ -304,7 +304,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc matched[[subclass]] <- c(as.vector(col(tmm)[!is.na(tmm)]), seq_len(nrow(mm))) matched[[weights]] <- c(1/num.matches[matched[[subclass]][seq_len(sum(!is.na(mm)))]], rep(1, nrow(mm))) - if (!is.null(object$s.weights) && include.s.weights) { + if (is_not_null(object$s.weights) && include.s.weights) { matched[[weights]] <- matched[[weights]] * object$s.weights[matched[[id]]] } @@ -315,7 +315,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc out[[subclass]] <- factor(out[[subclass]], labels = seq_len(nrow(mm))) - if (!is.null(object$distance)) attr(out, "distance") <- distance + if (is_not_null(object$distance)) attr(out, "distance") <- distance attr(out, "weights") <- weights attr(out, "subclass") <- subclass attr(out, "id") <- id diff --git a/R/match.qoi.R b/R/match.qoi.R index da05afca..bb683a3d 100644 --- a/R/match.qoi.R +++ b/R/match.qoi.R @@ -3,7 +3,7 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s.d.denom = "treated", standardize = FALSE, compute.pair.dist = TRUE) { - un <- is.null(ww) + un <- is_null(ww) bin.var <- all(xx == 0 | xx == 1) xsum <- rep(NA_real_, 7) @@ -133,9 +133,13 @@ bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", } #Compute within-pair/subclass distances -pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL, fast = TRUE) { +pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL) { - if (!is.null(mm)) { + if (is_not_null(subclass)) { + mpdiff <- pairdistsubC(as.numeric(xx), as.integer(tt), + as.integer(subclass)) + } + else if (is_not_null(mm)) { names(xx) <- names(tt) xx_t <- xx[rownames(mm)] xx_c <- matrix(0, nrow = nrow(mm), ncol = ncol(mm)) @@ -143,30 +147,11 @@ pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL, fast = TRU mpdiff <- mean(abs(xx_t - xx_c), na.rm = TRUE) } - else if (!is.null(subclass)) { - if (!fast) { - dists <- unlist(lapply(levels(subclass), function(s) { - t1 <- which(!is.na(subclass) & subclass == s & tt == 1) - t0 <- which(!is.na(subclass) & subclass == s & tt == 0) - if (length(t1) == 1 || length(t0) == 1) { - xx[t1] - xx[t0] - } - else { - outer(xx[t1], xx[t0], "-") - } - })) - mpdiff <- mean(abs(dists)) - } - else { - mpdiff <- pairdistsubC(as.numeric(xx), as.integer(tt), - as.integer(subclass)) - } - } else { return(NA_real_) } - if (!is.null(std) && abs(mpdiff) > 1e-8) { + if (is_not_null(std) && abs(mpdiff) > 1e-8) { mpdiff <- mpdiff/std } @@ -179,7 +164,7 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { n.obs <- length(x) - if (is.null(w)) w <- rep(1, n.obs) + if (is_null(w)) w <- rep(1, n.obs) if (all(x == 0 | x == 1)) { t1 <- t == t[1] diff --git a/R/matchit.R b/R/matchit.R index 49c2b66e..b161937f 100644 --- a/R/matchit.R +++ b/R/matchit.R @@ -409,14 +409,14 @@ matchit <- function(formula, mcall <- match.call() ## Process method - .chk_null_or(method, chk::chk_string) - if (length(method) == 1 && is.character(method)) { + chk::chk_null_or(method, vld = chk::vld_string) + if (length(method) == 1L && is.character(method)) { method <- tolower(method) method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", "genetic", "subclass", "cardinality", "quick")) fn2 <- paste0("matchit2", method) } - else if (is.null(method)) { + else if (is_null(method)) { fn2 <- "matchit2null" } else { @@ -424,7 +424,9 @@ matchit <- function(formula, } #Process formula and data inputs - .chk_formula(formula, sides = 2) + if (is_null(formula) || !rlang::is_formula(formula, lhs = TRUE)) { + .err("`formula` must be a formula relating treatment to covariates") + } treat.form <- update(terms(formula, data = data), . ~ 0) treat.mf <- model.frame(treat.form, data = data, na.action = "na.pass") @@ -432,7 +434,9 @@ matchit <- function(formula, #Check and binarize treat treat <- check_treat(treat) - if (length(treat) == 0) .err("the treatment cannot be `NULL`") + if (is_null(treat)) { + .err("the treatment cannot be `NULL`") + } names(treat) <- rownames(treat.mf) @@ -444,7 +448,7 @@ matchit <- function(formula, reestimate = reestimate, s.weights = s.weights, replace = replace, ratio = ratio, m.order = m.order, estimand = estimand) - if (length(ignored.inputs) > 0) { + if (is_not_null(ignored.inputs)) { for (i in ignored.inputs) assign(i, NULL) } @@ -455,40 +459,51 @@ matchit <- function(formula, ratio <- process.ratio(ratio, method, ...) #Process s.weights - if (!is.null(s.weights)) { + if (is_not_null(s.weights)) { if (is.character(s.weights)) { - if (is.null(data) || !is.data.frame(data)) { + if (is_null(data) || !is.data.frame(data)) { .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") } + if (!all(s.weights %in% names(data))) { .err("the name supplied to `s.weights` must be a variable in `data`") } + s.weights.form <- reformulate(s.weights) s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") + } + + s.weights <- s.weights[[1L]] } else if (inherits(s.weights, "formula")) { s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") + } + + s.weights <- s.weights[[1L]] } else if (!is.numeric(s.weights)) { .err("`s.weights` must be supplied as a numeric vector, string, or one-sided formula") } chk::chk_not_any_na(s.weights) - if (length(s.weights) != n.obs) .err("`s.weights` must be the same length as the treatment vector") + if (length(s.weights) != n.obs) { + .err("`s.weights` must be the same length as the treatment vector") + } names(s.weights) <- names(treat) - } #Process distance function is.full.mahalanobis <- FALSE fn1 <- NULL - if (is.null(method) || !method %in% c("exact", "cem", "cardinality")) { + if (is_null(method) || !method %in% c("exact", "cem", "cardinality")) { distance <- process.distance(distance, method, treat) if (is.numeric(distance)) { @@ -507,7 +522,7 @@ matchit <- function(formula, } #Process covs - if (!is.null(fn1) && fn1 == "distance2gam") { + if (is_not_null(fn1) && fn1 == "distance2gam") { rlang::check_installed("mgcv") env <- environment(formula) covs.formula <- mgcv::interpret.gam(formula)$fake.formula @@ -529,7 +544,10 @@ matchit <- function(formula, .err(paste0("Missing and non-finite values are not allowed in the covariates. Covariates with missingness or non-finite values:\n\t", paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) } - if (is.character(covs[[i]])) covs[[i]] <- factor(covs[[i]]) + + if (is.character(covs[[i]])) { + covs[[i]] <- factor(covs[[i]]) + } } #Process exact, mahvars, and antiexact @@ -543,7 +561,7 @@ matchit <- function(formula, antiexact <- attr(antiexactcovs, "terms") #Estimate distance, discard from common support, optionally re-estimate distance - if (is.null(fn1) || is.full.mahalanobis) { + if (is_null(fn1) || is.full.mahalanobis) { #No distance measure dist.model <- distance <- link <- NULL } @@ -555,21 +573,23 @@ matchit <- function(formula, cat("Estimating propensity scores... \n") } - if (!is.null(s.weights)) { + if (is_not_null(s.weights)) { attr(s.weights, "in_ps") <- !distance %in% c("bart") } #Estimate distance - if (is.null(distance.options$formula)) distance.options$formula <- formula - if (is.null(distance.options$data)) distance.options$data <- data - if (is.null(distance.options$verbose)) distance.options$verbose <- verbose - if (is.null(distance.options$estimand)) distance.options$estimand <- estimand - if (is.null(distance.options$weights) && !fn1 %in% c("distance2bart")) { + if (is_null(distance.options$formula)) distance.options$formula <- formula + if (is_null(distance.options$data)) distance.options$data <- data + if (is_null(distance.options$verbose)) distance.options$verbose <- verbose + if (is_null(distance.options$estimand)) distance.options$estimand <- estimand + if (is_null(distance.options$weights) && !fn1 %in% c("distance2bart")) { distance.options$weights <- s.weights } - if (!is.null(attr(distance, "link"))) distance.options$link <- attr(distance, "link") - else distance.options$link <- link + distance.options$link <- { + if (is_not_null(attr(distance, "link"))) attr(distance, "link") + else link + } dist.out <- do.call(fn1, distance.options, quote = TRUE) @@ -585,7 +605,7 @@ matchit <- function(formula, } #Process discard - if (is.null(fn1) || is.full.mahalanobis || fn1 == "distance2user") { + if (is_null(fn1) || is.full.mahalanobis || fn1 == "distance2user") { discarded <- discard(treat, distance, discard) } else { @@ -597,7 +617,7 @@ matchit <- function(formula, if (length(distance.options[[i]]) == n.obs) { distance.options[[i]] <- distance.options[[i]][!discarded] } - else if (length(dim(distance.options[[i]])) == 2 && nrow(distance.options[[i]]) == n.obs) { + else if (length(dim(distance.options[[i]])) == 2L && nrow(distance.options[[i]]) == n.obs) { distance.options[[i]] <- distance.options[[i]][!discarded,,drop = FALSE] } } @@ -610,12 +630,16 @@ matchit <- function(formula, #Process caliper calcovs <- NULL - if (!is.null(caliper)) { + if (is_not_null(caliper)) { caliper <- process.caliper(caliper, method, data, covs, mahcovs, distance, discarded, std.caliper) - if (!is.null(attr(caliper, "cal.formula"))) { + if (is_not_null(attr(caliper, "cal.formula"))) { calcovs <- model.frame(attr(caliper, "cal.formula"), data, na.action = "na.pass") - if (anyNA(calcovs)) .err("missing values are not allowed in the covariates named in `caliper`") + + if (anyNA(calcovs)) { + .err("missing values are not allowed in the covariates named in `caliper`") + } + attr(caliper, "cal.formula") <- NULL } } @@ -630,17 +654,23 @@ matchit <- function(formula, quote = TRUE) info <- create_info(method, fn1, link, discard, replace, ratio, - mahalanobis = is.full.mahalanobis || !is.null(mahvars), + mahalanobis = is.full.mahalanobis || is_not_null(mahvars), transform = attr(is.full.mahalanobis, "transform"), subclass = match.out$subclass, antiexact = colnames(antiexactcovs), - distance_is_matrix = !is.null(distance) && is.matrix(distance)) + distance_is_matrix = is_not_null(distance) && is.matrix(distance)) #Create X.list for X output, removing duplicate variables X.list <- list(covs, exactcovs, mahcovs, calcovs, antiexactcovs) all.covs <- lapply(X.list, names) - for (i in seq_along(X.list)[-1]) if (!is.null(X.list[[i]])) X.list[[i]][names(X.list[[i]]) %in% unlist(all.covs[1:(i-1)])] <- NULL - X.list[vapply(X.list, is.null, logical(1L))] <- NULL + + for (i in seq_along(X.list)[-1]) { + if (is_not_null(X.list[[i]])) { + X.list[[i]][names(X.list[[i]]) %in% unlist(all.covs[1:(i-1)])] <- NULL + } + } + + X.list[vapply(X.list, is_null, logical(1L))] <- NULL ## putting all the results together out <- list( @@ -653,7 +683,7 @@ matchit <- function(formula, estimand = estimand, formula = formula, treat = treat, - distance = if (!is.null(distance) && !is.matrix(distance)) setNames(distance, names(treat)), + distance = if (is_not_null(distance) && !is.matrix(distance)) setNames(distance, names(treat)), discarded = discarded, s.weights = s.weights, exact = exact, @@ -664,38 +694,40 @@ matchit <- function(formula, obj = if (include.obj) match.out[["obj"]] ) - out[vapply(out, is.null, logical(1L))] <- NULL + out[vapply(out, is_null, logical(1L))] <- NULL class(out) <- class(match.out) out + #:) } #' @export #' @rdname matchit print.matchit <- function(x, ...) { info <- x[["info"]] - cal <- !is.null(x[["caliper"]]) + cal <- is_not_null(x[["caliper"]]) dis <- c("both", "control", "treat")[pmatch(info$discard, c("both", "control", "treat"), 0L)] - disl <- length(dis) > 0 - nm <- is.null(x[["method"]]) + disl <- is_not_null(dis) + nm <- is_null(x[["method"]]) + cat("A matchit object") cat(paste0("\n - method: ", info.to.method(info))) - if (!is.null(info$distance) || info$mahalanobis) { + if (is_not_null(info$distance) || info$mahalanobis) { cat("\n - distance: ") if (info$mahalanobis) { - if (is.null(info$transform)) #mahvars used + if (is_null(info$transform)) #mahvars used cat("Mahalanobis") else { cat(capwords(gsub("_", " ", info$transform, fixed = TRUE))) } } - if (!is.null(info$distance) && !info$distance %in% matchit_distances()) { + if (is_not_null(info$distance) && !info$distance %in% matchit_distances()) { if (info$mahalanobis) cat(" [matching]\n ") if (info$distance_is_matrix) cat("User-defined (matrix)") else if (info$distance != "user") cat("Propensity score") - else if (!is.null(attr(info$distance, "custom"))) cat(attr(info$distance, "custom")) + else if (is_not_null(attr(info$distance, "custom"))) cat(attr(info$distance, "custom")) else cat("User-defined") if (cal || disl) { @@ -707,7 +739,7 @@ print.matchit <- function(x, ...) { if (info$distance != "user") { cat("\n - estimated with ") cat(info.to.distance(info)) - if (!is.null(x[["s.weights"]])) { + if (is_not_null(x[["s.weights"]])) { if (isTRUE(attr(x[["s.weights"]], "in_ps"))) cat("\n - sampling weights included in estimation") else cat("\n - sampling weights not included in estimation") @@ -715,11 +747,13 @@ print.matchit <- function(x, ...) { } } } + if (cal) { cat(paste0("\n - caliper: ", paste(vapply(seq_along(x[["caliper"]]), function(z) paste0(if (names(x[["caliper"]])[z] == "") "" else names(x[["caliper"]])[z], " (", format(round(x[["caliper"]][z], 3)), ")"), character(1L)), collapse = ", "))) } + if (disl) { cat("\n - common support: ") if (dis == "both") cat("units from both groups") @@ -727,10 +761,11 @@ print.matchit <- function(x, ...) { else if (dis == "control") cat("control units") cat(" dropped") } + cat(paste0("\n - number of obs.: ", length(x[["treat"]]), " (original)", if (!all(x[["weights"]] == 1)) paste0(", ", sum(x[["weights"]] != 0), " (matched)"))) - if (!is.null(x[["s.weights"]])) cat("\n - sampling weights: present") - if (!is.null(x[["estimand"]])) cat(paste0("\n - target estimand: ", x[["estimand"]])) - if (!is.null(x[["X"]])) cat(paste0("\n - covariates: ", ifelse(length(names(x[["X"]])) > 40, "too many to name", paste(names(x[["X"]]), collapse = ", ")))) + if (is_not_null(x[["s.weights"]])) cat("\n - sampling weights: present") + if (is_not_null(x[["estimand"]])) cat(paste0("\n - target estimand: ", x[["estimand"]])) + if (is_not_null(x[["X"]])) cat(paste0("\n - covariates: ", if (length(names(x[["X"]])) > 40) "too many to name" else paste(names(x[["X"]]), collapse = ", "))) cat("\n") invisible(x) } diff --git a/R/matchit2cardinality.R b/R/matchit2cardinality.R index f3de1e6b..429a0f31 100644 --- a/R/matchit2cardinality.R +++ b/R/matchit2cardinality.R @@ -281,8 +281,10 @@ matchit2cardinality <- function(treat, data, discarded, formula, estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - if (!is.null(focal)) { - if (!focal %in% tvals) .err("`focal` must be a value of the treatment") + if (is_not_null(focal)) { + if (!focal %in% tvals) { + .err("`focal` must be a value of the treatment") + } } else if (estimand == "ATC") { focal <- min(tvals) @@ -297,11 +299,14 @@ matchit2cardinality <- function(treat, data, discarded, formula, X <- get.covs.matrix(formula, data = data) - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) - cc <- do.call("intersect", lapply(tvals, function(t) as.integer(ex)[treat == t])) - if (length(cc) == 0) .err("no matches were found") + cc <- Reduce("intersect", lapply(tvals, function(t) as.integer(ex)[treat==t])) + + if (is_null(cc)) { + .err("no matches were found") + } } else { ex <- gl(1, length(treat)) @@ -309,11 +314,13 @@ matchit2cardinality <- function(treat, data, discarded, formula, } #Process mahvars - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { if (!is.finite(ratio) || !chk::vld_whole_number(ratio)) { .err("`mahvars` can only be used with `method = \"cardinality\"` when `ratio` is a whole number") } + rlang::check_installed("optmatch") + mahcovs <- transform_covariates(mahvars, data = data, method = "mahalanobis", s.weights = s.weights, treat = treat, discarded = discarded) @@ -332,7 +339,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, assign <- get_assign(X) chk::chk_numeric(tols) - if (length(tols) == 1) { + if (length(tols) == 1L) { tols <- rep(tols, ncol(X)) } else if (length(tols) == max(assign)) { @@ -343,7 +350,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, } chk::chk_logical(std.tols) - if (length(std.tols) == 1) { + if (length(std.tols) == 1L) { std.tols <- rep(std.tols, ncol(X)) } else if (length(std.tols) == max(assign)) { @@ -396,7 +403,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, weights[in.exact] <- out[["weights"]] opt.out[[e]] <- out[["opt.out"]] - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { mo <- eucdist_internal(mahcovs[in.exact[out[["weights"]] > 0],, drop = FALSE], treat_in.exact[out[["weights"]] > 0]) @@ -408,7 +415,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, } } - if (!is.null(pair)) { + if (is_not_null(pair)) { psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) names(psclass) <- names(treat) @@ -440,20 +447,27 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight time = 2*60, verbose = FALSE) { n <- length(treat) - if (is.null(tvals)) tvals <- if (is.factor(treat)) levels(treat) else sort(unique(treat)) + if (is_null(tvals)) tvals <- if (is.factor(treat)) levels(treat) else sort(unique(treat)) nt <- length(tvals) #Check inputs - if (is.null(s.weights)) s.weights <- rep(1, n) - else for (i in tvals) s.weights[treat == i] <- s.weights[treat == i]/mean(s.weights[treat == i]) + if (is_null(s.weights)) { + s.weights <- rep(1, n) + } + else { + for (i in tvals) { + s.weights[treat == i] <- s.weights[treat == i]/mean(s.weights[treat == i]) + } + } - if (is.null(focal)) focal <- tvals[length(tvals)] + if (is_null(focal)) focal <- tvals[length(tvals)] chk::chk_number(time) chk::chk_gt(time, 0) chk::chk_string(solver) solver <- match_arg(solver, c("highs", "glpk", "symphony", "gurobi")) + rlang::check_installed(switch(solver, glpk = "Rglpk", symphony = "Rsymphony", gurobi = "gurobi", highs = "highs")) @@ -688,7 +702,8 @@ cardinality_error_report <- function(out, solver) { } } -dispatch_optimizer <- function(solver = "glpk", obj, mat, dir, rhs, types, max = TRUE, lb = NULL, ub = NULL, time = NULL, verbose = FALSE) { +dispatch_optimizer <- function(solver = "glpk", obj, mat, dir, rhs, types, max = TRUE, + lb = NULL, ub = NULL, time = NULL, verbose = FALSE) { if (solver == "glpk") { dir[dir == "="] <- "==" opt.out <- Rglpk::Rglpk_solve_LP(obj = obj, mat = mat, dir = dir, rhs = rhs, max = max, diff --git a/R/matchit2exact.R b/R/matchit2exact.R index ecf3bbae..04006669 100644 --- a/R/matchit2exact.R +++ b/R/matchit2exact.R @@ -92,7 +92,9 @@ matchit2exact <- function(treat, covs, data, estimand = "ATT", verbose = FALSE, if(verbose) cat("Exact matching... \n") - if (length(covs) == 0) .err("covariates must be specified in the input formula to use exact matching") + if (is_null(covs)) { + .err("covariates must be specified in the input formula to use exact matching") + } estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) @@ -100,8 +102,8 @@ matchit2exact <- function(treat, covs, data, estimand = "ATT", verbose = FALSE, xx <- exactify(covs, names(treat)) cc <- do.call("intersect", lapply(unique(treat), function(t) xx[treat == t])) - if (length(cc) == 0) { - .err("No exact matches were found") + if (is_null(cc)) { + .err("no exact matches were found") } psclass <- setNames(factor(match(xx, cc), nmax = length(cc)), names(treat)) diff --git a/R/matchit2full.R b/R/matchit2full.R index 7fd625bd..84a297bc 100644 --- a/R/matchit2full.R +++ b/R/matchit2full.R @@ -72,7 +72,7 @@ #' place when `distance` corresponds to a propensity score (e.g., for #' caliper matching or to discard units for common support). If specified, the #' distance measure will not be used in matching. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place. #' Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}. #' @param discard a string containing a method for discarding units outside a #' region of common support. Only allowed when `distance` corresponds to a @@ -255,10 +255,10 @@ matchit2full <- function(treat, formula, data, distance, discarded, treat_ <- setNames(as.integer(treat[!discarded] == focal), names(treat)[!discarded]) - # if (!is.null(data)) data <- data[!discarded,] + # if (is_not_null(data)) data <- data[!discarded,] if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } @@ -269,13 +269,13 @@ matchit2full <- function(treat, formula, data, distance, discarded, max.controls <- attr(ratio, "max.controls") #Exact matching strata - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0L) { + if (is_null(cc)) { .err("No matches were found") } } @@ -286,7 +286,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, #Create distance matrix; note that Mahalanobis distance computed using entire #sample (minus discarded), like method2nearest, as opposed to within exact strata, like optmatch. - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" mahcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, @@ -311,7 +311,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, mo <- optmatch::as.InfinitySparseMatrix(mo) #Process antiexact - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) for (i in seq_len(ncol(antiexactcovs))) { mo <- mo + optmatch::antiExactMatch(antiexactcovs[[i]][!discarded], z = treat_) @@ -319,7 +319,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, } #Process caliper - if (!is.null(caliper)) { + if (is_not_null(caliper)) { if (min.controls != 0) { .err("calipers cannot be used with `method = \"full\"` when `min.controls` is specified") } @@ -333,7 +333,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, if (names(caliper)[i] != "") { mo_cal <- optmatch::match_on(setNames(calcovs[!discarded, names(caliper)[i]], names(treat_)), z = treat_) } - else if (is.null(mahvars) || is.matrix(distance)) { + else if (is_null(mahvars) || is.matrix(distance)) { mo_cal <- mo } else { diff --git a/R/matchit2optimal.R b/R/matchit2optimal.R index 7ac0ed18..30e2dad4 100644 --- a/R/matchit2optimal.R +++ b/R/matchit2optimal.R @@ -65,7 +65,7 @@ #' place when `distance` corresponds to a propensity score (e.g., for #' caliper matching or to discard units for common support). If specified, the #' distance measure will not be used in matching. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place. #' Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}. #' @param discard a string containing a method for discarding units outside a #' region of common support. Only allowed when `distance` is not @@ -272,17 +272,17 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, treat_ <- setNames(as.integer(treat[!discarded] == focal), names(treat)[!discarded]) - # if (!is.null(data)) data <- data[!discarded,] + # if (is_not_null(data)) data <- data[!discarded,] if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } mahvars <- formula } - if (!is.null(caliper)) { + if (is_not_null(caliper)) { .wrn("calipers are currently not compatible with `method = \"optimal\"` and will be ignored") caliper <- NULL } @@ -290,19 +290,19 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, min.controls <- attr(ratio, "min.controls") max.controls <- attr(ratio, "max.controls") - if (is.null(max.controls)) { + if (is_null(max.controls)) { min.controls <- max.controls <- ratio } #Exact matching strata - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0L) { - .err("No matches were found") + if (is_null(cc) ) { + .err("no matches were found") } e_ratios <- vapply(levels(ex), function(e) sum(treat_[ex == e] == 0)/sum(treat_[ex == e] == 1), numeric(1L)) @@ -342,7 +342,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, #Create distance matrix; note that Mahalanobis distance computed using entire #sample (minus discarded), like method2nearest, as opposed to within exact strata, like optmatch. - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" mahcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, @@ -367,7 +367,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, mo <- optmatch::as.InfinitySparseMatrix(mo) #Process antiexact - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) for (i in seq_len(ncol(antiexactcovs))) { mo <- mo + optmatch::antiExactMatch(antiexactcovs[[i]][!discarded], z = treat_) diff --git a/R/matchit2quick.R b/R/matchit2quick.R index 4cfef36d..f3ef87aa 100644 --- a/R/matchit2quick.R +++ b/R/matchit2quick.R @@ -165,7 +165,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, # treat_ <- setNames(as.integer(treat[!discarded] == focal), names(treat)[!discarded]) if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } @@ -173,13 +173,13 @@ matchit2quick <- function(treat, formula, data, distance, discarded, } #Exact matching strata - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0L) { + if (is_null(cc)) { .err("no matches were found") } } @@ -190,7 +190,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, #Create distance matrix; note that Mahalanobis distance computed using entire #sample (minus discarded), like method2nearest, as opposed to within exact strata, like optmatch. - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" distcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, @@ -205,8 +205,8 @@ matchit2quick <- function(treat, formula, data, distance, discarded, rownames(distcovs) <- names(treat_) #Process caliper - if (!is.null(caliper)) { - if (!is.null(mahvars)) { + if (is_not_null(caliper)) { + if (is_not_null(mahvars)) { .err("with `method = \"quick\"`, a caliper can only be used when `distance` is a propensity score or vector and `mahvars` is not specified") } if (length(caliper) > 1 || !identical(names(caliper), "")) { diff --git a/R/matchit2subclass.R b/R/matchit2subclass.R index b8a9c57a..45b6d355 100644 --- a/R/matchit2subclass.R +++ b/R/matchit2subclass.R @@ -169,49 +169,49 @@ matchit2subclass <- function(treat, distance, discarded, min.n <- A[["min.n"]] #Checks - if (is.null(subclass)) subclass <- 6 + if (is_null(subclass)) subclass <- 6 chk::chk_numeric(subclass) - if (length(subclass) == 1) { + if (length(subclass) == 1L) { chk::chk_gt(subclass, 1) } else if (!all(subclass <= 1 & subclass >= 0)) { .err("When specifying `subclass` as a vector of quantiles, all values must be between 0 and 1") } - if (!is.null(sub.by)) { + if (is_not_null(sub.by)) { .err("`sub.by` is defunct and has been replaced with `estimand`") } estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - if (is.null(min.n)) min.n <- 1 + if (is_null(min.n)) min.n <- 1 chk::chk_count(min.n) n.obs <- length(treat) ## Setting Cut Points - if (length(subclass) == 1) { + if (length(subclass) == 1L) { sprobs <- seq(0, 1, length.out = round(subclass) + 1) } else { sprobs <- sort(subclass) if (sprobs[1] != 0) sprobs <- c(0, sprobs) if (sprobs[length(sprobs)] != 1) sprobs <- c(sprobs, 1) - subclass <- length(sprobs) - 1 + subclass <- length(sprobs) - 1L } q <- switch(estimand, - "ATT" = quantile(distance[treat==1], probs = sprobs, na.rm = TRUE), - "ATC" = quantile(distance[treat==0], probs = sprobs, na.rm = TRUE), + "ATT" = quantile(distance[treat == 1], probs = sprobs, na.rm = TRUE), + "ATC" = quantile(distance[treat == 0], probs = sprobs, na.rm = TRUE), quantile(distance, probs = sprobs, na.rm = TRUE)) ## Calculating Subclasses psclass <- setNames(rep(NA_integer_, n.obs), names(treat)) psclass[!discarded] <- as.integer(findInterval(distance[!discarded], q, all.inside = TRUE)) - if (length(unique(na.omit(psclass))) != subclass){ + if (!has_n_unique(na.omit(psclass), subclass)) { .wrn("due to discreteness in the distance measure, fewer subclasses were generated than were requested") } diff --git a/R/plot.matchit.R b/R/plot.matchit.R index 2b839f8c..4ae6e88c 100644 --- a/R/plot.matchit.R +++ b/R/plot.matchit.R @@ -149,13 +149,13 @@ plot.matchit <- function(x, type = "qq", interactive = TRUE, which.xs = NULL, da which.xs = which.xs, data = data, ...) } else if (type == "jitter") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"jitter\"` cannot be used if a distance measure is not estimated or supplied. No plots generated") } jitter_pscore(x, interactive = interactive,...) } else if (type == "histogram") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"hist\"` cannot be used if a distance measure is not estimated or supplied. No plots generated") } hist_pscore(x,...) @@ -192,15 +192,22 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = #If subclass = NULL, if interactive, use to choose subclass, else display aggregate across subclass subclasses <- levels(x$subclass) - miss.sub <- missing(subclass) || is.null(subclass) - if (miss.sub || isFALSE(subclass)) which.subclass <- NULL - else if (isTRUE(subclass)) which.subclass <- subclasses - else if (!is.atomic(subclass) || !all(subclass %in% seq_along(subclasses))) { + miss.sub <- missing(subclass) || is_null(subclass) + + if (miss.sub || isFALSE(subclass)) { + which.subclass <- NULL + } + else if (isTRUE(subclass)) { + which.subclass <- subclasses + } + else if (is.atomic(subclass) && all(subclass %in% seq_along(subclasses))) { + which.subclass <- subclasses[subclass] + } + else { .err("`subclass` should be `TRUE`, `FALSE`, or a vector of subclass indices for which subclass balance is to be displayed") } - else which.subclass <- subclasses[subclass] - if (!is.null(which.subclass)) { + if (is_not_null(which.subclass)) { matchit.covplot.subclass(x, type = type, which.subclass = which.subclass, interactive = interactive, which.xs = which.xs, ...) } @@ -226,13 +233,13 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = } } else if (type=="jitter") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"jitter\"` cannot be used when no distance variable was estimated or supplied") } jitter_pscore(x, interactive = interactive, ...) } else if (type == "histogram") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"histogram\"` cannot be used when no distance variable was estimated or supplied") } hist_pscore(x,...) @@ -243,25 +250,26 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = ## plot helper functions matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = NULL, data = NULL, ...) { - if (is.null(which.xs)) { - if (length(object$X) == 0L) { + if (is_null(which.xs)) { + if (is_null(object$X)) { .wrn("No covariates to plot") return(invisible(NULL)) } + X <- object$X - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) X <- cbind(X, Xexact[setdiff(names(Xexact), names(X))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) X <- cbind(X, Xmahvars[setdiff(names(Xmahvars), names(X))]) } } else { - if (!is.null(data)) { + if (is_not_null(data)) { if (!is.data.frame(data) || nrow(data) != length(object$treat)) { .err("`data` must be a data frame with as many rows as there are units in the supplied `matchit` object") } @@ -271,12 +279,12 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = data <- object$X } - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) data <- cbind(data, Xexact[setdiff(names(Xexact), names(data))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) data <- cbind(data, Xmahvars[setdiff(names(Xmahvars), names(data))]) } @@ -304,13 +312,16 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = for (i in seq_len(k)) { if (anyNA(X[[i]]) || (is.numeric(X[[i]]) && any(!is.finite(X[[i]])))) { covariates.with.missingness <- names(X)[i:k][vapply(i:k, function(j) anyNA(X[[j]]) || - (is.numeric(X[[j]]) && - any(!is.finite(X[[j]]))), - logical(1L))] + (is.numeric(X[[j]]) && + any(!is.finite(X[[j]]))), + logical(1L))] .err(paste0("Missing and non-finite values are not allowed in the covariates named in `which.xs`. Variables with missingness or non-finite values:\n\t", paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) } - if (is.character(X[[i]])) X[[i]] <- factor(X[[i]]) + + if (is.character(X[[i]])) { + X[[i]] <- factor(X[[i]]) + } } } @@ -318,9 +329,13 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = t <- object$treat - sw <- if (is.null(object$s.weights)) rep(1, length(t)) else object$s.weights + sw <- { + if (is_null(object$s.weights)) rep(1, length(t)) + else object$s.weights + } + w <- object$weights * sw - if (is.null(w)) w <- rep(1, length(t)) + if (is_null(w)) w <- rep(1, length(t)) w <- .make_sum_to_1(w, by = t) sw <- .make_sum_to_1(sw, by = t) @@ -399,25 +414,25 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, interactive = TRUE, which.xs = NULL, data = NULL, ...) { - if (is.null(which.xs)) { - if (length(object$X) == 0L) { + if (is_null(which.xs)) { + if (is_null(object$X)) { .wrn("No covariates to plot") return(invisible(NULL)) } X <- object$X - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) X <- cbind(X, Xexact[setdiff(names(Xexact), names(X))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) X <- cbind(X, Xmahvars[setdiff(names(Xmahvars), names(X))]) } } else { - if (!is.null(data)) { + if (is_not_null(data)) { if (!is.data.frame(data) || nrow(data) != length(object$treat)) { .err("`data` must be a data frame with as many rows as there are units in the supplied `matchit` object") } @@ -427,12 +442,12 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, data <- object$X } - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) data <- cbind(data, Xexact[setdiff(names(Xexact), names(data))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) data <- cbind(data, Xmahvars[setdiff(names(Xmahvars), names(data))]) } @@ -492,15 +507,19 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, opar <- par(mfrow = c(3, 3), mar = c(1.5,.5,1.5,.5), oma = oma) } - sw <- if (is.null(object$s.weights)) rep(1, length(t)) else object$s.weights - w <- sw*(!is.na(object$subclass) & object$subclass == s) + sw <- { + if (is_null(object$s.weights)) rep(1, length(t)) + else object$s.weights + } + + w <- sw * (!is.na(object$subclass) & object$subclass == s) w <- .make_sum_to_1(w, by = t) sw <- .make_sum_to_1(sw, by = t) for (i in seq_along(varnames)){ - x <- if (type == "density") X[[i]] else X[,i] + x <- switch(type, "density" = X[[i]], X[,i]) plot.new() @@ -550,7 +569,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, } devAskNewPage(ask = FALSE) - invisible(NULL) + invisible(NULL) } qqplot_match <- function(x, t, w, sw, discrete.cutoff = 5, ...) { @@ -706,7 +725,7 @@ ecdfplot_match <- function(x, t, w, sw, ...) { densityplot_match <- function(x, t, w, sw, ...) { - if (length(unique(x)) == 2L) x <- factor(x) + if (has_n_unique(x, 2L)) x <- factor(x) if (!is.factor(x)) { #Density plot for continuous variable @@ -719,7 +738,9 @@ densityplot_match <- function(x, t, w, sw, ...) { A <- list(...) bw <- A[["bw"]] - if (is.null(bw)) A[["bw"]] <- bw.nrd0(x_small) + if (is_null(bw)) { + A[["bw"]] <- bw.nrd0(x_small) + } else if (is.character(bw)) { bw <- tolower(bw) bw <- match_arg(bw, c("nrd0", "nrd", "ucv", "bcv", "sj", "sj-ste", "sj-dpi")) @@ -728,7 +749,8 @@ densityplot_match <- function(x, t, w, sw, ...) { `sj-ste` = bw.SJ(x_small, method = "ste"), `sj-dpi` = bw.SJ(x_small, method = "dpi")) } - if (is.null(A[["cut"]])) A[["cut"]] <- 3 + + if (is_null(A[["cut"]])) A[["cut"]] <- 3 d_unmatched <- do.call("rbind", lapply(0:1, function(tr) { cbind(as.data.frame(do.call("density", c(list(x[t==tr], weights = sw[t==tr], @@ -783,8 +805,8 @@ densityplot_match <- function(x, t, w, sw, ...) { #Bar plot for binary variable x_t_un <- lapply(sort(unique(t)), function(t_) { vapply(levels(x), function(i) { - wm(x[t==t_] == i, sw[t==t_]) - }, numeric(1L))}) + wm(x[t==t_] == i, sw[t==t_]) + }, numeric(1L))}) x_t_m <- lapply(sort(unique(t)), function(t_) { vapply(levels(x), function(i) { @@ -820,7 +842,10 @@ hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { treat <- x$treat pscore <- x$distance[!is.na(x$distance)] - s.weights <- if (is.null(x$s.weights)) rep(1, length(treat)) else x$s.weights + s.weights <- { + if (is_null(x$s.weights)) rep(1, length(treat)) + else x$s.weights + } weights <- x$weights * s.weights matched <- weights != 0 q.cut <- x$q.cut @@ -829,7 +854,7 @@ hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { maxp <- max(pscore) ratio <- x$call$ratio - if (is.null(ratio)) ratio <- 1 + if (is_null(ratio)) ratio <- 1 if (freq) { weights <- .make_sum_to_n(weights, by = treat) @@ -848,11 +873,9 @@ hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { xlim <- range(breaks) for (n in c("Raw Treated", "Matched Treated", "Raw Control", "Matched Control")) { - if (startsWith(n, "Raw")) w <- s.weights - else w <- weights + w <- if (startsWith(n, "Raw")) s.weights else weights - if (endsWith(n, "Treated")) t <- 1 - else t <- 0 + t <- if (endsWith(n, "Treated")) 1 else 0 #Create histogram using weights #Manually assign density, which is used as height of the bars. The scaling @@ -863,10 +886,11 @@ hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { pm[["density"]] <- vapply(seq_len(length(pm$breaks) - 1), function(i) { sum(w[treat == t & pscore >= pm$breaks[i] & pscore < pm$breaks[i+1]]) }, numeric(1L)) + plot(pm, xlim = xlim, xlab = xlab, main = n, ylab = ylab, freq = FALSE, col = "lightgray", ...) - if (!startsWith(n, "Raw") && !is.null(q.cut)) abline(v = q.cut, lty=2) + if (!startsWith(n, "Raw") && is_not_null(q.cut)) abline(v = q.cut, lty=2) } } @@ -878,11 +902,11 @@ jitter_pscore <- function(x, interactive, pch = 1, ...) { treat <- x$treat pscore <- x$distance - s.weights <- if (is.null(x$s.weights)) rep(1, length(treat)) else x$s.weights + s.weights <- if (is_null(x$s.weights)) rep(1, length(treat)) else x$s.weights weights <- x$weights * s.weights matched <- weights > 0 q.cut <- x$q.cut - jitp <- jitter(rep(1,length(treat)), factor=6)+(treat==1)*(weights==0)-(treat==0) - (weights==0)*(treat==0) + jitp <- jitter(rep(1, length(treat)), factor = 6) + (treat==1) * (weights == 0) - (treat==0) - (weights==0) * (treat==0) cswt <- sqrt(s.weights) cwt <- sqrt(weights) minp <- min(pscore, na.rm = TRUE) @@ -891,7 +915,7 @@ jitter_pscore <- function(x, interactive, pch = 1, ...) { plot(pscore, xlim = c(minp - 0.05*(maxp-minp), maxp + 0.05*(maxp-minp)), ylim = c(-1.5,2.5), type = "n", ylab = "", xlab = "Propensity Score", axes = FALSE, main = "Distribution of Propensity Scores", ...) - if (!is.null(q.cut)) abline(v = q.cut, col = "grey", lty = 1) + if (is_not_null(q.cut)) abline(v = q.cut, col = "grey", lty = 1) #Matched treated points(pscore[treat==1 & matched], jitp[treat==1 & matched], diff --git a/R/plot.summary.matchit.R b/R/plot.summary.matchit.R index 6d2e4dab..d3df1178 100644 --- a/R/plot.summary.matchit.R +++ b/R/plot.summary.matchit.R @@ -85,8 +85,8 @@ plot.summary.matchit <- function(x, on.exit(par(.pardefault)) sub <- inherits(x, "summary.matchit.subclass") - matched <- sub || !is.null(x[["sum.matched"]]) - un <- !is.null(x[["sum.all"]]) + matched <- sub || is_not_null(x[["sum.matched"]]) + un <- is_not_null(x[["sum.all"]]) standard.sum <- if (un) x[["sum.all"]] else x[[if (sub) "sum.across" else "sum.matched"]] @@ -131,7 +131,7 @@ plot.summary.matchit <- function(x, bg = NA, color = NA, ...) abline(v = 0) - if (sub && length(x$sum.subclass) > 0) { + if (sub && is_not_null(x$sum.subclass)) { for (i in seq_along(x$sum.subclass)) { sd.sub <- x$sum.subclass[[i]][,"Std. Mean Diff."] if (abs) sd.sub <- abs(sd.sub) @@ -149,7 +149,7 @@ plot.summary.matchit <- function(x, pch = 21, bg = "black", col = "black") } - if (!is.null(threshold)) { + if (is_not_null(threshold)) { if (abs) { abline(v = threshold, lty = seq_along(threshold)) } @@ -159,7 +159,7 @@ plot.summary.matchit <- function(x, } } - if (sum(matched, un) > 1 && !is.null(position)) { + if (sum(matched, un) > 1 && is_not_null(position)) { position <- match_arg(position, c("bottomright", "bottom", "bottomleft", "left", "topleft", "top", "topright", "right", "center")) legend(position, legend = c("All", "Matched"), diff --git a/R/rbind.matchdata.R b/R/rbind.matchdata.R index 2fde0137..47491199 100644 --- a/R/rbind.matchdata.R +++ b/R/rbind.matchdata.R @@ -73,7 +73,7 @@ rbind.matchdata <- function(..., deparse.level = 1) { allargs <- list(...) allargs <- allargs[lengths(allargs) > 0L] - if (is.null(names(allargs))) { + if (is_null(names(allargs))) { md_list <- allargs allargs <- list() } @@ -84,8 +84,14 @@ rbind.matchdata <- function(..., deparse.level = 1) { allargs$deparse.level <- deparse.level type <- intersect(c("matchdata", "getmatches"), unlist(lapply(md_list, class))) - if (length(type) == 0L) .err("A `matchdata` or `getmatches` object must be supplied") - if (length(type) == 2L) .err("Supplied objects must be all `matchdata` objects or all `getmatches` objects") + + if (is_null(type)) { + .err("A `matchdata` or `getmatches` object must be supplied") + } + + if (length(type) == 2L) { + .err("Supplied objects must be all `matchdata` objects or all `getmatches` objects") + } attrs <- c("distance", "weights", "subclass", "id") attr_list <- setNames(vector("list", length(attrs)), attrs) @@ -94,10 +100,15 @@ rbind.matchdata <- function(..., deparse.level = 1) { for (i in attrs) { attr_list[[i]] <- unlist(lapply(md_list, function(m) { a <- attr(m, i) - if (length(a) == 0) NA_character_ else a + if (is_null(a)) NA_character_ else a })) - if (all(is.na(attr_list[[i]]))) attr_list[[i]] <- NULL - else key_attrs[i] <- attr_list[[i]][which(!is.na(attr_list[[i]]))[1]] + + if (all(is.na(attr_list[[i]]))) { + attr_list[[i]] <- NULL + } + else { + key_attrs[i] <- attr_list[[i]][which(!is.na(attr_list[[i]]))[1]] + } } attrs <- names(attr_list) key_attrs <- key_attrs[attrs] @@ -106,6 +117,7 @@ rbind.matchdata <- function(..., deparse.level = 1) { other_col_list <- lapply(seq_along(md_list), function(d) { setdiff(names(md_list[[d]]), unlist(lapply(attr_list, `[`, d))) }) + for (d in seq_along(md_list)[-1]) { if (length(other_col_list[[d]]) != length(other_col_list[[1]]) || !all(other_col_list[[d]] %in% other_col_list[[1]])) { .err(sprintf("the %s inputs must come from the same dataset", @@ -116,13 +128,21 @@ rbind.matchdata <- function(..., deparse.level = 1) { for (d in seq_along(md_list)) { for (i in attrs) { #Rename columns of each attribute the same across datasets - if (is.null(attr(md_list[[d]], i))) md_list[[d]] <- setNames(cbind(md_list[[d]], NA), c(names(md_list[[d]]), key_attrs[i])) - else names(md_list[[d]])[names(md_list[[d]]) == attr_list[[i]][d]] <- key_attrs[i] + if (is_null(attr(md_list[[d]], i))) { + md_list[[d]] <- setNames(cbind(md_list[[d]], NA), c(names(md_list[[d]]), key_attrs[i])) + } + else { + names(md_list[[d]])[names(md_list[[d]]) == attr_list[[i]][d]] <- key_attrs[i] + } #Give subclasses unique values across datasets if (i == "subclass") { - if (all(is.na(md_list[[d]][[key_attrs[i]]]))) md_list[[d]][[key_attrs[i]]] <- factor(md_list[[d]][[key_attrs[i]]], levels = NA) - else levels(md_list[[d]][[key_attrs[i]]]) <- paste(d, levels(md_list[[d]][[key_attrs[i]]]), sep = "_") + if (all(is.na(md_list[[d]][[key_attrs[i]]]))) { + md_list[[d]][[key_attrs[i]]] <- factor(md_list[[d]][[key_attrs[i]]], levels = NA) + } + else { + levels(md_list[[d]][[key_attrs[i]]]) <- paste(d, levels(md_list[[d]][[key_attrs[i]]]), sep = "_") + } } } @@ -130,7 +150,8 @@ rbind.matchdata <- function(..., deparse.level = 1) { if (d > 1) { md_list[[d]] <- md_list[[d]][names(md_list[[1]])] } - class(md_list[[d]]) <- class(md_list[[d]])[class(md_list[[d]]) != type] + + class(md_list[[d]]) <- setdiff(class(md_list[[d]]), type) } out <- do.call("rbind", c(md_list, allargs)) diff --git a/R/summary.matchit.R b/R/summary.matchit.R index 81926909..3fc0be3d 100644 --- a/R/summary.matchit.R +++ b/R/summary.matchit.R @@ -208,9 +208,12 @@ summary.matchit <- function(object, treat <- object$treat weights <- object$weights - s.weights <- if (is.null(object$s.weights)) rep(1, length(weights)) else object$s.weights + s.weights <- { + if (is_null(object$s.weights)) rep(1, length(weights)) + else object$s.weights + } - no_x <- length(X) == 0 + no_x <- is_null(X) if (no_x) { X <- matrix(1, nrow = length(treat), ncol = 1, @@ -227,7 +230,7 @@ summary.matchit <- function(object, kk <- ncol(X) - matched <- !is.null(object$info$method) + matched <- is_not_null(object$info$method) un <- un || !matched chk::chk_flag(interactions) @@ -321,7 +324,7 @@ summary.matchit <- function(object, } } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { if (un) { ad.all <- bal1var(object$distance, tt = treat, ww = NULL, s.weights = s.weights, standardize = standardize, s.d.denom = s.d.denom) @@ -400,7 +403,10 @@ summary.matchit.subclass <- function(object, which.subclass <- subclass treat <- object$treat weights <- object$weights - s.weights <- if (is.null(object$s.weights)) rep(1, length(weights)) else object$s.weights + s.weights <- { + if (is_null(object$s.weights)) rep(1, length(weights)) + else object$s.weights + } subclass <- object$subclass nam <- colnames(X) @@ -430,7 +436,7 @@ summary.matchit.subclass <- function(object, else which.subclass <- subclasses[which.subclass] matched <- TRUE #always compute aggregate balance so plot.summary can use it - subs <- !is.null(which.subclass) + subs <- is_not_null(which.subclass) ## Aggregate Subclass #Use the estimated weights to compute aggregate balance. @@ -507,7 +513,7 @@ summary.matchit.subclass <- function(object, } } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { if (un) { ad.all <- bal1var(object$distance, tt = treat, ww = NULL, s.weights = s.weights, standardize = standardize, s.d.denom = s.d.denom) @@ -580,7 +586,7 @@ summary.matchit.subclass <- function(object, sum.sub <- rbind(sum.sub, sum.sub.int[!to.remove,,drop = FALSE]) } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { ad <- bal1var.subclass(object$distance, tt = treat, s.weights = s.weights, subclass = subclass, s.d.denom = s.d.denom, standardize = standardize, which.subclass = s) sum.sub <- rbind(ad, sum.sub) @@ -611,28 +617,28 @@ summary.matchit.subclass <- function(object, print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), ...) { - if (!is.null(x$call)) { + if (is_not_null(x$call)) { cat("\nCall:", deparse(x$call), sep = "\n") } - if (!is.null(x$sum.all)) { + if (is_not_null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") print(round_df_char(x$sum.all[,-7, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$sum.matched)) { + if (is_not_null(x$sum.matched)) { cat("\nSummary of Balance for Matched Data:\n") if (all(is.na(x$sum.matched[,7]))) x$sum.matched <- x$sum.matched[,-7,drop = FALSE] #Remove pair dist if empty print(round_df_char(x$sum.matched, digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$reduction)) { + if (is_not_null(x$reduction)) { cat("\nPercent Balance Improvement:\n") print(round_df_char(x$reduction[,-5, drop = FALSE], 1, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$nn)) { + if (is_not_null(x$nn)) { cat("\nSample Sizes:\n") nn <- x$nn if (isTRUE(all.equal(nn["All (ESS)",], nn["All",]))) { @@ -653,43 +659,43 @@ print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), #' @exportS3Method print summary.matchit.subclass print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits") - 3), ...){ - if (!is.null(x$call)) { + if (is_not_null(x$call)) { cat("\nCall:", deparse(x$call), sep = "\n") } - if (!is.null(x$sum.all)) { + if (is_not_null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") print(round_df_char(x$sum.all[,-7, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (length(x$sum.subclass) > 0) { + if (is_not_null(x$sum.subclass)) { cat("\nSummary of Balance by Subclass:\n") for (s in seq_along(x$sum.subclass)) { cat(paste0("\n- ", names(x$sum.subclass)[s], "\n")) print(round_df_char(x$sum.subclass[[s]][,-7, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$qn)) { + if (is_not_null(x$qn)) { cat("\nSample Sizes by Subclass:\n") print(round_df_char(x$qn, 2, pad = " ", na_vals = "."), right = TRUE, quote = FALSE) } } else { - if (!is.null(x$sum.across)) { + if (is_not_null(x$sum.across)) { cat("\nSummary of Balance Across Subclasses\n") if (all(is.na(x$sum.across[,7]))) x$sum.across <- x$sum.across[,-7,drop = FALSE] print(round_df_char(x$sum.across, digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$reduction)) { + if (is_not_null(x$reduction)) { cat("\nPercent Balance Improvement:\n") print(round_df_char(x$reduction[,-5, drop = FALSE], 1, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$nn)) { + if (is_not_null(x$nn)) { cat("\nSample Sizes:\n") nn <- x$nn if (isTRUE(all.equal(nn["All (ESS)",], nn["All",]))) { @@ -710,25 +716,25 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" .process_X <- function(object, addlvariables = NULL, data = NULL) { X <- { - if (length(object$X) == 0) matrix(nrow = length(object$treat), ncol = 0) + if (is_null(object$X)) matrix(nrow = length(object$treat), ncol = 0) else get.covs.matrix(data = object$X) } - if (is.null(addlvariables)) { + if (is_null(addlvariables)) { return(X) } #Attempt to extrct data from matchit object; same as match.data() data.fram.matchit <- FALSE - if (is.null(data)) { + if (is_null(data)) { env <- environment(object$formula) data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { + if (null_or_error(data) || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { env <- parent.frame() data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { + if (null_or_error(data) || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { data <- object[["model"]][["data"]] - if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { + if (is_null(data) || nrow(data) != length(object[["treat"]])) { data <- NULL } else data.fram.matchit <- TRUE @@ -739,19 +745,19 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" } if (is.character(addlvariables)) { - if (is.null(data) || !is.data.frame(data)) { - .err("If `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") + if (is_null(data) || !is.data.frame(data)) { + .err("if `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") } if (!all(addlvariables %in% names(data))) { - .err("All variables in `addlvariables` must be in `data`") + .err("all variables in `addlvariables` must be in `data`") } addlvariables <- data[addlvariables] } else if (inherits(addlvariables, "formula")) { vars.in.formula <- all.vars(addlvariables) - if (!is.null(data) && is.data.frame(data)) { + if (is_not_null(data) && is.data.frame(data)) { data <- data.frame(data[names(data) %in% vars.in.formula], object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) } @@ -760,7 +766,7 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" # addlvariables <- get.covs.matrix(addlvariables, data = data) } else if (!is.matrix(addlvariables) && !is.data.frame(addlvariables)) { - .err("The argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") + .err("the argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") } @@ -770,7 +776,7 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" } if (nrow(addlvariables) != length(object$treat)) { - if (is.null(data) || data.fram.matchit) { + if (is_null(data) || data.fram.matchit) { .err("variables specified in `addlvariables` must have the same number of units as are present in the original call to `matchit()`") } else { From c136ee71e3bbc341f8f21af105e1afdc885be7ce Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:44:43 -0400 Subject: [PATCH 11/48] Moved some aux_functions to utils; utils incorporates some from WeightIt --- R/aux_functions.R | 481 +++++----------------------------------------- R/utils.R | 403 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 451 insertions(+), 433 deletions(-) create mode 100644 R/utils.R diff --git a/R/aux_functions.R b/R/aux_functions.R index 92d9e84b..937297f3 100644 --- a/R/aux_functions.R +++ b/R/aux_functions.R @@ -1,7 +1,5 @@ -#Auxiliary functions; some from WeightIt - #Function to ensure no subclass is devoid of both treated and control units by "scooting" units -#from other subclasses. From WeightIt. +#from other subclasses. subclass_scoot <- function(sub, treat, x, min.n = 1) { #Reassigns subclasses so there are no empty subclasses #for each treatment group. @@ -27,15 +25,15 @@ create_info <- function(method, fn1, link, discard, replace, ratio, mahalanobis, transform, subclass, antiexact, distance_is_matrix) { info <- list(method = method, - distance = if (is.null(fn1)) NULL else sub("distance2", "", fn1, fixed = TRUE), - link = if (is.null(link)) NULL else link, + distance = if (is_null(fn1)) NULL else sub("distance2", "", fn1, fixed = TRUE), + link = if (is_null(link)) NULL else link, discard = discard, - replace = if (!is.null(method) && method %in% c("nearest", "genetic")) replace else NULL, - ratio = if (!is.null(method) && method %in% c("nearest", "optimal", "genetic")) ratio else NULL, - max.controls = if (!is.null(method) && method %in% c("nearest", "optimal")) attr(ratio, "max.controls") else NULL, + replace = if (is_not_null(method) && method %in% c("nearest", "genetic")) replace else NULL, + ratio = if (is_not_null(method) && method %in% c("nearest", "optimal", "genetic")) ratio else NULL, + max.controls = if (is_not_null(method) && method %in% c("nearest", "optimal")) attr(ratio, "max.controls") else NULL, mahalanobis = mahalanobis, transform = transform, - subclass = if (!is.null(method) && method == "subclass") length(unique(subclass[!is.na(subclass)])) else NULL, + subclass = if (is_not_null(method) && method == "subclass") length(unique(subclass[!is.na(subclass)])) else NULL, antiexact = antiexact, distance_is_matrix = distance_is_matrix) info @@ -47,12 +45,12 @@ info.to.method <- function(info) { out.list <- setNames(vector("list", 3), c("kto1", "type", "replace")) out.list[["kto1"]] <- { - if (!is.null(info$ratio)) paste0(if (!is.null(info$max.controls)) "variable ratio ", round(info$ratio, 2), ":1") + if (is_not_null(info$ratio)) paste0(if (is_not_null(info$max.controls)) "variable ratio ", round(info$ratio, 2), ":1") else NULL } out.list[["type"]] <- { - if (is.null(info$method)) "none (no matching)" + if (is_null(info$method)) "none (no matching)" else switch(info$method, "exact" = "exact matching", "cem" = "coarsened exact matching", @@ -63,12 +61,12 @@ info.to.method <- function(info) { "genetic" = "genetic matching", "subclass" = sprintf("subclassification (%s subclasses)", info$subclass), "cardinality" = "cardinality matching", - if (is.null(attr(info$method, "method"))) "an unspecified matching method" + if (is_null(attr(info$method, "method"))) "an unspecified matching method" else attr(info$method, "method")) } out.list[["replace"]] <- { - if (is.null(info$replace) || !info$method %in% c("nearest", "genetic")) NULL + if (is_null(info$replace) || !info$method %in% c("nearest", "genetic")) NULL else if (info$replace) "with replacement" else "without replacement" } @@ -79,11 +77,13 @@ info.to.method <- function(info) { info.to.distance <- function(info) { distance <- info$distance link <- info$link - if (!is.null(link) && startsWith(as.character(link), "linear")) { + if (is_not_null(link) && startsWith(as.character(link), "linear")) { linear <- TRUE link <- sub("linear.", "", as.character(link)) } - else linear <- FALSE + else { + linear <- FALSE + } if (distance == "glm") { if (link == "logit") dist <- "logistic regression" @@ -127,152 +127,20 @@ info.to.distance <- function(info) { dist } -#Function to turn a vector into a string with "," and "and" or "or" for clean messages. 'and.or' -#controls whether words are separated by "and" or "or"; 'is.are' controls whether the list is -#followed by "is" or "are" (to avoid manually figuring out if plural); quotes controls whether -#quotes should be placed around words in string. From WeightIt. -word_list <- function(word.list = NULL, and.or = "and", is.are = FALSE, quotes = FALSE) { - #When given a vector of strings, creates a string of the form "a and b" - #or "a, b, and c" - #If is.are, adds "is" or "are" appropriately - L <- length(word.list) - word.list <- add_quotes(word.list, quotes) - - if (L == 0) { - out <- "" - attr(out, "plural") <- FALSE - } - else { - word.list <- word.list[!word.list %in% c(NA_character_, "")] - L <- length(word.list) - if (L == 0) { - out <- "" - attr(out, "plural") <- FALSE - } - else if (L == 1) { - out <- word.list - if (is.are) out <- paste(out, "is") - attr(out, "plural") <- FALSE - } - else { - and.or <- match_arg(and.or, c("and", "or")) - if (L == 2) { - out <- paste(word.list, collapse = paste0(" ", and.or, " ")) - } - else { - out <- paste(paste(word.list[seq_len(L - 1)], collapse = ", "), - word.list[L], sep = paste0(", ", and.or, " ")) - - } - if (is.are) out <- paste(out, "are") - attr(out, "plural") <- TRUE - } - - } - - out -} - -#Add quotes to a string -add_quotes <- function(x, quotes = 2L) { - if (isFALSE(quotes)) return(x) - - if (isTRUE(quotes)) quotes <- 2 - - if (chk::vld_string(quotes)) x <- paste0(quotes, x, quotes) - else if (chk::vld_whole_number(quotes)) { - if (as.integer(quotes) == 0) return(x) - else if (as.integer(quotes) == 1) x <- paste0("\'", x, "\'") - else if (as.integer(quotes) == 2) x <- paste0("\"", x, "\"") - else stop("`quotes` must be boolean, 1, 2, or a string.") - } - else { - stop("`quotes` must be boolean, 1, 2, or a string.") +#Make interaction vector out of matrix of covs; similar to interaction() +exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = "right") { + if (is_null(nam)) { + nam <- rownames(X) } - x -} - -#More informative and cleaner version of base::match.arg(). Uses chk. -match_arg <- function(arg, choices, several.ok = FALSE) { - #Replaces match.arg() but gives cleaner error message and processing - #of arg. - if (missing(arg)) - stop("No argument was supplied to match_arg.") - arg.name <- deparse1(substitute(arg), width.cutoff = 500L) - - if (missing(choices)) { - formal.args <- formals(sys.function(sysP <- sys.parent())) - choices <- eval(formal.args[[as.character(substitute(arg))]], - envir = sys.frame(sysP)) + if (is.matrix(X)) { + X <- setNames(lapply(seq_len(ncol(X)), function(i) X[,i]), colnames(X)) } - if (length(arg) == 0) return(choices[1L]) - - if (several.ok) { - chk::chk_character(arg, add_quotes(arg.name, "`")) - } - else { - chk::chk_string(arg, add_quotes(arg.name, "`")) - if (identical(arg, choices)) return(arg[1L]) + if (!is.list(X)) { + stop("X must be a matrix, data frame, or list.") } - i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE) - if (all(i == 0L)) - .err(sprintf("the argument to `%s` should be %s%s.", - arg.name, ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), - word_list(choices, and.or = "or", quotes = 2))) - i <- i[i > 0L] - - choices[i] -} - -#Turn a vector into a 0/1 vector. 'zero' and 'one' can be supplied to make it clear which is -#which; otherwise, a guess is used. From WeightIt. -binarize <- function(variable, zero = NULL, one = NULL) { - var.name <- deparse1(substitute(variable)) - if (length(unique(variable)) > 2) { - stop(sprintf("Cannot binarize %s: more than two levels.", var.name)) - } - if (is.character(variable) || is.factor(variable)) { - variable <- factor(variable, nmax = 2) - unique.vals <- levels(variable) - } - else { - unique.vals <- unique(variable, nmax = 2) - } - - if (is.null(zero)) { - if (is.null(one)) { - if (can_str2num(unique.vals)) { - variable.numeric <- str2num(variable) - } - else { - variable.numeric <- as.numeric(variable) - } - - if (0 %in% variable.numeric) zero <- 0 - else zero <- min(variable.numeric, na.rm = TRUE) - - return(setNames(as.integer(variable.numeric != zero), names(variable))) - } - else { - if (one %in% unique.vals) return(setNames(as.integer(variable == one), names(variable))) - else stop("The argument to 'one' is not the name of a level of variable.") - } - } - else { - if (zero %in% unique.vals) return(setNames(as.integer(variable != zero), names(variable))) - else stop("The argument to 'zero' is not the name of a level of variable.") - } -} - -#Make interaction vector out of matrix of covs; similar to interaction() -exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = "right") { - if (is.null(nam)) nam <- rownames(X) - if (is.matrix(X)) X <- setNames(lapply(seq_len(ncol(X)), function(i) X[,i]), colnames(X)) - if (!is.list(X)) stop("X must be a matrix, data frame, or list.") - for (i in seq_along(X)) { unique_x <- { if (is.factor(X[[i]])) levels(X[[i]]) @@ -285,7 +153,7 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = " names(X)[i], add_quotes(unique_x, is.character(X[[i]]) || is.factor(X[[i]]))) } - else if (is.null(justify)) unique_x + else if (is_null(justify)) unique_x else format(unique_x, justify = justify) } @@ -300,114 +168,22 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = " out <- factor(out, levels = all_levels[all_levels %in% out]) - if (!is.null(nam)) names(out) <- nam + if (is_not_null(nam)) names(out) <- nam out } -#Determine whether a character vector can be coerced to numeric -can_str2num <- function(x) { - nas <- is.na(x) - suppressWarnings(x_num <- as.numeric(as.character(x[!nas]))) - !anyNA(x_num) -} - -#Cleanly coerces a character vector to numeric; best to use after can_str2num() -str2num <- function(x) { - nas <- is.na(x) - suppressWarnings(x_num <- as.numeric(as.character(x))) - is.na(x_num)[nas] <- TRUE - x_num -} - -#Capitalize first letter of string -firstup <- function(x) { - substr(x, 1, 1) <- toupper(substr(x, 1, 1)) - x -} - -#Capitalize first letter of each word -capwords <- function(s, strict = FALSE) { - cap <- function(s) paste0(toupper(substring(s, 1, 1)), - {s <- substring(s, 2); if(strict) tolower(s) else s}, - collapse = " ") - sapply(strsplit(s, split = " "), cap, USE.NAMES = !is.null(names(s))) -} - -#Clean printing of data frames with numeric and NA elements. -round_df_char <- function(df, digits, pad = "0", na_vals = "") { - #Digits is passed to round(). pad is used to replace trailing zeros so decimal - #lines up. Should be "0" or " "; "" (the empty string) un-aligns decimals. - #na_vals is what NA should print as. - - if (NROW(df) == 0 || NCOL(df) == 0) return(as.matrix(df)) - if (!is.data.frame(df)) df <- as.data.frame.matrix(df, stringsAsFactors = FALSE) - - rn <- rownames(df) - cn <- colnames(df) - - infs <- o.negs <- array(FALSE, dim = dim(df)) - nas <- is.na(df) - nums <- vapply(df, is.numeric, logical(1)) - infs[,nums] <- vapply(which(nums), function(i) !nas[,i] & !is.finite(df[[i]]), logical(NROW(df))) - - for (i in which(!nums)) { - if (can_str2num(df[[i]])) { - df[[i]] <- str2num(df[[i]]) - nums[i] <- TRUE - } - } - - o.negs[,nums] <- !nas[,nums] & df[nums] < 0 & round(df[nums], digits) == 0 - df[nums] <- round(df[nums], digits = digits) - - for (i in which(nums)) { - df[[i]] <- format(df[[i]], scientific = FALSE, justify = "none", trim = TRUE, - drop0trailing = !identical(as.character(pad), "0")) - - if (!identical(as.character(pad), "0") && any(grepl(".", df[[i]], fixed = TRUE))) { - s <- strsplit(df[[i]], ".", fixed = TRUE) - lengths <- lengths(s) - digits.r.of.. <- rep(0, NROW(df)) - digits.r.of..[lengths > 1] <- nchar(vapply(s[lengths > 1], `[[`, character(1L), 2)) - max.dig <- max(digits.r.of..) - - dots <- ifelse(lengths > 1, "", if (as.character(pad) != "") "." else pad) - pads <- vapply(max.dig - digits.r.of.., function(n) paste(rep(pad, n), collapse = ""), character(1L)) - - df[[i]] <- paste0(df[[i]], dots, pads) - } - } - - df[o.negs] <- paste0("-", df[o.negs]) - - # Insert NA placeholders - df[nas] <- na_vals - df[infs] <- "N/A" - - if (length(rn) > 0) rownames(df) <- rn - if (length(cn) > 0) names(df) <- cn - - as.matrix(df) -} - -#Generalized inverse; port of MASS::ginv() -generalized_inverse <- function(sigma) { - sigmasvd <- svd(sigma) - pos <- sigmasvd$d > max(1e-8 * sigmasvd$d[1L], 0) - sigma_inv <- sigmasvd$v[, pos, drop = FALSE] %*% (sigmasvd$d[pos]^-1 * t(sigmasvd$u[, pos, drop = FALSE])) - sigma_inv -} - #Get covariates (RHS) vars from formula get.covs.matrix <- function(formula = NULL, data = NULL) { - if (is.null(formula)) { + if (is_null(formula)) { fnames <- colnames(data) fnames[!startsWith(fnames, "`")] <- paste0("`", fnames[!startsWith(fnames, "`")], "`") formula <- reformulate(fnames) } - else formula <- update(terms(formula, data = data), NULL ~ . + 1) + else { + formula <- update(terms(formula, data = data), NULL ~ . + 1) + } mf <- model.frame(terms(formula, data = data), data, na.action = na.pass) @@ -429,7 +205,9 @@ get.covs.matrix <- function(formula = NULL, data = NULL) { #Extracts and names the "assign" attribute from get.covs.matrix() get_assign <- function(mat) { - if (is.null(attr(mat, "assign"))) return(NULL) + if (is_null(attr(mat, "assign"))) { + return(NULL) + } setNames(attr(mat, "assign"), colnames(mat)) } @@ -452,67 +230,21 @@ charmm2nummm <- function(charmm, treat) { } #Get subclass from match.matrix. Only to be used if replace = FALSE. See subclass2mmC.cpp for reverse. -mm2subclass <- function(mm, treat) { - lab <- names(treat) - mmlab <- rownames(mm) - no.match <- is.na(mm) - - subclass <- setNames(rep(NA_character_, length(treat)), lab) - - subclass[mmlab[!no.match[,1]]] <- mmlab[!no.match[,1]] - - subclass[mm[!no.match]] <- mmlab[row(mm)[!no.match]] - - subclass <- setNames(factor(subclass, nmax = length(mmlab)), lab) - levels(subclass) <- seq_len(nlevels(subclass)) - - subclass -} - -#(Weighted) variance that uses special formula for binary variables -wvar <- function(x, bin.var = NULL, w = NULL) { - if (is.null(w)) w <- rep(1, length(x)) - if (is.null(bin.var)) bin.var <- all(x == 0 | x == 1) - - w <- w / sum(w) #weights normalized to sum to 1 - mx <- sum(w * x) #weighted mean - - if (bin.var) { - return(mx * (1 - mx)) +mm2subclass <- function(mm, treat, focal = NULL) { + if (!is.integer(mm)) { + mm <- charmm2nummm(mm, treat) } - #Reliability weights variance; same as cov.wt() - sum(w * (x - mx)^2)/(1 - sum(w^2)) -} - -#Weighted mean faster than weighted.mean() -wm <- function(x, w = NULL, na.rm = TRUE) { - if (is.null(w)) { - if (anyNA(x)) { - if (!na.rm) return(NA_real_) - nas <- which(is.na(x)) - x <- x[-nas] - } - return(sum(x)/length(x)) - } - - if (anyNA(x) || anyNA(w)) { - if (!na.rm) return(NA_real_) - nas <- which(is.na(x) | is.na(w)) - x <- x[-nas] - w <- w[-nas] - } - - sum(x*w)/sum(w) + mm2subclassC(mm, treat, focal) } #Pooled within-group (weighted) covariance by group-mean centering covariates. Used #in Mahalanobis distance pooled_cov <- function(X, t, w = NULL) { unique_t <- unique(t) - if (is.null(dim(X))) X <- matrix(X, nrow = length(X)) + if (is_null(dim(X))) X <- matrix(X, nrow = length(X)) - if (is.null(w)) { + if (is_null(w)) { n <- nrow(X) for (i in unique_t) { in_t <- which(t == i) @@ -535,9 +267,9 @@ pooled_cov <- function(X, t, w = NULL) { pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportional") { contribution <- match_arg(contribution, c("proportional", "equal")) unique_t <- unique(t) - if (is.null(dim(X))) X <- matrix(X, nrow = length(X)) + if (is_null(dim(X))) X <- matrix(X, nrow = length(X)) n <- nrow(X) - if (is.null(bin.var)) bin.var <- apply(X, 2, function(x) all(x == 0 | x == 1)) + if (is_null(bin.var)) bin.var <- apply(X, 2, function(x) all(x == 0 | x == 1)) if (contribution == "equal") { vars <- matrix(0, nrow = length(unique_t), ncol = ncol(X)) @@ -557,7 +289,7 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion b <- bin.var[j] if (b) { - if (is.null(w)) { + if (is_null(w)) { v <- vapply(unique_t, function(i) { sxi <- sum(x[t == i]) ni <- sum(t == i) @@ -575,7 +307,7 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion } } else { - if (is.null(w)) { + if (is_null(w)) { for (i in unique_t) { x[t==i] <- x[t==i] - wm(x[t==i]) } @@ -603,8 +335,8 @@ ESS <- function(w) { #Compute sample sizes nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { - if (is.null(discarded)) discarded <- rep(FALSE, length(treat)) - if (is.null(s.weights)) s.weights <- rep(1, length(treat)) + if (is_null(discarded)) discarded <- rep(FALSE, length(treat)) + if (is_null(s.weights)) s.weights <- rep(1, length(treat)) weights <- weights * s.weights n <- matrix(0, ncol=2, nrow=6, dimnames = list(c("All (ESS)", "All", "Matched (ESS)", "Matched", "Unmatched","Discarded"), @@ -625,7 +357,7 @@ nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { qn <- function(treat, subclass, discarded = NULL) { treat <- factor(treat, levels = 0:1, labels = c("Control", "Treated")) - if (is.null(discarded)) discarded <- rep(FALSE, length(treat)) + if (is_null(discarded)) discarded <- rep(FALSE, length(treat)) qn <- table(treat[!discarded], subclass[!discarded]) if (any(is.na(subclass) & !discarded)) { @@ -647,123 +379,6 @@ qn <- function(treat, subclass, discarded = NULL) { qn } -#Faster diff() -diff1 <- function(x) { - x[-1] - x[-length(x)] -} - -#cumsum() for probabilities to ensure they are between 0 and 1 -.cumsum_prob <- function(x) { - s <- cumsum(x) - s / s[length(s)] -} - -#Make vector sum to 1, optionally by group -.make_sum_to_1 <- function(x, by = NULL) { - if (is.null(by)) { - return(x / sum(x)) - } - - for (i in unique(by)) { - in_i <- which(by == i) - x[in_i] <- x[in_i] / sum(x[in_i]) - } - - x -} - -#Make vector sum to n (average of 1), optionally by group -.make_sum_to_n <- function(x, by = NULL) { - if (is.null(by)) { - return(length(x) * x / sum(x)) - } - - for (i in unique(by)) { - in_i <- which(by == i) - x[in_i] <- length(in_i) * x[in_i] / sum(x[in_i]) - } - - x -} - -#Functions for error handling; based on chk and rlang -pkg_caller_call <- function(start = 1) { - package.funs <- c(getNamespaceExports(utils::packageName()), - .getNamespaceInfo(asNamespace(utils::packageName()), "S3methods")[, 3]) - k <- start #skip checking pkg_caller_call() - e_max <- start - while (!is.null(e <- rlang::caller_call(k))) { - if (!is.null(n <- rlang::call_name(e)) && - n %in% package.funs) e_max <- k - k <- k + 1 - } - rlang::caller_call(e_max) -} - -.err <- function(...) { - chk::err(..., call = pkg_caller_call(start = 2)) -} -.wrn <- function(...) { - chk::wrn(...) -} -.msg <- function(...) { - chk::msg(...) -} - -#De-bugged version of chk::chk_null_or() -.chk_null_or <- function(x, chk, ..., x_name = NULL) { - if (is.null(x_name)) { - x_name <- deparse1(substitute(x)) - } - - x_name <- add_quotes(x_name, "`") - - if (is.null(x)) { - return(invisible(x)) - } - - tryCatch(chk(x, ..., x_name = x_name), - error = function(e) { - msg <- sub("[.]$", " or `NULL`.", - conditionMessage(e)) - chk::err(msg, .subclass = "chk_error") - }) -} - -.chk_formula <- function(x, sides = NULL, x_name = NULL) { - if (is.null(sides)) { - if (rlang::is_formula(x)) { - return(invisible(x)) - } - if (is.null(x_name)) { - x_name <- chk::deparse_backtick_chk(substitute(x)) - } - chk::abort_chk(x_name, " must be a formula", - x = x) - } - else if (sides == 1) { - if (rlang::is_formula(x, lhs = FALSE)) { - return(invisible(x)) - } - if (is.null(x_name)) { - x_name <- chk::deparse_backtick_chk(substitute(x)) - } - chk::abort_chk(x_name, " must be a formula with no left-hand side", - x = x) - } - else if (sides == 2) { - if (rlang::is_formula(x, lhs = TRUE)) { - return(invisible(x)) - } - if (is.null(x_name)) { - x_name <- chk::deparse_backtick_chk(substitute(x)) - } - chk::abort_chk(x_name, " must be a formula with a left-hand side", - x = x) - } - else stop("`sides` must be NULL, 1, or 2") -} - #Function to capture and print errors and warnings better matchit_try <- function(expr, from = NULL, dont_warn_if = NULL) { tryCatch({ @@ -771,14 +386,14 @@ matchit_try <- function(expr, from = NULL, dont_warn_if = NULL) { expr }, warning = function(w) { - if (is.null(dont_warn_if) || !grepl(dont_warn_if, conditionMessage(w), fixed = TRUE)) { - if (is.null(from)) .wrn(conditionMessage(w), tidy = FALSE) + if (is_null(dont_warn_if) || !grepl(dont_warn_if, conditionMessage(w), fixed = TRUE)) { + if (is_null(from)) .wrn(conditionMessage(w), tidy = FALSE) else .wrn(sprintf("(from %s) %s", from, conditionMessage(w)), tidy = FALSE) } invokeRestart("muffleWarning") })}, error = function(e) { - if (is.null(from)) .err(conditionMessage(e), tidy = FALSE) + if (is_null(from)) .err(conditionMessage(e), tidy = FALSE) else .err(sprintf("(from %s) %s", from, conditionMessage(e)), tidy = FALSE) }) } \ No newline at end of file diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 00000000..588c4ea6 --- /dev/null +++ b/R/utils.R @@ -0,0 +1,403 @@ +#Function to turn a vector into a string with "," and "and" or "or" for clean messages. 'and.or' +#controls whether words are separated by "and" or "or"; 'is.are' controls whether the list is +#followed by "is" or "are" (to avoid manually figuring out if plural); quotes controls whether +#quotes should be placed around words in string. From WeightIt. +word_list <- function(word.list = NULL, and.or = "and", is.are = FALSE, quotes = FALSE) { + #When given a vector of strings, creates a string of the form "a and b" + #or "a, b, and c" + #If is.are, adds "is" or "are" appropriately + + word.list <- setdiff(word.list, c(NA_character_, "")) + + if (is_null(word.list)) { + out <- "" + attr(out, "plural") <- FALSE + return(out) + } + + word.list <- add_quotes(word.list, quotes) + + L <- length(word.list) + + if (L == 1L) { + out <- word.list + if (is.are) out <- paste(out, "is") + attr(out, "plural") <- FALSE + return(out) + } + + if (is_null(and.or) || isFALSE(and.or)) { + out <- paste(word.list, collapse = ", ") + } + else { + and.or <- match_arg(and.or, c("and", "or")) + + if (L == 2L) { + out <- sprintf("%s %s %s", + word.list[1L], + and.or, + word.list[2L]) + } + else { + out <- sprintf("%s, %s %s", + paste(word.list[-L], collapse = ", "), + and.or, + word.list[L]) + } + } + + if (is.are) out <- sprintf("%s are", out) + + attr(out, "plural") <- TRUE + + out +} + +#Add quotes to a string +add_quotes <- function(x, quotes = 2L) { + if (isFALSE(quotes)) { + return(x) + } + + if (isTRUE(quotes)) + quotes <- '"' + + if (chk::vld_string(quotes)) { + return(paste0(quotes, x, quotes)) + } + + if (!chk::vld_count(quotes) || quotes > 2) { + stop("`quotes` must be boolean, 1, 2, or a string.") + } + + if (quotes == 0L) { + return(x) + } + + x <- { + if (quotes == 1) sprintf("'%s'", x) + else sprintf('"%s"', x) + } + + x +} + +#More informative and cleaner version of base::match.arg(). Uses chk. +match_arg <- function(arg, choices, several.ok = FALSE) { + #Replaces match.arg() but gives cleaner error message and processing + #of arg. + if (missing(arg)) { + stop("No argument was supplied to match_arg.") + } + + arg.name <- deparse1(substitute(arg), width.cutoff = 500L) + + if (missing(choices)) { + formal.args <- formals(sys.function(sysP <- sys.parent())) + choices <- eval(formal.args[[as.character(substitute(arg))]], + envir = sys.frame(sysP)) + } + + if (is_null(arg)) { + return(choices[1L]) + } + + if (several.ok) { + chk::chk_character(arg, x_name = add_quotes(arg.name, "`")) + } + else { + chk::chk_string(arg, x_name = add_quotes(arg.name, "`")) + if (identical(arg, choices)) return(arg[1L]) + } + + i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE) + if (all(i == 0L)) + .err(sprintf("the argument to `%s` should be %s%s", + arg.name, + ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), + word_list(choices, and.or = "or", quotes = 2))) + i <- i[i > 0L] + + choices[i] +} + +#Turn a vector into a 0/1 vector. 'zero' and 'one' can be supplied to make it clear which is +#which; otherwise, a guess is used. From WeightIt. +binarize <- function(variable, zero = NULL, one = NULL) { + var.name <- deparse1(substitute(variable)) + + if (has_n_unique(variable, 1L)) { + return(setNames(rep.int(1L, length(variable)), names(variable))) + } + + if (!has_n_unique(variable, 2L)) { + .err(sprintf("cannot binarize %s: more than two levels", var.name)) + } + + if (is.character(variable) || is.factor(variable)) { + variable <- factor(variable, nmax = if (is.factor(variable)) nlevels(variable) else NA) + unique.vals <- levels(variable) + } + else { + unique.vals <- unique(variable, nmax = 2L) + } + + if (is_not_null(zero)) { + if (!zero %in% unique.vals) { + .err(sprintf("the argument to `zero` is not the name of a level of %s", var.name)) + } + + return(setNames(as.integer(variable != zero), names(variable))) + } + + if (is_not_null(one)) { + if (!one %in% unique.vals) { + .err(sprintf("the argument to `one` is not the name of a level of %s", var.name)) + } + + return(setNames(as.integer(variable == one), names(variable))) + } + + if (is.logical(variable)) { + return(setNames(as.integer(variable), names(variable))) + } + + if (is.numeric(variable)) { + zero <- { + if (any(unique.vals == 0)) 0 + else min(unique.vals, na.rm = TRUE) + } + + return(setNames(as.integer(variable != zero), names(variable))) + } + + variable.numeric <- { + if (can_str2num(unique.vals)) setNames(str2num(unique.vals), unique.vals)[variable] + else as.numeric(factor(variable, levels = unique.vals)) + } + + zero <- { + if (0 %in% variable.numeric) 0 + else min(variable.numeric, na.rm = TRUE) + } + + setNames(as.integer(variable.numeric != zero), names(variable)) +} + +is_null <- function(x) length(x) == 0L +is_not_null <- function(x) !is_null(x) + +null_or_error <- function(x) {is_null(x) || inherits(x, "try-error")} + +#Determine whether a character vector can be coerced to numeric +can_str2num <- function(x) { + if (is.numeric(x) || is.logical(x)) { + return(TRUE) + } + + nas <- is.na(x) + suppressWarnings(x_num <- as.numeric(as.character(x[!nas]))) + + !anyNA(x_num) +} + +#Cleanly coerces a character vector to numeric; best to use after can_str2num() +str2num <- function(x) { + nas <- is.na(x) + if (!is.numeric(x) && !is.logical(x)) x <- as.character(x) + suppressWarnings(x_num <- as.numeric(x)) + is.na(x_num)[nas] <- TRUE + x_num +} + +#Capitalize first letter of string +firstup <- function(x) { + substr(x, 1, 1) <- toupper(substr(x, 1, 1)) + x +} + +#Capitalize first letter of each word +capwords <- function(s, strict = FALSE) { + cap <- function(s) paste0(toupper(substring(s, 1, 1)), + {s <- substring(s, 2); if(strict) tolower(s) else s}, + collapse = " ") + sapply(strsplit(s, split = " "), cap, USE.NAMES = is_not_null(names(s))) +} + +#Clean printing of data frames with numeric and NA elements. +round_df_char <- function(df, digits, pad = "0", na_vals = "") { + if (NROW(df) == 0L || NCOL(df) == 0L) { + return(df) + } + + if (!is.data.frame(df)) { + df <- as.data.frame.matrix(df, stringsAsFactors = FALSE) + } + + rn <- rownames(df) + cn <- colnames(df) + + infs <- o.negs <- array(FALSE, dim = dim(df)) + nas <- is.na(df) + nums <- vapply(df, is.numeric, logical(1)) + for (i in which(nums)) { + infs[,i] <- !nas[,i] & !is.finite(df[[i]]) + } + + for (i in which(!nums)) { + if (can_str2num(df[[i]])) { + df[[i]] <- str2num(df[[i]]) + nums[i] <- TRUE + } + } + + o.negs[,nums] <- !nas[,nums] & df[nums] < 0 & round(df[nums], digits) == 0 + df[nums] <- round(df[nums], digits = digits) + + for (i in which(nums)) { + df[[i]] <- format(df[[i]], scientific = FALSE, justify = "none", trim = TRUE, + drop0trailing = !identical(as.character(pad), "0")) + + if (!identical(as.character(pad), "0") && any(grepl(".", df[[i]], fixed = TRUE))) { + s <- strsplit(df[[i]], ".", fixed = TRUE) + lengths <- lengths(s) + digits.r.of.. <- rep.int(0, NROW(df)) + digits.r.of..[lengths > 1] <- nchar(vapply(s[lengths > 1], `[[`, character(1L), 2)) + max.dig <- max(digits.r.of..) + + dots <- ifelse(lengths > 1, "", if (as.character(pad) != "") "." else pad) + pads <- vapply(max.dig - digits.r.of.., function(n) paste(rep.int(pad, n), collapse = ""), character(1L)) + + df[[i]] <- paste0(df[[i]], dots, pads) + } + } + + df[o.negs] <- paste0("-", df[o.negs]) + + # Insert NA placeholders + df[nas] <- na_vals + df[infs] <- "N/A" + + if (length(rn) > 0) rownames(df) <- rn + if (length(cn) > 0) names(df) <- cn + + df +} + +#Generalized inverse; port of MASS::ginv() +generalized_inverse <- function(sigma) { + sigmasvd <- svd(sigma) + + pos <- sigmasvd$d > max(1e-8 * sigmasvd$d[1L], 0) + + sigmasvd$v[, pos, drop = FALSE] %*% (sigmasvd$d[pos]^-1 * t(sigmasvd$u[, pos, drop = FALSE])) +} + +#(Weighted) variance that uses special formula for binary variables +wvar <- function(x, bin.var = NULL, w = NULL) { + if (is_null(w)) w <- rep(1, length(x)) + if (is_null(bin.var)) bin.var <- all(x == 0 | x == 1) + + w <- w / sum(w) #weights normalized to sum to 1 + mx <- sum(w * x) #weighted mean + + if (bin.var) { + return(mx * (1 - mx)) + } + + #Reliability weights variance; same as cov.wt() + sum(w * (x - mx)^2)/(1 - sum(w^2)) +} + +#Weighted mean faster than weighted.mean() +wm <- function(x, w = NULL, na.rm = TRUE) { + if (is_null(w)) { + if (anyNA(x)) { + if (!na.rm) return(NA_real_) + nas <- which(is.na(x)) + x <- x[-nas] + } + return(sum(x)/length(x)) + } + + if (anyNA(x) || anyNA(w)) { + if (!na.rm) return(NA_real_) + nas <- which(is.na(x) | is.na(w)) + x <- x[-nas] + w <- w[-nas] + } + + sum(x*w)/sum(w) +} + +#Faster diff() +diff1 <- function(x) { + x[-1] - x[-length(x)] +} + +#cumsum() for probabilities to ensure they are between 0 and 1 +.cumsum_prob <- function(x) { + s <- cumsum(x) + s / s[length(s)] +} + +#Make vector sum to 1, optionally by group +.make_sum_to_1 <- function(x, by = NULL) { + if (is_null(by)) { + return(x / sum(x)) + } + + for (i in unique(by)) { + in_i <- which(by == i) + x[in_i] <- x[in_i] / sum(x[in_i]) + } + + x +} + +#Make vector sum to n (average of 1), optionally by group +.make_sum_to_n <- function(x, by = NULL) { + if (is_null(by)) { + return(length(x) * x / sum(x)) + } + + for (i in unique(by)) { + in_i <- which(by == i) + x[in_i] <- length(in_i) * x[in_i] / sum(x[in_i]) + } + + x +} + +#Functions for error handling; based on chk and rlang +pkg_caller_call <- function(start = 1) { + pn <- utils::packageName() + package.funs <- c(getNamespaceExports(pn), + .getNamespaceInfo(asNamespace(pn), "S3methods")[, 3]) + k <- start #skip checking pkg_caller_call() + e_max <- start + while (is_not_null(e <- rlang::caller_call(k))) { + if (is_not_null(n <- rlang::call_name(e)) && + n %in% package.funs) e_max <- k + k <- k + 1 + } + rlang::caller_call(e_max) +} + +.err <- function(..., n = NULL, tidy = TRUE) { + m <- chk::message_chk(..., n = n, tidy = tidy) + rlang::abort(paste(strwrap(m), collapse = "\n"), + call = pkg_caller_call(start = 2)) +} +.wrn <- function(..., n = NULL, tidy = TRUE, immediate = TRUE) { + if (immediate && isTRUE(all.equal(0, getOption("warn")))) { + op <- options(warn = 1) + on.exit(options(op)) + } + m <- chk::message_chk(..., n = n, tidy = tidy) + rlang::warn(paste(strwrap(m), collapse = "\n")) +} +.msg <- function(..., n = NULL, tidy = TRUE) { + m <- chk::message_chk(..., n = n, tidy = tidy) + rlang::inform(paste(strwrap(m), collapse = "\n"), tidy = FALSE) +} \ No newline at end of file From bfe730045ac1afabfeb564bad7a04fadddd7d706 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:45:45 -0400 Subject: [PATCH 12/48] Improvements --- R/get_weights_from_mm.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/get_weights_from_mm.R b/R/get_weights_from_mm.R index 2cf75edc..1d8ec6f6 100644 --- a/R/get_weights_from_mm.R +++ b/R/get_weights_from_mm.R @@ -1,10 +1,10 @@ -get_weights_from_mm <- function(match.matrix, treat) { +get_weights_from_mm <- function(match.matrix, treat, focal = NULL) { if (!is.integer(match.matrix)) { match.matrix <- charmm2nummm(match.matrix, treat) } - weights <- weights_matrixC(match.matrix, treat) + weights <- weights_matrixC(match.matrix, treat, focal) if (sum(weights) == 0) .err("No units were matched") From 1c456084c7c8b77aaffca2a34e6988ab8e97223f Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:46:12 -0400 Subject: [PATCH 13/48] Added progressbar with ETA and EMA estimation --- src/eta_progress_bar.h | 248 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 src/eta_progress_bar.h diff --git a/src/eta_progress_bar.h b/src/eta_progress_bar.h new file mode 100644 index 00000000..a7e7e7ae --- /dev/null +++ b/src/eta_progress_bar.h @@ -0,0 +1,248 @@ +/* + * eta_progress_bar.h + * + * A custom ProgressBar class to display a progress bar with time estimation + * + * Author: clemens@nevrome.de + * + * Copied from https://github.com/kforner/rcpp_progress/blob/master/inst/examples/RcppProgressETA/src/eta_progress_bar.hpp with modifications by NHG + * + */ +#ifndef _RcppProgress_ETA_PROGRESS_BAR_HPP +#define _RcppProgress_ETA_PROGRESS_BAR_HPP + +#include +#include +#include +#include +#include + +#include "progress_bar.hpp" + +// for unices only +#if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) +#include +#endif + +class ETAProgressBar: public ProgressBar{ + public: // ====== LIFECYCLE ===== + + /** + * Main constructor + */ + ETAProgressBar() { + _max_ticks = 50; + _finalized = false; + _timer_flag = true; + } + + ~ETAProgressBar() { + } + + public: // ===== main methods ===== + + void display() { + REprintf("0%% 10 20 30 40 50 60 70 80 90 100%%\n"); + REprintf("[----|----|----|----|----|----|----|----|----|----|\n"); + flush_console(); + } + + // update display + void update(float progress) { + + // stop if already finalized + if (_finalized) return; + + // start time measurement when update() is called the first time + if (_timer_flag) { + _timer_flag = false; + // measure start time + time(&start); + last_refresh = start; + current_old = start; + progress_old = progress; + + ema_rate = 0; + + time_string = "calculating..."; + + // create progress bar string + std::string progress_bar_string = _current_ticks_display(progress); + + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| ETA: " << time_string; + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); + + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + } else { + + // measure current time + time(¤t_new); + + if (progress != 1) { + // create progress bar string + std::string progress_bar_string = _current_ticks_display(progress); + + // ensure overwriting of old time info + int empty_length = time_string.length(); + + double time_since_start = std::difftime(current_new, start); + + if (time_since_start <= 1) { + ema_rate = progress / time_since_start; + } + else { + double time_since_last_refresh = std::difftime(current_new, last_refresh); + + if (time_since_last_refresh >= .5) { + double current_rate = (progress - progress_old) / time_since_last_refresh; + + double alpha = .8; + + ema_rate = alpha * ema_rate + (1 - alpha) * current_rate; + + // convert seconds to time string + time_string = "~"; + time_string += _time_to_string((1 - progress) / ema_rate); + + last_refresh = current_new; + progress_old = progress; + } + } + + std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); + + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| ETA: " << time_string << empty_space; + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); + + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + + current_old = current_new; + + } else { + // ensure overwriting of old time info + int empty_length = time_string.length(); + + // finalize display when ready + double time_since_start = std::difftime(current_new, start); + // convert seconds to time string + std::string time_string = _time_to_string(time_since_start); + + std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); + + // create progress bar string + std::string progress_bar_string = _current_ticks_display(progress); + + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| " << "Elapsed: " << time_string << empty_space; + + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); + + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + + _finalize_display(); + } + } + } + + void end_display() { + update(1); + } + + protected: // ==== other instance methods ===== + + // convert double with seconds to time string + std::string _time_to_string(double seconds) { + + int time = (int) seconds; + + int hour = 0; + int min = 0; + int sec = 0; + + hour = time / 3600; + time = time % 3600; + min = time / 60; + time = time % 60; + sec = time; + + std::stringstream time_strs; + if (hour != 0) time_strs << hour << "h "; + if (min != 0) time_strs << min << "min "; + if (sec != 0 || (hour == 0 && min == 0)) time_strs << sec << "s "; + std::string time_str = time_strs.str(); + + return time_str; + } + + // update the ticks display corresponding to progress + std::string _current_ticks_display(float progress) { + + int nb_ticks = _compute_nb_ticks(progress); + + std::string cur_display = _construct_ticks_display_string(nb_ticks); + + return cur_display; + } + + // construct progress bar display + std::string _construct_ticks_display_string(int nb) { + + std::stringstream ticks_strs; + for (int i = 0; i < (_max_ticks - 1); ++i) { + if (i < nb) { + ticks_strs << "="; + } else { + ticks_strs << " "; + } + } + std::string tick_space_string = ticks_strs.str(); + + return tick_space_string; + } + + // finalize + void _finalize_display() { + if (_finalized) return; + + REprintf("\n"); + flush_console(); + _finalized = true; + } + + // compute number of ticks according to progress + int _compute_nb_ticks(float progress) { + return int(progress * _max_ticks); + } + + // N.B: does nothing on windows + void flush_console() { +#if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) + R_FlushConsole(); +#endif + } + + private: // ===== INSTANCE VARIABLES ==== + int _max_ticks; // the total number of ticks to print + bool _finalized; + bool _timer_flag; + time_t start, current_new, last_refresh, current_old; + double ema_rate, progress_old; + std::string time_string; + +}; + +#endif \ No newline at end of file From 8f1f1312474b8a78acda31e11e3d763f8571d51a Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:46:25 -0400 Subject: [PATCH 14/48] Utility to speed up processing on large vectors --- src/has_n_unique.cpp | 65 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 src/has_n_unique.cpp diff --git a/src/has_n_unique.cpp b/src/has_n_unique.cpp new file mode 100644 index 00000000..b67d0acc --- /dev/null +++ b/src/has_n_unique.cpp @@ -0,0 +1,65 @@ +#include +#include +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// Templated function to check if a vector has exactly n unique values +template +bool has_n_unique_(Vector x, + const int& n) { + + Vector seen(n); + seen[0] = x[0]; + int n_seen = 1; + + int j; + bool was_seen; + + // Iterate over the vector and add elements to the unordered set + for (auto it = x.begin() + 1; it != x.end(); ++it) { + if (*it == *(it - 1)) { + continue; + } + + was_seen = false; + + for (j = 0; j < n_seen; j++) { + if (*it == seen[j]) { + was_seen = true; + break; + } + } + + if (!was_seen) { + n_seen++; + + if (n_seen > n) { + return false; + } + + seen[n_seen - 1] = *it; + } + } + + // Check if the number of unique elements is exactly n + return n_seen == n; +} + +// Wrapper function to handle different types of R vectors +// [[Rcpp::export]] +bool has_n_unique(const SEXP& x, + const int& n) { + switch (TYPEOF(x)) { + case INTSXP: + return has_n_unique_(x, n); + case REALSXP: + return has_n_unique_(x, n); + case STRSXP: + return has_n_unique_(x, n); + case LGLSXP: + return has_n_unique_(x, n); + default: + stop("Unsupported vector type"); + } +} \ No newline at end of file From abe0fb6507ef3dbaf5f25daf567eca8021b87d4e Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:46:55 -0400 Subject: [PATCH 15/48] Reordering and minor cleaning --- R/matchit2cem.R | 162 ++++++++++++++++++++++++++---------------------- 1 file changed, 87 insertions(+), 75 deletions(-) diff --git a/R/matchit2cem.R b/R/matchit2cem.R index 1f8c9faa..a6a4d504 100644 --- a/R/matchit2cem.R +++ b/R/matchit2cem.R @@ -262,8 +262,8 @@ NULL -matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose = FALSE, ...) { - if (length(covs) == 0) { +matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose = FALSE, TEST = 1, ...) { + if (is_null(covs)) { .err("Covariates must be specified in the input formula to use coarsened exact matching") } @@ -271,29 +271,51 @@ matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose A <- list(...) + # if (isTRUE(A[["k2k"]])) { + # if (!has_n_unique(treat, 2L)) { + # .err("`k2k` cannot be `TRUE` with a multi-category treatment") + # } + # } + estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - #Uses in-house cem, no need for cem package. See cem_matchit.R for code. - strat <- do.call("cem_matchit", c(list(treat = treat, X = covs, estimand = estimand, - s.weights = s.weights), - A[names(A) %in% names(formals(cem_matchit))]), - quote = TRUE) + #Uses in-house cem, no need for cem package. + strat <- do.call("cem_matchit", c(list(treat = treat, X = covs), + A[names(A) %in% names(formals(cem_matchit))])) levels(strat) <- seq_len(nlevels(strat)) names(strat) <- names(treat) mm <- NULL if (isTRUE(A[["k2k"]])) { - mm <- nummm2charmm(subclass2mmC(strat, treat, focal = switch(estimand, "ATC" = 0, 1)), + focal <- switch(estimand, "ATC" = 0, 1) + + strat <- do.call("do_k2k", c(list(treat = treat, X = covs, subclass = strat, + estimand = estimand, + s.weights = s.weights), + A[names(A) %in% names(formals(do_k2k))])) + + strat <- setNames(factor(strat), names(treat)) + levels(strat) <- seq_len(nlevels(strat)) + + mm <- nummm2charmm(subclass2mmC(strat, treat, focal = focal), treat) + + weights <- get_weights_from_mm(mm, treat, focal) + } + else { + strat <- setNames(factor(strat), names(treat)) + levels(strat) <- seq_len(nlevels(strat)) + + weights <- get_weights_from_subclass(strat, treat, estimand) } if (verbose) cat("Calculating matching weights... ") res <- list(match.matrix = mm, subclass = strat, - weights = get_weights_from_subclass(strat, treat, estimand)) + weights = weights) if (verbose) cat("Done.\n") @@ -302,30 +324,11 @@ matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose res } -cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k = FALSE, - k2k.method = "mahalanobis", mpower = 2, s.weights = NULL, - estimand = "ATT") { +cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list()) { #In-house implementation of cem. Basically the same except: #treat is a vector if treatment status, not the name of a variable #X is a data.frame of covariates #when cutpoints are given as integer or string, they define the number of bins, not the number of breakpoints. "ss" is no longer allowed. - #When k2k = TRUE, subclasses are created for each pair, mimicking true matching, not each covariate combo. - #k2k.method is used instead of method. When k2k.method = NULL, units are matched based on order rather than random. Default is "mahalanobis" (not available in cem). - #k2k now works with single covariates (previously it was ignored). k2k uses original variables, not coarsened versions - - if (k2k) { - if (length(unique(treat)) > 2L) { - .err("`k2k` cannot be `TRUE` with a multi-category treatment") - } - - if (!is.null(k2k.method)) { - k2k.method <- tolower(k2k.method) - k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - - X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, - method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") - } - } for (i in names(X)) { if (is.ordered(X[[i]])) X[[i]] <- as.numeric(X[[i]]) @@ -334,8 +337,8 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k is.numeric.cov <- setNames(vapply(X, is.numeric, logical(1L)), names(X)) #Process grouping - if (length(grouping) > 0L) { - if (!is.list(grouping) || is.null(names(grouping))) { + if (is_not_null(grouping)) { + if (!is.list(grouping) || is_null(names(grouping))) { .err("`grouping` must be a named list of grouping values with an element for each variable whose values are to be binned") } @@ -385,7 +388,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k cutpoints <- setNames(lapply(names(X)[is.numeric.cov], function(i) cutpoints), names(X)[is.numeric.cov]) } - if (is.null(names(cutpoints))) { + if (is_null(names(cutpoints))) { .err("`cutpoints` must be a named list of binning values with an element for each numeric variable") } @@ -399,7 +402,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k cutpoints[bad.names] <- NULL } - if (length(grouping) > 0L) { + if (is_not_null(grouping)) { grouping.cutpoint.names <- intersect(names(grouping), names(cutpoints)) ngc <- length(grouping.cutpoint.names) @@ -425,7 +428,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k bad.cuts <- setNames(rep(FALSE, length(cutpoints)), names(cutpoints)) for (i in names(cutpoints)) { - if (length(cutpoints[[i]]) == 0L) { + if (is_null(cutpoints[[i]])) { cutpoints[[i]] <- "sturges" } else if (length(cutpoints[[i]]) > 1L) { @@ -466,7 +469,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k #Create bins for numeric variables for (i in names(X)[is.numeric.cov]) { bins <- { - if (!is.null(cutpoints) && i %in% names(cutpoints)) cutpoints[[i]] + if (is_not_null(cutpoints) && i %in% names(cutpoints)) cutpoints[[i]] else "sturges" } @@ -498,7 +501,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k X[[i]] <- findInterval(X[[i]], breaks) } - if (length(X) == 0L) { + if (is_null(X)) { subclass <- setNames(rep(1L, length(treat)), names(treat)) } else { @@ -507,63 +510,72 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k cc <- do.call("intersect", unname(split(xx, treat))) - if (length(cc) == 0L) { + if (is_null(cc)) { .err("no units were matched. Try coarsening the variables further or decrease the number of variables to match on") } subclass <- setNames(match(xx, cc), names(treat)) } - extra.sub <- max(subclass, na.rm = TRUE) + subclass +} + +do_k2k <- function(treat, X, subclass, k2k.method = "mahalanobis", mpower = 2, s.weights = NULL, + estimand = "ATT") { - if (k2k) { + if (is_not_null(k2k.method)) { + k2k.method <- tolower(k2k.method) + k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - na.sub <- is.na(subclass) + X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, + method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") + } - s <- switch(estimand, "ATC" = 0, 1) + na.sub <- is.na(subclass) - for (i in which(tabulateC(subclass[!na.sub]) > 2)) { + s <- switch(estimand, "ATC" = 0, 1) - in.sub <- which(!na.sub & subclass == i) + extra.sub <- max(subclass, na.rm = TRUE) - #Compute distance matrix; all 0s if k2k.method = NULL for matched based on data order - if (is.null(k2k.method)) { - dist.mat <- matrix(0, nrow = sum(treat[in.sub] == s), ncol = sum(treat[in.sub] != s), - dimnames = list(names(treat)[in.sub][treat[in.sub] == s], - names(treat)[in.sub][treat[in.sub] != s])) + for (i in which(tabulateC(subclass[!na.sub]) > 2)) { - } - else if (k2k.method %in% matchit_distances()) { - #X.match has been transformed - dist.mat <- eucdist_internal(X.match[in.sub,,drop = FALSE], treat[in.sub] == s) - } - else { - dist.mat <- as.matrix(dist(X.match[in.sub,,drop = FALSE], method = k2k.method, p = mpower)) + in.sub <- which(!na.sub & subclass == i) - #Put smaller group on rows - d.rows <- which(rownames(dist.mat) %in% names(treat[in.sub])[treat[in.sub] == s]) - dist.mat <- dist.mat[d.rows, -d.rows, drop = FALSE] - } + #Compute distance matrix; all 0s if k2k.method = NULL for matched based on data order + if (is_null(k2k.method)) { + dist.mat <- matrix(0, nrow = sum(treat[in.sub] == s), ncol = sum(treat[in.sub] != s), + dimnames = list(names(treat)[in.sub][treat[in.sub] == s], + names(treat)[in.sub][treat[in.sub] != s])) - #For each member of group on row, find closest remaining pair from cols - while (all(dim(dist.mat) > 0)) { - extra.sub <- extra.sub + 1 + } + else if (k2k.method %in% matchit_distances()) { + #X.match has been transformed + dist.mat <- eucdist_internal(X.match[in.sub,,drop = FALSE], treat[in.sub] == s) + } + else { + dist.mat <- as.matrix(dist(X.match[in.sub,,drop = FALSE], method = k2k.method, p = mpower)) - closest <- which.min(dist.mat[1,]) - subclass[c(rownames(dist.mat)[1], colnames(dist.mat)[closest])] <- extra.sub + #Put smaller group on rows + d.rows <- which(rownames(dist.mat) %in% names(treat[in.sub])[treat[in.sub] == s]) + dist.mat <- dist.mat[d.rows, -d.rows, drop = FALSE] + } - #Drop already paired units from dist.mat - dist.mat <- dist.mat[-1,-closest, drop = FALSE] - } + #For each member of group on row, find closest remaining pair from cols + while (all(dim(dist.mat) > 0)) { + extra.sub <- extra.sub + 1 - #If any unmatched units remain, give them NA subclass - if (any(dim(dist.mat) > 0)) { - is.na(subclass)[unlist(dimnames(dist.mat))] <- TRUE - } + closest <- which.min(dist.mat[1,]) + subclass[c(rownames(dist.mat)[1], colnames(dist.mat)[closest])] <- extra.sub + + #Drop already paired units from dist.mat + dist.mat <- dist.mat[-1,-closest, drop = FALSE] } - } - subclass <- factor(subclass, nmax = extra.sub) + #If any unmatched units remain, give them NA subclass + if (any(dim(dist.mat) > 0)) { + is.na(subclass)[unlist(dimnames(dist.mat))] <- TRUE + } + } - setNames(subclass, names(treat)) -} + subclass +} \ No newline at end of file From 8195b1d9d3586f102fb568ce1453503085d8aad8 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:47:14 -0400 Subject: [PATCH 16/48] Improvements and updates --- R/matchit2genetic.R | 52 ++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/R/matchit2genetic.R b/R/matchit2genetic.R index cd0fcec9..721e3b11 100644 --- a/R/matchit2genetic.R +++ b/R/matchit2genetic.R @@ -80,7 +80,7 @@ #' to the distance matrix. Use `mahvars` to only supply a subset. Even if #' `mahvars` is specified, balance will be optimized on all covariates in #' `formula`. See Details. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place. #' Anti-exact matching is processed using the `restrict` argument to #' `Matching::GenMatch()` and `Matching::Match()`. #' @param discard a string containing a method for discarding units outside a @@ -190,6 +190,10 @@ #' support the ATE as an estimand, `matchit()` only supports the ATT and #' ATC for genetic matching. #' +#' ## Reproducibility +#' +#' Genetic matching involves a random component, so a seed must be set using [set.seed()] to ensure reproducibility. When `cluster` is used for parallel processing, the seed must be compatible with parallel processing (e.g., by setting `type = "L'Ecuyer-CMRG"`). +#' #' @seealso [matchit()] for a detailed explanation of the inputs and outputs of #' a call to `matchit()`. #' @@ -286,11 +290,11 @@ matchit2genetic <- function(treat, data, distance, discarded, n.obs <- length(treat) n1 <- sum(treat == 1) - if (is.null(names(treat))) names(treat) <- seq_len(n.obs) + if (is_null(names(treat))) names(treat) <- seq_len(n.obs) m.order <- { - if (is.null(distance)) match_arg(m.order, c("data", "random")) - else if (!is.null(m.order)) match_arg(m.order, c("largest", "smallest", "random", "data")) + if (is_null(distance)) match_arg(m.order, c("data", "random")) + else if (is_not_null(m.order)) match_arg(m.order, c("largest", "smallest", "random", "data")) else if (estimand == "ATC") "smallest" else "largest" } @@ -304,7 +308,7 @@ matchit2genetic <- function(treat, data, distance, discarded, #Create X (matching variables) and covs_to_balance covs_to_balance <- get.covs.matrix(formula, data = data) - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { X <- get.covs.matrix.for.dist(mahvars, data = data) } else if (is.full.mahalanobis) { @@ -319,12 +323,14 @@ matchit2genetic <- function(treat, data, distance, discarded, } #Process exact; exact.log will be supplied to GenMatch() and Match() - if (!is.null(exact)) { + if (is_not_null(exact)) { #Add covariates in exact not in X to X ex <- as.integer(factor(exactify(model.frame(exact, data = data), names(treat), sep = ", ", include_vars = TRUE))) cc <- intersect(ex[treat==1], ex[treat==0]) - if (length(cc) == 0) .err("No matches were found") + if (is_null(cc)) { + .err("No matches were found") + } X <- cbind(X, ex) @@ -335,15 +341,15 @@ matchit2genetic <- function(treat, data, distance, discarded, } #Process caliper; cal will be supplied to GenMatch() and Match() - if (!is.null(caliper)) { + if (is_not_null(caliper)) { #Add covariates in caliper other than distance (cov.cals) not in X to X cov.cals <- setdiff(names(caliper), "") - if (length(cov.cals) > 0 && any(!cov.cals %in% colnames(X))) { + if (is_not_null(cov.cals) && any(!cov.cals %in% colnames(X))) { calcovs <- get.covs.matrix(reformulate(cov.cals[!cov.cals %in% colnames(X)]), data = data) X <- cbind(X, calcovs) #Expand exact.log for newly added covariates - if (!is.null(exact.log)) exact.log <- c(exact.log, rep(FALSE, ncol(calcovs))) + if (is_not_null(exact.log)) exact.log <- c(exact.log, rep(FALSE, ncol(calcovs))) } else { cov.cals <- NULL @@ -360,20 +366,20 @@ matchit2genetic <- function(treat, data, distance, discarded, cal <- setNames(rep(Inf, ncol(X)), colnames(X)) #First put covariate calipers into cal - if (length(cov.cals) > 0) { + if (is_not_null(cov.cals)) { cal[intersect(cov.cals, names(cal))] <- caliper[intersect(cov.cals, names(cal))] } #Then put distance caliper into cal if ("" %in% names(caliper)) { dist.cal <- caliper[names(caliper) == ""] - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { #If mahvars specified, distance is not yet in X, so add it to X X <- cbind(X, distance) cal <- c(cal, dist.cal) #Expand exact.log for newly added distance - if (!is.null(exact.log)) exact.log <- c(exact.log, FALSE) + if (is_not_null(exact.log)) exact.log <- c(exact.log, FALSE) } else { #Otherwise, distance is in X at the specified index @@ -394,9 +400,9 @@ matchit2genetic <- function(treat, data, distance, discarded, treat_ <- treat[ord] covs_to_balance <- covs_to_balance[ord,,drop = FALSE] X <- X[ord,,drop = FALSE] - if (!is.null(s.weights)) s.weights <- s.weights[ord] + if (is_not_null(s.weights)) s.weights <- s.weights[ord] - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data)[ord,,drop = FALSE] antiexact_restrict <- cbind(do.call("rbind", lapply(seq_len(ncol(antiexactcovs)), function(i) { unique.vals <- unique(antiexactcovs[,i]) @@ -405,14 +411,14 @@ matchit2genetic <- function(treat, data, distance, discarded, })) })), -1) - if (!is.null(A[["restrict"]])) A[["restrict"]] <- rbind(A[["restrict"]], antiexact_restrict) + if (is_not_null(A[["restrict"]])) A[["restrict"]] <- rbind(A[["restrict"]], antiexact_restrict) else A[["restrict"]] <- antiexact_restrict } else { antiexactcovs <- NULL } - if (is.null(A[["distance.tolerance"]])) { + if (is_null(A[["distance.tolerance"]])) { A[["distance.tolerance"]] <- 0 } @@ -449,7 +455,7 @@ matchit2genetic <- function(treat, data, distance, discarded, distance.tolerance = A[["distance.tolerance"]], Weight = 3, Weight.matrix = { if (use.genetic) g.out - else if (is.null(s.weights)) generalized_inverse(cor(X)) + else if (is_null(s.weights)) generalized_inverse(cor(X)) else generalized_inverse(cov.wt(X, s.weights, cor = TRUE)$cor) }, restrict = A[["restrict"]], version = "fast") @@ -472,10 +478,10 @@ matchit2genetic <- function(treat, data, distance, discarded, # else { # #Use nn_match() instead of Match() # ord1 <- ord[ord %in% which(treat == 1)] - # if (!is.null(cov.cals)) calcovs <- get.covs.matrix(reformulate(cov.cals), data = data) + # if (is_not_null(cov.cals)) calcovs <- get.covs.matrix(reformulate(cov.cals), data = data) # else calcovs <- NULL # - # if (is.null(g.out)) MWM <- generalized_inverse(cov(X)) + # if (is_null(g.out)) MWM <- generalized_inverse(cov(X)) # else MWM <- g.out$Weight.matrix %*% diag(1/apply(X, 2, var)) # # if (isFALSE(A$fast)) { @@ -495,14 +501,16 @@ matchit2genetic <- function(treat, data, distance, discarded, if (replace) { psclass <- NULL + weights <- get_weights_from_mm(mm, treat, 1L) } else { - psclass <- mm2subclass(mm, treat) + psclass <- mm2subclass(mm, treat, 1L) + weights <- get_weights_from_subclass(psclass, treat) } res <- list(match.matrix = nummm2charmm(mm, treat), subclass = psclass, - weights = get_weights_from_mm(mm, treat), + weights = weights, obj = g.out) if (verbose) cat("Done.\n") From 85fa5f9c4dacbc64a2bd52c61bdbe34f060a38dd Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:48:11 -0400 Subject: [PATCH 17/48] Support for long vectors, ETA progress bar --- src/nn_matchC.cpp | 42 +++--- src/nn_matchC_closest.cpp | 132 +++++++++++------- src/nn_matchC_vec.cpp | 64 +++++---- src/nn_matchC_vec_closest.cpp | 253 +++++++++++++++++++--------------- 4 files changed, 280 insertions(+), 211 deletions(-) diff --git a/src/nn_matchC.cpp b/src/nn_matchC.cpp index 5141ebf0..da5c2404 100644 --- a/src/nn_matchC.cpp +++ b/src/nn_matchC.cpp @@ -1,5 +1,6 @@ // [[Rcpp::depends(RcppProgress)]] #include +#include "eta_progress_bar.h" #include #include "internal.h" #include @@ -36,17 +37,20 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, } } - int n = treat.size(); + R_xlen_t n = treat.size(); IntegerVector ind = Range(0, n - 1); - int i, gi; + R_xlen_t i; + int gi; IntegerVector indt(n); - IntegerVector indt_begin(g), indt_end(g); + IntegerVector indt_sep(g + 1); IntegerVector indt_tmp; IntegerVector nt(g); - IntegerVector ind_match = rep(NA_INTEGER, n); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); - IntegerVector times_matched = rep(0, n); + IntegerVector times_matched(n); + times_matched.fill(0); LogicalVector eligible = !discarded; IntegerVector g_c = Range(0, g - 1); @@ -60,26 +64,23 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, int max_nc = max(as(nt[g_c])); - indt_begin[0] = 0; - indt_end[0] = nt[0]; + indt_sep[0] = 0; for (gi = 0; gi < g; gi++) { - if (gi > 0) { - indt_begin[gi] = indt_end[gi - 1]; - indt_end[gi] = indt_begin[gi] + nt[gi]; - } + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; indt_tmp = ind[treat == gi]; for (i = 0; i < nt[gi]; i++) { - indt[indt_begin[gi] + i] = indt_tmp[i]; + indt[indt_sep[gi] + i] = indt_tmp[i]; ind_match[indt_tmp[i]] = i; } } - IntegerVector ind_focal = indt[Range(indt_begin[focal], indt_end[focal] - 1)]; + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - IntegerVector times_matched_allowed = rep(reuse_max, n); + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); times_matched_allowed[ind_focal] = ratio; IntegerVector n_eligible(unique_treat.size()); @@ -174,7 +175,8 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, int prog_length; if (use_reuse_max) prog_length = sum(ratio) + 1; else prog_length = nf + 1; - Progress p(prog_length, disl_prog); + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); //Counters int r, t_id_t_i, t_id_i, c_id_i, c, k; @@ -190,7 +192,7 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, for (i = 0; i < nf && max(as(n_eligible[g_c])) > 0; i++) { counter++; - if (counter % 500 == 0) Rcpp::checkUserInterrupt(); + if (counter % 200 == 0) Rcpp::checkUserInterrupt(); t_id_t_i = ord[i] - 1; // index among treated t_id_i = ind_focal[t_id_t_i]; // index among sample @@ -211,7 +213,7 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, k = 0; if (n_eligible[gi] > 0) { - for (c = indt_begin[gi]; c < indt_end[gi]; c++) { + for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { c_id_i = indt[c]; if (!eligible[c_id_i]) { @@ -254,7 +256,7 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, //Compute distances among eligible if (use_mah_covs) { - dist = sqrt(sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0))); + dist = sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0)); } else if (ps_diff_calculated) { dist = ps_diff; @@ -314,7 +316,7 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, ck_ = matches_i[Range(0, k_total)]; if (use_unit_id) { - ck_ = ind[match(unit_id, as(unit_id[ck_])) > 0]; + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); } for (int ck : ck_) { @@ -354,7 +356,7 @@ IntegerMatrix nn_matchC(const IntegerVector& treat_, k = 0; if (n_eligible[gi] > 0) { - for (c = indt_begin[gi]; c < indt_end[gi]; c++) { + for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { c_id_i = indt[c]; if (!eligible[c_id_i]) { diff --git a/src/nn_matchC_closest.cpp b/src/nn_matchC_closest.cpp index d437bb04..f33bb45b 100644 --- a/src/nn_matchC_closest.cpp +++ b/src/nn_matchC_closest.cpp @@ -1,5 +1,6 @@ // [[Rcpp::depends(RcppProgress)]] #include +#include "eta_progress_bar.h" #include #include "internal.h" using namespace Rcpp; @@ -18,22 +19,74 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, const Nullable& caliper_covs_mat_ = R_NilValue, const Nullable& antiexact_covs_ = R_NilValue, const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) -{ + const bool& disl_prog = false) { - int r = distance_mat.nrow(); + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; - int n = treat.size(); + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); - IntegerMatrix mm(r, max(ratio)); - mm.fill(NA_INTEGER); + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + IntegerVector times_matched(n); + times_matched.fill(0); + LogicalVector eligible = !discarded; + + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); + } - CharacterVector lab = treat.names(); + int nf = nt[focal]; + indt_sep[0] = 0; - IntegerVector ind = Range(0, n - 1); - IntegerVector ind0 = ind[treat == 0]; - IntegerVector ind1 = ind[treat == 1]; + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); + times_matched_allowed[ind_focal] = ratio; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } //caliper_dist double caliper_dist; @@ -41,7 +94,7 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, caliper_dist = as(caliper_dist_); } else { - caliper_dist = max_finite(distance_mat) + 1; + caliper_dist = max_finite(distance_mat) + .1; } //caliper_covs @@ -57,14 +110,6 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, ncc = 0; } - //exact - bool use_exact = false; - IntegerVector exact; - if (exact_.isNotNull()) { - exact = as(exact_); - use_exact = true; - } - //antiexact IntegerMatrix antiexact_covs; int aenc; @@ -84,35 +129,29 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, use_unit_id = true; } - IntegerVector times_matched = rep(0, n); - LogicalVector eligible = rep(true, n); - eligible[discarded] = false; - IntegerVector times_matched_allowed = rep(reuse_max, n); - times_matched_allowed[ind1] = ratio; - - int n_eligible0 = sum(as(eligible[treat == 0])); - int n_eligible1 = sum(as(eligible[treat == 1])); - //progress bar int prog_length; prog_length = sum(ratio) + 1; - Progress p(prog_length, disl_prog); - p.increment(); + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); Function o("order"); IntegerVector d_ord = o(distance_mat); d_ord = d_ord - 1; //Because R uses 1-indexing - int rj, cj, c_id_i, t_id_i; + gi = 0; - for (int dj : d_ord) { + R_xlen_t r = distance_mat.nrow(); - if (n_eligible1 <= 0) { - break; - } + int rj, cj, c_id_i, t_id_i; + int counter = -1; + + for (R_xlen_t dj : d_ord) { + counter++; + if (counter % 200 == 0) Rcpp::checkUserInterrupt(); - if (n_eligible0 <= 0) { + if (min(n_eligible) <= 0) { break; } @@ -127,14 +166,15 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, cj = dj / r; // Get sample indices of members of potential pair - t_id_i = ind1[rj]; - c_id_i = ind0[cj]; + t_id_i = ind_focal[rj]; // If either member is discarded, move on if (!eligible[t_id_i]) { continue; } + c_id_i = indt[indt_sep[gi] + cj]; + if (!eligible[c_id_i]) { continue; } @@ -158,11 +198,10 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, mm(rj, sum(!is_na(mm(rj, _)))) = c_id_i; // If unit_id used, increase match count of all units with that ID + ck_ = {t_id_i, c_id_i}; + if (use_unit_id) { - ck_ = ind[unit_id == unit_id[t_id_i] | unit_id == unit_id[c_id_i]]; - } - else { - ck_ = {t_id_i, c_id_i}; + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); } for (int ck : ck_) { @@ -174,12 +213,7 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, times_matched[ck]++; if (times_matched[ck] >= times_matched_allowed[ck]) { eligible[ck] = false; - if (treat[ck] == 1) { - n_eligible1--; - } - else { - n_eligible0--; - } + n_eligible[treat[ck]]--; } } @@ -189,7 +223,7 @@ IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, p.update(prog_length); mm = mm + 1; - rownames(mm) = lab[treat == 1]; + rownames(mm) = lab[ind_focal]; return mm; } diff --git a/src/nn_matchC_vec.cpp b/src/nn_matchC_vec.cpp index cc8ef424..0b364673 100644 --- a/src/nn_matchC_vec.cpp +++ b/src/nn_matchC_vec.cpp @@ -1,5 +1,6 @@ // [[Rcpp::depends(RcppProgress)]] #include +#include "eta_progress_bar.h" #include #include "internal.h" using namespace Rcpp; @@ -33,17 +34,20 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, } } - int n = treat.size(); + R_xlen_t n = treat.size(); IntegerVector ind = Range(0, n - 1); - int i, gi; + R_xlen_t i; + int gi; IntegerVector indt(n); - IntegerVector indt_begin(g), indt_end(g); + IntegerVector indt_sep(g + 1); IntegerVector indt_tmp; IntegerVector nt(g); - IntegerVector ind_match = rep(NA_INTEGER, n); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); - IntegerVector times_matched = rep(0, n); + IntegerVector times_matched(n); + times_matched.fill(0); LogicalVector eligible = !discarded; IntegerVector g_c = Range(0, g - 1); @@ -55,29 +59,26 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, int nf = nt[focal]; - indt_begin[0] = 0; - indt_end[0] = nt[0]; + indt_sep[0] = 0; for (gi = 0; gi < g; gi++) { - if (gi > 0) { - indt_begin[gi] = indt_end[gi - 1]; - indt_end[gi] = indt_begin[gi] + nt[gi]; - } + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; indt_tmp = ind[treat == gi]; for (i = 0; i < nt[gi]; i++) { - indt[indt_begin[gi] + i] = indt_tmp[i]; + indt[indt_sep[gi] + i] = indt_tmp[i]; ind_match[indt_tmp[i]] = i; } } - IntegerVector ind_focal = indt[Range(indt_begin[focal], indt_end[focal] - 1)]; + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - IntegerVector times_matched_allowed = rep(reuse_max, n); + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); times_matched_allowed[ind_focal] = ratio; - IntegerVector n_eligible(unique_treat.size()); + IntegerVector n_eligible(g); for (i = 0; i < n; i++) { if (eligible[i]) { n_eligible[treat[i]]++; @@ -94,13 +95,15 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, //Use base::order() because faster than Rcpp implementation of order() Function o("order"); - IntegerVector ind_d_ord = o(distance, Named("decreasing") = false); + IntegerVector ind_d_ord = o(distance); ind_d_ord = ind_d_ord - 1; //location of each unit after sorting IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; - IntegerVector first_control = rep(0, g); - IntegerVector last_control = rep(n - 1, g); + IntegerVector last_control(g); + last_control.fill(n - 1); + IntegerVector first_control(g); + first_control.fill(0); //exact bool use_exact = false; @@ -150,7 +153,8 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, //progress bar int prog_length = sum(ratio) + 1; - Progress p(prog_length, disl_prog); + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); int r, t_id_t_i, t_id_i, c_id_i, c; IntegerVector ck_; @@ -164,7 +168,7 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, // t_id_t_i; index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; - if (counter % 500 == 0) Rcpp::checkUserInterrupt(); + if (counter % 200 == 0) Rcpp::checkUserInterrupt(); t_id_t_i = ord[i] - 1; t_id_i = ind_focal[t_id_t_i]; @@ -182,12 +186,12 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, k_total = 0; for (int gi : g_c) { - while (!eligible[ind_d_ord[first_control[gi]]] || treat[ind_d_ord[first_control[gi]]] != gi) { - first_control[gi]++; - } - while (!eligible[ind_d_ord[last_control[gi]]] || treat[ind_d_ord[last_control[gi]]] != gi) { - last_control[gi]--; - } + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); c_id_i = find_both(t_id_i, ind_d_ord, @@ -206,8 +210,8 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, exact, aenc, antiexact_covs, - first_control[gi], - last_control[gi]); + first_control, + last_control); if (c_id_i < 0) { if (r == 1) { @@ -236,7 +240,7 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, ck_ = matches_i[Range(0, k_total)]; if (use_unit_id) { - ck_ = ind[match(unit_id, as(unit_id[ck_])) > 0]; + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); } for (int ck : ck_) { @@ -259,4 +263,4 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, rownames(mm) = lab[ind_focal]; return mm; -} \ No newline at end of file +} diff --git a/src/nn_matchC_vec_closest.cpp b/src/nn_matchC_vec_closest.cpp index 32d930ba..0b2abd18 100644 --- a/src/nn_matchC_vec_closest.cpp +++ b/src/nn_matchC_vec_closest.cpp @@ -1,5 +1,6 @@ // [[Rcpp::depends(RcppProgress)]] #include +#include "eta_progress_bar.h" #include #include "internal.h" #include @@ -21,32 +22,81 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, const Nullable& unit_id_ = R_NilValue, const bool& disl_prog = false) { - int n = treat.size(); + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; - //Use base::order() because faster than Rcpp implementation of order() - Function o("order"); + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); - IntegerVector ind_d_ord = o(distance, Named("decreasing") = false); - ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); - IntegerVector ind = Range(0, n - 1); - IntegerVector ind1 = ind[treat == 1]; + IntegerVector times_matched(n); + times_matched.fill(0); + LogicalVector eligible = !discarded; - int n1 = ind1.size(); + // IntegerVector g_c = Range(0, g - 1); + // g_c = g_c[g_c != focal]; - int i, j; + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); + } - IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; + int nf = nt[focal]; - int max_ratio = max(ratio); + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; - //ind: 1 2 3 4 5 6 7 8 - //ind1: 2 3 5 7 + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); + times_matched_allowed[ind_focal] = ratio; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } - IntegerMatrix mm(n1, max_ratio); + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); mm.fill(NA_INTEGER); CharacterVector lab = treat.names(); + //Use base::order() because faster than Rcpp implementation of order() + Function o("order"); + + IntegerVector ind_d_ord = o(distance); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + + IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; + + IntegerVector last_control(g); + last_control.fill(n - 1); + IntegerVector first_control(g); + first_control.fill(0); + //exact bool use_exact = false; IntegerVector exact; @@ -90,52 +140,35 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, use_unit_id = true; } - IntegerVector ind1_match = rep(NA_INTEGER, n); - for (i = 0; i < n1; i++) { - ind1_match[ind1[i]] = i; - } - //storing closeness - IntegerVector t_id(2 * n1); - IntegerVector c_id(2 * n1); - NumericVector dist = rep(R_PosInf, 2 * n1); - for (i = 0; i < n1; i++) { - t_id[i] = ind1[i]; - t_id[i + n1] = ind1[i]; - c_id[i] = -1; - c_id[i + n1] = -2; + IntegerVector t_id = rep_each(ind_focal, 2); + IntegerVector c_id = rep(-1, 2 * nf); + NumericVector dist = rep(R_PosInf, 2 * nf); + for (i = 0; i < nf; i++) { + c_id[2 * i] = -2; } - IntegerVector times_matched = rep(0, n); - LogicalVector eligible = rep(true, n); - eligible[discarded] = false; - IntegerVector times_matched_allowed = rep(reuse_max, n); - times_matched_allowed[ind1] = ratio; - - IntegerVector times_skipped = rep(0, n); + LogicalVector skipped_once = rep(false, nf); //progress bar int prog_length = sum(ratio) + 1; - Progress p(prog_length, disl_prog); - p.increment(); + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); IntegerVector ck_; - int t_id_i, c_id_i, t_id_t_i, k; + int t_id_i, c_id_i, t_id_t_i, c, k; - int first_control = 0; - int last_control = n - 1; - - while (!eligible[ind_d_ord[first_control]] || treat[ind_d_ord[first_control]] == 1) { - first_control++; - } - while (!eligible[ind_d_ord[last_control]] || treat[ind_d_ord[last_control]] == 1) { - last_control--; - } + gi = 0; + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); //Find left and right matches for each treated unit - - for (i = 0; i < (2 * n1); i++) { + for (i = 0; i < (2 * nf); i++) { t_id_i = t_id[i]; if (!eligible[t_id_i]) { @@ -151,7 +184,7 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, treat, distance, eligible, - 0, + gi, ncc, caliper_covs_mat, caliper_covs, @@ -164,10 +197,6 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, last_control); if (k < 0) { - times_skipped[t_id_i]++; - if (times_skipped[t_id_i] == 2) { - eligible[t_id_i] = false; - } continue; } @@ -175,19 +204,6 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, dist[i] = std::abs(distance[t_id_i] - distance[k]); } - int n_eligible0 = 0; - int n_eligible1 = 0; - for (i = 0; i < n; i++) { - if (eligible[i]) { - if (treat[i] == 0) { - n_eligible0++; - } - else { - n_eligible1++; - } - } - } - //Order the list IntegerVector heap_ord = o(dist); heap_ord = heap_ord - 1; @@ -195,10 +211,13 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, //Go down the list; update as needed int hi; int counter = -1; + bool find_new = false; + + i = 0; + while (min(n_eligible) > 0 && i < (2 * nf)) { - for (i = 0; (i < 2 * n1) && n_eligible1 > 0 && n_eligible0 > 0; i++) { counter++; - if (counter % 500 == 0) Rcpp::checkUserInterrupt(); + if (counter % 200 == 0) Rcpp::checkUserInterrupt(); hi = heap_ord[i]; @@ -209,20 +228,43 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, t_id_i = t_id[hi]; if (!eligible[t_id_i]) { + i++; continue; } c_id_i = c_id[hi]; - if (!eligible[c_id_i]) { - // If control isn't eligible, find new control and try again + t_id_t_i = ind_match[t_id_i]; - while (!eligible[ind_d_ord[first_control]] || treat[ind_d_ord[first_control]] == 1) { - first_control++; + if (c_id_i < 0) { + if (skipped_once[t_id_t_i]) { + eligible[t_id_i] = false; + n_eligible[focal]--; } - while (!eligible[ind_d_ord[last_control]] || treat[ind_d_ord[last_control]] == 1) { - last_control--; + else { + skipped_once[t_id_t_i] = true; } + i++; + continue; + } + + find_new = false; + if (!eligible[c_id_i]) { + find_new = true; + } + else if (!mm_okay(times_matched[t_id_i] + 1, c_id_i, mm.row(t_id_t_i))) { + find_new = true; + } + + if (find_new) { + // If control isn't eligible, find new control and try again + + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); k = find_lr(c_id_i, t_id_i, @@ -231,7 +273,7 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, treat, distance, eligible, - 0, + gi, ncc, caliper_covs_mat, caliper_covs, @@ -243,68 +285,55 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, first_control, last_control); - //If no new control found, mark treated unit as ineligible and continue + c_id[hi] = k; + + //If no new control found, continue if (k < 0) { - times_skipped[t_id_i]++; - if (times_skipped[t_id_i] == 2) { - eligible[t_id_i] = false; - n_eligible1--; - } continue; } - c_id[hi] = k; dist[hi] = std::abs(distance[t_id_i] - distance[k]); //Find new position of pair in heap - for (j = i; j < (2 * n1) - 1; j++) { - if (dist[heap_ord[j]] < dist[heap_ord[j + 1]]) { + for (c = i; c < (2 * nf) - 1; c++) { + if (dist[heap_ord[c]] < dist[heap_ord[c + 1]]) { break; } - swap_pos(heap_ord, j, j + 1); + swap_pos(heap_ord, c, c + 1); } - i--; + continue; } - else { - t_id_t_i = ind1_match[t_id_i]; - mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; - if (use_unit_id) { - ck_ = ind[unit_id == unit_id[t_id_i] | unit_id == unit_id[c_id_i]]; - } - else { - ck_ = {c_id_i, t_id_i}; - } + ck_ = {c_id_i, t_id_i}; - for (int ck : ck_) { + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } - if (!eligible[ck]) { - continue; - } + for (int ck : ck_) { - times_matched[ck]++; - if (times_matched[ck] >= times_matched_allowed[ck]) { - eligible[ck] = false; - if (treat[ck] == 1) { - n_eligible1--; - } - else { - n_eligible0--; - } - } + if (!eligible[ck]) { + continue; } - p.increment(); + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } } + + p.increment(); } p.update(prog_length); mm = mm + 1; - rownames(mm) = lab[ind1]; + rownames(mm) = lab[ind_focal]; return mm; -} \ No newline at end of file +} From 3b24a08b72b6704e64a8514950f5b98e9c736b91 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:49:20 -0400 Subject: [PATCH 18/48] New matching for m.order = "closest" with mahcovs; computing full distance matrix no logner required --- src/nn_matchC_mahcovs_closest.cpp | 371 ++++++++++++++++++++++++++++++ 1 file changed, 371 insertions(+) create mode 100644 src/nn_matchC_mahcovs_closest.cpp diff --git a/src/nn_matchC_mahcovs_closest.cpp b/src/nn_matchC_mahcovs_closest.cpp new file mode 100644 index 00000000..667f711e --- /dev/null +++ b/src/nn_matchC_mahcovs_closest.cpp @@ -0,0 +1,371 @@ +// [[Rcpp::depends(RcppProgress)]] +#include +#include "eta_progress_bar.h" +#include +#include "internal.h" +#include +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const NumericMatrix& mah_covs, + const Nullable& distance_ = R_NilValue, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { + + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + IntegerVector times_matched(n); + times_matched.fill(0); + LogicalVector eligible = !discarded; + + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); + times_matched_allowed[ind_focal] = ratio; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //distance + bool use_caliper_dist = false; + double caliper_dist, ps_diff; + NumericVector distance; + if (distance_.isNotNull()) { + distance = distance_; + + //caliper_dist + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + use_caliper_dist = true; + } + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + } + + //storing closeness + IntegerVector t_id = ind_focal; + IntegerVector c_id = rep(-1, nf); + NumericVector dist = rep(R_PosInf, nf); + + //progress bar + int prog_length = nf + sum(ratio) + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + gi = 0; + + IntegerVector ck_, t_inds; + IntegerVector c_eligible(nt[gi]); + NumericVector match_distance(nt[gi]); + + R_xlen_t c; + int c_id_i, t_id_t_i; + int t_id_i = -1; + double dist_c; + bool any_match_found; + + int counter = -1; + int r = 1; + + //Find closest control unit to each treated unit + for (i = 0; i < nf; i++) { + counter++; + if (counter % 200 == 0) Rcpp::checkUserInterrupt(); + + p.increment(); + + t_id_i = ind_focal[i]; + + if (!eligible[t_id_i]) { + continue; + } + + t_inds = which(t_id == t_id_i); + + any_match_found = false; + + for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { + c_id_i = indt[c]; + + if (!eligible[c_id_i]) { + continue; + } + + if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { + continue; + } + + if (use_caliper_dist) { + ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); + + if (ps_diff > caliper_dist) { + continue; + } + } + + if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; + } + + //Compute distances among eligible + dist_c = sqrt(sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0))); + + if (!std::isfinite(dist_c)) { + continue; + } + + if (any_match_found) { + if (dist_c < dist[i]) { + c_id[i] = c_id_i; + dist[i] = dist_c; + } + } + else { + c_id[i] = c_id_i; + dist[i] = dist_c; + any_match_found = true; + } + } + + if (!any_match_found) { + eligible[t_id_i] = false; + n_eligible[focal]--; + } + } + + //Order the list + // Use base::order() because faster than Rcpp implementation of order() + Function o("order"); + IntegerVector heap_ord = o(dist); + heap_ord = heap_ord - 1; + + //Go down the list; update as needed + R_xlen_t hi; + bool find_new; + + i = 0; + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter % 200 == 0) Rcpp::checkUserInterrupt(); + + hi = heap_ord[i]; + + t_id_i = t_id[hi]; + + if (!eligible[t_id_i]) { + i++; + continue; + } + + r = times_matched[t_id_i] + 1; + + t_id_t_i = ind_match[t_id_i]; + + c_id_i = c_id[hi]; + + find_new = false; + if (!eligible[c_id_i]) { + find_new = true; + } + else if (!mm_okay(r, c_id_i, mm.row(t_id_t_i))) { + find_new = true; + } + + if (find_new) { + // If control isn't eligible, find new control and try again + any_match_found = false; + + for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { + c_id_i = indt[c]; + + if (!eligible[c_id_i]) { + continue; + } + + //Prevent control units being matched to same treated unit again + if (!mm_okay(r, c_id_i, mm.row(t_id_t_i))) { + continue; + } + + if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { + continue; + } + + if (use_caliper_dist) { + ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); + + if (ps_diff > caliper_dist) { + continue; + } + } + + if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; + } + + //Compute distances among eligible + dist_c = sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0)); + + if (!std::isfinite(dist_c)) { + continue; + } + + if (any_match_found) { + if (dist_c < dist[hi]) { + c_id[hi] = c_id_i; + dist[hi] = dist_c; + } + } + else { + c_id[hi] = c_id_i; + dist[hi] = dist_c; + any_match_found = true; + } + } + + //If no matches... + if (!any_match_found) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + //Find new position of pair in heap + for (c = i; c < nf - 1; c++) { + if (dist[heap_ord[c]] < dist[heap_ord[c + 1]]) { + break; + } + + swap_pos(heap_ord, c, c + 1); + } + + continue; + } + + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + + ck_ = {c_id_i, t_id_i}; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + + p.increment(); + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = lab[ind_focal]; + + return mm; +} From 18841e7ec41667941d6a9e9a6982e14f3ff6beca Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:49:44 -0400 Subject: [PATCH 19/48] Updates to support new matching infrastructure in Rcpp --- R/matchit2nearest.R | 278 +++++++++++++++++++++----------------------- 1 file changed, 135 insertions(+), 143 deletions(-) diff --git a/R/matchit2nearest.R b/R/matchit2nearest.R index 77a158fa..682d3b7e 100644 --- a/R/matchit2nearest.R +++ b/R/matchit2nearest.R @@ -56,12 +56,12 @@ #' by the argument to `distance`. #' @param estimand a string containing the desired estimand. Allowable options #' include `"ATT"` and `"ATC"`. See Details. -#' @param exact for which variables exact matching should take place. +#' @param exact for which variables exact matching should take place; two units with different values of an exact matching variable will not be paired. #' @param mahvars for which variables Mahalanobis distance matching should take #' place when `distance` corresponds to a propensity score (e.g., for #' caliper matching or to discard units for common support). If specified, the #' distance measure will not be used in matching. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place; two units with the same value of an anti-exact matching variable will not be paired. #' @param discard a string containing a method for discarding units outside a #' region of common support. Only allowed when `distance` corresponds to a #' propensity score. @@ -69,20 +69,19 @@ #' re-estimate the propensity score in the remaining sample prior to matching. #' @param s.weights the variable containing sampling weights to be incorporated #' into propensity score models and balance statistics. -#' @param replace whether matching should be done with replacement. +#' @param replace whether matching should be done with replacement (i.e., whether control units can be used as matches multiple times). See also the `reuse.max` argument below. Default is `FALSE` for matching without replacement. #' @param m.order the order that the matching takes place. Allowable options #' include `"largest"`, where matching takes place in descending order of #' distance measures; `"smallest"`, where matching takes place in ascending #' order of distance measures; `"closest"`, where matching takes place in -#' order of the distance between units; `"random"`, where matching takes place +#' ascending order of the distance between units; `"random"`, where matching takes place #' in a random order; and `"data"` where matching takes place based on the #' order of units in the data. When `m.order = "random"`, results may differ #' across different runs of the same code unless a seed is set and specified #' with [set.seed()]. The default of `NULL` corresponds to `"largest"` when a #' propensity score is estimated or supplied as a vector and `"data"` #' otherwise. -#' @param caliper the width(s) of the caliper(s) used for caliper matching. See -#' Details and Examples. +#' @param caliper the width(s) of the caliper(s) used for caliper matching. Two units with a difference on a caliper variable larger than the caliper will not be paired. See Details and Examples. #' @param std.caliper `logical`; when calipers are specified, whether they #' are in standard deviation units (`TRUE`) or raw units (`FALSE`). #' @param ratio how many control units should be matched to each treated unit @@ -93,7 +92,7 @@ #' section "Variable Ratio Matching" in Details below. #' @param verbose `logical`; whether information about the matching #' process should be printed to the console. When `TRUE`, a progress bar -#' implemented using *RcppProgress* will be displayed. +#' implemented using *RcppProgress* will be displayed along with an estimate of the time remaining. #' @param \dots additional arguments that control the matching specification: #' \describe{ #' \item{`reuse.max`}{ `numeric`; the maximum number of @@ -109,8 +108,7 @@ #' unit. Once a control observation has been matched, no other observation with #' the same unit ID can be used as matches. This ensures each control unit is #' used only once even if it has multiple observations associated with it. -#' Omitting this argument is the same as giving each observation a unique ID. -#' Ignored when `replace = TRUE`. } +#' Omitting this argument is the same as giving each observation a unique ID.} #' } #' #' @note Sometimes an error will be produced by *Rcpp* along the lines of @@ -189,8 +187,7 @@ #' #' ## Variable Ratio Matching #' -#' `matchit()` can perform variable -#' ratio "extremal" matching as described by Ming and Rosenbaum (2000). This +#' `matchit()` can perform variable ratio "extremal" matching as described by Ming and Rosenbaum (2000). This #' method tends to result in better balance than fixed ratio matching at the #' expense of some precision. When `ratio > 1`, rather than requiring all #' treated units to receive `ratio` matches, each treated unit is assigned @@ -221,7 +218,11 @@ #' #' ## Using `m.order = "closest"` #' -#' As of version 4.6.0, `m.order` can be set to `"closest"`, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to `m.order = "smallest"`. +#' As of version 4.6.0, `m.order` can be set to `"closest"`, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to `m.order = "smallest"`. When `distance` is a vector or a method of estimating a propensity score, an algorithm described by [Rassen et al. (2012)](https://doi.org/10.1002/pds.3263) is used. +#' +#' ## Reproducibility +#' +#' Nearest neighbor matching involves a random component only when `m.order = "random"`, so a seed must be set in that case using [set.seed()] to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. #' #' @seealso [matchit()] for a detailed explanation of the inputs and outputs of #' a call to `matchit()`. @@ -283,7 +284,7 @@ matchit2nearest <- function(treat, data, distance, discarded, caliper = NULL, mahvars = NULL, exact = NULL, formula = NULL, estimand = "ATT", verbose = FALSE, is.full.mahalanobis, - antiexact = NULL, unit.id = NULL, ...){ + antiexact = NULL, unit.id = NULL, ...) { if (verbose) { rlang::check_installed("RcppProgress") @@ -292,6 +293,7 @@ matchit2nearest <- function(treat, data, distance, discarded, estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC")) + if (estimand == "ATC") { tc <- c("control", "treated") focal <- 0 @@ -304,7 +306,7 @@ matchit2nearest <- function(treat, data, distance, discarded, treat <- setNames(as.integer(treat == focal), names(treat)) if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } @@ -318,7 +320,7 @@ matchit2nearest <- function(treat, data, distance, discarded, lab <- names(treat) lab1 <- lab[treat == 1] - if (!is.null(distance)) { + if (is_not_null(distance)) { names(distance) <- names(treat) } @@ -326,8 +328,12 @@ matchit2nearest <- function(treat, data, distance, discarded, max.controls <- attr(ratio, "max.controls") mahcovs <- distance_mat <- NULL - if (!is.null(mahvars)) { - transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" + if (is_not_null(mahvars)) { + transform <- { + if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") + else "mahalanobis" + } + mahcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, discarded = discarded) @@ -343,21 +349,21 @@ matchit2nearest <- function(treat, data, distance, discarded, #Process caliper ex.caliper.list <- list() - if (!is.null(caliper)) { + if (is_not_null(caliper)) { if (any(names(caliper) != "")) { caliper.covs <- caliper[names(caliper) != ""] caliper.covs.mat <- get.covs.matrix(reformulate(names(caliper.covs)), data = data) - ex.caliper.list <- setNames(lapply(names(caliper.covs), function(i) { - splits <- get_splitsC(as.numeric(caliper.covs.mat[,i]), - as.numeric(caliper.covs[i])) - - if (length(splits) == 0) return(NULL) - - cut(caliper.covs.mat[,i], - breaks = splits, - include.lowest = TRUE) - }), names(caliper.covs)) + # ex.caliper.list <- setNames(lapply(names(caliper.covs), function(i) { + # splits <- get_splitsC(as.numeric(caliper.covs.mat[,i]), + # as.numeric(caliper.covs[i])) + # + # if (is_null(splits)) return(NULL) + # + # cut(caliper.covs.mat[,i], + # breaks = splits, + # include.lowest = TRUE) + # }), names(caliper.covs)) } else { caliper.covs.mat <- caliper.covs <- NULL @@ -366,21 +372,21 @@ matchit2nearest <- function(treat, data, distance, discarded, if (any(names(caliper) == "")) { caliper.dist <- caliper[names(caliper) == ""] - if (!is.null(distance)) { - splits <- get_splitsC(as.numeric(distance), - as.numeric(caliper.dist)) - - ex.caliper.list <- c(ex.caliper.list, - list(distance = cut(distance, - breaks = splits, - include.lowest = TRUE))) - } + # if (is_not_null(distance)) { + # splits <- get_splitsC(as.numeric(distance), + # as.numeric(caliper.dist)) + # + # ex.caliper.list <- c(ex.caliper.list, + # list(distance = cut(distance, + # breaks = splits, + # include.lowest = TRUE))) + # } } else { caliper.dist <- NULL } - if (length(ex.caliper.list) > 0L) { + if (is_not_null(ex.caliper.list)) { ex.caliper.list <- ex.caliper.list[lengths(ex.caliper.list) > 0L] for (i in seq_along(ex.caliper.list)) { @@ -393,18 +399,14 @@ matchit2nearest <- function(treat, data, distance, discarded, caliper.covs.mat <- NULL } - #### - ex.caliper.list <- list() - #### - ex.caliper <- { - if (length(ex.caliper.list) == 0) NULL + if (is_null(ex.caliper.list)) NULL else as.factor(exactify(ex.caliper.list, nam = lab, sep = ", ", justify = NULL)) } #Process antiexact - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) antiexactcovs <- do.call("cbind", lapply(seq_len(ncol(antiexactcovs)), function(i) { as.integer(as.factor(antiexactcovs[[i]])) @@ -416,11 +418,11 @@ matchit2nearest <- function(treat, data, distance, discarded, reuse.max <- attr(replace, "reuse.max") - if (reuse.max >= n1) { - m.order <- "data" - } + # if (reuse.max >= n1) { + # m.order <- "data" + # } - if (!is.null(unit.id) && reuse.max < n1) { + if (is_not_null(unit.id) && reuse.max < n1) { unit.id <- process.variable.input(unit.id, data) unit.id <- factor(exactify(model.frame(unit.id, data = data), nam = lab, sep = ", ", include_vars = TRUE)) @@ -429,7 +431,7 @@ matchit2nearest <- function(treat, data, distance, discarded, #If each unit is a unit.id, unit.ids are meaningless if (num_ctrl_unit.ids == n0 && num_trt_unit.ids == n1) { - # unit.id <- NULL + unit.id <- NULL } } else { @@ -437,27 +439,26 @@ matchit2nearest <- function(treat, data, distance, discarded, } #Process exact - if (!is.null(exact)) { - ex <- as.factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) - } - else { - ex <- NULL + ex <- { + if (is_not_null(exact)) as.factor(exactify(model.frame(exact, data = data), + nam = lab, sep = ", ", include_vars = TRUE)) + else NULL } - if (!is.null(ex) || !is.null(ex.caliper)) { + if (is_not_null(ex) || is_not_null(ex.caliper)) { ex0 <- { - if (!is.null(ex) && !is.null(ex.caliper)) { + if (is_not_null(ex) && is_not_null(ex.caliper)) { as.factor(exactify(list(ex, ex.caliper), nam = lab, sep = ", ", justify = NULL)) } - else if (!is.null(ex)) ex + else if (is_not_null(ex)) ex else ex.caliper } - cc <- intersect(as.integer(ex0)[treat==1], as.integer(ex0)[treat==0]) + cc <- Reduce("intersect", lapply(unique(treat), function(t) as.integer(ex0)[treat==t])) - if (length(cc) == 0L) { - .err("No matches were found") + if (is_null(cc)) { + .err("no matches were found") } cc <- sort(cc) @@ -467,18 +468,19 @@ matchit2nearest <- function(treat, data, distance, discarded, } if (reuse.max < n1) { - if (!is.null(ex)) { + if (is_not_null(ex)) { e_ratios <- vapply(levels(ex), function(e) { - if (is.null(unit.id)) sum(treat[ex == e] == 0)*(reuse.max/sum(treat[ex == e] == 1)) - else length(unique(unit.id[treat == 0 & ex == e]))*(reuse.max/sum(treat[ex == e] == 1)) + if (is_null(unit.id)) reuse.max * sum(treat[ex == e] == 0) / sum(treat[ex == e] == 1) + else reuse.max * length(unique(unit.id[treat == 0 & ex == e])) / sum(treat[ex == e] == 1) }, numeric(1L)) if (any(e_ratios < 1)) { .wrn(sprintf("fewer %s units than %s units in some `exact` strata; not all %s units will get a match", tc[2], tc[1], tc[1])) } + if (ratio > 1 && any(e_ratios < ratio)) { - if (is.null(max.controls) || ratio == max.controls) + if (is_null(max.controls) || ratio == max.controls) .wrn(sprintf("not all %s units will get %s matches", tc[1], ratio)) else @@ -488,42 +490,44 @@ matchit2nearest <- function(treat, data, distance, discarded, } else { e_ratios <- { - if (is.null(unit.id)) as.numeric(reuse.max)*n0/n1 + if (is_null(unit.id)) as.numeric(reuse.max)*n0/n1 else as.numeric(reuse.max)*num_ctrl_unit.ids/n1 } if (e_ratios < 1) { .wrn(sprintf("fewer %s %s than %s units; not all %s units will get a match", - tc[2], if (is.null(unit.id)) "units" else "unit IDs", tc[1], tc[1])) + tc[2], if (is_null(unit.id)) "units" else "unit IDs", tc[1], tc[1])) } else if (e_ratios < ratio) { - if (is.null(max.controls) || ratio == max.controls) + if (is_null(max.controls) || ratio == max.controls) .wrn(sprintf("not all %s units will get %s matches", tc[1], ratio)) else .wrn(sprintf("not enough %s %s for an average of %s matches per %s unit", - tc[2], if (is.null(unit.id)) "units" else "unit IDs", ratio, tc[1])) + tc[2], if (is_null(unit.id)) "units" else "unit IDs", ratio, tc[1])) } } } #Variable ratio (extremal matching), Ming & Rosenbaum (2000) #Each treated unit get its own value of ratio - if (!is.null(max.controls)) { - if (is.null(distance)) { + if (is_not_null(max.controls)) { + if (is_null(distance)) { if (is.full.mahalanobis) { - .err(sprintf("`distance` cannot be \"%s\" for variable ratio matching", + .err(sprintf("`distance` cannot be \"%s\" for variable ratio nearest neighbor matching", transform)) } - .err("`distance` cannot be supplied as a matrix for variable ratio matching") + else { + .err("`distance` cannot be supplied as a matrix for variable ratio nearest neighbor matching") + } } m <- round(ratio * n1) # if (m > sum(treat == 0)) stop("'ratio' must be less than or equal to n0/n1") - kmax <- floor((m - min.controls*(n1-1)) / (max.controls - min.controls)) + kmax <- floor((m - min.controls * (n1 - 1)) / (max.controls - min.controls)) kmin <- n1 - kmax - 1 - kmed <- m - (min.controls*kmin + max.controls*kmax) + kmed <- m - (min.controls * kmin + max.controls * kmax) ratio0 <- c(rep(min.controls, kmin), kmed, rep(max.controls, kmax)) @@ -546,37 +550,24 @@ matchit2nearest <- function(treat, data, distance, discarded, } m.order <- { - if (is.null(distance)) match_arg(m.order, c("data", "random", "closest")) - else if (!is.null(m.order)) match_arg(m.order, c("largest", "smallest", "data", "random", "closest")) + if (is_null(distance)) match_arg(m.order, c("data", "random", "closest")) + else if (is_not_null(m.order)) match_arg(m.order, c("largest", "smallest", "data", "random", "closest")) else if (estimand == "ATC") "smallest" else "largest" } - if (is.null(ex0) || !is.null(unit.id) || (is.null(mahcovs) && is.null(distance_mat))) { - if (m.order == "closest") { - if (!is.null(mahcovs)) { - distance_mat <- eucdist_internal(mahcovs, treat) - - # If caliper on PS (distance), treat it as a covariate - if (!is.null(distance) && !is.null(caliper.dist)) { - caliper.covs.mat <- { - if (is.null(caliper.covs)) as.matrix(distance) - else cbind(caliper.covs.mat, distance) - } - caliper.covs <- c(caliper.covs, caliper.dist) - caliper.dist <- NULL - } - } + if (is_not_null(mahcovs) && ncol(mahcovs) == 1L && is.full.mahalanobis && is_null(distance)) { + distance <- mahcovs[,1] + mahcovs <- NULL + } - ord <- NULL - } - else { - ord <- switch(m.order, - "largest" = order(distance[treat == 1], decreasing = TRUE), - "smallest" = order(distance[treat == 1], decreasing = FALSE), - "random" = sample.int(n1), - "data" = seq_len(n1)) - } + if (is_null(ex0) || is_not_null(unit.id) || (is_null(mahcovs) && is_null(distance_mat))) { + ord <- switch(m.order, + "largest" = order(distance[treat == 1], decreasing = TRUE), + "smallest" = order(distance[treat == 1], decreasing = FALSE), + "random" = sample.int(n1), + "data" = seq_len(n1), + "closest" = NULL) mm <- nn_matchC_dispatch(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, ex0, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, @@ -596,43 +587,37 @@ matchit2nearest <- function(treat, data, distance, discarded, treat_ <- treat[.e] discarded_ <- discarded[.e] - if (!is.null(distance)) distance_ <- distance[.e] - if (!is.null(caliper.covs.mat)) caliper.covs.mat_ <- caliper.covs.mat[.e,,drop = FALSE] - if (!is.null(mahcovs)) mahcovs_ <- mahcovs[.e,,drop = FALSE] - if (!is.null(antiexactcovs)) antiexactcovs_ <- antiexactcovs[.e,,drop = FALSE] - if (!is.null(distance_mat)) { - .e0 <- which(ex0[treat==0] == e) - distance_mat_ <- distance_mat[.e1, .e0, drop = FALSE] - } - ratio_ <- ratio[.e1] - n1_ <- sum(treat_ == 1) + if (is_not_null(distance)) { + distance_ <- distance[.e] + } - if (m.order == "closest") { - if (!is.null(mahcovs)) { - distance_mat_ <- eucdist_internal(mahcovs_, treat_) + if (is_not_null(caliper.covs.mat)) { + caliper.covs.mat_ <- caliper.covs.mat[.e,, drop = FALSE] + } - # If caliper on PS (distance), treat it as a covariate - if (!is.null(distance) && !is.null(caliper.dist)) { - caliper.covs.mat <- { - if (is.null(caliper.covs)) as.matrix(distance_) - else cbind(caliper.covs.mat_, distance_) - } - caliper.covs <- c(caliper.covs, caliper.dist) - caliper.dist <- NULL - } - } + if (is_not_null(mahcovs)) { + mahcovs_ <- mahcovs[.e,,drop = FALSE] + } - ord <- NULL + if (is_not_null(antiexactcovs)) { + antiexactcovs_ <- antiexactcovs[.e,, drop = FALSE] } - else { - ord_ <- switch(m.order, - "largest" = order(distance_[treat_ == 1], decreasing = TRUE), - "smallest" = order(distance_[treat_ == 1], decreasing = FALSE), - "random" = sample.int(n1_), - "data" = seq_len(n1_)) + + if (is_not_null(distance_mat)) { + .e0 <- which(ex0[treat==0] == e) + distance_mat_ <- distance_mat[.e1, .e0, drop = FALSE] } + ratio_ <- ratio[.e1] + + ord_ <- switch(m.order, + "largest" = order(distance_[treat_ == 1], decreasing = TRUE), + "smallest" = order(distance_[treat_ == 1], decreasing = FALSE), + "random" = sample.int(sum(treat_ == 1)), + "data" = seq_len(sum(treat_ == 1)), + "closest" = NULL) + mm_ <- nn_matchC_dispatch(treat_, ord_, ratio_, discarded_, reuse.max, distance_, distance_mat_, NULL, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, antiexactcovs_, NULL, m.order, verbose) @@ -656,14 +641,16 @@ matchit2nearest <- function(treat, data, distance, discarded, if (reuse.max > 1) { psclass <- NULL + weights <- get_weights_from_mm(mm, treat, 1L) } else { - psclass <- mm2subclass(mm, treat) + psclass <- mm2subclass(mm, treat, 1L) + weights <- get_weights_from_subclass(psclass, treat) } res <- list(match.matrix = nummm2charmm(mm, treat), subclass = psclass, - weights = get_weights_from_mm(mm, treat)) + weights = weights) if (verbose) cat("Done.\n") @@ -678,27 +665,32 @@ matchit2nearest <- function(treat, data, distance, discarded, nn_matchC_dispatch <- function(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) { if (m.order == "closest") { - if (is.null(distance_mat) && is.null(mahcovs)) { - nn_matchC_vec_closest(treat, ratio, discarded, reuse.max, distance, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) + if (is_not_null(mahcovs)) { + nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse.max, mahcovs, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) } - else { + else if (is_not_null(distance_mat)) { nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse.max, ex, caliper.dist, caliper.covs, caliper.covs.mat, antiexactcovs, unit.id, verbose) } + else { + nn_matchC_vec_closest(treat, ratio, discarded, reuse.max, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } } else { - if (is.null(distance_mat) && is.null(mahcovs)) { - nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, 1L, distance, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) - } - else { + if (is_not_null(distance_mat) || is_not_null(mahcovs)) { nn_matchC(treat, ord, ratio, discarded, reuse.max, 1L, distance, distance_mat, ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, verbose) } + else { + nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, 1L, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } } } \ No newline at end of file From cc0eb3811a91ec10a16d748744ff548266bf8c7e Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:50:17 -0400 Subject: [PATCH 20/48] Doc and metadata updates --- DESCRIPTION | 8 +-- NEWS.md | 14 +++++ man/distance.Rd | 87 +++++++++++++++++--------------- man/mahalanobis_dist.Rd | 5 +- man/method_cem.Rd | 1 - man/method_full.Rd | 2 +- man/method_genetic.Rd | 7 ++- man/method_nearest.Rd | 26 +++++----- man/method_optimal.Rd | 2 +- vignettes/estimating-effects.Rmd | 38 +++++++------- 10 files changed, 104 insertions(+), 86 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 17606cf5..305c94d8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: MatchIt -Version: 4.5.5.9000 +Version: 4.5.5.9001 Title: Nonparametric Preprocessing for Parametric Causal Inference Description: Selects matched samples of the original treated and control groups with similar covariate distributions -- can be @@ -43,13 +43,13 @@ Suggests: rpart, mgcv, CBPS (>= 0.17), - dbarts, + dbarts (>= 0.9-28), randomForest (>= 4.7-1), glmnet (>= 4.0), gbm (>= 2.1.7), cobalt (>= 4.2.3), boot, - marginaleffects (>= 0.11.0), + marginaleffects (>= 0.19.0), sandwich (>= 2.5-1), survival, RcppProgress (>= 0.4.2), @@ -71,5 +71,5 @@ URL: https://kosukeimai.github.io/MatchIt/, BugReports: https://github.com/kosukeimai/MatchIt/issues VignetteBuilder: knitr Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 Config/testthat/edition: 3 diff --git a/NEWS.md b/NEWS.md index f6817261..92b54388 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,8 +8,22 @@ output: # MatchIt (development version) +Most improvements are related to performance. Some of these dramatically improve speeds for large datasets. Most come from improvements to `Rcpp` code. + +* Speed improvements to `method = "nearest"`, especially when matching on a propensity score. + +* Speed improvements to `summary()` when `pair.dist = TRUE` and a `match.matrix` component is not included in the output (e.g., for `method = "full"` or `method = "quick"`). + +* Speed improvements to `method = "subclass"` with `min.n` greater than 0. + +* When using `method = "nearest"` with `m.order = "closest"`, the full distance matrix is no longer computed, which increases support for larger samples. This uses an adaptation of an algorithm described by [Rassen et al. (2012)](https://doi.org/10.1002/pds.3263). + +* When using `method = "nearest"` with `verbose = TRUE`, the progress bar now displays an estimate of how much time remains. + * Fixed a bug when using `method = "optimal"` or `method = "full"` with `discard` specified and `data` given as a tibble (`tbl_df` object). (#185) +* Fixed a bug when using `method = "cardinality"` with a single covariate. (#194) + # MatchIt 4.5.5 * When using `method = "cardinality"`, a new solver, HiGHS, can be requested by setting `solver = "highs"`, which relies on the `highs` package. This is much faster and more reliable than GLPK and is free and easy to install as a regular R package with no additional requirements. diff --git a/man/distance.Rd b/man/distance.Rd index c234b760..f4c45075 100644 --- a/man/distance.Rd +++ b/man/distance.Rd @@ -12,12 +12,13 @@ defining calipers. This page documents the options that can be supplied to the \code{distance} argument to \code{\link[=matchit]{matchit()}}. } \note{ -In versions of \emph{MatchIt} prior to 4.0.0, \code{distance} was -specified in a slightly different way. When specifying arguments using the -old syntax, they will automatically be converted to the corresponding method -in the new syntax but a warning will be thrown. \code{distance = "logit"}, -the old default, will still work in the new syntax, though \verb{distance = "glm", link = "logit"} is preferred (note that these are the default -settings and don't need to be made explicit). +In versions of \emph{MatchIt} prior to 4.0.0, \code{distance} was specified in a +slightly different way. When specifying arguments using the old syntax, they +will automatically be converted to the corresponding method in the new syntax +but a warning will be thrown. \code{distance = "logit"}, the old default, will +still work in the new syntax, though \verb{distance = "glm", link = "logit"} is +preferred (note that these are the default settings and don't need to be made +explicit). } \section{Allowable options}{ @@ -104,8 +105,11 @@ defaults as used in \emph{WeightIt} and \emph{twang}, except for generalized boosted modeling as in \emph{twang}; here, the number of trees is chosen based on cross-validation or out-of-bag error, rather than based on optimizing balance. \pkg{twang} should not be cited when using this method -to estimate propensity scores. } -\item{\code{"lasso"}, \code{"ridge"}, \code{"elasticnet"}}{ The propensity +to estimate propensity scores. Note that because there is a random component to choosing the tuning +parameter, results will vary across runs unless a \link[=set.seed]{seed} is +set.} +\item{\code{"lasso"}, \code{"ridge"}, \code{"elasticnet"}}{ +The propensity scores are estimated using a lasso, ridge, or elastic net model, respectively. The \code{formula} supplied to \code{matchit()} is processed with \code{\link[=model.matrix]{model.matrix()}} and passed to \pkgfun{glmnet}{cv.glmnet}, and @@ -136,7 +140,8 @@ random forest. The \code{formula} supplied to \code{matchit()} is passed directly to \pkgfun{randomForest}{randomForest}, and \pkgfun{randomForest}{predict.randomForest} is used to compute the propensity scores. The \code{link} argument is ignored, and predicted probabilities are -always returned as the distance measure.} +always returned as the distance measure. Note that because there is a random component, results will vary across runs unless a \link[=set.seed]{seed} is +set. } \item{\code{"nnet"}}{ The propensity scores are estimated using a single-hidden-layer neural network. The \code{formula} supplied to \code{matchit()} is passed directly to @@ -164,36 +169,35 @@ scores. The \code{link} argument can be specified as \code{"linear"} to use the linear predictor instead of the predicted probabilities. When \code{s.weights} is supplied to \code{matchit()}, it will not be passed to \code{bart2} because the \code{weights} argument in \code{bart2} does not -correspond to sampling weights. } +correspond to sampling weights. Note that because there is a random component to choosing the tuning +parameter, results will vary across runs unless the \code{seed} argument is supplied to \code{distance.options}. Note that setting a seed using \code{\link[=set.seed]{set.seed()}} is not sufficient to guarantee reproducibility unless single-threading is used. See \pkgfun{dbarts}{bart2} for details.} } } \subsection{Methods for computing distances from covariates}{ -The following methods involve computing a distance matrix from the covariates themselves -without estimating a propensity score. Calipers on the distance measure and -common support restrictions cannot be used, and the \code{distance} -component of the output object will be empty because no propensity scores -are estimated. The \code{link} and \code{distance.options} arguments are -ignored with these methods. See the individual matching methods pages for -whether these distances are allowed and how they are used. Each of these -distance measures can also be calculated outside \code{matchit()} using its -\link[=euclidean_dist]{corresponding function}. +The following methods involve computing a distance matrix from the covariates +themselves without estimating a propensity score. Calipers on the distance +measure and common support restrictions cannot be used, and the \code{distance} +component of the output object will be empty because no propensity scores are +estimated. The \code{link} and \code{distance.options} arguments are ignored with these +methods. See the individual matching methods pages for whether these +distances are allowed and how they are used. Each of these distance measures +can also be calculated outside \code{matchit()} using its \link[=euclidean_dist]{corresponding function}. \describe{ \item{\code{"euclidean"}}{ The Euclidean distance is the raw -distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - -x_j)'}} It is sensitive to the scale of the covariates, so covariates with +distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - x_j)'}} It is sensitive to the scale of the covariates, so covariates with larger scales will take higher priority. } -\item{\code{"scaled_euclidean"}}{ The scaled Euclidean distance is the +\item{\code{"scaled_euclidean"}}{ +The scaled Euclidean distance is the Euclidean distance computed on the scaled (i.e., standardized) covariates. This ensures the covariates are on the same scale. The covariates are standardized using the pooled within-group standard deviations, computed by treatment group-mean centering each covariate before computing the standard -deviation in the full sample. } -\item{\code{"mahalanobis"}}{ The -Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - -x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group +deviation in the full sample. +} +\item{\code{"mahalanobis"}}{ The Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group covariance matrix of the covariates, computed by treatment group-mean centering each covariate before computing the covariance in the full sample. This ensures the variables are on the same scale and accounts for the @@ -206,30 +210,29 @@ that handles outliers and rare categories better than the standard Mahalanobis distance but is not affinely invariant. } } -To perform Mahalanobis distance matching \emph{and} estimate propensity -scores to be used for a purpose other than matching, the \code{mahvars} -argument should be used along with a different specification to -\code{distance}. See the individual matching method pages for details on how -to use \code{mahvars}. +To perform Mahalanobis distance matching \emph{and} estimate propensity scores to +be used for a purpose other than matching, the \code{mahvars} argument should be +used along with a different specification to \code{distance}. See the individual +matching method pages for details on how to use \code{mahvars}. } \subsection{Distances supplied as a numeric vector or matrix}{ -\code{distance} can also be supplied as a numeric vector whose values will be taken to -function like propensity scores; their pairwise difference will define the -distance between units. This might be useful for supplying propensity scores -computed outside \code{matchit()} or resupplying \code{matchit()} with -propensity scores estimated previously without having to recompute them. +\code{distance} can also be supplied as a numeric vector whose values will be +taken to function like propensity scores; their pairwise difference will +define the distance between units. This might be useful for supplying +propensity scores computed outside \code{matchit()} or resupplying \code{matchit()} +with propensity scores estimated previously without having to recompute them. \code{distance} can also be supplied as a matrix whose values represent the pairwise distances between units. The matrix should either be a square, with a row and column for each unit (e.g., as the output of a call to -\verb{as.matrix(}\code{\link{dist}}\verb{(.))}), or have as many rows as there are treated -units and as many columns as there are control units (e.g., as the output of -a call to \code{\link[=mahalanobis_dist]{mahalanobis_dist()}} or \pkgfun{optmatch}{match_on}). Distance values -of \code{Inf} will disallow the corresponding units to be matched. When -\code{distance} is a supplied as a numeric vector or matrix, \code{link} and -\code{distance.options} are ignored. +\verb{as.matrix(}\code{\link{dist}}\verb{(.))}), or have as many rows as there are treated units +and as many columns as there are control units (e.g., as the output of a call +to \code{\link[=mahalanobis_dist]{mahalanobis_dist()}} or \pkgfun{optmatch}{match_on}). Distance values of +\code{Inf} will disallow the corresponding units to be matched. When \code{distance} is +a supplied as a numeric vector or matrix, \code{link} and \code{distance.options} are +ignored. } } diff --git a/man/mahalanobis_dist.Rd b/man/mahalanobis_dist.Rd index 3871069a..0e10934f 100644 --- a/man/mahalanobis_dist.Rd +++ b/man/mahalanobis_dist.Rd @@ -177,8 +177,7 @@ table(lalonde$treat) } \references{ -Rosenbaum, P. R. (2010). \emph{Design of observational studies}. -Springer. +Rosenbaum, P. R. (2010). \emph{Design of observational studies}. Springer. Rosenbaum, P. R., & Rubin, D. B. (1985). Constructing a Control Group Using Multivariate Matched Sampling Methods That Incorporate the Propensity Score. @@ -189,7 +188,7 @@ Rubin, D. B. (1980). Bias Reduction Using Mahalanobis-Metric Matching. } \seealso{ \code{\link{distance}}, \code{\link[=matchit]{matchit()}}, \code{\link[=dist]{dist()}} (which is used -internally to compute Euclidean distances) +internally to compute some Euclidean distances) \pkgfun{optmatch}{match_on}, which provides similar functionality but with fewer options and a focus on efficient storage of the output. diff --git a/man/method_cem.Rd b/man/method_cem.Rd index aa337d15..5cac4310 100644 --- a/man/method_cem.Rd +++ b/man/method_cem.Rd @@ -246,7 +246,6 @@ m.out2 <- matchit(treat ~ age + race + married + educ, data = lalonde, k2k = TRUE, k2k.method = "mahalanobis") m.out2 summary(m.out2, un = FALSE) - } \references{ In a manuscript, you don't need to cite another package when diff --git a/man/method_full.Rd b/man/method_full.Rd index 3dfb0e35..28aab09f 100644 --- a/man/method_full.Rd +++ b/man/method_full.Rd @@ -39,7 +39,7 @@ place when \code{distance} corresponds to a propensity score (e.g., for caliper matching or to discard units for common support). If specified, the distance measure will not be used in matching.} -\item{antiexact}{for which variables ant-exact matching should take place. +\item{antiexact}{for which variables anti-exact matching should take place. Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}.} \item{discard}{a string containing a method for discarding units outside a diff --git a/man/method_genetic.Rd b/man/method_genetic.Rd index 23aed566..13060c7d 100644 --- a/man/method_genetic.Rd +++ b/man/method_genetic.Rd @@ -50,7 +50,7 @@ to the distance matrix. Use \code{mahvars} to only supply a subset. Even if \code{mahvars} is specified, balance will be optimized on all covariates in \code{formula}. See Details.} -\item{antiexact}{for which variables ant-exact matching should take place. +\item{antiexact}{for which variables anti-exact matching should take place. Anti-exact matching is processed using the \code{restrict} argument to \code{Matching::GenMatch()} and \code{Matching::Match()}.} @@ -200,6 +200,11 @@ considered "focal". Note that while \code{GenMatch()} and \code{Match()} support the ATE as an estimand, \code{matchit()} only supports the ATT and ATC for genetic matching. } + +\subsection{Reproducibility}{ + +Genetic matching involves a random component, so a seed must be set using \code{\link[=set.seed]{set.seed()}} to ensure reproducibility. When \code{cluster} is used for parallel processing, the seed must be compatible with parallel processing (e.g., by setting \code{type = "L'Ecuyer-CMRG"}). +} } \section{Outputs}{ diff --git a/man/method_nearest.Rd b/man/method_nearest.Rd index b211d9b9..d0652bf1 100644 --- a/man/method_nearest.Rd +++ b/man/method_nearest.Rd @@ -28,14 +28,14 @@ by the argument to \code{distance}.} \item{estimand}{a string containing the desired estimand. Allowable options include \code{"ATT"} and \code{"ATC"}. See Details.} -\item{exact}{for which variables exact matching should take place.} +\item{exact}{for which variables exact matching should take place; two units with different values of an exact matching variable will not be paired.} \item{mahvars}{for which variables Mahalanobis distance matching should take place when \code{distance} corresponds to a propensity score (e.g., for caliper matching or to discard units for common support). If specified, the distance measure will not be used in matching.} -\item{antiexact}{for which variables ant-exact matching should take place.} +\item{antiexact}{for which variables anti-exact matching should take place; two units with the same value of an anti-exact matching variable will not be paired.} \item{discard}{a string containing a method for discarding units outside a region of common support. Only allowed when \code{distance} corresponds to a @@ -47,13 +47,13 @@ re-estimate the propensity score in the remaining sample prior to matching.} \item{s.weights}{the variable containing sampling weights to be incorporated into propensity score models and balance statistics.} -\item{replace}{whether matching should be done with replacement.} +\item{replace}{whether matching should be done with replacement (i.e., whether control units can be used as matches multiple times). See also the \code{reuse.max} argument below. Default is \code{FALSE} for matching without replacement.} \item{m.order}{the order that the matching takes place. Allowable options include \code{"largest"}, where matching takes place in descending order of distance measures; \code{"smallest"}, where matching takes place in ascending order of distance measures; \code{"closest"}, where matching takes place in -order of the distance between units; \code{"random"}, where matching takes place +ascending order of the distance between units; \code{"random"}, where matching takes place in a random order; and \code{"data"} where matching takes place based on the order of units in the data. When \code{m.order = "random"}, results may differ across different runs of the same code unless a seed is set and specified @@ -61,8 +61,7 @@ with \code{\link[=set.seed]{set.seed()}}. The default of \code{NULL} corresponds propensity score is estimated or supplied as a vector and \code{"data"} otherwise.} -\item{caliper}{the width(s) of the caliper(s) used for caliper matching. See -Details and Examples.} +\item{caliper}{the width(s) of the caliper(s) used for caliper matching. Two units with a difference on a caliper variable larger than the caliper will not be paired. See Details and Examples.} \item{std.caliper}{\code{logical}; when calipers are specified, whether they are in standard deviation units (\code{TRUE}) or raw units (\code{FALSE}).} @@ -77,7 +76,7 @@ section "Variable Ratio Matching" in Details below.} \item{verbose}{\code{logical}; whether information about the matching process should be printed to the console. When \code{TRUE}, a progress bar -implemented using \emph{RcppProgress} will be displayed.} +implemented using \emph{RcppProgress} will be displayed along with an estimate of the time remaining.} \item{\dots}{additional arguments that control the matching specification: \describe{ @@ -94,8 +93,7 @@ observation, i.e., in case multiple observations correspond to the same unit. Once a control observation has been matched, no other observation with the same unit ID can be used as matches. This ensures each control unit is used only once even if it has multiple observations associated with it. -Omitting this argument is the same as giving each observation a unique ID. -Ignored when \code{replace = TRUE}. } +Omitting this argument is the same as giving each observation a unique ID.} }} } \description{ @@ -198,8 +196,7 @@ switch to trigger which treatment group is considered "focal". \subsection{Variable Ratio Matching}{ -\code{matchit()} can perform variable -ratio "extremal" matching as described by Ming and Rosenbaum (2000). This +\code{matchit()} can perform variable ratio "extremal" matching as described by Ming and Rosenbaum (2000). This method tends to result in better balance than fixed ratio matching at the expense of some precision. When \code{ratio > 1}, rather than requiring all treated units to receive \code{ratio} matches, each treated unit is assigned @@ -231,7 +228,12 @@ Examples below for an example of their use. \subsection{Using \code{m.order = "closest"}}{ -As of version 4.6.0, \code{m.order} can be set to \code{"closest"}, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to \code{m.order = "smallest"}. +As of version 4.6.0, \code{m.order} can be set to \code{"closest"}, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to \code{m.order = "smallest"}. When \code{distance} is a vector or a method of estimating a propensity score, an algorithm described by \href{https://doi.org/10.1002/pds.3263}{Rassen et al. (2012)} is used. +} + +\subsection{Reproducibility}{ + +Nearest neighbor matching involves a random component only when \code{m.order = "random"}, so a seed must be set in that case using \code{\link[=set.seed]{set.seed()}} to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. } } \note{ diff --git a/man/method_optimal.Rd b/man/method_optimal.Rd index 3c7c825e..9fc08157 100644 --- a/man/method_optimal.Rd +++ b/man/method_optimal.Rd @@ -37,7 +37,7 @@ place when \code{distance} corresponds to a propensity score (e.g., for caliper matching or to discard units for common support). If specified, the distance measure will not be used in matching.} -\item{antiexact}{for which variables ant-exact matching should take place. +\item{antiexact}{for which variables anti-exact matching should take place. Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}.} \item{discard}{a string containing a method for discarding units outside a diff --git a/vignettes/estimating-effects.Rmd b/vignettes/estimating-effects.Rmd index ed0fded9..7f336efc 100644 --- a/vignettes/estimating-effects.Rmd +++ b/vignettes/estimating-effects.Rmd @@ -33,6 +33,7 @@ me_ok <- requireNamespace("marginaleffects", quietly = TRUE) && su_ok <- requireNamespace("survival", quietly = TRUE) boot_ok <- requireNamespace("boot", quietly = TRUE) ``` + ```{r, include = FALSE} #Generating data similar to Austin (2009) for demonstrating treatment effect estimation gen_X <- function(n) { @@ -190,7 +191,7 @@ library("marginaleffects") All effect estimates will be computed using `marginaleffects::avg_comparions()`, even when its use may be superfluous (e.g., for performing a t-test in the matched set). As previously mentioned, this is because it is useful to have a single workflow that works no matter the situation, perhaps with very slight modifications to accommodate different contexts. Using `avg_comparions()` has several advantages, even when the alternatives are simple: it only provides the effect estimate, and not other coefficients; it automatically incorporates robust and cluster-robust SEs if requested; and it always produces average marginal effects for the correct population if requested. -Other packages may be of use but are not used here. There are alternatives to the `marginaleffects` package for computing average marginal effects, including `margins` and `stdReg`. The `survey` package can be used to estimate robust SEs incorporating weights and provides functions for survey-weighted generalized linear models and Cox-proportional hazards models. It is often used with propensity score weighting. +Other packages may be of use but are not used here. There are alternatives to the `marginaleffects` package for computing average marginal effects, including `margins` and `stdReg`. The `survey` package can be used to estimate robust SEs incorporating weights and provides functions for survey-weighted generalized linear models and Cox-proportional hazards models. ### The Standard Case @@ -235,24 +236,22 @@ Next, we use `marginaleffects::avg_comparisons()` to estimate the ATT. ```{r, eval=me_ok} avg_comparisons(fit1, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights") + newdata = subset(A == 1)) ``` -Let's break down the call to `avg_comparisons()`: to the first argument, we supply the model fit, `fit1`; to the `variables` argument, the name of the treatment (`"A"`); to the `vcov` argument, a formula with subclass membership (`~subclass`) to request cluster-robust SEs; to the `newdata` argument, a version of the matched dataset containing only the treated units (`subset(md, A == 1)`) to request the ATT; and to the `wts` argument, the names of the variable in `md` containing the matching weights (`"weights"`) to ensure they are included in the analysis. Some of these arguments differ depending on the specifics of the matching method and outcome type; see the sections below for information. +Let's break down the call to `avg_comparisons()`: to the first argument, we supply the model fit, `fit1`; to the `variables` argument, the name of the treatment (`"A"`); to the `vcov` argument, a formula with subclass membership (`~subclass`) to request cluster-robust SEs; and to the `newdata` argument, a version of the matched dataset containing only the treated units (`subset(A == 1)`) to request the ATT. Some of these arguments differ depending on the specifics of the matching method and outcome type; see the sections below for information. If, in addition to the effect estimate, we want the average estimated potential outcomes, we can use `marginaleffects::avg_predictions()`, which we demonstrate below. Note the interpretation of the resulting estimates as the expected potential outcomes is only valid if all covariates present in the outcome model (if any) are interacted with the treatment. ```{r, eval=me_ok && packageVersion("marginaleffects") >= "0.11.0"} avg_predictions(fit1, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights") + newdata = subset(A == 1)) ``` We can see that the difference in potential outcome means is equal to the average treatment effect computed previously[^4]. All of the arguments to `avg_predictions()` are the same as those to `avg_comparisons()`. -[^4]: To verify that they are equal, supply the output of `avg_predictions()` to `hypotheses(), e.g., `avg_predictions(...) |> hypotheses("revpairwise")`; this explicitly compares the average potential outcomes and should yield identical estimates to the `avg_comparisons()` call. +[^4]: To verify that they are equal, supply the output of `avg_predictions()` to `hypotheses(), e.g.,`avg_predictions(...) \|\> hypotheses("revpairwise")`; this explicitly compares the average potential outcomes and should yield identical estimates to the`avg_comparisons()\` call. ### Adjustments to the Standard Case @@ -260,7 +259,7 @@ This section explains how the procedure might differ if any of the following spe #### Matching for the ATE -When matching for the ATE (including [coarsened] exact matching, full matching, subclassification, and cardinality matching), everything is identical to the Standard Case except that in the calls to `avg_comparisons()` and `avg_predictions()`, the `newdata` argument is omitted (or can be replaced with `newdata = md`). This is because the estimated potential outcomes are computed for the full matched sample rather than just the treated units. +When matching for the ATE (including [coarsened] exact matching, full matching, subclassification, and cardinality matching), everything is identical to the Standard Case except that in the calls to `avg_comparisons()` and `avg_predictions()`, the `newdata` argument is omitted. This is because the estimated potential outcomes are computed for the full sample rather than just the treated units. #### Matching with replacement @@ -280,7 +279,7 @@ There are two natural ways to estimate marginal effects after subclassification: All of the methods described above for the Standard Case also work with MMWS because the formation of the weights is the same; the only difference is that it is not appropriate to use cluster-robust SEs with MMWS because of how few clusters are present, so one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()` to use robust SEs instead of cluster-robust SEs. The subclasses can optionally be included in the outcome model (optionally interacting with treatment) as an alternative to including the propensity score. -The subclass-specific approach omits the weights and uses the subclasses directly. It is only appropriate when there are a small number of subclasses relative to the sample size. In the outcome model, `subclass` should interact with all other predictors in the model (including the treatment, covariates, and interactions, if any), and the `weights` argument should be omitted. In the calls to `avg_comparisons()` and `avg_predictions()`, the `wts` argument should be omitted. As with MMWS, one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()`. See an example below: +The subclass-specific approach omits the weights and uses the subclasses directly. It is only appropriate when there are a small number of subclasses relative to the sample size. In the outcome model, `subclass` should interact with all other predictors in the model (including the treatment, covariates, and interactions, if any), and the `weights` argument should be omitted. As with MMWS, one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()`. See an example below: ```{r, eval=me_ok} #Subclassification on the PS for the ATT @@ -298,7 +297,7 @@ fitS <- lm(Y_C ~ subclass * (A * (X1 + X2 + X3 + X4 + X5 + avg_comparisons(fitS, variables = "A", vcov = "HC3", - newdata = subset(md, A == 1)) + newdata = subset(A == 1)) ``` A model with fewer terms may be required when subclasses are small; removing covariates or their interactions with treatment may be required and can increase precision in smaller datasets. Remember that if subclassification is done for the ATE (even if units are dropped), the `newdata` argument should be dropped. @@ -311,7 +310,7 @@ To fit a logistic regression model, change `lm()` to `glm()` and set `family = q [^6]: We use `quasibinomial()` instead of `binomial()` simply to avoid a spurious warning that can occur with certain kinds of matching; the results will be identical regardless. -[^7]: Note that for low or high average expected risks computed with `predictions()`, the confidence intervals may go below 0 or above 1; this is because an approximation is used. To avoid this problem, bootstrapping or simulation-based inference can be used instead. +[^7]: Note that for low or high average expected risks computed with `avg_predictions()`, the confidence intervals may go below 0 or above 1; this is because an approximation is used. To avoid this problem, bootstrapping or simulation-based inference can be used instead. To compute the marginal RR, we need to add `comparison = "lnratioavg"` to `avg_comparisons()`; this computes the marginal log RR. To get the marginal RR, we need to add `transform = "exp"` to `avg_comparisons()`, which exponentiates the marginal log RR and its confidence interval. The code below computes the effects and displays the statistics of interest: @@ -326,8 +325,7 @@ fit2 <- glm(Y_B ~ A * (X1 + X2 + X3 + X4 + X5 + avg_comparisons(fit2, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights", + newdata = subset(A == 1), comparison = "lnratioavg", transform = "exp") ``` @@ -424,12 +422,12 @@ boot_fun <- function(data, i) { #Estimated potential outcomes under treatment p1 <- predict(fit, type = "response", newdata = transform(md1, A = 1)) - Ep1 <- weighted.mean(p1, md1$weights) + Ep1 <- mean(p1) #Estimated potential outcomes under control p0 <- predict(fit, type = "response", newdata = transform(md1, A = 0)) - Ep0 <- weighted.mean(p0, md1$weights) + Ep0 <- mean(p0) #Risk ratio return(Ep1 / Ep0) @@ -500,12 +498,12 @@ cluster_boot_fun <- function(pairs, i) { #Estimated potential outcomes under treatment p1 <- predict(fit, type = "response", newdata = transform(md1, A = 1)) - Ep1 <- weighted.mean(p1, md1$weights) + Ep1 <- mean(p1) #Estimated potential outcomes under control p0 <- predict(fit, type = "response", newdata = transform(md1, A = 0)) - Ep0 <- weighted.mean(p0, md1$weights) + Ep0 <- mean(p0) #Risk ratio return(Ep1 / Ep0) @@ -565,8 +563,7 @@ To estimate the subgroup ATTs, we can use `avg_comparisons()`, this time specify ```{r, eval=me_ok} avg_comparisons(fitP, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights", + newdata = subset(A == 1), by = "X5") ``` @@ -575,8 +572,7 @@ We can see that the subgroup mean differences are quite similar to each other. F ```{r, eval=me_ok} avg_comparisons(fitP, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights", + newdata = subset(A == 1), by = "X5", hypothesis = "pairwise") ``` From ac11929cb85841cea11f2019d1d08ba074884ce5 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:50:30 -0400 Subject: [PATCH 21/48] Rcpp updates --- R/RcppExports.R | 52 +++++++++--- src/RcppExports.cpp | 191 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 195 insertions(+), 48 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 96fa2764..2ba81509 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,36 +1,64 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -dist_to_matrixC <- function(d) { - .Call(`_MatchIt_dist_to_matrixC`, d) +eucdistC_N1xN0 <- function(x, t) { + .Call(`_MatchIt_eucdistC_N1xN0`, x, t) } -nn_matchC <- function(treat_, ord_, ratio, discarded, reuse_max, distance_ = NULL, distance_mat_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, mah_covs_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC`, treat_, ord_, ratio, discarded, reuse_max, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog) +get_splitsC <- function(x, caliper) { + .Call(`_MatchIt_get_splitsC`, x, caliper) +} + +has_n_unique <- function(x, n) { + .Call(`_MatchIt_has_n_unique`, x, n) +} + +nn_matchC <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance_ = NULL, distance_mat_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, mah_covs_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC`, treat_, ord, ratio, discarded, reuse_max, focal_, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog) } nn_matchC_closest <- function(distance_mat, treat, ratio, discarded, reuse_max, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { .Call(`_MatchIt_nn_matchC_closest`, distance_mat, treat, ratio, discarded, reuse_max, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) } -nn_matchC_vec <- function(treat_, ord_, ratio_, discarded_, reuse_max, distance_, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC_vec`, treat_, ord_, ratio_, discarded_, reuse_max, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +nn_matchC_mahcovs_closest <- function(treat, ratio, discarded, reuse_max, mah_covs, distance_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_mahcovs_closest`, treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +} + +nn_matchC_vec <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_vec`, treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +} + +nn_matchC_vec_closest <- function(treat, ratio, discarded, reuse_max, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_vec_closest`, treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +} + +pairdistsubC <- function(x, t, s) { + .Call(`_MatchIt_pairdistsubC`, x, t, s) } -pairdistsubC <- function(x_, t_, s_, num_sub) { - .Call(`_MatchIt_pairdistsubC`, x_, t_, s_, num_sub) +subclass2mmC <- function(subclass_, treat, focal) { + .Call(`_MatchIt_subclass2mmC`, subclass_, treat, focal) } -subclass2mmC <- function(subclass, treat, focal) { - .Call(`_MatchIt_subclass2mmC`, subclass, treat, focal) +mm2subclassC <- function(mm, treat, focal = NULL) { + .Call(`_MatchIt_mm2subclassC`, mm, treat, focal) +} + +subclass_scootC <- function(subclass_, treat_, x_, min_n) { + .Call(`_MatchIt_subclass_scootC`, subclass_, treat_, x_, min_n) } tabulateC <- function(bins, nbins = NULL) { .Call(`_MatchIt_tabulateC`, bins, nbins) } -weights_matrixC <- function(mm, treat) { - .Call(`_MatchIt_weights_matrixC`, mm, treat) +weights_matrixC <- function(mm, treat_, focal = NULL) { + .Call(`_MatchIt_weights_matrixC`, mm, treat_, focal) +} + +weights_subclassC <- function(subclass_, treat_, focal_ = NULL) { + .Call(`_MatchIt_weights_subclassC`, subclass_, treat_, focal_) } # Register entry points for exported C++ functions diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 4adbd0b2..c3245b02 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,28 +13,54 @@ Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif -// dist_to_matrixC -NumericMatrix dist_to_matrixC(const NumericVector& d); -RcppExport SEXP _MatchIt_dist_to_matrixC(SEXP dSEXP) { +// eucdistC_N1xN0 +NumericVector eucdistC_N1xN0(const NumericMatrix& x, const IntegerVector& t); +RcppExport SEXP _MatchIt_eucdistC_N1xN0(SEXP xSEXP, SEXP tSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const NumericVector& >::type d(dSEXP); - rcpp_result_gen = Rcpp::wrap(dist_to_matrixC(d)); + Rcpp::traits::input_parameter< const NumericMatrix& >::type x(xSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type t(tSEXP); + rcpp_result_gen = Rcpp::wrap(eucdistC_N1xN0(x, t)); + return rcpp_result_gen; +END_RCPP +} +// get_splitsC +NumericVector get_splitsC(const NumericVector& x, const double& caliper); +RcppExport SEXP _MatchIt_get_splitsC(SEXP xSEXP, SEXP caliperSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const NumericVector& >::type x(xSEXP); + Rcpp::traits::input_parameter< const double& >::type caliper(caliperSEXP); + rcpp_result_gen = Rcpp::wrap(get_splitsC(x, caliper)); + return rcpp_result_gen; +END_RCPP +} +// has_n_unique +bool has_n_unique(const SEXP& x, const int& n); +RcppExport SEXP _MatchIt_has_n_unique(SEXP xSEXP, SEXP nSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const SEXP& >::type x(xSEXP); + Rcpp::traits::input_parameter< const int& >::type n(nSEXP); + rcpp_result_gen = Rcpp::wrap(has_n_unique(x, n)); return rcpp_result_gen; END_RCPP } // nn_matchC -IntegerMatrix nn_matchC(const IntegerVector& treat_, const IntegerVector& ord_, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const Nullable& distance_, const Nullable& distance_mat_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& mah_covs_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC(SEXP treat_SEXP, SEXP ord_SEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distance_SEXP, SEXP distance_mat_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP mah_covs_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +IntegerMatrix nn_matchC(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const Nullable& distance_, const Nullable& distance_mat_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& mah_covs_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP distance_SEXP, SEXP distance_mat_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP mah_covs_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type ord_(ord_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ord(ordSEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type distance_mat_(distance_mat_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); @@ -45,7 +71,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC(treat_, ord_, ratio, discarded, reuse_max, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC(treat_, ord, ratio, discarded, reuse_max, focal_, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog)); return rcpp_result_gen; END_RCPP } @@ -71,18 +97,42 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// nn_matchC_mahcovs_closest +IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericMatrix& mah_covs, const Nullable& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_mahcovs_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP mah_covsSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); + Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type mah_covs(mah_covsSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); + rcpp_result_gen = Rcpp::wrap(nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + return rcpp_result_gen; +END_RCPP +} // nn_matchC_vec -IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, const IntegerVector& ord_, const IntegerVector& ratio_, const LogicalVector& discarded_, const int& reuse_max, const NumericVector& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC_vec(SEXP treat_SEXP, SEXP ord_SEXP, SEXP ratio_SEXP, SEXP discarded_SEXP, SEXP reuse_maxSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const NumericVector& distance, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_vec(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP distanceSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type ord_(ord_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type ratio_(ratio_SEXP); - Rcpp::traits::input_parameter< const LogicalVector& >::type discarded_(discarded_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ord(ordSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); - Rcpp::traits::input_parameter< const NumericVector& >::type distance_(distance_SEXP); + Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); + Rcpp::traits::input_parameter< const NumericVector& >::type distance(distanceSEXP); Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); @@ -90,34 +140,82 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC_vec(treat_, ord_, ratio_, discarded_, reuse_max, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_vec(treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + return rcpp_result_gen; +END_RCPP +} +// nn_matchC_vec_closest +IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericVector& distance, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_vec_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distanceSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); + Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const NumericVector& >::type distance(distanceSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); + rcpp_result_gen = Rcpp::wrap(nn_matchC_vec_closest(treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); return rcpp_result_gen; END_RCPP } // pairdistsubC -double pairdistsubC(const NumericVector& x_, const IntegerVector& t_, const IntegerVector& s_, const int& num_sub); -RcppExport SEXP _MatchIt_pairdistsubC(SEXP x_SEXP, SEXP t_SEXP, SEXP s_SEXP, SEXP num_subSEXP) { +double pairdistsubC(const NumericVector& x, const IntegerVector& t, const IntegerVector& s); +RcppExport SEXP _MatchIt_pairdistsubC(SEXP xSEXP, SEXP tSEXP, SEXP sSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const NumericVector& >::type x_(x_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type t_(t_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type s_(s_SEXP); - Rcpp::traits::input_parameter< const int& >::type num_sub(num_subSEXP); - rcpp_result_gen = Rcpp::wrap(pairdistsubC(x_, t_, s_, num_sub)); + Rcpp::traits::input_parameter< const NumericVector& >::type x(xSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type t(tSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type s(sSEXP); + rcpp_result_gen = Rcpp::wrap(pairdistsubC(x, t, s)); return rcpp_result_gen; END_RCPP } // subclass2mmC -IntegerMatrix subclass2mmC(const IntegerVector& subclass, const IntegerVector& treat, const int& focal); -RcppExport SEXP _MatchIt_subclass2mmC(SEXP subclassSEXP, SEXP treatSEXP, SEXP focalSEXP) { +IntegerMatrix subclass2mmC(const IntegerVector& subclass_, const IntegerVector& treat, const int& focal); +RcppExport SEXP _MatchIt_subclass2mmC(SEXP subclass_SEXP, SEXP treatSEXP, SEXP focalSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const IntegerVector& >::type subclass(subclassSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type subclass_(subclass_SEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); Rcpp::traits::input_parameter< const int& >::type focal(focalSEXP); - rcpp_result_gen = Rcpp::wrap(subclass2mmC(subclass, treat, focal)); + rcpp_result_gen = Rcpp::wrap(subclass2mmC(subclass_, treat, focal)); + return rcpp_result_gen; +END_RCPP +} +// mm2subclassC +IntegerVector mm2subclassC(const IntegerMatrix& mm, const IntegerVector& treat, const Nullable& focal); +RcppExport SEXP _MatchIt_mm2subclassC(SEXP mmSEXP, SEXP treatSEXP, SEXP focalSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerMatrix& >::type mm(mmSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type focal(focalSEXP); + rcpp_result_gen = Rcpp::wrap(mm2subclassC(mm, treat, focal)); + return rcpp_result_gen; +END_RCPP +} +// subclass_scootC +IntegerVector subclass_scootC(const IntegerVector& subclass_, const IntegerVector& treat_, const NumericVector& x_, const int& min_n); +RcppExport SEXP _MatchIt_subclass_scootC(SEXP subclass_SEXP, SEXP treat_SEXP, SEXP x_SEXP, SEXP min_nSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type subclass_(subclass_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const NumericVector& >::type x_(x_SEXP); + Rcpp::traits::input_parameter< const int& >::type min_n(min_nSEXP); + rcpp_result_gen = Rcpp::wrap(subclass_scootC(subclass_, treat_, x_, min_n)); return rcpp_result_gen; END_RCPP } @@ -134,14 +232,28 @@ BEGIN_RCPP END_RCPP } // weights_matrixC -NumericVector weights_matrixC(const IntegerMatrix& mm, const IntegerVector& treat); -RcppExport SEXP _MatchIt_weights_matrixC(SEXP mmSEXP, SEXP treatSEXP) { +NumericVector weights_matrixC(const IntegerMatrix& mm, const IntegerVector& treat_, const Nullable& focal); +RcppExport SEXP _MatchIt_weights_matrixC(SEXP mmSEXP, SEXP treat_SEXP, SEXP focalSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerMatrix& >::type mm(mmSEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); - rcpp_result_gen = Rcpp::wrap(weights_matrixC(mm, treat)); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type focal(focalSEXP); + rcpp_result_gen = Rcpp::wrap(weights_matrixC(mm, treat_, focal)); + return rcpp_result_gen; +END_RCPP +} +// weights_subclassC +NumericVector weights_subclassC(const IntegerVector& subclass_, const IntegerVector& treat_, const Nullable& focal_); +RcppExport SEXP _MatchIt_weights_subclassC(SEXP subclass_SEXP, SEXP treat_SEXP, SEXP focal_SEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type subclass_(subclass_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type focal_(focal_SEXP); + rcpp_result_gen = Rcpp::wrap(weights_subclassC(subclass_, treat_, focal_)); return rcpp_result_gen; END_RCPP } @@ -161,14 +273,21 @@ RcppExport SEXP _MatchIt_RcppExport_registerCCallable() { } static const R_CallMethodDef CallEntries[] = { - {"_MatchIt_dist_to_matrixC", (DL_FUNC) &_MatchIt_dist_to_matrixC, 1}, - {"_MatchIt_nn_matchC", (DL_FUNC) &_MatchIt_nn_matchC, 15}, + {"_MatchIt_eucdistC_N1xN0", (DL_FUNC) &_MatchIt_eucdistC_N1xN0, 2}, + {"_MatchIt_get_splitsC", (DL_FUNC) &_MatchIt_get_splitsC, 2}, + {"_MatchIt_has_n_unique", (DL_FUNC) &_MatchIt_has_n_unique, 2}, + {"_MatchIt_nn_matchC", (DL_FUNC) &_MatchIt_nn_matchC, 16}, {"_MatchIt_nn_matchC_closest", (DL_FUNC) &_MatchIt_nn_matchC_closest, 12}, - {"_MatchIt_nn_matchC_vec", (DL_FUNC) &_MatchIt_nn_matchC_vec, 13}, - {"_MatchIt_pairdistsubC", (DL_FUNC) &_MatchIt_pairdistsubC, 4}, + {"_MatchIt_nn_matchC_mahcovs_closest", (DL_FUNC) &_MatchIt_nn_matchC_mahcovs_closest, 13}, + {"_MatchIt_nn_matchC_vec", (DL_FUNC) &_MatchIt_nn_matchC_vec, 14}, + {"_MatchIt_nn_matchC_vec_closest", (DL_FUNC) &_MatchIt_nn_matchC_vec_closest, 12}, + {"_MatchIt_pairdistsubC", (DL_FUNC) &_MatchIt_pairdistsubC, 3}, {"_MatchIt_subclass2mmC", (DL_FUNC) &_MatchIt_subclass2mmC, 3}, + {"_MatchIt_mm2subclassC", (DL_FUNC) &_MatchIt_mm2subclassC, 3}, + {"_MatchIt_subclass_scootC", (DL_FUNC) &_MatchIt_subclass_scootC, 4}, {"_MatchIt_tabulateC", (DL_FUNC) &_MatchIt_tabulateC, 2}, - {"_MatchIt_weights_matrixC", (DL_FUNC) &_MatchIt_weights_matrixC, 2}, + {"_MatchIt_weights_matrixC", (DL_FUNC) &_MatchIt_weights_matrixC, 3}, + {"_MatchIt_weights_subclassC", (DL_FUNC) &_MatchIt_weights_subclassC, 3}, {"_MatchIt_RcppExport_registerCCallable", (DL_FUNC) &_MatchIt_RcppExport_registerCCallable, 0}, {NULL, NULL, 0} }; From d259d70263e441e84088610ca398857118bdeb7c Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 18:50:39 -0400 Subject: [PATCH 22/48] Improved tests --- tests/testthat/helpers.R | 25 +++--- tests/testthat/test-method_nearest.R | 109 +++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 9 deletions(-) create mode 100644 tests/testthat/test-method_nearest.R diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 72656448..236ceeb3 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -12,11 +12,14 @@ expect_good_matchit <- function(m, expect_subclass = NULL, expect_distance = NUL expect_s3_class(m, "matchit") + n <- length(m$treat) + n1 <- sum(m$treat == 1) + #Related to subclass if (!is.null(expect_subclass)) { if (expect_subclass) { expect_false(is.null(m$subclass)) - expect_length(m$sunclass, n) + expect_length(m$subclass, n) expect_true(is.factor(m$subclass)) expect_false(is.null(names(m$subclass))) } @@ -29,23 +32,27 @@ expect_good_matchit <- function(m, expect_subclass = NULL, expect_distance = NUL if (!is.null(expect_match.matrix)) { if (expect_match.matrix) { expect_false(is.null(m$match.matrix)) - expect_type(m$match.matrix, "matrix") + expect_true(is.matrix(m$match.matrix)) expect_true(is.character(m$match.matrix)) - expect_equal(nrow(m$match.matrix), n) - expect_equal(ncol(m$match.matrix), ratio) + expect_equal(nrow(m$match.matrix), n1) + expect_equal(ncol(m$match.matrix), max(ratio[1], attr(ratio, "max.controls"))) expect_false(is.null(rownames(m$match.matrix))) expect_false(any(rownames(m$match.matrix) %in% m$match.matrix)) #Check no duplicates within each row - expect_true(all(apply(m$match.matrix, 1, function(i) anyDuplicated(i[!is.na(i)]) == 0))) + expect_true(all(apply(m$match.matrix, 1, function(i) anyDuplicated(na.omit(i)) == 0))) if (!is.null(replace)) { if (replace) { - #May not be duplicates incidentially; make sure examples induce duplicates - expect_true(!isTRUE(all.equal(anyDuplicated(m$match.matrix), 0))) + #May not be duplicates incidentally; make sure examples induce duplicates + expect_true(!isTRUE(all.equal(anyDuplicated(na.omit(m$match.matrix)), 0))) + + if (!is.null(attr(replace, "reuse.max"))) { + expect_true(max(table(m$match.matrix)) <= attr(replace, "reuse.max")) + } } else { - expect_equal(anyDuplicated(m$match.matrix), 0) + expect_equal(anyDuplicated(na.omit(m$match.matrix)), 0) } } } @@ -55,7 +62,7 @@ expect_good_matchit <- function(m, expect_subclass = NULL, expect_distance = NUL } #Related to distance - if (!is.null(distance)) { + if (!is.null(expect_distance)) { if (expect_distance) { expect_false(is.null(m$distance)) expect_length(m$distance, n) diff --git a/tests/testthat/test-method_nearest.R b/tests/testthat/test-method_nearest.R new file mode 100644 index 00000000..df217448 --- /dev/null +++ b/tests/testthat/test-method_nearest.R @@ -0,0 +1,109 @@ +test_that("distance vector, mah vars, and distance matrix yield identical results", { + set.seed(1234) + n <- 1e3 + p <- runif(n, 0, .4) + x <- runif(n) + g <- sample(1:5, n, TRUE) + a <- rbinom(n, 1, p) + u <- 1:n; u[a == 0] <- sample(u[a == 0][1:round(sum(a == 0)/5)], sum(a == 0), replace = TRUE) + dis <- as.logical(rbinom(n, 1, .1)) + d <- data.frame(p, x, a, g, u, dis) + d$p_ <- d$p + + dd <- euclidean_dist(a ~ p, data = d) + + test_all <- function(..., which = 1:4) { + + M <- list() + if (any(which == 1)) { + m <- matchit(a ~ p + p_, data = d, + distance = d$p, + ...) + expect_good_matchit(m, expect_distance = TRUE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + if (any(which == 2)) { + m <- matchit(a ~ p + p_, data = d, + distance = "euclidean", + ...) + expect_good_matchit(m, expect_distance = FALSE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + if (any(which == 3)) { + m <- matchit(a ~ p + p_, data = d, + distance = d$p, + mahvars = ~p + p_, + ...) + expect_good_matchit(m, expect_distance = TRUE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + if (any(which == 4)) { + m <- matchit(a ~ p + p_, data = d, + distance = dd, + ...) + expect_good_matchit(m, expect_distance = FALSE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + + all(unlist(lapply(combn(seq_along(M), 2, simplify = FALSE), + function(i) isTRUE(all.equal(M[[i[1]]]$match.matrix, + M[[i[2]]]$match.matrix))))) + } + + expect_true(test_all(m.order = "data")) + expect_true(test_all(m.order = "closest")) + expect_true(test_all(m.order = "largest", which = c(1, 3))) + + expect_true(test_all(m.order = "data", ratio = 2)) + expect_true(test_all(m.order = "closest", ratio = 2)) + + expect_true(test_all(m.order = "data", ratio = 2, max.controls = 3, which = c(1, 3))) + expect_true(test_all(m.order = "closest", ratio = 2, max.controls = 3, which = c(1, 3))) + + expect_true(test_all(m.order = "data", ratio = 2, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, replace = TRUE)) + + expect_true(test_all(m.order = "data", ratio = 2, reuse.max = 3)) + expect_true(test_all(m.order = "closest", ratio = 2, reuse.max = 3)) + + expect_true(test_all(m.order = "data", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) + expect_true(test_all(m.order = "closest", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) + + expect_true(test_all(m.order = "data", ratio = 2, caliper = c(p = .001), std.caliper = FALSE)) + expect_true(test_all(m.order = "closest", ratio = 2, caliper = c(p = .001), std.caliper = FALSE)) + + expect_true(test_all(m.order = "data", ratio = 2, caliper = c(p = .001), std.caliper = FALSE, reuse.max = 3)) + expect_true(test_all(m.order = "closest", ratio = 2, caliper = c(p = .001), std.caliper = FALSE, reuse.max = 3)) + + expect_true(test_all(m.order = "data", ratio = 2, exact = ~g)) + expect_true(test_all(m.order = "closest", ratio = 2, exact = ~g)) + + expect_true(test_all(m.order = "data", ratio = 2, exact = ~g, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, exact = ~g, replace = TRUE)) + + expect_true(test_all(m.order = "data", ratio = 2, antiexact = ~g)) + expect_true(test_all(m.order = "closest", ratio = 2, antiexact = ~g)) + + expect_true(test_all(m.order = "data", ratio = 2, antiexact = ~g, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, antiexact = ~g, replace = TRUE)) + + expect_true(test_all(m.order = "data", ratio = 2, discard = dis)) + expect_true(test_all(m.order = "closest", ratio = 2, discard = dis)) + + expect_true(test_all(m.order = "data", ratio = 2, unit.id = ~u)) + expect_true(test_all(m.order = "closest", ratio = 2, unit.id = ~u)) + + expect_true(test_all(m.order = "data", ratio = 2, unit.id = ~u, reuse.max = 3)) + expect_true(test_all(m.order = "closest", ratio = 2, unit.id = ~u, reuse.max = 3)) + + expect_true(test_all(m.order = "data", ratio = 2, unit.id = ~u, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, unit.id = ~u, replace = TRUE)) +}) From 87aad0e4e10b595577819f2a31084ee656a9813a Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Thu, 24 Oct 2024 19:07:23 -0400 Subject: [PATCH 23/48] Rcpp updates --- R/RcppExports.R | 4 ---- src/RcppExports.cpp | 14 -------------- 2 files changed, 18 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 2ba81509..7d57c39d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -57,10 +57,6 @@ weights_matrixC <- function(mm, treat_, focal = NULL) { .Call(`_MatchIt_weights_matrixC`, mm, treat_, focal) } -weights_subclassC <- function(subclass_, treat_, focal_ = NULL) { - .Call(`_MatchIt_weights_subclassC`, subclass_, treat_, focal_) -} - # Register entry points for exported C++ functions methods::setLoadAction(function(ns) { .Call(`_MatchIt_RcppExport_registerCCallable`) diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index c3245b02..92277c84 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -244,19 +244,6 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// weights_subclassC -NumericVector weights_subclassC(const IntegerVector& subclass_, const IntegerVector& treat_, const Nullable& focal_); -RcppExport SEXP _MatchIt_weights_subclassC(SEXP subclass_SEXP, SEXP treat_SEXP, SEXP focal_SEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const IntegerVector& >::type subclass_(subclass_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type focal_(focal_SEXP); - rcpp_result_gen = Rcpp::wrap(weights_subclassC(subclass_, treat_, focal_)); - return rcpp_result_gen; -END_RCPP -} // validate (ensure exported C++ functions exist before calling them) static int _MatchIt_RcppExport_validate(const char* sig) { @@ -287,7 +274,6 @@ static const R_CallMethodDef CallEntries[] = { {"_MatchIt_subclass_scootC", (DL_FUNC) &_MatchIt_subclass_scootC, 4}, {"_MatchIt_tabulateC", (DL_FUNC) &_MatchIt_tabulateC, 2}, {"_MatchIt_weights_matrixC", (DL_FUNC) &_MatchIt_weights_matrixC, 3}, - {"_MatchIt_weights_subclassC", (DL_FUNC) &_MatchIt_weights_subclassC, 3}, {"_MatchIt_RcppExport_registerCCallable", (DL_FUNC) &_MatchIt_RcppExport_registerCCallable, 0}, {NULL, NULL, 0} }; From 3ca53e421558c5b01567a5527aa857570253e91d Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:21:53 -0500 Subject: [PATCH 24/48] Metadata updates --- DESCRIPTION | 2 +- NAMESPACE | 4 +--- NEWS.md | 12 ++++++++++++ R/MatchIt-package.R | 5 +---- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 305c94d8..e650109a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: MatchIt -Version: 4.5.5.9001 +Version: 4.5.5.9003 Title: Nonparametric Preprocessing for Parametric Causal Inference Description: Selects matched samples of the original treated and control groups with similar covariate distributions -- can be diff --git a/NAMESPACE b/NAMESPACE index b986c168..54f41307 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -21,13 +21,11 @@ export(scaled_euclidean_dist) import(graphics) import(stats) importFrom(Rcpp,evalCpp) -importFrom(Rcpp,sourceCpp) importFrom(grDevices,devAskNewPage) importFrom(grDevices,nclass.FD) importFrom(grDevices,nclass.Sturges) importFrom(grDevices,nclass.scott) importFrom(utils,capture.output) importFrom(utils,combn) -importFrom(utils,setTxtProgressBar) -importFrom(utils,txtProgressBar) +importFrom(utils,hasName) useDynLib(MatchIt, .registration = TRUE) diff --git a/NEWS.md b/NEWS.md index 92b54388..6f31204c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,16 +10,28 @@ output: Most improvements are related to performance. Some of these dramatically improve speeds for large datasets. Most come from improvements to `Rcpp` code. +* When using `method = "nearest"`, `m.order` can now be set to `"farthest"` to prioritize hard-to-match treated units. Note this **does not** implement "far matching" but simply changes the order in which the closest matches are selected. + * Speed improvements to `method = "nearest"`, especially when matching on a propensity score. * Speed improvements to `summary()` when `pair.dist = TRUE` and a `match.matrix` component is not included in the output (e.g., for `method = "full"` or `method = "quick"`). * Speed improvements to `method = "subclass"` with `min.n` greater than 0. +* A new `normalize` argument has been added to `matchit()`. When set to `TRUE` (the default, which used to be the only option), the nonzero weights in each treatment group are rescaled to have an average of 1. When `FALSE`, the weights generated directly by the matching are returned instead. + * When using `method = "nearest"` with `m.order = "closest"`, the full distance matrix is no longer computed, which increases support for larger samples. This uses an adaptation of an algorithm described by [Rassen et al. (2012)](https://doi.org/10.1002/pds.3263). * When using `method = "nearest"` with `verbose = TRUE`, the progress bar now displays an estimate of how much time remains. +* When using `method = "nearest"` with `m.order = "closest"` and `ratio` greater than 1, all eligible units will receive their first match before any receive their second, etc. Previously, the closest pairs would be matched regardless of whether other units had been matched. This ensures consistency with other `m.order` arguments. + +* Speed and memory improvements to `method = "cem"` with many covariates and a large sample size. Previous versions used a Cartesian expansion of all levels of factor variables, which could easily explode. + +* When using `method = "cem"` with `k2k = TRUE`, `m.order` can be set to select the matching order. Allowable options include `"data"` (the default), `"closest"`, `"farthest"`, and `"random"`. `"closest"` is recommended, but `"data"` is the default for now to remain consistent with previous versions. + +* Documentation updates. + * Fixed a bug when using `method = "optimal"` or `method = "full"` with `discard` specified and `data` given as a tibble (`tbl_df` object). (#185) * Fixed a bug when using `method = "cardinality"` with a single covariate. (#194) diff --git a/R/MatchIt-package.R b/R/MatchIt-package.R index d38adc78..5efe8e1a 100644 --- a/R/MatchIt-package.R +++ b/R/MatchIt-package.R @@ -2,7 +2,6 @@ "_PACKAGE" ## usethis namespace: start -#' #' @import graphics #' @import stats #' @importFrom grDevices devAskNewPage @@ -10,11 +9,9 @@ #' @importFrom grDevices nclass.scott #' @importFrom grDevices nclass.Sturges #' @importFrom Rcpp evalCpp -#' @importFrom Rcpp sourceCpp #' @importFrom utils capture.output #' @importFrom utils combn -#' @importFrom utils setTxtProgressBar -#' @importFrom utils txtProgressBar +#' @importFrom utils hasName #' @useDynLib MatchIt, .registration = TRUE ## usethis namespace: end NULL From 16e9e60043a496d8cc230c40f43e03e5ce2f93d2 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:24:26 -0500 Subject: [PATCH 25/48] Cleaning and improvements --- R/add_s.weights.R | 8 +- R/aux_functions.R | 64 +++++--- R/dist_functions.R | 38 +++-- R/distance2_methods.R | 83 ++++++---- R/get_weights_from_mm.R | 11 +- R/get_weights_from_subclass.R | 68 ++++---- R/input_processing.R | 187 +++++++++++----------- R/match.data.R | 43 +++--- R/match.qoi.R | 29 ++-- R/matchit2cardinality.R | 36 ++--- R/matchit2exact.R | 9 +- R/matchit2full.R | 30 ++-- R/matchit2genetic.R | 18 ++- R/matchit2nearest.R | 282 ++++++++++++++++------------------ R/matchit2optimal.R | 46 +++--- R/matchit2quick.R | 22 ++- R/matchit2subclass.R | 24 ++- R/plot.matchit.R | 152 +++++++++--------- R/rbind.matchdata.R | 5 +- R/summary.matchit.R | 55 ++++--- 20 files changed, 644 insertions(+), 566 deletions(-) diff --git a/R/add_s.weights.R b/R/add_s.weights.R index 366356fc..9f4c08f8 100644 --- a/R/add_s.weights.R +++ b/R/add_s.weights.R @@ -69,10 +69,13 @@ add_s.weights <- function(m, if (is_null(data)) { if (is_not_null(m$model)) { env <- attributes(terms(m$model))$.Environment - } else { + } + else { env <- parent.frame() } + data <- eval(m$call$data, envir = env) + if (is_null(data)) { .err("a dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") } @@ -95,7 +98,7 @@ add_s.weights <- function(m, .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") } - if (!all(s.weights %in% names(data))) { + if (!all(hasName(data, s.weights))) { .err("the name supplied to `s.weights` must be a variable in `data`") } @@ -124,6 +127,7 @@ add_s.weights <- function(m, } chk::chk_not_any_na(s.weights) + if (length(s.weights) != length(m$treat)) { .err("`s.weights` must be the same length as the treatment vector") } diff --git a/R/aux_functions.R b/R/aux_functions.R index 937297f3..4dc54dbc 100644 --- a/R/aux_functions.R +++ b/R/aux_functions.R @@ -134,13 +134,18 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = " } if (is.matrix(X)) { - X <- setNames(lapply(seq_len(ncol(X)), function(i) X[,i]), colnames(X)) + X <- as.data.frame.matrix(X) } - - if (!is.list(X)) { + else if (!is.list(X)) { stop("X must be a matrix, data frame, or list.") } + X <- X[lengths(X) > 0] + + if (is_null(X)) { + return(NULL) + } + for (i in seq_along(X)) { unique_x <- { if (is.factor(X[[i]])) levels(X[[i]]) @@ -160,17 +165,13 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = " X[[i]] <- factor(X[[i]], levels = unique_x, labels = lev) } - all_levels <- do.call("paste", c(rev(expand.grid(rev(lapply(X, levels)), - KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)), - sep = sep)) - - out <- do.call("paste", c(X, sep = sep)) - - out <- factor(out, levels = all_levels[all_levels %in% out]) + out <- interaction2(X, sep = sep, lex.order = if (include_vars) TRUE else NULL) - if (is_not_null(nam)) names(out) <- nam + if (is_null(nam)) { + return(out) + } - out + setNames(out, nam) } #Get covariates (RHS) vars from formula @@ -178,7 +179,7 @@ get.covs.matrix <- function(formula = NULL, data = NULL) { if (is_null(formula)) { fnames <- colnames(data) - fnames[!startsWith(fnames, "`")] <- paste0("`", fnames[!startsWith(fnames, "`")], "`") + fnames[!startsWith(fnames, "`")] <- add_quotes(fnames[!startsWith(fnames, "`")], "`") formula <- reformulate(fnames) } else { @@ -189,15 +190,17 @@ get.covs.matrix <- function(formula = NULL, data = NULL) { na.action = na.pass) chars.in.mf <- vapply(mf, is.character, logical(1L)) - mf[chars.in.mf] <- lapply(mf[chars.in.mf], factor) + mf[chars.in.mf] <- lapply(mf[chars.in.mf], as.factor) mf <- droplevels(mf) X <- model.matrix(formula, data = mf, contrasts.arg = lapply(Filter(is.factor, mf), contrasts, contrasts = FALSE)) + assign <- attr(X, "assign")[-1] X <- X[,-1, drop = FALSE] + attr(X, "assign") <- assign X @@ -252,6 +255,7 @@ pooled_cov <- function(X, t, w = NULL) { X[in_t, j] <- X[in_t, j] - mean(X[in_t, j]) } } + return(cov(X)*(n-1)/(n-length(unique_t))) } @@ -261,6 +265,7 @@ pooled_cov <- function(X, t, w = NULL) { X[in_t, j] <- X[in_t, j] - wm(X[in_t, j], w[in_t]) } } + cov.wt(X, w)$cov } @@ -269,10 +274,14 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion unique_t <- unique(t) if (is_null(dim(X))) X <- matrix(X, nrow = length(X)) n <- nrow(X) - if (is_null(bin.var)) bin.var <- apply(X, 2, function(x) all(x == 0 | x == 1)) + + if (is_null(bin.var)) { + bin.var <- apply(X, 2, function(x) all(x == 0 | x == 1)) + } if (contribution == "equal") { vars <- matrix(0, nrow = length(unique_t), ncol = ncol(X)) + for (i in seq_along(unique_t)) { in_t <- which(t == unique_t[i]) vars[i,] <- vapply(seq_len(ncol(X)), function(j) { @@ -281,6 +290,7 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion wvar(x[in_t], w = w[in_t], bin.var = b) }, numeric(1L)) } + pooled_var <- colMeans(vars) } else { @@ -335,12 +345,20 @@ ESS <- function(w) { #Compute sample sizes nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { - if (is_null(discarded)) discarded <- rep(FALSE, length(treat)) - if (is_null(s.weights)) s.weights <- rep(1, length(treat)) + if (is_null(discarded)) { + discarded <- rep.int(FALSE, length(treat)) + } + + if (is_null(s.weights)) { + s.weights <- rep.int(1, length(treat)) + } + weights <- weights * s.weights - n <- matrix(0, ncol=2, nrow=6, dimnames = list(c("All (ESS)", "All", "Matched (ESS)", - "Matched", "Unmatched","Discarded"), - c("Control", "Treated"))) + + n <- matrix(0, ncol = 2L, nrow = 6L, + dimnames = list(c("All (ESS)", "All", "Matched (ESS)", + "Matched", "Unmatched","Discarded"), + c("Control", "Treated"))) # Control Treated n["All (ESS)",] <- c(ESS(s.weights[treat==0]), ESS(s.weights[treat==1])) @@ -357,7 +375,11 @@ nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { qn <- function(treat, subclass, discarded = NULL) { treat <- factor(treat, levels = 0:1, labels = c("Control", "Treated")) - if (is_null(discarded)) discarded <- rep(FALSE, length(treat)) + + if (is_null(discarded)) { + discarded <- rep.int(FALSE, length(treat)) + } + qn <- table(treat[!discarded], subclass[!discarded]) if (any(is.na(subclass) & !discarded)) { diff --git a/R/dist_functions.R b/R/dist_functions.R index c879e22b..7ffa48e4 100644 --- a/R/dist_functions.R +++ b/R/dist_functions.R @@ -225,13 +225,15 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano method <- "euclidean" X <- X[, 1, drop = FALSE] } - else if (length(no_variance) > 0) { + else if (is_not_null(no_variance)) { X <- X[, -no_variance, drop = FALSE] } method <- match_arg(method, matchit_distances()) - if (is_null(discarded)) discarded <- rep(FALSE, nrow(X)) + if (is_null(discarded)) { + discarded <- rep.int(FALSE, nrow(X)) + } if (method == "mahalanobis") { # X <- sweep(X, 2, colMeans(X)) @@ -270,7 +272,9 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano X_r <- matrix(0, nrow = sum(!discarded), ncol = ncol(X), dimnames = list(rownames(X)[!discarded], colnames(X))) - for (i in seq_len(ncol(X_r))) X_r[,i] <- rank(X[!discarded, i]) + for (i in seq_len(ncol(X_r))) { + X_r[,i] <- rank(X[!discarded, i]) + } var_r <- { if (is_null(s.weights)) cov(X_r) @@ -293,7 +297,9 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano if (any(discarded)) { X_r <- array(0, dim = dim(X), dimnames = dimnames(X)) - for (i in seq_len(ncol(X_r))) X_r[!discarded,i] <- rank(X[!discarded,i]) + for (i in seq_len(ncol(X_r))) { + X_r[!discarded,i] <- rank(X[!discarded,i]) + } } X <- mahalanobize(X_r, inv_var) @@ -326,6 +332,7 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano } attr(X, "treat") <- treat + X } @@ -359,7 +366,10 @@ eucdist_internal <- function(X, treat = NULL) { get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { if (is_null(formula)) { - if (is_null(colnames(data))) colnames(data) <- paste0("X", seq_len(ncol(data))) + if (is_null(colnames(data))) { + colnames(data) <- paste0("X", seq_len(ncol(data))) + } + fnames <- colnames(data) fnames[!startsWith(fnames, "`")] <- add_quotes(fnames[!startsWith(fnames, "`")], "`") data <- as.data.frame(data) @@ -387,9 +397,9 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { contrasts.arg = lapply(Filter(is.factor, mf), function(x) contrasts(x, contrasts = FALSE)/sqrt(2))) - if (ncol(X) > 1) { - assign <- attr(X, "assign")[-1] - X <- X[, -1, drop = FALSE] + if (ncol(X) > 1L) { + assign <- attr(X, "assign")[-1L] + X <- X[, -1L, drop = FALSE] } attr(X, "assign") <- assign @@ -399,7 +409,9 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { } .check_X <- function(X) { - if (isTRUE(attr(X, "checked"))) return(X) + if (isTRUE(attr(X, "checked"))) { + return(X) + } treat <- attr(X, "treat") @@ -411,12 +423,10 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { dimnames = list(names(X), NULL)) } - if (anyNA(X)) { - .err("missing values are not allowed in the covariates") - } + chk::chk_not_any_na(X, "the covariates") if (any(!is.finite(X))) { - .err("Non-finite values are not allowed in the covariates") + .err("non-finite values are not allowed in the covariates") } if (!is.numeric(X) || length(dim(X)) != 2) { @@ -430,7 +440,7 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { is.cov_like <- function(var, X) { is.numeric(var) && - length(dim(var)) == 2 && + length(dim(var)) == 2L && (missing(X) || all(dim(var) == ncol(X))) && isSymmetric(var) && all(diag(var) >= 0) diff --git a/R/distance2_methods.R b/R/distance2_methods.R index 232d2572..f16b57eb 100644 --- a/R/distance2_methods.R +++ b/R/distance2_methods.R @@ -299,13 +299,13 @@ distance2glm <- function(formula, data = NULL, link = "logit", ...) { linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) - A <- list(...) - A[!names(A) %in% c(names(formals(glm)), names(formals(glm.control)))] <- NULL + args <- c(names(formals(glm)), names(formals(glm.control))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL res <- do.call("glm", c(list(formula = formula, data = data, family = quasibinomial(link = link)), A)) - if (linear) pred <- predict(res, type = "link") - else pred <- predict(res, type = "response") + pred <- predict(res, type = if (linear) "link" else "response") list(model = res, distance = pred) } @@ -325,8 +325,7 @@ distance2gam <- function(formula, data = NULL, link = "logit", ...) { weights = weights), A), quote = TRUE) - if (linear) pred <- predict(res, type = "link") - else pred <- predict(res, type = "response") + pred <- predict(res, type = if (linear) "link" else "response") list(model = res, distance = as.numeric(pred)) } @@ -334,8 +333,11 @@ distance2gam <- function(formula, data = NULL, link = "logit", ...) { #distance2rpart----------------- distance2rpart <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("rpart") - A <- list(...) - A[!names(A) %in% c(names(formals(rpart::rpart)), names(formals(rpart::rpart.control)))] <- NULL + + args <- c(names(formals(rpart::rpart)), names(formals(rpart::rpart.control))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + A$formula <- formula A$data <- data A$method <- "class" @@ -404,15 +406,16 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { linear <- is_not_null(link) && startsWith(as.character(link), "linear") - A <- list(...) - A[!names(A) %in% c(names(formals(dbarts::bart2)), names(formals(dbarts::dbartsControl)))] <- NULL + args <- c(names(formals(dbarts::bart2)), names(formals(dbarts::dbartsControl))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + A$formula <- formula A$data <- data res <- do.call(dbarts::bart2, A) - if (linear) pred <- fitted(res, type = "link") - else pred <- fitted(res, type = "response") + pred <- fitted(res, type = if (linear) "link" else "response") list(model = res, distance = pred) } @@ -477,9 +480,14 @@ distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) - A <- list(...) - s <- A[["s"]] - A[!names(A) %in% c(names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet)))] <- NULL + s <- ...get("s", ...) + if (is_null(s)) { + s <- "lambda.1se" + } + + args <- c(names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL if (is_null(link)) link <- "logit" @@ -488,7 +496,9 @@ distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { "log" = "poisson", binomial(link = link)) - if (is_null(A[["alpha"]])) A[["alpha"]] <- .5 + if (is_null(A[["alpha"]])) { + A[["alpha"]] <- .5 + } mf <- model.frame(formula, data = data) @@ -497,24 +507,38 @@ distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { res <- do.call(glmnet::cv.glmnet, A) - if (is_null(s)) s <- "lambda.1se" - pred <- drop(predict(res, newx = A$x, s = s, type = if (linear) "link" else "response")) list(model = res, distance = pred) } distance2lasso <- function(formula, data = NULL, link = NULL, ...) { - A <- list(...) - A$alpha <- 1 - do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), - quote = TRUE) + if ("alpha" %in% ...names()) { + args <- c("s", names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + + A$alpha <- 1 + do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), + quote = TRUE) + } + else { + distance2elasticnet(formula = formula, data = data, link = link, alpha = 1, ...) + } } distance2ridge <- function(formula, data = NULL, link = NULL, ...) { - A <- list(...) - A$alpha <- 0 - do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), - quote = TRUE) + if ("alpha" %in% ...names()) { + args <- c("s", names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + + A$alpha <- 0 + do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), + quote = TRUE) + } + else { + distance2elasticnet(formula = formula, data = data, link = link, alpha = 0, ...) + } } #distance2gbm-------------- @@ -525,8 +549,11 @@ distance2gbm <- function(formula, data = NULL, link = NULL, ...) { A <- list(...) - method <- A[["method"]] - A[!names(A) %in% names(formals(gbm::gbm))] <- NULL + method <- ...get("method", ...) + + args <- names(formals(gbm::gbm)) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL A$formula <- formula A$data <- data diff --git a/R/get_weights_from_mm.R b/R/get_weights_from_mm.R index 1d8ec6f6..64390a94 100644 --- a/R/get_weights_from_mm.R +++ b/R/get_weights_from_mm.R @@ -6,12 +6,17 @@ get_weights_from_mm <- function(match.matrix, treat, focal = NULL) { weights <- weights_matrixC(match.matrix, treat, focal) - if (sum(weights) == 0) + if (all_equal_to(weights, 0)) { .err("No units were matched") - if (sum(weights[treat == 1]) == 0) + } + + if (all_equal_to(weights[treat == 1], 0)) { .err("No treated units were matched") - if (sum(weights[treat == 0]) == 0) + } + + if (all_equal_to(weights[treat == 0], 0)) { .err("No control units were matched") + } setNames(weights, names(treat)) } \ No newline at end of file diff --git a/R/get_weights_from_subclass.R b/R/get_weights_from_subclass.R index b6e2baa1..edbf2a9a 100644 --- a/R/get_weights_from_subclass.R +++ b/R/get_weights_from_subclass.R @@ -2,52 +2,58 @@ get_weights_from_subclass <- function(psclass, treat, estimand = "ATT") { NAsub <- is.na(psclass) - i1 <- treat == 1 & !NAsub - i0 <- treat == 0 & !NAsub + i1 <- which(treat == 1 & !NAsub) + i0 <- which(treat == 0 & !NAsub) - weights <- setNames(rep(0, length(treat)), names(treat)) + if (is_null(i1)) { + if (is_null(i0)) { + .err("No units were matched") + } - if (!is.factor(psclass)) { - psclass <- factor(psclass, nmax = min(sum(i1), sum(i0))) - levels(psclass) <- seq_len(nlevels(psclass)) + .err("No treated units were matched") + } + else if (is_null(i0)) { + .err("No control units were matched") } - treated_by_sub <- setNames(tabulateC(psclass[i1], nlevels(psclass)), levels(psclass)) - control_by_sub <- setNames(tabulateC(psclass[i0], nlevels(psclass)), levels(psclass)) + weights <- setNames(rep(0.0, length(treat)), names(treat)) + + if (!is.factor(psclass)) { + psclass <- factor(psclass, nmax = min(length(i1), length(i0))) + } - total_by_sub <- treated_by_sub + control_by_sub + treated_by_sub <- tabulate(psclass[i1], nlevels(psclass)) + control_by_sub <- tabulate(psclass[i0], nlevels(psclass)) - psclass <- as.character(psclass) + psclass <- unclass(psclass) if (estimand == "ATT") { - weights[i1] <- 1 + weights[i1] <- 1.0 weights[i0] <- (treated_by_sub/control_by_sub)[psclass[i0]] - - #Weights average 1 - weights[i0] <- .make_sum_to_n(weights[i0]) } else if (estimand == "ATC") { weights[i1] <- (control_by_sub/treated_by_sub)[psclass[i1]] - weights[i0] <- 1 - - #Weights average 1 - weights[i1] <- .make_sum_to_n(weights[i1]) + weights[i0] <- 1.0 } else if (estimand == "ATE") { - weights[i1] <- (total_by_sub/treated_by_sub)[psclass[i1]] - weights[i0] <- (total_by_sub/control_by_sub)[psclass[i0]] - - #Weights average 1 - weights[i1] <- .make_sum_to_n(weights[i1]) - weights[i0] <- .make_sum_to_n(weights[i0]) + weights[i1] <- 1.0 + (control_by_sub/treated_by_sub)[psclass[i1]] + weights[i0] <- 1.0 + (treated_by_sub/control_by_sub)[psclass[i0]] } - if (sum(weights) == 0) - .err("No units were matched") - if (sum(weights[treat == 1]) == 0) - .err("No treated units were matched") - if (sum(weights[treat == 0]) == 0) - .err("No control units were matched") - weights } + +# get_weights_from_subclass2 <- function(psclass, treat, estimand = "ATT") { +# +# weights <- weights_subclassC(psclass, treat, +# switch(estimand, "ATT" = 1, "ATC" = 0, NULL)) +# +# if (sum(weights) == 0) +# .err("No units were matched") +# if (sum(weights[treat == 1]) == 0) +# .err("No treated units were matched") +# if (sum(weights[treat == 0]) == 0) +# .err("No control units were matched") +# +# weights +# } \ No newline at end of file diff --git a/R/input_processing.R b/R/input_processing.R index 61460e78..ed850f26 100644 --- a/R/input_processing.R +++ b/R/input_processing.R @@ -5,6 +5,7 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, min.controls = NULL, max.controls = NULL) { null.method <- is_null(method) + if (null.method) { method <- "NULL" } @@ -16,23 +17,24 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, ignored.inputs <- character(0) error.inputs <- character(0) + if (null.method) { for (i in c("exact", "mahvars", "antiexact", "caliper", "std.caliper", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "exact") { for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "cem") { - for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls")) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -40,7 +42,7 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "nearest") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } @@ -49,14 +51,14 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "optimal") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "caliper", "std.caliper", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -65,14 +67,14 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "full") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "ratio", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -80,27 +82,27 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "genetic") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("min.controls", "max.controls")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "cardinality") { for (i in c("distance", "antiexact", "caliper", "std.caliper", "reestimate", "replace", "min.controls", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "subclass") { for (i in c("exact", "mahvars", "antiexact", "caliper", "std.caliper", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -108,32 +110,32 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "quick") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "ratio", "min.controls", "max.controls", "m.order", "antiexact")) { - if (i %in% names(mcall) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } if (is_not_null(ignored.inputs)) { - .wrn(sprintf("the %s %s not used with `method = %s` and will be ignored", - ngettext(length(ignored.inputs), "argument", "arguments"), - word_list(ignored.inputs, quotes = 1, is.are = TRUE), - add_quotes(method, quotes = !null.method))) + .wrn(sprintf("the argument%%s %s %%r not used with `method = %s` and will be ignored", + word_list(ignored.inputs, quotes = "`"), + add_quotes(method, quotes = !null.method)), + n = length(ignored.inputs)) } if (is_not_null(error.inputs)) { - .err(sprintf("the %s %s not used with `method = %s` and `distance = \"%s\"`", - ngettext(length(error.inputs), "argument", "arguments"), - word_list(error.inputs, quotes = 1, is.are = TRUE), + .err(sprintf("the argument%%s %s %%r not used with `method = %s` and `distance = %s`", + word_list(error.inputs, quotes = "`"), add_quotes(method, quotes = !null.method), - distance)) + add_quotes(distance)), + n = length(error.inputs)) } ignored.inputs @@ -186,7 +188,7 @@ process.distance <- function(distance, method = NULL, treat) { return(distance) } - if (is.character(distance) && length(distance) == 1L) { + if (chk::vld_string(distance)) { allowable.distances <- c( #Propensity score methods "glm", "cbps", "gam", "nnet", "rpart", "bart", @@ -198,15 +200,19 @@ process.distance <- function(distance, method = NULL, treat) { if (tolower(distance) %in% c("cauchit", "cloglog", "linear.cloglog", "linear.log", "linear.logit", "linear.probit", "linear.cauchit", "log", "probit")) { link <- tolower(distance) + .wrn(sprintf("`distance = \"%s\"` will be deprecated; please use `distance = \"glm\", link = \"%s\"` in the future", distance, link)) + distance <- "glm" attr(distance, "link") <- link } else if (tolower(distance) %in% tolower(c("GAMcloglog", "GAMlog", "GAMlogit", "GAMprobit"))) { link <- tolower(substr(distance, 4, nchar(distance))) + .wrn(sprintf("`distance = \"%s\"` will be deprecated; please use `distance = \"gam\", link = \"%s\"` in the future", distance, link)) + distance <- "gam" attr(distance, "link") <- link } @@ -227,51 +233,53 @@ process.distance <- function(distance, method = NULL, treat) { else { distance <- tolower(distance) } + + return(distance) } - else if (!is.numeric(distance) || (is_not_null(dim(distance)) && length(dim(distance)) != 2)) { + + if (!is.numeric(distance) || (is_not_null(dim(distance)) && length(dim(distance)) != 2)) { .err("`distance` must be a string with the name of the distance measure to be used or a numeric vector or matrix containing distance measures") } - else if (is.matrix(distance) && (is_null(method) || !method %in% c("nearest", "optimal", "full"))) { + + if (is.matrix(distance) && (is_null(method) || !method %in% c("nearest", "optimal", "full"))) { .err(sprintf("`distance` cannot be supplied as a matrix with `method = %s`", add_quotes(method, quotes = is_not_null(method)))) } - if (is.numeric(distance)) { - if (is.matrix(distance)) { - dim.distance <- dim(distance) - if (all(dim.distance == length(treat))) { - if (is_not_null(rownames(distance))) { - distance <- distance[names(treat),, drop = FALSE] - } + if (is.matrix(distance)) { + dim.distance <- dim(distance) - if (is_not_null(colnames(distance))) { - distance <- distance[,names(treat), drop = FALSE] - } + if (all_equal_to(dim.distance, length(treat))) { + if (is_not_null(rownames(distance))) { + distance <- distance[names(treat),, drop = FALSE] + } - distance <- distance[treat == 1, treat == 0, drop = FALSE] + if (is_not_null(colnames(distance))) { + distance <- distance[,names(treat), drop = FALSE] } - else if (all(dim.distance == c(sum(treat==1), sum(treat==0)))) { - if (is_not_null(rownames(distance))) { - distance <- distance[names(treat)[treat == 1],, drop = FALSE] - } - if (is_not_null(colnames(distance))) { - distance <- distance[,names(treat)[treat == 0], drop = FALSE] - } + distance <- distance[treat == 1, treat == 0, drop = FALSE] + } + else if (dim.distance[1L] == sum(treat == 1) && + dim.distance[2L] == sum(treat == 0)) { + if (is_not_null(rownames(distance))) { + distance <- distance[names(treat)[treat == 1],, drop = FALSE] } - else { - .err("when supplied as a matrix, `distance` must have dimensions NxN or N1xN0. See `help(\"distance\")` for details") + + if (is_not_null(colnames(distance))) { + distance <- distance[,names(treat)[treat == 0], drop = FALSE] } } else { - if (length(distance) != length(treat)) { - .err("`distance` must be the same length as the dataset if specified as a numeric vector") - } + .err("when supplied as a matrix, `distance` must have dimensions NxN or N1xN0. See `help(\"distance\")` for details") } - - chk::chk_not_any_na(distance) } + else if (length(distance) != length(treat)) { + .err("`distance` must be the same length as the dataset if specified as a numeric vector") + } + + chk::chk_not_any_na(distance) distance } @@ -287,28 +295,28 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co } if (method %in% c("nearest", "optimal")) { - if (ratio.null) ratio <- 1 - else if (ratio.na) .err("`ratio` cannot be `NA`") - else if (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 1) { - .err("`ratio` must be a single number greater than or equal to 1") + if (ratio.null) { + ratio <- 1 + } + else { + chk::chk_number(ratio) + chk::chk_gte(ratio, 1) } if (is_null(max.controls)) { if (!chk::vld_whole_number(ratio)) { .err("`ratio` must be a whole number when `max.controls` is not specified") } + ratio <- round(ratio) } - else if (anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1) { - .err("`max.controls` must be a single positive number") - } else { - if (ratio <= 1) { + chk::chk_count(max.controls) + + if (ratio == 1) { .err("`ratio` must be greater than 1 for variable ratio matching") } - max.controls <- ceiling(max.controls) - if (max.controls <= ratio) { .err("`max.controls` must be greater than `ratio` for variable ratio matching") } @@ -316,11 +324,8 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co if (is_null(min.controls)) { min.controls <- 1 } - else if (!anyNA(max.controls) && is.atomic(max.controls) && is.numeric(max.controls) && length(max.controls) == 1L) { - min.controls <- floor(min.controls) - } else { - .err("`max.controls` must be a single positive number") + chk::chk_count(min.controls) } if (min.controls < 1) { @@ -336,15 +341,17 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co if (is_null(max.controls)) { max.controls <- Inf } - else if ((anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1)) { - .err("`max.controls` must be a single positive number") + else { + chk::chk_number(max.controls) + chk::chk_gt(max.controls, 0) } if (is_null(min.controls)) { min.controls <- 0 } - else if ((anyNA(min.controls) || !is.atomic(min.controls) || !is.numeric(min.controls) || length(min.controls) > 1)) { - .err("`min.controls` must be a single positive number") + else { + chk::chk_number(min.controls) + chk::chk_gt(min.controls, 0) } ratio <- 1 #Just to get min.controls and max.controls out @@ -353,23 +360,17 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co if (ratio.null) { ratio <- 1 } - else if (ratio.na) { - .err("`ratio` cannot be `NA`") - } - else if (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 1 || - !chk::vld_whole_number(ratio)) { - .err("`ratio` must be a single whole number greater than or equal to 1") + else { + chk::chk_count(ratio) } - ratio <- round(ratio) - min.controls <- max.controls <- NULL } else if (method == "cardinality") { if (ratio.null) { ratio <- 1 } - else if (!ratio.na && (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 0)) { + else if (!ratio.na && (!chk::vld_number(ratio) || !chk::vld_gte(ratio, 0))) { .err("`ratio` must be a single positive number or `NA`") } @@ -423,7 +424,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N .err("no more than one entry in `caliper` can have no name") } - if (any(names(caliper) == "") && is_null(distance)) { + if (hasName(caliper, "") && is_null(distance)) { .err("all entries in `caliper` must be named when `distance` does not correspond to a propensity score") } @@ -439,7 +440,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N #Check std.caliper chk::chk_logical(std.caliper) - if (length(std.caliper) == 1) { + if (length(std.caliper) == 1L) { std.caliper <- setNames(rep.int(std.caliper, length(caliper)), names(caliper)) } else if (length(std.caliper) == length(caliper)) { @@ -477,7 +478,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N chk::chk_not_any_na(std.caliper) if (any(std.caliper)) { - if (any(names(std.caliper) == "") && isTRUE(std.caliper[names(std.caliper) == ""]) && is.matrix(distance)) { + if (hasName(std.caliper, "") && isTRUE(std.caliper[names(std.caliper) == ""]) && is.matrix(distance)) { .err("when `distance` is supplied as a matrix and a caliper for it is specified, `std.caliper` must be `FALSE` for the distance measure") } @@ -512,20 +513,18 @@ process.replace <- function(replace, method = NULL, ..., reuse.max = NULL) { if (method %in% c("nearest")) { if (is_null(reuse.max)) { - if (replace) reuse.max <- .Machine$integer.max - else reuse.max <- 1L + reuse.max <- if (replace) .Machine$integer.max else 1L } - else if (length(reuse.max) == 1 && is.numeric(reuse.max) && - (!is.finite(reuse.max) || reuse.max > .Machine$integer.max) && - !anyNA(reuse.max)) { - reuse.max <- .Machine$integer.max - } - else if (abs(reuse.max - round(reuse.max)) > 1e-8 || length(reuse.max) != 1 || - anyNA(reuse.max) || reuse.max < 1) { - .err("`reuse.max` must be a positive integer of length 1") + else { + chk::chk_count(reuse.max) + chk::chk_gte(reuse.max, 1) + + if (reuse.max > .Machine$integer.max) { + reuse.max <- .Machine$integer.max + } } - replace <- reuse.max != 1L + replace <- reuse.max > 1L attr(replace, "reuse.max") <- as.integer(reuse.max) } @@ -547,7 +546,7 @@ process.variable.input <- function(x, data = NULL) { n)) } - if (!all(x %in% names(data))) { + if (!all(hasName(data, x))) { .err(sprintf("All names supplied to `%s` must be variables in `data`. Variables not in `data`:\n\t%s", n, paste(add_quotes(setdiff(x, names(data))), collapse = ", ")), tidy = FALSE) } @@ -568,4 +567,4 @@ process.variable.input <- function(x, data = NULL) { } x_covs -} \ No newline at end of file +} diff --git a/R/match.data.R b/R/match.data.R index 9f90198e..a9a52033 100644 --- a/R/match.data.R +++ b/R/match.data.R @@ -178,21 +178,26 @@ match.data <- function(object, group = "all", distance = "distance", weights = " chk::chk_is(object, "matchit") - if (is_null(data)) { - env <- environment(object$formula) - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (null_or_error(data) || - length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - env <- parent.frame() - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (null_or_error(data) || - length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - data <- object[["model"]][["data"]] - if (is_null(data) || nrow(data) != length(object[["treat"]])) { - .err("a valid dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") - } - } + data.found <- FALSE + for (i in 1:4) { + if (i == 2) { + data <- try(eval(object$call$data, envir = environment(object$formula)), silent = TRUE) } + else if (i == 3) { + data <- try(eval(object$call$data, envir = parent.frame()), silent = TRUE) + } + else if (i == 4) { + data <- object[["model"]][["data"]] + } + + if (!null_or_error(data) && length(dim(data)) == 2L && nrow(data) == length(object[["treat"]])) { + data.found <- TRUE + break + } + } + + if (!data.found) { + .err("a valid dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") } if (!is.data.frame(data)) { @@ -209,7 +214,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " if (is_not_null(object$distance)) { chk::chk_not_null(distance) chk::chk_string(distance) - if (distance %in% names(data)) { + if (hasName(data, distance)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for distance using the `distance` argument", add_quotes(distance))) } @@ -219,7 +224,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " if (is_not_null(object$weights)) { chk::chk_not_null(weights) chk::chk_string(weights) - if (weights %in% names(data)) { + if (hasName(data, weights)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for weights using the `weights` argument", add_quotes(weights))) } @@ -233,7 +238,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " if (is_not_null(object$subclass)) { chk::chk_not_null(subclass) chk::chk_string(subclass) - if (subclass %in% names(data)) { + if (hasName(data, subclass)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for subclass using the `subclass` argument", add_quotes(subclass))) } @@ -280,7 +285,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc chk::chk_not_null(id) chk::chk_string(id) - if (id %in% names(m.data)) { + if (hasName(m.data, id)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for id using the `id` argument", add_quotes(id))) } @@ -288,7 +293,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc m.data[[id]] <- names(object$treat)[object$weights > 0] for (i in c(weights, subclass)) { - if (i %in% names(m.data)) m.data[[i]] <- NULL + if (hasName(m.data, i)) m.data[[i]] <- NULL } mm <- object$match.matrix diff --git a/R/match.qoi.R b/R/match.qoi.R index bb683a3d..ad47f891 100644 --- a/R/match.qoi.R +++ b/R/match.qoi.R @@ -22,8 +22,8 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, too.small <- sum(ww[i1] != 0) < 2 && sum(ww[i0] != 0) < 2 - xsum["Means Treated"] <- wm(xx[i1], ww[i1], na.rm=TRUE) - xsum["Means Control"] <- wm(xx[i0], ww[i0], na.rm=TRUE) + xsum["Means Treated"] <- wm(xx[i1], ww[i1], na.rm = TRUE) + xsum["Means Control"] <- wm(xx[i0], ww[i0], na.rm = TRUE) mdiff <- xsum["Means Treated"] - xsum["Means Control"] @@ -47,12 +47,17 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, } xsum[3] <- mdiff/std - if (!un && compute.pair.dist) xsum[7] <- pair.dist(xx, tt, subclass, mm, std) + if (!un && compute.pair.dist) { + xsum[7] <- pair.dist(xx, tt, subclass, mm, std) + } } } else { xsum[3] <- mdiff - if (!un && compute.pair.dist) xsum[7] <- pair.dist(xx, tt, subclass, mm) + + if (!un && compute.pair.dist) { + xsum[7] <- pair.dist(xx, tt, subclass, mm) + } } if (bin.var) { @@ -86,10 +91,10 @@ bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", i1 <- which(in.sub & tt == 1) i0 <- which(in.sub & tt == 0) - too.small <- length(i1) < 2 && length(i0) < 2 + too.small <- length(i1) < 2L && length(i0) < 2L - xsum["Subclass","Means Treated"] <- wm(xx[i1], s.weights[i1], na.rm=TRUE) - xsum["Subclass","Means Control"] <- wm(xx[i0], s.weights[i0], na.rm=TRUE) + xsum["Subclass","Means Treated"] <- wm(xx[i1], s.weights[i1], na.rm = TRUE) + xsum["Subclass","Means Control"] <- wm(xx[i0], s.weights[i0], na.rm = TRUE) mdiff <- xsum["Subclass","Means Treated"] - xsum["Subclass","Means Control"] @@ -152,7 +157,7 @@ pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL) { } if (is_not_null(std) && abs(mpdiff) > 1e-8) { - mpdiff <- mpdiff/std + return(mpdiff/std) } mpdiff @@ -164,12 +169,15 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { n.obs <- length(x) - if (is_null(w)) w <- rep(1, n.obs) + if (is_null(w)) { + w <- rep(1, n.obs) + } - if (all(x == 0 | x == 1)) { + if (has_n_unique(x, 2) && all(x == 0 | x == 1)) { t1 <- t == t[1] #For binary variables, just difference in means ediff <- abs(wm(x[t1], w[t1]) - wm(x[-t1], w[-t1])) + return(c(meandiff = ediff, maxdiff = ediff)) } @@ -228,6 +236,7 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { method = "constant", ties = "ordered")$y } } + ediff <- abs(x1 - x0) } diff --git a/R/matchit2cardinality.R b/R/matchit2cardinality.R index 429a0f31..093a6988 100644 --- a/R/matchit2cardinality.R +++ b/R/matchit2cardinality.R @@ -265,16 +265,14 @@ #' # covariates included if possible. NULL -matchit2cardinality <- function(treat, data, discarded, formula, - ratio = 1, focal = NULL, s.weights = NULL, - replace = FALSE, mahvars = NULL, exact = NULL, - estimand = "ATT", verbose = FALSE, - tols = .05, std.tols = TRUE, - solver = "glpk", time = 1*60, ...){ +matchit2cardinality <- function(treat, data, discarded, formula, + ratio = 1, focal = NULL, s.weights = NULL, + replace = FALSE, mahvars = NULL, exact = NULL, + estimand = "ATT", verbose = FALSE, + tols = .05, std.tols = TRUE, + solver = "glpk", time = 1*60, ...) { - if (verbose) { - cat("Cardinality matching... \n") - } + .cat_verbose("Cardinality matching... \n", verbose = verbose) tvals <- unique(treat) nt <- length(tvals) @@ -300,9 +298,9 @@ matchit2cardinality <- function(treat, data, discarded, formula, X <- get.covs.matrix(formula, data = data) if (is_not_null(exact)) { - ex <- factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) + ex <- exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE) - cc <- Reduce("intersect", lapply(tvals, function(t) as.integer(ex)[treat==t])) + cc <- Reduce("intersect", lapply(tvals, function(t) unclass(ex)[treat==t])) if (is_null(cc)) { .err("no matches were found") @@ -388,23 +386,23 @@ matchit2cardinality <- function(treat, data, discarded, formula, match(e, levels(ex)[cc]), length(cc), e)) } - in.exact <- which(!discarded & ex == e) + .e <- which(!discarded & ex == e) - treat_in.exact <- treat[in.exact] + treat_in.exact <- treat[.e] out <- cardinality_matchit(treat = treat_in.exact, - X = X[in.exact,, drop = FALSE], + X = X[.e,, drop = FALSE], estimand = estimand, tols = tols, - s.weights = s.weights[in.exact], + s.weights = s.weights[.e], ratio = ratio, focal = focal, tvals = tvals, solver = solver, time = time, verbose = verbose) - weights[in.exact] <- out[["weights"]] + weights[.e] <- out[["weights"]] opt.out[[e]] <- out[["opt.out"]] if (is_not_null(mahvars)) { - mo <- eucdist_internal(mahcovs[in.exact[out[["weights"]] > 0],, drop = FALSE], + mo <- eucdist_internal(mahcovs[.e[out[["weights"]] > 0],, drop = FALSE], treat_in.exact[out[["weights"]] > 0]) pm <- optmatch::pairmatch(mo, @@ -434,7 +432,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, weights = weights, obj = opt.out) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" @@ -668,7 +666,7 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight cardinality_error_report <- function(out, solver) { if (solver == "glpk") { if (out$status == 1) { - if (all(out$solution == 0)) { + if (all_equal_to(out$solution, 0)) { .err("the optimization problem may be infeasible. Try increasing the value of `tols`.\nSee `?method_cardinality` for additional details") } .wrn("the optimizer failed to find an optimal solution in the time alotted. The returned solution may not be optimal.\nSee `?method_cardinality` for additional details") diff --git a/R/matchit2exact.R b/R/matchit2exact.R index 04006669..85170634 100644 --- a/R/matchit2exact.R +++ b/R/matchit2exact.R @@ -89,8 +89,7 @@ NULL matchit2exact <- function(treat, covs, data, estimand = "ATT", verbose = FALSE, ...){ - if(verbose) - cat("Exact matching... \n") + .cat_verbose("Exact matching...\n", verbose = verbose) if (is_null(covs)) { .err("covariates must be specified in the input formula to use exact matching") @@ -100,7 +99,7 @@ matchit2exact <- function(treat, covs, data, estimand = "ATT", verbose = FALSE, estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) xx <- exactify(covs, names(treat)) - cc <- do.call("intersect", lapply(unique(treat), function(t) xx[treat == t])) + cc <- Reduce("intersect", lapply(unique(treat), function(t) xx[treat == t])) if (is_null(cc)) { .err("no exact matches were found") @@ -108,12 +107,12 @@ matchit2exact <- function(treat, covs, data, estimand = "ATT", verbose = FALSE, psclass <- setNames(factor(match(xx, cc), nmax = length(cc)), names(treat)) - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) res <- list(subclass = psclass, weights = get_weights_from_subclass(psclass, treat, estimand)) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res diff --git a/R/matchit2full.R b/R/matchit2full.R index 84a297bc..b9f9429a 100644 --- a/R/matchit2full.R +++ b/R/matchit2full.R @@ -230,12 +230,11 @@ matchit2full <- function(treat, formula, data, distance, discarded, rlang::check_installed("optmatch") - if (verbose) cat("Full matching... \n") - - A <- list(...) + .cat_verbose("Full matching... \n", verbose = verbose) fm.args <- c("omit.fraction", "mean.controls", "tol", "solver") - A[!names(A) %in% fm.args] <- NULL + A <- setNames(lapply(fm.args, ...get, ...), fm.args) + A[lengths(A) == 0L] <- NULL #Set max problem size to Inf and return to original value after match omps <- getOption("optmatch_max_problem_size") @@ -273,7 +272,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) - cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) + cc <- Reduce("intersect", lapply(unique(treat_), function(t) unclass(ex)[treat_==t])) if (is_null(cc)) { .err("No matches were found") @@ -353,20 +352,22 @@ matchit2full <- function(treat, formula, data, distance, discarded, t_df <- data.frame(treat) for (e in levels(ex)[cc]) { - if (nlevels(ex) > 1) { - if (verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) - } + if (nlevels(ex) > 1L) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) + mo_ <- mo[ex[treat_==1] == e, ex[treat_==0] == e] } - else mo_ <- mo + else { + mo_ <- mo + } if (any(dim(mo_) == 0) || !any(is.finite(mo_))) { next } - if (all(dim(mo_) == 1) && all(is.finite(mo_))) { + if (all_equal_to(dim(mo_), 1) && all(is.finite(mo_))) { pair[ex == e] <- paste(1, e, sep = "|") next } @@ -398,13 +399,14 @@ matchit2full <- function(treat, formula, data, distance, discarded, #No match.matrix because treated units don't index matched strata (i.e., more than one #treated unit can be in the same stratum). Stratum information is contained in subclass. - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", + verbose = verbose) res <- list(subclass = psclass, weights = get_weights_from_subclass(psclass, treat, estimand), obj = p) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- c("matchit") res diff --git a/R/matchit2genetic.R b/R/matchit2genetic.R index 721e3b11..d6269d05 100644 --- a/R/matchit2genetic.R +++ b/R/matchit2genetic.R @@ -259,9 +259,11 @@ matchit2genetic <- function(treat, data, distance, discarded, rlang::check_installed(c("Matching", "rgenoud")) - if (verbose) cat("Genetic matching... \n") + .cat_verbose("Genetic matching...\n", verbose = verbose) - A <- list(...) + args <- names(formals(Matching::GenMatch)) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC")) @@ -325,9 +327,11 @@ matchit2genetic <- function(treat, data, distance, discarded, #Process exact; exact.log will be supplied to GenMatch() and Match() if (is_not_null(exact)) { #Add covariates in exact not in X to X - ex <- as.integer(factor(exactify(model.frame(exact, data = data), names(treat), sep = ", ", include_vars = TRUE))) + ex <- unclass(exactify(model.frame(exact, data = data), names(treat), + sep = ", ", include_vars = TRUE)) cc <- intersect(ex[treat==1], ex[treat==0]) + if (is_null(cc)) { .err("No matches were found") } @@ -371,7 +375,7 @@ matchit2genetic <- function(treat, data, distance, discarded, } #Then put distance caliper into cal - if ("" %in% names(caliper)) { + if (hasName(caliper, "")) { dist.cal <- caliper[names(caliper) == ""] if (is_not_null(mahvars)) { #If mahvars specified, distance is not yet in X, so add it to X @@ -430,7 +434,7 @@ matchit2genetic <- function(treat, data, distance, discarded, replace = replace, estimand = "ATT", ties = FALSE, CommonSupport = FALSE, verbose = verbose, weights = s.weights, print.level = 2*verbose), - A[names(A) %in% names(formals(Matching::GenMatch))])) + A[names(A) %in% args])) }, from = "Matching", dont_warn_if = "replace==FALSE, but there are more (weighted) treated obs than control obs.") } @@ -497,7 +501,7 @@ matchit2genetic <- function(treat, data, distance, discarded, # dimnames(mm) <- list(lab1, seq_len(ratio)) # } - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) if (replace) { psclass <- NULL @@ -513,7 +517,7 @@ matchit2genetic <- function(treat, data, distance, discarded, weights = weights, obj = g.out) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res diff --git a/R/matchit2nearest.R b/R/matchit2nearest.R index 682d3b7e..4205e690 100644 --- a/R/matchit2nearest.R +++ b/R/matchit2nearest.R @@ -74,19 +74,20 @@ #' include `"largest"`, where matching takes place in descending order of #' distance measures; `"smallest"`, where matching takes place in ascending #' order of distance measures; `"closest"`, where matching takes place in -#' ascending order of the distance between units; `"random"`, where matching takes place +#' ascending order of the smallest distance between units; `"farthest"`, where matching takes place in +#' descending order of the smallest distance between units; `"random"`, where matching takes place #' in a random order; and `"data"` where matching takes place based on the #' order of units in the data. When `m.order = "random"`, results may differ #' across different runs of the same code unless a seed is set and specified #' with [set.seed()]. The default of `NULL` corresponds to `"largest"` when a #' propensity score is estimated or supplied as a vector and `"data"` -#' otherwise. +#' otherwise. See Details for more information. #' @param caliper the width(s) of the caliper(s) used for caliper matching. Two units with a difference on a caliper variable larger than the caliper will not be paired. See Details and Examples. #' @param std.caliper `logical`; when calipers are specified, whether they #' are in standard deviation units (`TRUE`) or raw units (`FALSE`). #' @param ratio how many control units should be matched to each treated unit #' for k:1 matching. For variable ratio matching, see section "Variable Ratio -#' Matching" in Details below. +#' Matching" in Details below. When `ratio` is greater than 1, all treated units will be attempted to be matched with a control unit before any treated unit is matched with a second control unit, etc. This reduces the possibility that control units will be used up before some treated units receive any matches. #' @param min.controls,max.controls for variable ratio matching, the minimum #' and maximum number of controls units to be matched to each treated unit. See #' section "Variable Ratio Matching" in Details below. @@ -216,20 +217,22 @@ #' `ratio`, and `max.controls` must be greater than `ratio`. See #' Examples below for an example of their use. #' -#' ## Using `m.order = "closest"` +#' ## Using `m.order = "closest"` or `"farthest"` #' -#' As of version 4.6.0, `m.order` can be set to `"closest"`, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to `m.order = "smallest"`. When `distance` is a vector or a method of estimating a propensity score, an algorithm described by [Rassen et al. (2012)](https://doi.org/10.1002/pds.3263) is used. +#' `m.order` can be set to `"closest"` or `"farthest"`, which work regardless of how the distance measure is specified. This matches in order of the distance between units. First, all the closest match is found for all treated units and the pairwise distances computed; when `m.order = "closest"` the pair with the smallest of the distances is matched first, and when `m.order = "farthest"`, the pair with the largest of the distances is matched first. Then, the pair with the second smallest (or largest) is matched second. If the matched control is ineligible (i.e., because it has already been used in a prior match), a new match is found for the treated unit, the new pair's distance is re-computed, and the pairs are re-ordered by distance. +#' +#' Using `m.order = "closest"` ensures that the best possible matches are given priority, and in that sense should perform similarly to `m.order = "smallest"`. It can be used to ensure the best matches, especially when matching with a caliper. Using `m.order = "farthest"` ensures that the hardest units to match are given their best chance to find a close match, and in that sense should perform similarly to `m.order = "largest"`. It can be used to reduce the possibility of extreme imbalance when there are hard-to-match units competing for controls. Note that `m.order = "farthest"` **does not** implement "far matching" (i.e., finding the farthest control unit from each treated unit); it defines the order in which the closest matches are selected. #' #' ## Reproducibility #' -#' Nearest neighbor matching involves a random component only when `m.order = "random"`, so a seed must be set in that case using [set.seed()] to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. +#' Nearest neighbor matching involves a random component only when `m.order = "random"` (or when the propensity is estimated using a method with randomness; see [`distance`] for details), so a seed must be set in that case using [set.seed()] to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. #' #' @seealso [matchit()] for a detailed explanation of the inputs and outputs of #' a call to `matchit()`. #' #' [method_optimal()] for optimal pair matching, which is similar to -#' nearest neighbor matching except that an overall distance criterion is -#' minimized. +#' nearest neighbor matching without replacement except that an overall distance criterion is +#' minimized (i.e., as an alternative to specifying `m.order`). #' #' @references In a manuscript, you don't need to cite another package when #' using `method = "nearest"` because the matching is performed completely @@ -286,10 +289,7 @@ matchit2nearest <- function(treat, data, distance, discarded, is.full.mahalanobis, antiexact = NULL, unit.id = NULL, ...) { - if (verbose) { - rlang::check_installed("RcppProgress") - cat("Nearest neighbor matching... \n") - } + .cat_verbose("Nearest neighbor matching... \n", verbose = verbose) estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC")) @@ -317,13 +317,6 @@ matchit2nearest <- function(treat, data, distance, discarded, n1 <- sum(treat == 1) n0 <- n.obs - n1 - lab <- names(treat) - lab1 <- lab[treat == 1] - - if (is_not_null(distance)) { - names(distance) <- names(treat) - } - min.controls <- attr(ratio, "min.controls") max.controls <- attr(ratio, "max.controls") @@ -348,84 +341,68 @@ matchit2nearest <- function(treat, data, distance, discarded, } #Process caliper - ex.caliper.list <- list() + caliper.dist <- caliper.covs <- NULL + caliper.covs.mat <- NULL + ex.caliper <- NULL + if (is_not_null(caliper)) { if (any(names(caliper) != "")) { caliper.covs <- caliper[names(caliper) != ""] caliper.covs.mat <- get.covs.matrix(reformulate(names(caliper.covs)), data = data) - # ex.caliper.list <- setNames(lapply(names(caliper.covs), function(i) { - # splits <- get_splitsC(as.numeric(caliper.covs.mat[,i]), - # as.numeric(caliper.covs[i])) - # - # if (is_null(splits)) return(NULL) - # - # cut(caliper.covs.mat[,i], - # breaks = splits, - # include.lowest = TRUE) - # }), names(caliper.covs)) - } - else { - caliper.covs.mat <- caliper.covs <- NULL - } + ex.caliper.list <- setNames(lapply(names(caliper.covs), function(i) { + splits <- get_splitsC(as.numeric(caliper.covs.mat[,i]), + as.numeric(caliper.covs[i])) - if (any(names(caliper) == "")) { - caliper.dist <- caliper[names(caliper) == ""] + if (is_null(splits)) { + return(integer(0)) + } - # if (is_not_null(distance)) { - # splits <- get_splitsC(as.numeric(distance), - # as.numeric(caliper.dist)) - # - # ex.caliper.list <- c(ex.caliper.list, - # list(distance = cut(distance, - # breaks = splits, - # include.lowest = TRUE))) - # } - } - else { - caliper.dist <- NULL - } + cut(caliper.covs.mat[,i], + breaks = splits, + include.lowest = TRUE) + }), names(caliper.covs)) - if (is_not_null(ex.caliper.list)) { ex.caliper.list <- ex.caliper.list[lengths(ex.caliper.list) > 0L] - for (i in seq_along(ex.caliper.list)) { - levels(ex.caliper.list[[i]]) <- paste(names(ex.caliper.list)[i], "\u2208", levels(ex.caliper.list[[i]])) + # ex.caliper.list <- Filter(ex.caliper.list, f = function(x) nlevels(x) > 1) + + if (is_not_null(ex.caliper.list)) { + for (i in seq_along(ex.caliper.list)) { + levels(ex.caliper.list[[i]]) <- paste(names(ex.caliper.list)[i], "\u2208", levels(ex.caliper.list[[i]])) + } + + ex.caliper <- exactify(ex.caliper.list, nam = names(treat), sep = ", ", + justify = NULL) } } - } - else { - caliper.dist <- caliper.covs <- NULL - caliper.covs.mat <- NULL - } - ex.caliper <- { - if (is_null(ex.caliper.list)) NULL - else as.factor(exactify(ex.caliper.list, nam = lab, sep = ", ", - justify = NULL)) + if (hasName(caliper, "")) { + caliper.dist <- caliper[names(caliper) == ""] + } } #Process antiexact + antiexactcovs <- NULL if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) antiexactcovs <- do.call("cbind", lapply(seq_len(ncol(antiexactcovs)), function(i) { - as.integer(as.factor(antiexactcovs[[i]])) + unclass(as.factor(antiexactcovs[[i]])) })) } - else { - antiexactcovs <- NULL - } reuse.max <- attr(replace, "reuse.max") - # if (reuse.max >= n1) { - # m.order <- "data" - # } + if (reuse.max >= n1) { + m.order <- "data" + } + #unit.id if (is_not_null(unit.id) && reuse.max < n1) { unit.id <- process.variable.input(unit.id, data) - unit.id <- factor(exactify(model.frame(unit.id, data = data), - nam = lab, sep = ", ", include_vars = TRUE)) + unit.id <- exactify(model.frame(unit.id, data = data), + nam = names(treat), sep = ", ", include_vars = TRUE) + num_ctrl_unit.ids <- length(unique(unit.id[treat == 0])) num_trt_unit.ids <- length(unique(unit.id[treat == 1])) @@ -439,23 +416,12 @@ matchit2nearest <- function(treat, data, distance, discarded, } #Process exact - ex <- { - if (is_not_null(exact)) as.factor(exactify(model.frame(exact, data = data), - nam = lab, sep = ", ", include_vars = TRUE)) - else NULL - } + ex <- NULL + if (is_not_null(exact)) { + ex <- exactify(model.frame(exact, data = data), + nam = names(treat), sep = ", ", include_vars = TRUE) - if (is_not_null(ex) || is_not_null(ex.caliper)) { - ex0 <- { - if (is_not_null(ex) && is_not_null(ex.caliper)) { - as.factor(exactify(list(ex, ex.caliper), nam = lab, sep = ", ", - justify = NULL)) - } - else if (is_not_null(ex)) ex - else ex.caliper - } - - cc <- Reduce("intersect", lapply(unique(treat), function(t) as.integer(ex0)[treat==t])) + cc <- Reduce("intersect", lapply(unique(treat), function(t) unclass(ex)[treat==t])) if (is_null(cc)) { .err("no matches were found") @@ -463,9 +429,6 @@ matchit2nearest <- function(treat, data, distance, discarded, cc <- sort(cc) } - else { - ex0 <- NULL - } if (reuse.max < n1) { if (is_not_null(ex)) { @@ -490,8 +453,8 @@ matchit2nearest <- function(treat, data, distance, discarded, } else { e_ratios <- { - if (is_null(unit.id)) as.numeric(reuse.max)*n0/n1 - else as.numeric(reuse.max)*num_ctrl_unit.ids/n1 + if (is_null(unit.id)) as.numeric(reuse.max) * n0 / n1 + else as.numeric(reuse.max) * num_ctrl_unit.ids / num_trt_unit.ids } if (e_ratios < 1) { @@ -532,13 +495,18 @@ matchit2nearest <- function(treat, data, distance, discarded, ratio0 <- c(rep(min.controls, kmin), kmed, rep(max.controls, kmax)) #Make sure no units are assigned 0 matches - if (any(ratio0 == 0)) { - ind <- which(ratio0 == 0) + while (any(ratio0 == 0)) { + ind <- which(ratio0 == 0)[1L] ratio0[ind] <- 1 - ratio0[ind + 1] <- ratio0[ind + 1] - 1 + + if (ind == length(ratio0)) { + break + } + + ratio0[ind + 1L] <- ratio0[ind + 1] - 1 } - ratio <- rep(NA_integer_, n1) + ratio <- rep.int(NA_integer_, n1) #Order by distance; treated are supposed to have higher values ratio[order(distance[treat == 1], @@ -546,98 +514,104 @@ matchit2nearest <- function(treat, data, distance, discarded, ratio <- as.integer(ratio) } else { - ratio <- as.integer(rep(ratio, n1)) + ratio <- as.integer(rep.int(ratio, n1)) } m.order <- { - if (is_null(distance)) match_arg(m.order, c("data", "random", "closest")) - else if (is_not_null(m.order)) match_arg(m.order, c("largest", "smallest", "data", "random", "closest")) + if (is_null(distance)) match_arg(m.order, c("data", "random", "closest", "farthest")) + else if (is_not_null(m.order)) match_arg(m.order, c("largest", "smallest", "data", "random", "closest", "farthest")) else if (estimand == "ATC") "smallest" else "largest" } - if (is_not_null(mahcovs) && ncol(mahcovs) == 1L && is.full.mahalanobis && is_null(distance)) { + #If mahcovs is only 1 variable, use vec matching for speed + if (is.full.mahalanobis && is_null(distance) && is_null(caliper.dist) && + is_not_null(mahcovs) && ncol(mahcovs) == 1L) { distance <- mahcovs[,1] mahcovs <- NULL } - if (is_null(ex0) || is_not_null(unit.id) || (is_null(mahcovs) && is_null(distance_mat))) { - ord <- switch(m.order, - "largest" = order(distance[treat == 1], decreasing = TRUE), - "smallest" = order(distance[treat == 1], decreasing = FALSE), - "random" = sample.int(n1), - "data" = seq_len(n1), - "closest" = NULL) - - mm <- nn_matchC_dispatch(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, - ex0, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, + if (is_not_null(ex)) { + discarded[!ex %in% levels(ex)[cc]] <- TRUE + } + + if (is_null(ex) || is_not_null(unit.id) || (is_null(mahcovs) && is_null(distance_mat) && + !(m.order %in% c("closest", "farthest")))) { + if (is_not_null(ex.caliper)) { + ex <- exactify(list(ex, ex.caliper), + nam = names(treat), sep = ", ", include_vars = FALSE) + } + + mm <- nn_matchC_dispatch(treat, 1L, ratio, discarded, reuse.max, distance, distance_mat, + ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) } else { - distance_ <- caliper.covs.mat_ <- mahcovs_ <- antiexactcovs_ <- distance_mat_ <- NULL - - mm_list <- lapply(levels(ex0)[cc], function(e) { - if (verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex0)[cc]), length(cc), e)) - } + mm_list <- lapply(levels(ex)[cc], function(e) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) - .e <- which(ex0 == e) - .e1 <- which(ex0[treat==1] == e) + .e <- which(ex == e) + .e1 <- which(ex[treat==1] == e) treat_ <- treat[.e] discarded_ <- discarded[.e] + distance_ <- NULL if (is_not_null(distance)) { distance_ <- distance[.e] } + ex.caliper_ <- NULL + if (is_not_null(ex.caliper)) { + ex.caliper_ <- ex.caliper[.e] + } + + caliper.covs.mat_ <- NULL if (is_not_null(caliper.covs.mat)) { caliper.covs.mat_ <- caliper.covs.mat[.e,, drop = FALSE] } + mahcovs_ <- NULL if (is_not_null(mahcovs)) { mahcovs_ <- mahcovs[.e,,drop = FALSE] } + antiexactcovs_ <- NULL if (is_not_null(antiexactcovs)) { antiexactcovs_ <- antiexactcovs[.e,, drop = FALSE] } + distance_mat_ <- NULL if (is_not_null(distance_mat)) { - .e0 <- which(ex0[treat==0] == e) + .e0 <- which(ex[treat==0] == e) distance_mat_ <- distance_mat[.e1, .e0, drop = FALSE] } ratio_ <- ratio[.e1] - ord_ <- switch(m.order, - "largest" = order(distance_[treat_ == 1], decreasing = TRUE), - "smallest" = order(distance_[treat_ == 1], decreasing = FALSE), - "random" = sample.int(sum(treat_ == 1)), - "data" = seq_len(sum(treat_ == 1)), - "closest" = NULL) - - mm_ <- nn_matchC_dispatch(treat_, ord_, ratio_, discarded_, reuse.max, distance_, distance_mat_, - NULL, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, + mm_ <- nn_matchC_dispatch(treat_, 1L, ratio_, discarded_, reuse.max, distance_, distance_mat_, + ex.caliper_, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, antiexactcovs_, NULL, m.order, verbose) #Ensure matched indices correspond to indices in full sample, not subgroup - mm_[] <- seq_along(treat)[.e][mm_] + mm_[] <- .e[mm_] mm_ }) #Construct match.matrix - mm <- matrix(NA_integer_, nrow = length(lab1), + mm <- matrix(NA_integer_, nrow = n1, ncol = max(vapply(mm_list, ncol, numeric(1L))), - dimnames = list(lab1, NULL)) + dimnames = list(names(treat)[treat == 1], NULL)) for (m in mm_list) { mm[rownames(m), seq_len(ncol(m))] <- m } } - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", + verbose = verbose) if (reuse.max > 1) { psclass <- NULL @@ -652,45 +626,55 @@ matchit2nearest <- function(treat, data, distance, discarded, subclass = psclass, weights = weights) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res } -# Dispatches Rcpp function for NN matching -# nn_matchC_vec() if distance_mat and mahcovs are NULL -# nn_matchC() otherwise -nn_matchC_dispatch <- function(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, +# Dispatches Rcpp functions for NN matching +nn_matchC_dispatch <- function(treat, focal, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) { - if (m.order == "closest") { + if (m.order %in% c("closest", "farthest")) { if (is_not_null(mahcovs)) { nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse.max, mahcovs, distance, ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) + antiexactcovs, unit.id, m.order == "closest", verbose) } else if (is_not_null(distance_mat)) { - nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse.max, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) + nn_matchC_distmat_closest(treat, ratio, discarded, reuse.max, distance_mat, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, m.order == "closest", verbose) + } else { nn_matchC_vec_closest(treat, ratio, discarded, reuse.max, distance, ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) + antiexactcovs, unit.id, m.order == "closest", verbose) } } else { - if (is_not_null(distance_mat) || is_not_null(mahcovs)) { - nn_matchC(treat, ord, ratio, discarded, reuse.max, 1L, distance, distance_mat, - ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, - antiexactcovs, unit.id, verbose) + ord <- switch(m.order, + "largest" = order(distance[treat == focal], decreasing = TRUE), + "smallest" = order(distance[treat == focal], decreasing = FALSE), + "random" = sample(which(!discarded[treat == focal])), + "data" = which(!discarded[treat == focal])) + + if (is_not_null(mahcovs)) { + nn_matchC_mahcovs(treat, ord, ratio, discarded, reuse.max, focal, mahcovs, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } + else if (is_not_null(distance_mat)) { + nn_matchC_distmat(treat, ord, ratio, discarded, reuse.max, focal, distance_mat, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) } else { - nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, 1L, distance, + nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, focal, distance, ex, caliper.dist, caliper.covs, caliper.covs.mat, antiexactcovs, unit.id, verbose) } } -} \ No newline at end of file +} diff --git a/R/matchit2optimal.R b/R/matchit2optimal.R index 30e2dad4..33090f4f 100644 --- a/R/matchit2optimal.R +++ b/R/matchit2optimal.R @@ -248,11 +248,11 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, rlang::check_installed("optmatch") - if (verbose) cat("Optimal matching... \n") + .cat_verbose("Optimal matching...\n", verbose = verbose) - A <- list(...) - pm.args <- c("tol", "solver") - A[!names(A) %in% pm.args] <- NULL + args <- c("tol", "solver") + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL #Set max problem size to Inf and return to original value after match omps <- getOption("optmatch_max_problem_size") @@ -299,25 +299,27 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) - cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) + cc <- Reduce("intersect", lapply(unique(treat_), function(t) unclass(ex)[treat_==t])) if (is_null(cc) ) { .err("no matches were found") } - e_ratios <- vapply(levels(ex), function(e) sum(treat_[ex == e] == 0)/sum(treat_[ex == e] == 1), numeric(1L)) + e_ratios <- vapply(levels(ex), function(e) { + sum(treat_[ex == e] == 0)/sum(treat_[ex == e] == 1) + }, numeric(1L)) if (any(e_ratios < 1)) { .wrn(sprintf("Fewer %s units than %s units in some `exact` strata; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } if (ratio > 1 && any(e_ratios < ratio)) { if (ratio == max.controls) .wrn(sprintf("Not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("Not enough %s units for an average of %s matches per %s unit in all `exact` strata", - tc[2], ratio, tc[1])) + tc[2], ratio, tc[1])) } } else { @@ -328,15 +330,15 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, if (e_ratios < 1) { .wrn(sprintf("Fewer %s units than %s units; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } else if (e_ratios < ratio) { if (ratio == max.controls) .wrn(sprintf("Not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("Not enough %s units for an average of %s matches per %s unit", - tc[2], ratio, tc[1])) + tc[2], ratio, tc[1])) } } @@ -381,11 +383,11 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, t_df <- data.frame(treat_) for (e in levels(ex)[cc]) { - if (nlevels(ex) > 1) { - if (verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) - } + if (nlevels(ex) > 1L) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) + mo_ <- mo[ex[treat_==1] == e, ex[treat_==0] == e] } else mo_ <- mo @@ -394,7 +396,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, next } - if (all(dim(mo_) == 1) && all(is.finite(mo_))) { + if (all_equal_to(dim(mo_), 1) && all(is.finite(mo_))) { pair[ex == e] <- paste(1, e, sep = "|") next } @@ -431,10 +433,6 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) { - .err("No matches were found") - } - if (length(p) == 1L) { p <- p[[1]] } @@ -445,7 +443,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, mm <- nummm2charmm(subclass2mmC(psclass, treat, focal), treat) - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) ## calculate weights and return the results res <- list(match.matrix = mm, @@ -453,7 +451,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, weights = get_weights_from_subclass(psclass, treat, estimand), obj = p) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res diff --git a/R/matchit2quick.R b/R/matchit2quick.R index f3ef87aa..f3589ac0 100644 --- a/R/matchit2quick.R +++ b/R/matchit2quick.R @@ -143,7 +143,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, rlang::check_installed("quickmatch") - if (verbose) cat("Generalized full matching... \n") + .cat_verbose("Generalized full matching...\n", verbose = verbose) A <- list(...) @@ -177,7 +177,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) - cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) + cc <- Reduce("intersect", lapply(unique(treat_), function(t) unclass(ex)[treat_==t])) if (is_null(cc)) { .err("no matches were found") @@ -209,7 +209,8 @@ matchit2quick <- function(treat, formula, data, distance, discarded, if (is_not_null(mahvars)) { .err("with `method = \"quick\"`, a caliper can only be used when `distance` is a propensity score or vector and `mahvars` is not specified") } - if (length(caliper) > 1 || !identical(names(caliper), "")) { + + if (length(caliper) > 1L || !identical(names(caliper), "")) { .err("with `method = \"quick\"`, calipers cannot be placed on covariates") } } @@ -219,9 +220,10 @@ matchit2quick <- function(treat, formula, data, distance, discarded, p <- setNames(vector("list", nlevels(ex)), levels(ex)) for (e in levels(ex)[cc]) { - if (verbose && nlevels(ex) > 1) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) + if (nlevels(ex) > 1L) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) } distcovs_ <- distcovs[ex == e,, drop = FALSE] @@ -237,10 +239,6 @@ matchit2quick <- function(treat, formula, data, distance, discarded, pair[which(ex == e)[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) { - .err("no matches were found") - } - if (length(p) == 1L) { p <- p[[1]] } @@ -252,13 +250,13 @@ matchit2quick <- function(treat, formula, data, distance, discarded, #No match.matrix because treated units don't index matched strata (i.e., more than one #treated unit can be in the same stratum). Stratum information is contained in subclass. - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) res <- list(subclass = psclass, weights = get_weights_from_subclass(psclass, treat, estimand), obj = p) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- c("matchit") res diff --git a/R/matchit2subclass.R b/R/matchit2subclass.R index 45b6d355..93fbaed8 100644 --- a/R/matchit2subclass.R +++ b/R/matchit2subclass.R @@ -158,35 +158,28 @@ NULL matchit2subclass <- function(treat, distance, discarded, replace = FALSE, exact = NULL, estimand = "ATT", verbose = FALSE, + subclass = 6L, min.n = 1L, ...) { - if(verbose) - cat("Subclassifying... \n") - - A <- list(...) - subclass <- A[["subclass"]] - sub.by <- A[["sub.by"]] - min.n <- A[["min.n"]] + .cat_verbose("Subclassifying...\n", verbose = verbose) #Checks - if (is_null(subclass)) subclass <- 6 chk::chk_numeric(subclass) if (length(subclass) == 1L) { chk::chk_gt(subclass, 1) } - else if (!all(subclass <= 1 & subclass >= 0)) { - .err("When specifying `subclass` as a vector of quantiles, all values must be between 0 and 1") + else if (any(subclass > 1) || any(subclass < 0)) { + .err("when specifying `subclass` as a vector of quantiles, all values must be between 0 and 1") } - if (is_not_null(sub.by)) { + if (is_not_null(...get("sub.by", ...))) { .err("`sub.by` is defunct and has been replaced with `estimand`") } estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - if (is_null(min.n)) min.n <- 1 chk::chk_count(min.n) n.obs <- length(treat) @@ -233,12 +226,13 @@ matchit2subclass <- function(treat, distance, discarded, psclass <- setNames(factor(psclass, nmax = length(q)), names(treat)) levels(psclass) <- as.character(seq_len(nlevels(psclass))) - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) - res <- list(subclass = psclass, q.cut = q, + res <- list(subclass = psclass, + q.cut = q, weights = get_weights_from_subclass(psclass, treat, estimand)) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- c("matchit.subclass", "matchit") res diff --git a/R/plot.matchit.R b/R/plot.matchit.R index 4ae6e88c..0c48d3ec 100644 --- a/R/plot.matchit.R +++ b/R/plot.matchit.R @@ -290,7 +290,7 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = } if (is.character(which.xs)) { - if (!all(which.xs %in% names(data))) { + if (!all(hasName(data, which.xs))) { .err("all variables in `which.xs` must be in the supplied `matchit` object or in `data`") } X <- data[which.xs] @@ -406,6 +406,7 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = devAskNewPage(ask = interactive) } + devAskNewPage(ask = FALSE) invisible(NULL) @@ -416,9 +417,11 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, if (is_null(which.xs)) { if (is_null(object$X)) { - .wrn("No covariates to plot") + .wrn("no covariates to plot") + return(invisible(NULL)) } + X <- object$X if (is_not_null(object$exact)) { @@ -453,7 +456,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, } if (is.character(which.xs)) { - if (!all(which.xs %in% names(data))) { + if (!all(hasName(data, which.xs))) { .err("all variables in `which.xs` must be in the supplied `matchit` object or in `data`") } X <- data[which.xs] @@ -462,9 +465,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, which.xs <- update(terms(which.xs, data = data), NULL ~ .) X <- model.frame(which.xs, data, na.action = "na.pass") - if (anyNA(X)) { - .err("missing values are not allowed in the covariates named in `which.xs`") - } + chk::chk_not_any_na(X, "the covariates named in `which.xs`") } else { .err("`which.xs` must be supplied as a character vector of names or a one-sided formula") @@ -479,11 +480,11 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, t <- object$treat if (!is.atomic(which.subclass)) { - .err("The argument to `subclass` must be NULL or the indices of the subclasses for which to display covariate distributions") + .err("the argument to `subclass` must be `NULL` or the indices of the subclasses for which to display covariate distributions") } if (!all(which.subclass %in% object$subclass[!is.na(object$subclass)])) { - .err("The argument supplied to `subclass` is not the index of any subclass in the matchit object") + .err("the argument supplied to `subclass` is not the index of any subclass in the `matchit` object") } if (type == "density") { @@ -508,7 +509,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, } sw <- { - if (is_null(object$s.weights)) rep(1, length(t)) + if (is_null(object$s.weights)) rep.int(1, length(t)) else object$s.weights } @@ -523,27 +524,27 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, plot.new() - if (((i-1)%%3)==0) { + if (((i-1) %% 3) == 0) { if (type == "qq") { - htext <- paste0("eQQ Plots (Subclass ", s,")") - mtext(htext, 3, 2, TRUE, 0.5, cex=1.1, font = 2) - mtext("All", 3, .25, TRUE, 0.5, cex=1, font = 1) - mtext("Matched", 3, .25, TRUE, 0.83, cex=1, font = 1) - mtext("Control Units", 1, 0, TRUE, 2/3, cex=1, font = 1) - mtext("Treated Units", 4, 0, TRUE, 0.5, cex=1, font = 1) + htext <- sprintf("eQQ Plots (Subclass %s)", s) + mtext(htext, 3, 2, TRUE, 0.5, cex = 1.1, font = 2) + mtext("All", 3, .25, TRUE, 0.5, cex = 1, font = 1) + mtext("Matched", 3, .25, TRUE, 0.83, cex = 1, font = 1) + mtext("Control Units", 1, 0, TRUE, 2/3, cex = 1, font = 1) + mtext("Treated Units", 4, 0, TRUE, 0.5, cex = 1, font = 1) } else if (type == "ecdf") { - htext <- paste0("eCDF Plots (Subclass ", s,")") - mtext(htext, 3, 2, TRUE, 0.5, cex=1.1, font = 2) - mtext("All", 3, .25, TRUE, 0.5, cex=1, font = 1) - mtext("Matched", 3, .25, TRUE, 0.83, cex=1, font = 1) + htext <- sprintf("eCDF Plots (Subclass %s)", s) + mtext(htext, 3, 2, TRUE, 0.5, cex = 1.1, font = 2) + mtext("All", 3, .25, TRUE, 0.5, cex = 1, font = 1) + mtext("Matched", 3, .25, TRUE, 0.83, cex = 1, font = 1) } else if (type == "density") { - htext <- paste0("Density Plots (Subclass ", s,")") - mtext(htext, 3, 2, TRUE, 0.5, cex=1.1, font = 2) - mtext("All", 3, .25, TRUE, 0.5, cex=1, font = 1) - mtext("Matched", 3, .25, TRUE, 0.83, cex=1, font = 1) + htext <- sprintf("Density Plots (Subclass %s)", s) + mtext(htext, 3, 2, TRUE, 0.5, cex = 1.1, font = 2) + mtext("All", 3, .25, TRUE, 0.5, cex = 1, font = 1) + mtext("Matched", 3, .25, TRUE, 0.83, cex = 1, font = 1) } } @@ -567,6 +568,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, devAskNewPage(ask = interactive) } } + devAskNewPage(ask = FALSE) invisible(NULL) @@ -628,8 +630,8 @@ qqplot_match <- function(x, t, w, sw, discrete.cutoff = 5, ...) { rr <- range(c(x0, x1)) plot(x0, x1, xlab = "", ylab = "", xlim = rr, ylim = rr, axes = FALSE, ...) abline(a = 0, b = 1) - abline(a = (rr[2]-rr[1])*0.1, b = 1, lty = 2) - abline(a = -(rr[2]-rr[1])*0.1, b = 1, lty = 2) + abline(a = (rr[2]-rr[1]) * 0.1, b = 1, lty = 2) + abline(a = -(rr[2]-rr[1]) * 0.1, b = 1, lty = 2) axis(2) box() @@ -678,8 +680,8 @@ qqplot_match <- function(x, t, w, sw, discrete.cutoff = 5, ...) { plot(x0, x1, xlab = "", ylab = "", xlim = rr, ylim = rr, axes = FALSE, ...) abline(a = 0, b = 1) - abline(a = (rr[2]-rr[1])*0.1, b = 1, lty = 2) - abline(a = -(rr[2]-rr[1])*0.1, b = 1, lty = 2) + abline(a = (rr[2]-rr[1]) * 0.1, b = 1, lty = 2) + abline(a = -(rr[2]-rr[1]) * 0.1, b = 1, lty = 2) box() } @@ -690,7 +692,8 @@ ecdfplot_match <- function(x, t, w, sw, ...) { x.range <- x.max - x.min #Unmatched samples - plot(x = x, y = w, type= "n" , xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), + plot(x = x, y = w, type = "n" , + xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), ylim = c(0, 1), axes = TRUE, ...) for (tr in 0:1) { @@ -706,8 +709,10 @@ ecdfplot_match <- function(x, t, w, sw, ...) { box() #Matched sample - plot(x = x, y = w, type= "n" , xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), + plot(x = x, y = w, type = "n" , + xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), ylim = c(0, 1), axes = FALSE, ...) + for (tr in 0:1) { in.tr <- t[ord] == tr ordt <- ord[in.tr] @@ -723,46 +728,49 @@ ecdfplot_match <- function(x, t, w, sw, ...) { box() } -densityplot_match <- function(x, t, w, sw, ...) { +densityplot_match <- function(x, t, w, sw, bw = NULL, cut = 3, ...) { + + if (has_n_unique(x, 2L)) { + x <- factor(x, nmax = 2) + } - if (has_n_unique(x, 2L)) x <- factor(x) + u <- unique(t) if (!is.factor(x)) { #Density plot for continuous variable - small.tr <- (0:1)[which.min(c(sum(t==0), sum(t==1)))] - x_small <- x[t==small.tr] + small.tr <- u[which.min(vapply(u, function(tr) sum(t == tr), numeric(1L)))] + x_small <- x[t == small.tr] x.min <- min(x) x.max <- max(x) - # - A <- list(...) - bw <- A[["bw"]] if (is_null(bw)) { - A[["bw"]] <- bw.nrd0(x_small) + bw <- bw.nrd0(x_small) } else if (is.character(bw)) { bw <- tolower(bw) bw <- match_arg(bw, c("nrd0", "nrd", "ucv", "bcv", "sj", "sj-ste", "sj-dpi")) - A[["bw"]] <- switch(bw, nrd0 = bw.nrd0(x_small), nrd = bw.nrd(x_small), - ucv = bw.ucv(x_small), bcv = bw.bcv(x_small), sj = , - `sj-ste` = bw.SJ(x_small, method = "ste"), - `sj-dpi` = bw.SJ(x_small, method = "dpi")) - } - - if (is_null(A[["cut"]])) A[["cut"]] <- 3 - - d_unmatched <- do.call("rbind", lapply(0:1, function(tr) { - cbind(as.data.frame(do.call("density", c(list(x[t==tr], weights = sw[t==tr], - from = x.min - A[["cut"]]*A[["bw"]], - to = x.max + A[["cut"]]*A[["bw"]]), A))[1:2]), + bw <- switch(bw, nrd0 = bw.nrd0(x_small), nrd = bw.nrd(x_small), + ucv = bw.ucv(x_small), bcv = bw.bcv(x_small), sj = , + `sj-ste` = bw.SJ(x_small, method = "ste"), + `sj-dpi` = bw.SJ(x_small, method = "dpi")) + } + + d_unmatched <- do.call("rbind", lapply(u, function(tr) { + cbind(as.data.frame(density(x[t == tr], + weights = sw[t==tr], + from = x.min - cut * bw, + to = x.max + cut * bw, + bw = bw, cut = cut, ...)[1:2]), t = tr) })) - d_matched <- do.call("rbind", lapply(0:1, function(tr) { - cbind(as.data.frame(do.call("density", c(list(x[t==tr], weights = w[t==tr], - from = x.min - A[["cut"]]*A[["bw"]], - to = x.max + A[["cut"]]*A[["bw"]]), A))[1:2]), + d_matched <- do.call("rbind", lapply(u, function(tr) { + cbind(as.data.frame(density(x[t == tr], + weights = w[t == tr], + from = x.min - cut * bw, + to = x.max + cut * bw, + bw = bw, cut = cut, ...)[1:2]), t = tr) })) @@ -770,8 +778,9 @@ densityplot_match <- function(x, t, w, sw, ...) { #Unmatched samples plot(x = d_unmatched$x, y = d_unmatched$y, type = "n", - xlim = c(x.min - A[["cut"]]*A[["bw"]], x.max + A[["cut"]]*A[["bw"]]), - ylim = c(0, 1.1*y.max), axes = TRUE, ...) + xlim = c(x.min - cut * bw, x.max + cut * bw), + ylim = c(0, 1.1 * y.max), + axes = TRUE, ...) for (tr in 0:1) { in.tr <- d_unmatched$t == tr @@ -786,8 +795,9 @@ densityplot_match <- function(x, t, w, sw, ...) { #Matched sample plot(x = d_matched$x, y = d_matched$y, type = "n", - xlim = c(x.min - A[["cut"]]*A[["bw"]], x.max + A[["cut"]]*A[["bw"]]), - ylim = c(0, 1.1*y.max), axes = FALSE, ...) + xlim = c(x.min - cut * bw, x.max + cut * bw), + ylim = c(0, 1.1 * y.max), + axes = FALSE, ...) for (tr in 0:1) { in.tr <- d_matched$t == tr @@ -805,15 +815,15 @@ densityplot_match <- function(x, t, w, sw, ...) { #Bar plot for binary variable x_t_un <- lapply(sort(unique(t)), function(t_) { vapply(levels(x), function(i) { - wm(x[t==t_] == i, sw[t==t_]) + wm(x[t == t_] == i, sw[t == t_]) }, numeric(1L))}) x_t_m <- lapply(sort(unique(t)), function(t_) { vapply(levels(x), function(i) { - wm(x[t==t_] == i, w[t==t_]) + wm(x[t == t_] == i, w[t == t_]) }, numeric(1L))}) - ylim <- c(0, 1.1*max(unlist(x_t_un), unlist(x_t_m))) + ylim <- c(0, 1.1 * max(unlist(x_t_un), unlist(x_t_m))) borders <- c("grey60", "black") for (i in seq_along(x_t_un)) { @@ -842,19 +852,18 @@ hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { treat <- x$treat pscore <- x$distance[!is.na(x$distance)] + s.weights <- { if (is_null(x$s.weights)) rep(1, length(treat)) else x$s.weights } + weights <- x$weights * s.weights matched <- weights != 0 q.cut <- x$q.cut minp <- min(pscore) maxp <- max(pscore) - ratio <- x$call$ratio - - if (is_null(ratio)) ratio <- 1 if (freq) { weights <- .make_sum_to_n(weights, by = treat) @@ -890,7 +899,9 @@ hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { plot(pm, xlim = xlim, xlab = xlab, main = n, ylab = ylab, freq = FALSE, col = "lightgray", ...) - if (!startsWith(n, "Raw") && is_not_null(q.cut)) abline(v = q.cut, lty=2) + if (!startsWith(n, "Raw") && is_not_null(q.cut)) { + abline(v = q.cut, lty = 2) + } } } @@ -902,20 +913,23 @@ jitter_pscore <- function(x, interactive, pch = 1, ...) { treat <- x$treat pscore <- x$distance - s.weights <- if (is_null(x$s.weights)) rep(1, length(treat)) else x$s.weights + s.weights <- if (is_null(x$s.weights)) rep.int(1, length(treat)) else x$s.weights weights <- x$weights * s.weights matched <- weights > 0 q.cut <- x$q.cut - jitp <- jitter(rep(1, length(treat)), factor = 6) + (treat==1) * (weights == 0) - (treat==0) - (weights==0) * (treat==0) + jitp <- jitter(rep.int(1, length(treat)), factor = 6) + (treat==1) * (weights == 0) - (treat==0) - (weights==0) * (treat==0) cswt <- sqrt(s.weights) cwt <- sqrt(weights) minp <- min(pscore, na.rm = TRUE) maxp <- max(pscore, na.rm = TRUE) - plot(pscore, xlim = c(minp - 0.05*(maxp-minp), maxp + 0.05*(maxp-minp)), ylim = c(-1.5,2.5), + plot(pscore, xlim = c(minp - 0.05*(maxp-minp), maxp + 0.05 * (maxp - minp)), ylim = c(-1.5, 2.5), type = "n", ylab = "", xlab = "Propensity Score", axes = FALSE, main = "Distribution of Propensity Scores", ...) - if (is_not_null(q.cut)) abline(v = q.cut, col = "grey", lty = 1) + + if (is_not_null(q.cut)) { + abline(v = q.cut, col = "grey", lty = 1) + } #Matched treated points(pscore[treat==1 & matched], jitp[treat==1 & matched], diff --git a/R/rbind.matchdata.R b/R/rbind.matchdata.R index 47491199..295743ce 100644 --- a/R/rbind.matchdata.R +++ b/R/rbind.matchdata.R @@ -107,7 +107,7 @@ rbind.matchdata <- function(..., deparse.level = 1) { attr_list[[i]] <- NULL } else { - key_attrs[i] <- attr_list[[i]][which(!is.na(attr_list[[i]]))[1]] + key_attrs[i] <- Find(Negate(is.na), attr_list[[i]]) } } attrs <- names(attr_list) @@ -119,7 +119,8 @@ rbind.matchdata <- function(..., deparse.level = 1) { }) for (d in seq_along(md_list)[-1]) { - if (length(other_col_list[[d]]) != length(other_col_list[[1]]) || !all(other_col_list[[d]] %in% other_col_list[[1]])) { + if (length(other_col_list[[d]]) != length(other_col_list[[1]]) || + !all(other_col_list[[d]] %in% other_col_list[[1]])) { .err(sprintf("the %s inputs must come from the same dataset", switch(type, "matchdata" = "`match.data()`", "`get_matches()`"))) } diff --git a/R/summary.matchit.R b/R/summary.matchit.R index 3fc0be3d..4be27821 100644 --- a/R/summary.matchit.R +++ b/R/summary.matchit.R @@ -724,24 +724,23 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" return(X) } - #Attempt to extrct data from matchit object; same as match.data() - data.fram.matchit <- FALSE - if (is_null(data)) { - env <- environment(object$formula) - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (null_or_error(data) || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - env <- parent.frame() - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (null_or_error(data) || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - data <- object[["model"]][["data"]] - if (is_null(data) || nrow(data) != length(object[["treat"]])) { - data <- NULL - } - else data.fram.matchit <- TRUE - } - else data.fram.matchit <- TRUE + #Attempt to extract data from matchit object; same as match.data() + data.found <- FALSE + for (i in 1:4) { + if (i == 2) { + data <- try(eval(object$call$data, envir = environment(object$formula)), silent = TRUE) + } + else if (i == 3) { + data <- try(eval(object$call$data, envir = parent.frame()), silent = TRUE) + } + else if (i == 4) { + data <- object[["model"]][["data"]] + } + + if (!null_or_error(data) && length(dim(data)) == 2L && nrow(data) == length(object[["treat"]])) { + data.found <- TRUE + break } - else data.fram.matchit <- TRUE } if (is.character(addlvariables)) { @@ -749,34 +748,34 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" .err("if `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") } - if (!all(addlvariables %in% names(data))) { + if (!all(hasName(data, addlvariables))) { .err("all variables in `addlvariables` must be in `data`") } addlvariables <- data[addlvariables] } - else if (inherits(addlvariables, "formula")) { - vars.in.formula <- all.vars(addlvariables) + else if (rlang::is_formula(addlvariables)) { if (is_not_null(data) && is.data.frame(data)) { - data <- data.frame(data[names(data) %in% vars.in.formula], - object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) + vars.in.formula <- all.vars(addlvariables) + data <- cbind(data[names(data) %in% vars.in.formula], + object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) + } + else { + data <- object$X } - else data <- object$X - - # addlvariables <- get.covs.matrix(addlvariables, data = data) } else if (!is.matrix(addlvariables) && !is.data.frame(addlvariables)) { .err("the argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") } - if (af <- inherits(addlvariables, "formula")) { + if (af <- rlang::is_formula(addlvariables)) { addvariables_f <- addlvariables addlvariables <- model.frame(addvariables_f, data = data, na.action = "na.pass") } if (nrow(addlvariables) != length(object$treat)) { - if (is_null(data) || data.fram.matchit) { + if (is_null(data) || data.found) { .err("variables specified in `addlvariables` must have the same number of units as are present in the original call to `matchit()`") } else { @@ -810,4 +809,4 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" # addl_assign <- get_assign(addlvariables) cbind(X, addlvariables[, setdiff(colnames(addlvariables), colnames(X)), drop = FALSE]) -} \ No newline at end of file +} From 5e532e45b06b6930801eeb017f6b5c3ae1679970 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:25:06 -0500 Subject: [PATCH 26/48] Improvements, added normalize option --- R/matchit.R | 249 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 150 insertions(+), 99 deletions(-) diff --git a/R/matchit.R b/R/matchit.R index b161937f..3cd6bbfa 100644 --- a/R/matchit.R +++ b/R/matchit.R @@ -1,7 +1,5 @@ #' Matching for Causal Inference #' -#' @aliases matchit print.matchit -#' #' @description #' `matchit()` is the main function of *MatchIt* and performs #' pairing, subset selection, and subclassification with the aim of creating @@ -27,6 +25,7 @@ #' [`"nearest"`][method_nearest] for nearest neighbor matching (on #' the propensity score by default), [`"optimal"`][method_optimal] #' for optimal pair matching, [`"full"`][method_full] for optimal +#' full matching, [`"quick"`][method_quick] for generalized (quick) #' full matching, [`"genetic"`][method_genetic] for genetic #' matching, [`"cem"`][method_cem] for coarsened exact matching, #' [`"exact"`][method_exact] for exact matching, @@ -54,11 +53,10 @@ #' scores. #' @param distance.options a named list containing additional arguments #' supplied to the function that estimates the distance measure as determined -#' by the argument to `distance`. See [distance] for an +#' by the argument to `distance`. See [`distance`] for an #' example of its use. #' @param estimand a string containing the name of the target estimand desired. -#' Can be one of `"ATT"` or `"ATC"`. Some methods accept `"ATE"` -#' as well. Default is `"ATT"`. See Details and the individual methods +#' Can be one of `"ATT"`, `"ATC"`, or `"ATE"`. Default is `"ATT"`. See Details and the individual methods #' pages for information on how this argument is used. #' @param exact for methods that allow it, for which variables exact matching #' should take place. Can be specified as a string containing the names of @@ -99,7 +97,7 @@ #' be specified as a string containing the name of variable in `data` to #' be used or a one-sided formula with the variable on the right-hand side #' (e.g., `~ SW`). Not all propensity score models accept sampling -#' weights; see [distance] for information on which do and do not, +#' weights; see [`distance`] for information on which do and do not, #' and see `vignette("sampling-weights")` for details on how to use #' sampling weights in a matching analysis. #' @param replace for methods that allow it, whether matching should be done @@ -134,10 +132,10 @@ #' the matching process in the output, i.e., by the functions from other #' packages `matchit()` calls. What is included depends on the matching #' method. Default is `FALSE`. +#' @param normalize `logical`; whether to rescale the nonzero weights in each treatment group to have an average of 1. Default is `TRUE`. See "How Matching Weights Are Computed" below for more details. #' @param \dots additional arguments passed to the functions used in the #' matching process. See the individual methods pages for information on what -#' additional arguments are allowed for each method. Ignored for `print()`. -#' @param x a `matchit` object. +#' additional arguments are allowed for each method. #' #' @details #' Details for the various matching methods can be found at the following help @@ -145,10 +143,11 @@ #' * [`method_nearest`] for nearest neighbor matching #' * [`method_optimal`] for optimal pair matching #' * [`method_full`] for optimal full matching +#' * [`method_quick`] for generalized (quick) full matching #' * [`method_genetic`] for genetic matching #' * [`method_cem`] for coarsened exact matching #' * [`method_exact`] for exact matching -#' * [`method_cardinality`] for cardinality and template matching +#' * [`method_cardinality`] for cardinality and profile matching #' * [`method_subclass`] for subclassification #' #' The pages contain information on what the method does, which of the arguments above are @@ -171,7 +170,7 @@ #' specified. All arguments other than `distance`, `discard`, and #' `reestimate` will be ignored. #' -#' See [distance] for details on the several ways to +#' See [`distance`] for details on the several ways to #' specify the `distance`, `link`, and `distance.options` #' arguments to estimate propensity scores and create distance measures. #' @@ -180,7 +179,7 @@ #' Value, below). The following rules are used: 1) if `0` is one of the #' values, it will be considered the control and the other value the treated; #' 2) otherwise, if the variable is a factor, `levels(treat)[1]` will be -#' considered control and the other variable the treated; 3) otherwise, +#' considered control and the other value the treated; 3) otherwise, #' `sort(unique(treat))[1]` will be considered control and the other value #' the treated. It is safest to ensure the treatment variable is a `0/1` #' variable. @@ -217,22 +216,30 @@ #' Matching weights are computed in one of two ways depending on whether matching was done with replacement #' or not. #' -#' For matching *without* replacement (except for cardinality matching), each +#' ### Matching without replacement and subclassification +#' +#' For matching *without* replacement (except for cardinality matching), including subclassification, each #' unit is assigned to a subclass, which represents the pair they are a part of #' (in the case of k:1 matching) or the stratum they belong to (in the case of #' exact matching, coarsened exact matching, full matching, or #' subclassification). The formula for computing the weights depends on the #' argument supplied to `estimand`. A new "stratum propensity score" -#' (`sp`) is computed as the proportion of units in each stratum that are -#' in the treated group, and all units in that stratum are assigned that +#' (\eqn{p^s_i}) is computed for each unit \eqn{i} as \eqn{p^s_i = \frac{1}{n_s}\sum_{j: s_j =s_i}{I(A_j=1)}} where \eqn{n_s} is the size of subclass \eqn{s} and \eqn{I(A_j=1)} is 1 if unit \eqn{j} is treated and 0 otherwise. That is, the stratum propensity score for stratum \eqn{s} is the proportion of units in stratum \eqn{s} that are +#' in the treated group, and all units in stratum \eqn{s} are assigned that #' stratum propensity score. This is distinct from the propensity score used for matching, if any. Weights are then computed using the standard formulas for -#' inverse probability weights with the stratum propensity score inserted: for the ATT, weights are 1 for the treated -#' units and `sp/(1-sp)` for the control units; for the ATC, weights are -#' `(1-sp)/sp` for the treated units and 1 for the control units; for the -#' ATE, weights are `1/sp` for the treated units and `1/(1-sp)` for the -#' control units. For cardinality matching, all matched units receive a weight +#' inverse probability weights with the stratum propensity score inserted: +#' * for the ATT, weights are 1 for the treated +#' units and \eqn{\frac{p^s}{1-p^s}} for the control units +#' * for the ATC, weights are +#' \eqn{\frac{1-p^s}{p^s}} for the treated units and 1 for the control units +#' * for the ATE, weights are \eqn{\frac{1}{p^s}} for the treated units and \eqn{\frac{1}{1-p^s}} for the +#' control units. +#' +#' For cardinality matching, all matched units receive a weight #' of 1. #' +#' ### Matching witht replacement +#' #' For matching *with* replacement, units are not assigned to unique strata. For #' the ATT, each treated unit gets a weight of 1. Each control unit is weighted #' as the sum of the inverse of the number of control units matched to the same @@ -244,9 +251,15 @@ #' switched. The weights are computed using the `match.matrix` component #' of the `matchit()` output object. #' -#' In each treatment group, weights are divided by the mean of the nonzero +#' ### Normalized weights +#' +#' When `normalize = TRUE` (the default), in each treatment group, weights are divided by the mean of the nonzero #' weights in that treatment group to make the weights sum to the number of -#' units in that treatment group. If sampling weights are included through the +#' units in that treatment group (i.e., to have an average of 1). +#' +#' ### Sampling weights +#' +#' If sampling weights are included through the #' `s.weights` argument, they will be included in the `matchit()` #' output object but not incorporated into the matching weights. #' [match.data()], which extracts the matched set from a `matchit` object, @@ -258,22 +271,21 @@ #' \item{match.matrix}{a matrix containing the matches. The rownames correspond #' to the treated units and the values in each row are the names (or indices) #' of the control units matched to each treated unit. When treated units are -#' matched to different numbers of control units (e.g., with exact matching or +#' matched to different numbers of control units (e.g., with variable ratio matching or #' matching with a caliper), empty spaces will be filled with `NA`. Not -#' included when `method` is `"full"`, `"cem"` (unless `k2k -#' = TRUE`), `"exact"`, or `"cardinality"`.} +#' included when `method` is `"full"`, `"cem"` (unless `k2k = TRUE`), `"exact"`, `"quick"`, or `"cardinality"` (unless `mahvars` is supplied and `ratio` is an integer).} #' \item{subclass}{a factor #' containing matching pair/stratum membership for each unit. Unmatched units -#' will have a value of `NA`. Not included when `replace = TRUE`.} +#' will have a value of `NA`. Not included when `replace = TRUE` or when `method = "cardinality"` unless `mahvars` is supplied and `ratio` is an integer.} #' \item{weights}{a numeric vector of estimated matching weights. Unmatched and #' discarded units will have a weight of zero.} #' \item{model}{the fit object of #' the model used to estimate propensity scores when `distance` is -#' specified and not `"mahalanobis"` or a numeric vector. When +#' specified as a method of estimating propensity scores. When #' `reestimate = TRUE`, this is the model estimated after discarding #' units.} #' \item{X}{a data frame of covariates mentioned in `formula`, -#' `exact`, `mahvars`, and `antiexact`.} +#' `exact`, `mahvars`, `caliper`, and `antiexact`.} #' \item{call}{the `matchit()` call.} #' \item{info}{information on the matching method and #' distance measures used.} @@ -304,24 +316,18 @@ #' distance measure cutpoints used to define the subclasses. See #' [`method_subclass`] for details. #' -#' @author Daniel Ho (\email{dho@@law.stanford.edu}); Kosuke Imai -#' (\email{imai@@harvard.edu}); Gary King (\email{king@@harvard.edu}); -#' Elizabeth Stuart (\email{estuart@@jhsph.edu}) -#' -#' Version 4.0.0 update by Noah Greifer (\email{noah.greifer@@gmail.com}) +#' @author Daniel Ho, Kosuke Imai, Gary King, and Elizabeth Stuart wrote the original package. Starting with version 4.0.0, Noah Greifer is the primary maintainer and developer. #' #' @seealso [summary.matchit()] for balance assessment after matching, [plot.matchit()] for plots of covariate balance and propensity score overlap after matching. #' -#' `vignette("MatchIt")` for an introduction to matching with -#' *MatchIt*; `vignette("matching-methods")` for descriptions of the -#' variety of matching methods and options available; -#' `vignette("assessing-balance")` for information on assessing the -#' quality of a matching specification; `vignette("estimating-effects")` -#' for instructions on how to estimate treatment effects after matching; and -#' `vignette("sampling-weights")` for a guide to using *MatchIt* with -#' sampling weights. +#' * `vignette("MatchIt")` for an introduction to matching with *MatchIt* +#' * `vignette("matching-methods")` for descriptions of the variety of matching methods and options available +#' * `vignette("assessing-balance")` for information on assessing the quality of a matching specification +#' * `vignette("estimating-effects")` for instructions on how to estimate treatment effects after matching +#' * `vignette("sampling-weights")` for a guide to using *MatchIt* with sampling weights. #' -#' @references Ho, D. E., Imai, K., King, G., & Stuart, E. A. (2007). Matching +#' @references +#' Ho, D. E., Imai, K., King, G., & Stuart, E. A. (2007). Matching #' as Nonparametric Preprocessing for Reducing Model Dependence in Parametric #' Causal Inference. *Political Analysis*, 15(3), 199–236. \doi{10.1093/pan/mpl013} #' @@ -402,6 +408,7 @@ matchit <- function(formula, ratio = 1, verbose = FALSE, include.obj = FALSE, + normalize = TRUE, ...) { #Checking input format @@ -410,17 +417,15 @@ matchit <- function(formula, ## Process method chk::chk_null_or(method, vld = chk::vld_string) - if (length(method) == 1L && is.character(method)) { - method <- tolower(method) - method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", "genetic", "subclass", "cardinality", - "quick")) - fn2 <- paste0("matchit2", method) - } - else if (is_null(method)) { + if (is_null(method)) { fn2 <- "matchit2null" } else { - .err("`method` must be the name of a supported matching method. See `?matchit` for allowable options") + method <- tolower(method) + method <- match_arg(method, c("exact", "cem", "nearest", "optimal", + "full", "genetic", "subclass", "cardinality", + "quick")) + fn2 <- paste0("matchit2", method) } #Process formula and data inputs @@ -465,7 +470,7 @@ matchit <- function(formula, .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") } - if (!all(s.weights %in% names(data))) { + if (!all(hasName(data, s.weights))) { .err("the name supplied to `s.weights` must be a variable in `data`") } @@ -478,7 +483,7 @@ matchit <- function(formula, s.weights <- s.weights[[1L]] } - else if (inherits(s.weights, "formula")) { + else if (rlang::is_formula(s.weights)) { s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") @@ -560,6 +565,9 @@ matchit <- function(formula, antiexactcovs <- process.variable.input(antiexact, data) antiexact <- attr(antiexactcovs, "terms") + chk::chk_flag(verbose) + chk::chk_flag(normalize) + #Estimate distance, discard from common support, optionally re-estimate distance if (is_null(fn1) || is.full.mahalanobis) { #No distance measure @@ -569,19 +577,28 @@ matchit <- function(formula, dist.model <- link <- NULL } else { - if (verbose) { - cat("Estimating propensity scores... \n") - } + .cat_verbose("Estimating propensity scores... \n", verbose = verbose) if (is_not_null(s.weights)) { attr(s.weights, "in_ps") <- !distance %in% c("bart") } #Estimate distance - if (is_null(distance.options$formula)) distance.options$formula <- formula - if (is_null(distance.options$data)) distance.options$data <- data - if (is_null(distance.options$verbose)) distance.options$verbose <- verbose - if (is_null(distance.options$estimand)) distance.options$estimand <- estimand + if (is_null(distance.options)) { + distance.options <- list(formula = formula, + data = data, + verbose = verbose, + estimand = estimand) + } + else { + chk::chk_list(distance.options) + + if (is_null(distance.options$formula)) distance.options$formula <- formula + if (is_null(distance.options$data)) distance.options$data <- data + if (is_null(distance.options$verbose)) distance.options$verbose <- verbose + if (is_null(distance.options$estimand)) distance.options$estimand <- estimand + } + if (is_null(distance.options$weights) && !fn1 %in% c("distance2bart")) { distance.options$weights <- s.weights } @@ -605,7 +622,7 @@ matchit <- function(formula, } #Process discard - if (is_null(fn1) || is.full.mahalanobis || fn1 == "distance2user") { + if (is_null(fn1) || is.full.mahalanobis || identical(fn1, "distance2user")) { discarded <- discard(treat, distance, discard) } else { @@ -634,7 +651,8 @@ matchit <- function(formula, caliper <- process.caliper(caliper, method, data, covs, mahcovs, distance, discarded, std.caliper) if (is_not_null(attr(caliper, "cal.formula"))) { - calcovs <- model.frame(attr(caliper, "cal.formula"), data, na.action = "na.pass") + calcovs <- model.frame(attr(caliper, "cal.formula"), data = data, + na.action = "na.pass") if (anyNA(calcovs)) { .err("missing values are not allowed in the covariates named in `caliper`") @@ -653,6 +671,14 @@ matchit <- function(formula, antiexact = antiexact, ...), quote = TRUE) + weights <- match.out[["weights"]] + + #Normalize weights + if (normalize) { + wi <- which(weights > 0) + weights[wi] <- .make_sum_to_n(weights[wi], treat[wi]) + } + info <- create_info(method, fn1, link, discard, replace, ratio, mahalanobis = is.full.mahalanobis || is_not_null(mahvars), transform = attr(is.full.mahalanobis, "transform"), @@ -660,24 +686,28 @@ matchit <- function(formula, antiexact = colnames(antiexactcovs), distance_is_matrix = is_not_null(distance) && is.matrix(distance)) - #Create X.list for X output, removing duplicate variables - X.list <- list(covs, exactcovs, mahcovs, calcovs, antiexactcovs) - all.covs <- lapply(X.list, names) + #Create X output, removing duplicate variables + X.list.nm <- c("covs", "exactcovs", "mahcovs", "calcovs", "antiexactcovs") + X <- NULL + for (i in X.list.nm) { + X_tmp <- get0(i, inherits = FALSE) - for (i in seq_along(X.list)[-1]) { - if (is_not_null(X.list[[i]])) { - X.list[[i]][names(X.list[[i]]) %in% unlist(all.covs[1:(i-1)])] <- NULL + if (is_not_null(X_tmp)) { + if (is_null(X)) { + X <- X_tmp + } + else if (!all(hasName(X, names(X_tmp)))) { + X <- cbind(X, X_tmp[!names(X_tmp) %in% names(X)]) + } } } - X.list[vapply(X.list, is_null, logical(1L))] <- NULL - ## putting all the results together out <- list( match.matrix = match.out[["match.matrix"]], subclass = match.out[["subclass"]], - weights = match.out[["weights"]], - X = do.call("cbind", X.list), + weights = weights, + X = X, call = mcall, info = info, estimand = estimand, @@ -694,15 +724,14 @@ matchit <- function(formula, obj = if (include.obj) match.out[["obj"]] ) - out[vapply(out, is_null, logical(1L))] <- NULL + out[lengths(out) == 0L] <- NULL class(out) <- class(match.out) + out - #:) } -#' @export -#' @rdname matchit +#' @exportS3Method print matchit print.matchit <- function(x, ...) { info <- x[["info"]] cal <- is_not_null(x[["caliper"]]) @@ -710,11 +739,12 @@ print.matchit <- function(x, ...) { disl <- is_not_null(dis) nm <- is_null(x[["method"]]) - cat("A matchit object") - cat(paste0("\n - method: ", info.to.method(info))) + cat("A `matchit` object\n") + + cat(sprintf(" - method: %s\n", info.to.method(info))) if (is_not_null(info$distance) || info$mahalanobis) { - cat("\n - distance: ") + cat(" - distance: ") if (info$mahalanobis) { if (is_null(info$transform)) #mahvars used cat("Mahalanobis") @@ -722,6 +752,7 @@ print.matchit <- function(x, ...) { cat(capwords(gsub("_", " ", info$transform, fixed = TRUE))) } } + if (is_not_null(info$distance) && !info$distance %in% matchit_distances()) { if (info$mahalanobis) cat(" [matching]\n ") @@ -731,42 +762,62 @@ print.matchit <- function(x, ...) { else cat("User-defined") if (cal || disl) { - cal.ps <- "" %in% names(x[["caliper"]]) - cat(" [") - cat(paste(c("matching", "subclassification", "caliper", "common support")[c(!nm && !info$mahalanobis && info$method != "subclass", !nm && info$method == "subclass", cal.ps, disl)], collapse = ", ")) - cat("]") + cal.ps <- hasName(x[["caliper"]], "") + cat(sprintf(" [%s]\n", + paste(c("matching", "subclassification", "caliper", "common support")[c(!nm && !info$mahalanobis && info$method != "subclass", !nm && info$method == "subclass", cal.ps, disl)], collapse = ", "))) } + if (info$distance != "user") { - cat("\n - estimated with ") - cat(info.to.distance(info)) + cat(sprintf(" - estimated with %s\n", + info.to.distance(info))) if (is_not_null(x[["s.weights"]])) { - if (isTRUE(attr(x[["s.weights"]], "in_ps"))) - cat("\n - sampling weights included in estimation") - else cat("\n - sampling weights not included in estimation") + cat(sprintf(" - sampling weights %sincluded in estimation\n", + if (isTRUE(attr(x[["s.weights"]], "in_ps"))) "" else "not ")) } } } } if (cal) { - cat(paste0("\n - caliper: ", paste(vapply(seq_along(x[["caliper"]]), function(z) paste0(if (names(x[["caliper"]])[z] == "") "" else names(x[["caliper"]])[z], - " (", format(round(x[["caliper"]][z], 3)), ")"), character(1L)), - collapse = ", "))) + cat(sprintf(" - caliper: %s\n", + paste(vapply(seq_along(x[["caliper"]]), + function(z) { + sprintf("%s (%s)", + if (names(x[["caliper"]])[z] == "") "" + else names(x[["caliper"]])[z], + format(round(x[["caliper"]][z], 3))) + }, character(1L)), + collapse = ", "))) } if (disl) { - cat("\n - common support: ") - if (dis == "both") cat("units from both groups") - else if (dis == "treat") cat("treated units") - else if (dis == "control") cat("control units") - cat(" dropped") + cat(sprintf(" - common support: %s dropped\n", + switch(dis, + "both" = "units from both groups", + "treat" = "treated units", + "control" = "control units"))) + } + + cat(sprintf(" - number of obs.: %s (original)%s\n", + length(x[["treat"]]), + if (!all_equal_to(x[["weights"]], 1)) + sprintf(", %s (matched)", sum(x[["weights"]] != 0)) + else "")) + + if (is_not_null(x[["s.weights"]])) { + cat(" - sampling weights: present\n") + } + + if (is_not_null(x[["estimand"]])) { + cat(sprintf(" - target estimand: %s\n", x[["estimand"]])) + } + + if (is_not_null(x[["X"]])) { + cat(sprintf(" - covariates: %s\n", + if (length(names(x[["X"]])) > 40L) "too many to name" + else paste(names(x[["X"]]), collapse = ", "))) } - cat(paste0("\n - number of obs.: ", length(x[["treat"]]), " (original)", if (!all(x[["weights"]] == 1)) paste0(", ", sum(x[["weights"]] != 0), " (matched)"))) - if (is_not_null(x[["s.weights"]])) cat("\n - sampling weights: present") - if (is_not_null(x[["estimand"]])) cat(paste0("\n - target estimand: ", x[["estimand"]])) - if (is_not_null(x[["X"]])) cat(paste0("\n - covariates: ", if (length(names(x[["X"]])) > 40) "too many to name" else paste(names(x[["X"]]), collapse = ", "))) - cat("\n") invisible(x) } From 61df63a6e8e9982e5df27eaa89e99b93ef4b8bda Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:25:48 -0500 Subject: [PATCH 27/48] k2k now uses nnmatch and allows m.order --- R/matchit2cem.R | 195 +++++++++++++++++++++++++----------------------- 1 file changed, 102 insertions(+), 93 deletions(-) diff --git a/R/matchit2cem.R b/R/matchit2cem.R index a6a4d504..c1b76076 100644 --- a/R/matchit2cem.R +++ b/R/matchit2cem.R @@ -75,18 +75,29 @@ #' the stratum, the treated units without a match will be dropped). The #' `k2k.method` argument controls how the distance between units is #' calculated. } -#' \item{`k2k.method`}{ `character`; how the distance +#' \item{`k2k.method`}{`character`; how the distance #' between units should be calculated if `k2k = TRUE`. Allowable arguments #' include `NULL` (for random matching), any argument to #' [distance()] for computing a distance matrix from covariates #' (e.g., `"mahalanobis"`), or any allowable argument to `method` in #' [dist()]. Matching will take place on the original -#' (non-coarsened) variables. The default is `"mahalanobis"`. } -#' \item{`mpower`}{ if `k2k.method = "minkowski"`, the power used in -#' creating the distance. This is passed to the `p` argument of [dist()].} +#' (non-coarsened) variables. The default is `"mahalanobis"`. +#' } +#' \item{`mpower`}{if `k2k.method = "minkowski"`, the power used in +#' creating the distance. This is passed to the `p` argument of [dist()]. +#' } +#' \item{`m.order`}{`character`; the order that the matching takes place when `k2k = TRUE`. Allowable options +#' include `"closest"`, where matching takes place in +#' ascending order of the smallest distance between units; `"farthest"`, where matching takes place in +#' descending order of the smallest distance between units; `"random"`, where matching takes place +#' in a random order; and `"data"` where matching takes place based on the +#' order of units in the data. When `m.order = "random"`, results may differ +#' across different runs of the same code unless a seed is set and specified +#' with [set.seed()]. The default of `NULL` corresponds to `"data"`. See [`method_nearest`] for more information. +#' } #' } #' -#' The arguments `distance` (and related arguments), `exact`, `mahvars`, `discard` (and related arguments), `replace`, `m.order`, `caliper` (and related arguments), and `ratio` are ignored with a warning. +#' The arguments `distance` (and related arguments), `exact`, `mahvars`, `discard` (and related arguments), `replace`, `caliper` (and related arguments), and `ratio` are ignored with a warning. #' #' @section Outputs: #' @@ -109,8 +120,7 @@ #' Setting `k2k = TRUE` is equivalent to first doing coarsened exact #' matching with `k2k = FALSE` and then supplying stratum membership as an #' exact matching variable (i.e., in `exact`) to another call to -#' `matchit()` with `method = "nearest"`, `distance = -#' "mahalanobis"` and an argument to `discard` denoting unmatched units. +#' `matchit()` with `method = "nearest"`. #' It is also equivalent to performing nearest neighbor matching supplying #' coarsened versions of the variables to `exact`, except that #' `method = "cem"` automatically coarsens the continuous variables. The @@ -262,14 +272,12 @@ NULL -matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose = FALSE, TEST = 1, ...) { +matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, m.order = NULL, verbose = FALSE, ...) { if (is_null(covs)) { .err("Covariates must be specified in the input formula to use coarsened exact matching") } - if (verbose) cat("Coarsened exact matching...\n") - - A <- list(...) + .cat_verbose("Coarsened exact matching... \n", verbose = verbose) # if (isTRUE(A[["k2k"]])) { # if (!has_n_unique(treat, 2L)) { @@ -281,57 +289,55 @@ matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) #Uses in-house cem, no need for cem package. - strat <- do.call("cem_matchit", c(list(treat = treat, X = covs), - A[names(A) %in% names(formals(cem_matchit))])) - - levels(strat) <- seq_len(nlevels(strat)) - names(strat) <- names(treat) + strat <- cem_matchit(treat = treat, X = covs, ...) mm <- NULL - if (isTRUE(A[["k2k"]])) { + if (isTRUE(...get("k2k", ...))) { focal <- switch(estimand, "ATC" = 0, 1) - strat <- do.call("do_k2k", c(list(treat = treat, X = covs, subclass = strat, - estimand = estimand, - s.weights = s.weights), - A[names(A) %in% names(formals(do_k2k))])) + mm <- do_k2k(treat = treat, + X = covs, + subclass = strat, + s.weights = s.weights, + focal = focal, + m.order = m.order, + verbose = verbose, + ...) - strat <- setNames(factor(strat), names(treat)) + strat <- mm2subclass(mm, treat, focal = focal) levels(strat) <- seq_len(nlevels(strat)) - mm <- nummm2charmm(subclass2mmC(strat, treat, focal = focal), - treat) + mm <- nummm2charmm(mm, treat) weights <- get_weights_from_mm(mm, treat, focal) } else { - strat <- setNames(factor(strat), names(treat)) levels(strat) <- seq_len(nlevels(strat)) weights <- get_weights_from_subclass(strat, treat, estimand) } - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) res <- list(match.matrix = mm, subclass = strat, weights = weights) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res } -cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list()) { +cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), ...) { #In-house implementation of cem. Basically the same except: #treat is a vector if treatment status, not the name of a variable #X is a data.frame of covariates #when cutpoints are given as integer or string, they define the number of bins, not the number of breakpoints. "ss" is no longer allowed. - for (i in names(X)) { - if (is.ordered(X[[i]])) X[[i]] <- as.numeric(X[[i]]) + for (i in seq_along(X)) { + if (is.ordered(X[[i]])) X[[i]] <- unclass(X[[i]]) } is.numeric.cov <- setNames(vapply(X, is.numeric, logical(1L)), names(X)) @@ -385,7 +391,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list()) { #Process cutpoints if (!is.list(cutpoints)) { - cutpoints <- setNames(lapply(names(X)[is.numeric.cov], function(i) cutpoints), names(X)[is.numeric.cov]) + cutpoints <- setNames(rep(list(cutpoints), sum(is.numeric.cov)), names(X)[is.numeric.cov]) } if (is_null(names(cutpoints))) { @@ -441,10 +447,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list()) { bad.cuts[i] <- !(startsWith(cutpoints[[i]], "q") && can_str2num(substring(cutpoints[[i]], 2))) && is.na(pmatch(cutpoints[[i]], c("sturges", "fd", "scott"))) } - else if (!is.numeric(cutpoints[[i]])) { - bad.cuts[i] <- TRUE - } - else if (!is.finite(cutpoints[[i]]) || cutpoints[[i]] < 0) { + else if (!is.numeric(cutpoints[[i]]) || !is.finite(cutpoints[[i]]) || cutpoints[[i]] < 0) { bad.cuts[i] <- TRUE } else if (cutpoints[[i]] == 0) { @@ -466,10 +469,14 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list()) { tidy = FALSE, n = sum(bad.cuts)) } + if (is_null(X)) { + return(setNames(rep(1L, length(treat)), names(treat))) + } + #Create bins for numeric variables for (i in names(X)[is.numeric.cov]) { bins <- { - if (is_not_null(cutpoints) && i %in% names(cutpoints)) cutpoints[[i]] + if (is_not_null(cutpoints) && any(names(cutpoints) == i)) cutpoints[[i]] else "sturges" } @@ -501,81 +508,83 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list()) { X[[i]] <- findInterval(X[[i]], breaks) } - if (is_null(X)) { - subclass <- setNames(rep(1L, length(treat)), names(treat)) - } - else { - #Exact match - xx <- exactify(X, names(treat)) + #Exact match + ex <- unclass(exactify(X, names(treat))) - cc <- do.call("intersect", unname(split(xx, treat))) + cc <- Reduce("intersect", lapply(unique(treat), function(t) ex[treat==t])) - if (is_null(cc)) { - .err("no units were matched. Try coarsening the variables further or decrease the number of variables to match on") - } - - subclass <- setNames(match(xx, cc), names(treat)) + if (is_null(cc)) { + .err("no units were matched. Try coarsening the variables further or decrease the number of variables to match on") } - subclass + setNames(factor(match(ex, cc), nmax = length(cc)), names(treat)) } do_k2k <- function(treat, X, subclass, k2k.method = "mahalanobis", mpower = 2, s.weights = NULL, - estimand = "ATT") { - - if (is_not_null(k2k.method)) { - k2k.method <- tolower(k2k.method) - k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - - X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, - method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") - } - - na.sub <- is.na(subclass) - - s <- switch(estimand, "ATC" = 0, 1) - - extra.sub <- max(subclass, na.rm = TRUE) + focal, m.order = "data", verbose = FALSE, k2k = TRUE, ...) { + #Note: need k2k argument to prevent partial matching for k2k.method - for (i in which(tabulateC(subclass[!na.sub]) > 2)) { + m.order <- match_arg(m.order, c("data", "random", "closest", "farthest")) - in.sub <- which(!na.sub & subclass == i) + .cat_verbose("K:K matching...\n", verbose = verbose) - #Compute distance matrix; all 0s if k2k.method = NULL for matched based on data order - if (is_null(k2k.method)) { - dist.mat <- matrix(0, nrow = sum(treat[in.sub] == s), ncol = sum(treat[in.sub] != s), - dimnames = list(names(treat)[in.sub][treat[in.sub] == s], - names(treat)[in.sub][treat[in.sub] != s])) + if (is_not_null(k2k.method)) { + chk::chk_string(k2k.method) + k2k.method <- tolower(k2k.method) + k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - } - else if (k2k.method %in% matchit_distances()) { - #X.match has been transformed - dist.mat <- eucdist_internal(X.match[in.sub,,drop = FALSE], treat[in.sub] == s) - } - else { - dist.mat <- as.matrix(dist(X.match[in.sub,,drop = FALSE], method = k2k.method, p = mpower)) + if (k2k.method == "minkowski") { + chk::chk_number(mpower) + chk::chk_gt(mpower, 0) - #Put smaller group on rows - d.rows <- which(rownames(dist.mat) %in% names(treat[in.sub])[treat[in.sub] == s]) - dist.mat <- dist.mat[d.rows, -d.rows, drop = FALSE] + if (mpower == 2) { + k2k.method <- "euclidean" + } } - #For each member of group on row, find closest remaining pair from cols - while (all(dim(dist.mat) > 0)) { - extra.sub <- extra.sub + 1 + X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, + method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") + distance <- NULL + } + else { + k2k.method <- "euclidean" + X.match <- NULL + distance <- rep(0.0, length(treat)) + } - closest <- which.min(dist.mat[1,]) - subclass[c(rownames(dist.mat)[1], colnames(dist.mat)[closest])] <- extra.sub + reuse.max <- 1L + caliper.dist <- caliper.covs <- caliper.covs.mat <- antiexactcovs <- unit.id <- NULL - #Drop already paired units from dist.mat - dist.mat <- dist.mat[-1,-closest, drop = FALSE] - } + if (k2k.method %in% matchit_distances()) { + discarded <- is.na(subclass) + ratio <- rep(1L, sum(treat == focal)) - #If any unmatched units remain, give them NA subclass - if (any(dim(dist.mat) > 0)) { - is.na(subclass)[unlist(dimnames(dist.mat))] <- TRUE + mm <- nn_matchC_dispatch(treat, focal, ratio, discarded, reuse.max, distance, NULL, + subclass, caliper.dist, caliper.covs, caliper.covs.mat, X.match, + antiexactcovs, unit.id, m.order, verbose) + } + else { + mm <- matrix(NA_integer_, ncol = 1, nrow = sum(treat == 1), + dimnames = list(names(treat)[treat == 1], NULL)) + + for (s in levels(subclass)) { + .e <- which(subclass == s) + treat_ <- treat[.e] + discarded_ <- rep(FALSE, length(.e)) + ex_ <- NULL + ratio_ <- rep(1L, sum(treat_ == focal)) + distance_mat <- as.matrix(dist(X.match[.e,,drop = FALSE], + method = k2k.method, p = mpower))[treat_ == focal, treat_ != focal, drop = FALSE] + + mm_ <- nn_matchC_dispatch(treat_, focal, ratio_, discarded_, reuse.max, distance, distance_mat, + ex_, caliper.dist, caliper.covs, caliper.covs.mat, NULL, + antiexactcovs, unit.id, m.order, FALSE) + + #Ensure matched indices correspond to indices in full sample, not subgroup + mm_[] <- .e[mm_] + mm[rownames(mm_),] <- mm_ } } - subclass + mm } \ No newline at end of file From 3beef175e680033fce3f348938fac926c9027ebe Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:26:07 -0500 Subject: [PATCH 28/48] New faster helper functions and improvements. --- R/utils.R | 107 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 93 insertions(+), 14 deletions(-) diff --git a/R/utils.R b/R/utils.R index 588c4ea6..8e1cce4b 100644 --- a/R/utils.R +++ b/R/utils.R @@ -46,7 +46,9 @@ word_list <- function(word.list = NULL, and.or = "and", is.are = FALSE, quotes = } } - if (is.are) out <- sprintf("%s are", out) + if (is.are) { + out <- sprintf("%s are", out) + } attr(out, "plural") <- TRUE @@ -107,20 +109,69 @@ match_arg <- function(arg, choices, several.ok = FALSE) { } else { chk::chk_string(arg, x_name = add_quotes(arg.name, "`")) - if (identical(arg, choices)) return(arg[1L]) + + if (identical(arg, choices)) { + return(arg[1L]) + } } i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE) - if (all(i == 0L)) + + if (all_equal_to(i, 0L)) { .err(sprintf("the argument to `%s` should be %s%s", arg.name, ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), word_list(choices, and.or = "or", quotes = 2))) + } + i <- i[i > 0L] choices[i] } +# Version of interaction(., drop = TRUE) that doesn't succumb to vector limit reached by +# avoiding Cartesian expansion. Falls back to interaction() for small problems. +interaction2 <- function(..., sep = ".", lex.order = TRUE) { + + narg <- ...length() + + if (narg == 0L) { + stop("No factors specified") + } + + if (narg == 1L && is.list(..1)) { + args <- ..1 + narg <- length(args) + } + else { + args <- list(...) + } + + for (i in seq_len(narg)) { + args[[i]] <- as.factor(args[[i]]) + } + + if (do.call("prod", lapply(args, nlevels)) <= 1e6) { + return(interaction(args, drop = TRUE, sep = sep, + lex.order = if (is.null(lex.order)) TRUE else lex.order)) + } + + out <- do.call(function(...) paste(..., sep = sep), args) + + args_char <- lapply(args, function(x) { + x <- unclass(x) + formatC(x, format = "d", flag = "0", width = ceiling(log10(max(x)))) + }) + + lev <- { + if (is.null(lex.order)) unique(out) + else if (lex.order) unique(out[order(do.call("paste", c(args_char, sep = sep)))]) + else unique(out[order(do.call("paste", c(rev(args_char), sep = sep)))]) + } + + factor(out, levels = lev) +} + #Turn a vector into a 0/1 vector. 'zero' and 'one' can be supplied to make it clear which is #which; otherwise, a guess is used. From WeightIt. binarize <- function(variable, zero = NULL, one = NULL) { @@ -135,7 +186,7 @@ binarize <- function(variable, zero = NULL, one = NULL) { } if (is.character(variable) || is.factor(variable)) { - variable <- factor(variable, nmax = if (is.factor(variable)) nlevels(variable) else NA) + variable <- factor(variable, nmax = 2L) unique.vals <- levels(variable) } else { @@ -173,7 +224,7 @@ binarize <- function(variable, zero = NULL, one = NULL) { variable.numeric <- { if (can_str2num(unique.vals)) setNames(str2num(unique.vals), unique.vals)[variable] - else as.numeric(factor(variable, levels = unique.vals)) + else unclass(factor(variable, levels = unique.vals)) } zero <- { @@ -219,7 +270,8 @@ firstup <- function(x) { #Capitalize first letter of each word capwords <- function(s, strict = FALSE) { cap <- function(s) paste0(toupper(substring(s, 1, 1)), - {s <- substring(s, 2); if(strict) tolower(s) else s}, + {s <- substring(s, 2) + if (strict) tolower(s) else s}, collapse = " ") sapply(strsplit(s, split = " "), cap, USE.NAMES = is_not_null(names(s))) } @@ -240,6 +292,7 @@ round_df_char <- function(df, digits, pad = "0", na_vals = "") { infs <- o.negs <- array(FALSE, dim = dim(df)) nas <- is.na(df) nums <- vapply(df, is.numeric, logical(1)) + for (i in which(nums)) { infs[,i] <- !nas[,i] & !is.finite(df[[i]]) } @@ -254,19 +307,24 @@ round_df_char <- function(df, digits, pad = "0", na_vals = "") { o.negs[,nums] <- !nas[,nums] & df[nums] < 0 & round(df[nums], digits) == 0 df[nums] <- round(df[nums], digits = digits) + pad0 <- identical(as.character(pad), "0") + for (i in which(nums)) { df[[i]] <- format(df[[i]], scientific = FALSE, justify = "none", trim = TRUE, - drop0trailing = !identical(as.character(pad), "0")) + drop0trailing = !pad0) - if (!identical(as.character(pad), "0") && any(grepl(".", df[[i]], fixed = TRUE))) { + if (!pad0 && any(grepl(".", df[[i]], fixed = TRUE))) { s <- strsplit(df[[i]], ".", fixed = TRUE) lengths <- lengths(s) digits.r.of.. <- rep.int(0, NROW(df)) digits.r.of..[lengths > 1] <- nchar(vapply(s[lengths > 1], `[[`, character(1L), 2)) - max.dig <- max(digits.r.of..) - dots <- ifelse(lengths > 1, "", if (as.character(pad) != "") "." else pad) - pads <- vapply(max.dig - digits.r.of.., function(n) paste(rep.int(pad, n), collapse = ""), character(1L)) + dots <- rep.int("", length(s)) + dots[lengths <= 1] <- if (as.character(pad) != "") "." else pad + + pads <- vapply(max(digits.r.of..) - digits.r.of.., + function(n) paste(rep.int(pad, n), collapse = ""), + character(1L)) df[[i]] <- paste0(df[[i]], dots, pads) } @@ -285,17 +343,17 @@ round_df_char <- function(df, digits, pad = "0", na_vals = "") { } #Generalized inverse; port of MASS::ginv() -generalized_inverse <- function(sigma) { +generalized_inverse <- function(sigma, tol = 1e-8) { sigmasvd <- svd(sigma) - pos <- sigmasvd$d > max(1e-8 * sigmasvd$d[1L], 0) + pos <- sigmasvd$d > max(tol * sigmasvd$d[1L], 0) sigmasvd$v[, pos, drop = FALSE] %*% (sigmasvd$d[pos]^-1 * t(sigmasvd$u[, pos, drop = FALSE])) } #(Weighted) variance that uses special formula for binary variables wvar <- function(x, bin.var = NULL, w = NULL) { - if (is_null(w)) w <- rep(1, length(x)) + if (is_null(w)) w <- rep.int(1, length(x)) if (is_null(bin.var)) bin.var <- all(x == 0 | x == 1) w <- w / sum(w) #weights normalized to sum to 1 @@ -369,6 +427,27 @@ diff1 <- function(x) { x } +...get <- function(x, ...) { + m <- match(x, ...names(), 0L) + + if (m == 0L) { + return(NULL) + } + + ...elt(m) +} + +#cat() if verbose = TRUE (default sep = "", line wrapping) +.cat_verbose <- function(..., verbose = TRUE, sep = "") { + if (!verbose) { + return(invisible(NULL)) + } + + m <- do.call(function(...) paste(..., sep = sep), list(...)) + + cat(paste(strwrap(m), collapse = "\n")) +} + #Functions for error handling; based on chk and rlang pkg_caller_call <- function(start = 1) { pn <- utils::packageName() From f7645ac2ae9f3d1864035a082bfee6b01dbc65ac Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:26:21 -0500 Subject: [PATCH 29/48] test updates --- tests/testthat/test-method_cem.R | 137 +++++++++++++++++++++++++++ tests/testthat/test-method_nearest.R | 43 +++++---- 2 files changed, 164 insertions(+), 16 deletions(-) create mode 100644 tests/testthat/test-method_cem.R diff --git a/tests/testthat/test-method_cem.R b/tests/testthat/test-method_cem.R new file mode 100644 index 00000000..4ac1a419 --- /dev/null +++ b/tests/testthat/test-method_cem.R @@ -0,0 +1,137 @@ +test_that("Coarsened exact matching works", { + set.seed(123) + k <- 6 + n <- 1e4 + + d <- as.data.frame(matrix(rnorm(k * n), nrow = n)) + + d[[1]] <- factor(cut(d[[1]], 4, labels = FALSE)) + d[[2]] <- factor(cut(d[[2]], 10, labels = FALSE)) + + d$a <- rbinom(n, 1, .3) + + m <- matchit(a ~ ., data = d, method = "cem") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = FALSE) + + #Categories are exactly matched by default + expect_true(all(sapply(levels(m$subclass), function(s) length(unique(d[[1]][which(m$subclass == s)])) == 1))) + expect_true(all(sapply(levels(m$subclass), function(s) length(unique(d[[2]][which(m$subclass == s)])) == 1))) + expect_false(all(sapply(levels(m$subclass), function(s) length(unique(d[[3]][which(m$subclass == s)])) == 1))) + + #k2k didn't accidentally activate + expect_true(length(unique(sapply(unique(m$treat), function(t) { + sum(m$weights[m$treat == t] > 0) + }))) > 1L) + + #Groupings: V1 into 2 categories, no grouping of V2 + m <- matchit(a ~ ., data = d, method = "cem", + grouping = list(V1 = list(c("1", "2"), c("3", "4")), + V2 = list(levels(d$V2)))) + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = FALSE) + + + #Each subclass has V1 in 1,2 or 3,4 + expect_true(all(sapply(levels(m$subclass), function(s) all(d[[1]][which(m$subclass == s)] %in% c("1", "2")) || + all(d[[1]][which(m$subclass == s)] %in% c("3", "4"))))) + + #No restriction on bins for V2 + expect_false(all(sapply(levels(m$subclass), function(s) length(unique(d[[2]][which(m$subclass == s)])) == 1))) + + + m <- matchit(a ~ ., data = d, method = "cem", + grouping = list(V1 = list(c("1", "2"), c("3", "4")), + V2 = list(levels(d$V2))), + cutpoints = list(V3 = c(-1.5, 1.5), + V4 = 1)) + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = FALSE) + + #Each subclass has V1 in 1,2 or 3,4 + expect_true(all(sapply(levels(m$subclass), function(s) all(d[[1]][which(m$subclass == s)] %in% c("1", "2")) || + all(d[[1]][which(m$subclass == s)] %in% c("3", "4"))))) + + #No restriction on bins for V2 + expect_false(all(sapply(levels(m$subclass), function(s) length(unique(d[[2]][which(m$subclass == s)])) == 1))) + + #V3 correctly split into defined bins + expect_true(all(sapply(levels(m$subclass), function(s) { + all(d[[3]][which(m$subclass == s)] < -1.5) || + all(d[[3]][which(m$subclass == s)] > -1.5 | d[[3]][which(m$subclass == s)] < 1.5) || + all(d[[3]][which(m$subclass == s)] > 1.5) + }))) + + #Setting V1 = 1 in cutpoints same as omitting it + m1 <- matchit(a ~ . - V4, data = d, method = "cem", + grouping = list(V1 = list(c("1", "2"), c("3", "4")), + V2 = list(levels(d$V2))), + cutpoints = list(V3 = c(-1.5, 1.5))) + + expect_equal(m$subclass, m1$subclass) + + m <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE) + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + #1:1 matched + expect_true(length(unique(sapply(unique(m$treat), function(t) { + sum(m$weights[m$treat == t] > 0) + }))) == 1L) + + #Default is using Mahalanobis + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, k2k.method = "mahalanobis") + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_equal(m$match.matrix, m1$match.matrix) + expect_equal(m$subclass, m1$subclass) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, k2k.method = "scaled_euclidean") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_failure(expect_equal(m$match.matrix, m1$match.matrix)) + expect_failure(expect_equal(m$subclass, m1$subclass)) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, k2k.method = "manhattan") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_failure(expect_equal(m$match.matrix, m1$match.matrix)) + expect_failure(expect_equal(m$subclass, m1$subclass)) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, m.order = "data") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_equal(m$match.matrix, m1$match.matrix) + expect_equal(m$subclass, m1$subclass) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, m.order = "closest") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_failure(expect_equal(m$match.matrix, m1$match.matrix)) + expect_failure(expect_equal(m$subclass, m1$subclass)) + + m2 <- matchit(a ~ ., data = d, method = "cem") + m2$subclass[is.na(m2$subclass)] <- m2$subclass[!is.na(m2$subclass)][1] + + # Equivalent to NN matching with exact matching on subclass + suppressWarnings({ + m1 <- matchit(a ~ ., data = d, method = "nearest", distance = "mahalanobis", + discard = m$weights == 0, exact = ~m2$subclass) + }) + + expect_equal(m$match.matrix, m1$match.matrix) + expect_equal(m$subclass, m1$subclass) +}) diff --git a/tests/testthat/test-method_nearest.R b/tests/testthat/test-method_nearest.R index df217448..51c2a962 100644 --- a/tests/testthat/test-method_nearest.R +++ b/tests/testthat/test-method_nearest.R @@ -16,57 +16,67 @@ test_that("distance vector, mah vars, and distance matrix yield identical result M <- list() if (any(which == 1)) { - m <- matchit(a ~ p + p_, data = d, - distance = d$p, - ...) + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = d$p, + ...) + }, dont_warn_if = "Fewer control") expect_good_matchit(m, expect_distance = TRUE, expect_match.matrix = TRUE, expect_subclass = !m$info$replace, replace = m$info$replace, ratio = m$info$ratio) M <- c(M, list(m)) } if (any(which == 2)) { - m <- matchit(a ~ p + p_, data = d, - distance = "euclidean", - ...) + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = "euclidean", + ...) + }, dont_warn_if = "Fewer control") expect_good_matchit(m, expect_distance = FALSE, expect_match.matrix = TRUE, expect_subclass = !m$info$replace, replace = m$info$replace, ratio = m$info$ratio) M <- c(M, list(m)) } if (any(which == 3)) { - m <- matchit(a ~ p + p_, data = d, - distance = d$p, - mahvars = ~p + p_, - ...) + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = d$p, + mahvars = ~p + p_, + ...) + }, dont_warn_if = "Fewer control") expect_good_matchit(m, expect_distance = TRUE, expect_match.matrix = TRUE, expect_subclass = !m$info$replace, replace = m$info$replace, ratio = m$info$ratio) M <- c(M, list(m)) } if (any(which == 4)) { - m <- matchit(a ~ p + p_, data = d, - distance = dd, - ...) + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = dd, + ...) + }, dont_warn_if = "Fewer control") expect_good_matchit(m, expect_distance = FALSE, expect_match.matrix = TRUE, expect_subclass = !m$info$replace, replace = m$info$replace, ratio = m$info$ratio) M <- c(M, list(m)) } - all(unlist(lapply(combn(seq_along(M), 2, simplify = FALSE), - function(i) isTRUE(all.equal(M[[i[1]]]$match.matrix, - M[[i[2]]]$match.matrix))))) + all(unlist(lapply(M[-1], function(m) isTRUE(all.equal(M[[1]]$match.matrix, + m$match.matrix))))) } expect_true(test_all(m.order = "data")) expect_true(test_all(m.order = "closest")) + expect_true(test_all(m.order = "farthest")) expect_true(test_all(m.order = "largest", which = c(1, 3))) + expect_true(test_all(m.order = "smallest", which = c(1, 3))) expect_true(test_all(m.order = "data", ratio = 2)) expect_true(test_all(m.order = "closest", ratio = 2)) expect_true(test_all(m.order = "data", ratio = 2, max.controls = 3, which = c(1, 3))) expect_true(test_all(m.order = "closest", ratio = 2, max.controls = 3, which = c(1, 3))) + expect_true(test_all(m.order = "largest", ratio = 2, max.controls = 3, which = c(1, 3))) expect_true(test_all(m.order = "data", ratio = 2, replace = TRUE)) expect_true(test_all(m.order = "closest", ratio = 2, replace = TRUE)) @@ -76,6 +86,7 @@ test_that("distance vector, mah vars, and distance matrix yield identical result expect_true(test_all(m.order = "data", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) expect_true(test_all(m.order = "closest", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) + expect_true(test_all(m.order = "largest", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) expect_true(test_all(m.order = "data", ratio = 2, caliper = c(p = .001), std.caliper = FALSE)) expect_true(test_all(m.order = "closest", ratio = 2, caliper = c(p = .001), std.caliper = FALSE)) From ad43269229b9eb1050bad2ab6ce7c7638bf79cc4 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:27:39 -0500 Subject: [PATCH 30/48] Rcpp cleaning and updates --- src/RcppExports.cpp | 86 ++++++++++++++++++++++++++++++----------- src/all_equal_to.cpp | 35 +++++++++++++++++ src/eucdistC.cpp | 4 +- src/get_splitsC.cpp | 6 ++- src/has_n_unique.cpp | 3 +- src/pairdistC.cpp | 4 +- src/subclass2mm.cpp | 23 ++++++----- src/subclass_scootC.cpp | 10 ++--- src/weights_matrixC.cpp | 30 -------------- 9 files changed, 125 insertions(+), 76 deletions(-) create mode 100644 src/all_equal_to.cpp diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 92277c84..7203f657 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,6 +13,18 @@ Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif +// all_equal_to +bool all_equal_to(RObject x, RObject y); +RcppExport SEXP _MatchIt_all_equal_to(SEXP xSEXP, SEXP ySEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< RObject >::type x(xSEXP); + Rcpp::traits::input_parameter< RObject >::type y(ySEXP); + rcpp_result_gen = Rcpp::wrap(all_equal_to(x, y)); + return rcpp_result_gen; +END_RCPP +} // eucdistC_N1xN0 NumericVector eucdistC_N1xN0(const NumericMatrix& x, const IntegerVector& t); RcppExport SEXP _MatchIt_eucdistC_N1xN0(SEXP xSEXP, SEXP tSEXP) { @@ -49,9 +61,9 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// nn_matchC -IntegerMatrix nn_matchC(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const Nullable& distance_, const Nullable& distance_mat_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& mah_covs_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP distance_SEXP, SEXP distance_mat_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP mah_covs_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +// nn_matchC_distmat +IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const NumericMatrix& distance_mat, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_distmat(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP distance_matSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -61,31 +73,55 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type distance_mat_(distance_mat_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type distance_mat(distance_matSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type mah_covs_(mah_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC(treat_, ord, ratio, discarded, reuse_max, focal_, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_distmat(treat_, ord, ratio, discarded, reuse_max, focal_, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); return rcpp_result_gen; END_RCPP } -// nn_matchC_closest -IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC_closest(SEXP distance_matSEXP, SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +// nn_matchC_distmat_closest +IntegerMatrix nn_matchC_distmat_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericMatrix& distance_mat, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& close, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_distmat_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distance_matSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP closeSEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const NumericMatrix& >::type distance_mat(distance_matSEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type distance_mat(distance_matSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type close(closeSEXP); + Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); + rcpp_result_gen = Rcpp::wrap(nn_matchC_distmat_closest(treat, ratio, discarded, reuse_max, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog)); + return rcpp_result_gen; +END_RCPP +} +// nn_matchC_mahcovs +IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const NumericMatrix& mah_covs, const Nullable& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_mahcovs(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP mah_covsSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ord(ordSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); + Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type mah_covs(mah_covsSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); @@ -93,13 +129,13 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse_max, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_mahcovs(treat_, ord, ratio, discarded, reuse_max, focal_, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); return rcpp_result_gen; END_RCPP } // nn_matchC_mahcovs_closest -IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericMatrix& mah_covs, const Nullable& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC_mahcovs_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP mah_covsSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericMatrix& mah_covs, const Nullable& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& close, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_mahcovs_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP mah_covsSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP closeSEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -115,8 +151,9 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type close(closeSEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog)); return rcpp_result_gen; END_RCPP } @@ -145,8 +182,8 @@ BEGIN_RCPP END_RCPP } // nn_matchC_vec_closest -IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericVector& distance, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC_vec_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distanceSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericVector& distance, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& close, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_vec_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distanceSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP closeSEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -161,8 +198,9 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type close(closeSEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC_vec_closest(treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_vec_closest(treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog)); return rcpp_result_gen; END_RCPP } @@ -260,14 +298,16 @@ RcppExport SEXP _MatchIt_RcppExport_registerCCallable() { } static const R_CallMethodDef CallEntries[] = { + {"_MatchIt_all_equal_to", (DL_FUNC) &_MatchIt_all_equal_to, 2}, {"_MatchIt_eucdistC_N1xN0", (DL_FUNC) &_MatchIt_eucdistC_N1xN0, 2}, {"_MatchIt_get_splitsC", (DL_FUNC) &_MatchIt_get_splitsC, 2}, {"_MatchIt_has_n_unique", (DL_FUNC) &_MatchIt_has_n_unique, 2}, - {"_MatchIt_nn_matchC", (DL_FUNC) &_MatchIt_nn_matchC, 16}, - {"_MatchIt_nn_matchC_closest", (DL_FUNC) &_MatchIt_nn_matchC_closest, 12}, - {"_MatchIt_nn_matchC_mahcovs_closest", (DL_FUNC) &_MatchIt_nn_matchC_mahcovs_closest, 13}, + {"_MatchIt_nn_matchC_distmat", (DL_FUNC) &_MatchIt_nn_matchC_distmat, 14}, + {"_MatchIt_nn_matchC_distmat_closest", (DL_FUNC) &_MatchIt_nn_matchC_distmat_closest, 13}, + {"_MatchIt_nn_matchC_mahcovs", (DL_FUNC) &_MatchIt_nn_matchC_mahcovs, 15}, + {"_MatchIt_nn_matchC_mahcovs_closest", (DL_FUNC) &_MatchIt_nn_matchC_mahcovs_closest, 14}, {"_MatchIt_nn_matchC_vec", (DL_FUNC) &_MatchIt_nn_matchC_vec, 14}, - {"_MatchIt_nn_matchC_vec_closest", (DL_FUNC) &_MatchIt_nn_matchC_vec_closest, 12}, + {"_MatchIt_nn_matchC_vec_closest", (DL_FUNC) &_MatchIt_nn_matchC_vec_closest, 13}, {"_MatchIt_pairdistsubC", (DL_FUNC) &_MatchIt_pairdistsubC, 3}, {"_MatchIt_subclass2mmC", (DL_FUNC) &_MatchIt_subclass2mmC, 3}, {"_MatchIt_mm2subclassC", (DL_FUNC) &_MatchIt_mm2subclassC, 3}, diff --git a/src/all_equal_to.cpp b/src/all_equal_to.cpp new file mode 100644 index 00000000..9a72a771 --- /dev/null +++ b/src/all_equal_to.cpp @@ -0,0 +1,35 @@ +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// Templated function to check if a vector has exactly n unique values +template +bool all_equal_to_(Vector x, + typename traits::storage_type::type y) { + + for (auto xi : x) { + if (xi != y) { + return false; + } + } + + return true; +} + +// Wrapper function to handle different types of R vectors +// [[Rcpp::export]] +bool all_equal_to(RObject x, + RObject y) { + + switch (TYPEOF(x)) { + case INTSXP: + return all_equal_to_(as(x), as(y)); + case REALSXP: + return all_equal_to_(as(x), as(y)); + case LGLSXP: + return all_equal_to_(as(x), as(y)); + default: + stop("Unsupported vector type"); + } +} \ No newline at end of file diff --git a/src/eucdistC.cpp b/src/eucdistC.cpp index 509ca20a..0cc68869 100644 --- a/src/eucdistC.cpp +++ b/src/eucdistC.cpp @@ -1,6 +1,6 @@ -#include #include "internal.h" using namespace Rcpp; + // [[Rcpp::plugins(cpp11)]] // [[Rcpp::export]] @@ -31,4 +31,4 @@ NumericVector eucdistC_N1xN0(const NumericMatrix& x, dist.attr("dim") = Dimension(ind1.size(), ind0.size()); return dist; -} \ No newline at end of file +} diff --git a/src/get_splitsC.cpp b/src/get_splitsC.cpp index 5ecef79c..fde059d6 100644 --- a/src/get_splitsC.cpp +++ b/src/get_splitsC.cpp @@ -1,6 +1,8 @@ -#include +#include "internal.h" using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + // [[Rcpp::export]] NumericVector get_splitsC(const NumericVector& x, const double& caliper) { @@ -10,7 +12,7 @@ NumericVector get_splitsC(const NumericVector& x, NumericVector x_ = unique(x); NumericVector x_sorted = x_.sort(); - int n = x_sorted.size(); + R_xlen_t n = x_sorted.size(); if (n <= 1) { return splits; diff --git a/src/has_n_unique.cpp b/src/has_n_unique.cpp index b67d0acc..714569a9 100644 --- a/src/has_n_unique.cpp +++ b/src/has_n_unique.cpp @@ -1,5 +1,4 @@ -#include -#include +#include "internal.h" using namespace Rcpp; // [[Rcpp::plugins(cpp11)]] diff --git a/src/pairdistC.cpp b/src/pairdistC.cpp index 8bc02aaf..2625ba12 100644 --- a/src/pairdistC.cpp +++ b/src/pairdistC.cpp @@ -1,7 +1,8 @@ -#include #include "internal.h" using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + // [[Rcpp::export]] double pairdistsubC(const NumericVector& x, const IntegerVector& t, @@ -19,7 +20,6 @@ double pairdistsubC(const NumericVector& x, R_xlen_t n = sum(!is_na(s)); - for (i = 0; i < n; i++) { ord_i = ord[i]; s_i = s[ord_i]; diff --git a/src/subclass2mm.cpp b/src/subclass2mm.cpp index ce574cbc..07db755d 100644 --- a/src/subclass2mm.cpp +++ b/src/subclass2mm.cpp @@ -1,7 +1,8 @@ -#include #include "internal.h" using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + //Turns subclass vector given as a factor into a numeric match.matrix. //focal is the treatment level (0/1) that corresponds to the rownames. @@ -14,14 +15,15 @@ IntegerMatrix subclass2mmC(const IntegerVector& subclass_, IntegerVector unique_sub = unique(as(subclass_[!na_sub])); IntegerVector subclass = match(subclass_, unique_sub) - 1; - int nsub = unique_sub.size(); + R_xlen_t nsub = unique_sub.size(); R_xlen_t n = treat.size(); IntegerVector ind = Range(0, n - 1); IntegerVector ind_focal = ind[treat == focal]; R_xlen_t n1 = ind_focal.size(); - IntegerVector subtab = rep(-1, nsub); + IntegerVector subtab(nsub); + subtab.fill(-1); R_xlen_t i; for (i = 0; i < n; i++) { @@ -76,8 +78,7 @@ IntegerMatrix subclass2mmC(const IntegerVector& subclass_, } mm = mm + 1; - - rownames(mm) = lab[ind_focal]; + rownames(mm) = as(lab[ind_focal]); return mm; } @@ -89,7 +90,9 @@ IntegerVector mm2subclassC(const IntegerMatrix& mm, CharacterVector lab = treat.names(); - IntegerVector subclass(treat.size()); + R_xlen_t n1 = treat.size(); + + IntegerVector subclass(n1); subclass.fill(NA_INTEGER); subclass.names() = lab; @@ -101,10 +104,10 @@ IntegerVector mm2subclassC(const IntegerMatrix& mm, ind1 = match(as(rownames(mm)), lab) - 1; } - int r = mm.nrow(); - int ki = 0; + R_xlen_t r = mm.nrow(); + R_xlen_t ki = 0; - for (int i : which(!is_na(mm))) { + for (R_xlen_t i : which(!is_na(mm))) { if (i / r == 0) { //If first entry in row, increment ki and assign subclass of treated ki++; @@ -115,7 +118,7 @@ IntegerVector mm2subclassC(const IntegerMatrix& mm, } CharacterVector levs(ki); - for (int j = 0; j < ki; j++){ + for (R_xlen_t j = 0; j < ki; j++){ levs[j] = std::to_string(j + 1); } diff --git a/src/subclass_scootC.cpp b/src/subclass_scootC.cpp index 91cb7430..67c7669d 100644 --- a/src/subclass_scootC.cpp +++ b/src/subclass_scootC.cpp @@ -1,4 +1,3 @@ -#include #include "internal.h" using namespace Rcpp; @@ -32,9 +31,9 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, subclass = match(subclass, unique_sub) - 1; - int nsub = unique_sub.size(); + R_xlen_t nsub = unique_sub.size(); - IntegerVector subtab(nsub); + NumericVector subtab(nsub); IntegerVector indt; bool left = false; @@ -45,7 +44,8 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, nt = indt.size(); //Tabulate - subtab = rep(0, nsub); + subtab.fill(0.0); + for (int i : indt) { subtab[subclass[i]]++; } @@ -77,7 +77,7 @@ IntegerVector subclass_scootC(const IntegerVector& subclass_, continue; } - score += static_cast(subtab[s2] - 1) / static_cast(s2 - s); + score += (subtab[s2] - 1) / static_cast(s2 - s); } left = (score <= 0); diff --git a/src/weights_matrixC.cpp b/src/weights_matrixC.cpp index f6b06fc5..10a923e3 100644 --- a/src/weights_matrixC.cpp +++ b/src/weights_matrixC.cpp @@ -1,6 +1,4 @@ -#include #include "internal.h" -#include using namespace Rcpp; // [[Rcpp::plugins(cpp11)]] @@ -59,33 +57,5 @@ NumericVector weights_matrixC(const IntegerMatrix& mm, weights[row_ind[r]] += 1.0; } - //Scale control weights to sum to number of matched controls - NumericVector weights_gi; - IntegerVector indg; - double sum_w; - double sum_matched; - - for (gi = 0; gi < g; gi++) { - indg = which(treat == gi); - - weights_gi = weights[indg]; - - sum_w = sum(weights_gi); - - if (sum_w == 0) { - continue; - } - - sum_matched = sum(weights_gi > 0); - - if (sum_matched == sum_w) { - continue; - } - - for (int i : indg) { - weights[i] *= sum_matched / sum_w; - } - } - return weights; } \ No newline at end of file From 8f900dcb46acd9a6f95e4f2cbc729dc87add9b8c Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:27:55 -0500 Subject: [PATCH 31/48] Improvd eta estimation and speed for progress bar --- src/eta_progress_bar.h | 105 +++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 47 deletions(-) diff --git a/src/eta_progress_bar.h b/src/eta_progress_bar.h index a7e7e7ae..590338e7 100644 --- a/src/eta_progress_bar.h +++ b/src/eta_progress_bar.h @@ -17,6 +17,7 @@ #include #include +#include #include "progress_bar.hpp" // for unices only @@ -53,21 +54,25 @@ class ETAProgressBar: public ProgressBar{ // stop if already finalized if (_finalized) return; + // measure current time + time(¤t_time); + // start time measurement when update() is called the first time if (_timer_flag) { _timer_flag = false; // measure start time - time(&start); - last_refresh = start; - current_old = start; - progress_old = progress; + time_at_start = current_time; + + time_at_last_refresh = current_time; - ema_rate = 0; + progress_at_last_refresh = progress; + + _num_ticks = _compute_nb_ticks(progress); time_string = "calculating..."; // create progress bar string - std::string progress_bar_string = _current_ticks_display(progress); + std::string progress_bar_string = _current_ticks_display(_num_ticks); // merge progress bar and time string std::stringstream strs; @@ -80,67 +85,82 @@ class ETAProgressBar: public ProgressBar{ REprintf("%s", char_type); } else { - // measure current time - time(¤t_new); + double time_since_start = std::difftime(current_time, time_at_start); if (progress != 1) { - // create progress bar string - std::string progress_bar_string = _current_ticks_display(progress); - // ensure overwriting of old time info int empty_length = time_string.length(); - double time_since_start = std::difftime(current_new, start); + int _num_ticks_current = _compute_nb_ticks(progress); - if (time_since_start <= 1) { - ema_rate = progress / time_since_start; - } - else { - double time_since_last_refresh = std::difftime(current_new, last_refresh); + bool update_bar = (_num_ticks_current != _num_ticks); + + _num_ticks = _num_ticks_current; + + if (progress > 0 && time_since_start > 1) { + double time_since_last_refresh = std::difftime(current_time, time_at_last_refresh); if (time_since_last_refresh >= .5) { - double current_rate = (progress - progress_old) / time_since_last_refresh; + update_bar = true; + + double progress_since_last_refresh = progress - progress_at_last_refresh; + + double total_rate = progress / time_since_start; + + if (progress_since_last_refresh == 0) { + progress_since_last_refresh = .0000001; + } + double current_rate = progress_since_last_refresh / time_since_last_refresh; + + //alpha weights average rate against recent recent (current) rate; + //alpha = 1 => estimate based on on total_rate (treats as constant) + //alpha = 0 => estimate based on recent rate (high fluctuation) double alpha = .8; - ema_rate = alpha * ema_rate + (1 - alpha) * current_rate; + double eta = (1 - progress) * (alpha / total_rate + (1 - alpha) / current_rate); // convert seconds to time string time_string = "~"; - time_string += _time_to_string((1 - progress) / ema_rate); + time_string += _time_to_string(eta); - last_refresh = current_new; - progress_old = progress; + time_at_last_refresh = current_time; + progress_at_last_refresh = progress; } } - std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); + if (update_bar) { + // create progress bar string + std::string progress_bar_string = _current_ticks_display(_num_ticks); - // merge progress bar and time string - std::stringstream strs; - strs << "|" << progress_bar_string << "| ETA: " << time_string << empty_space; - std::string temp_str = strs.str(); - char const* char_type = temp_str.c_str(); + std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); - // print: remove old and replace with new - REprintf("\r"); - REprintf("%s", char_type); + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| ETA: " << time_string << empty_space; + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); - current_old = current_new; + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + } } else { // ensure overwriting of old time info int empty_length = time_string.length(); // finalize display when ready - double time_since_start = std::difftime(current_new, start); + // convert seconds to time string std::string time_string = _time_to_string(time_since_start); std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); // create progress bar string - std::string progress_bar_string = _current_ticks_display(progress); + _num_ticks = _compute_nb_ticks(progress); + + std::string progress_bar_string = _current_ticks_display(_num_ticks); // merge progress bar and time string std::stringstream strs; @@ -189,17 +209,7 @@ class ETAProgressBar: public ProgressBar{ } // update the ticks display corresponding to progress - std::string _current_ticks_display(float progress) { - - int nb_ticks = _compute_nb_ticks(progress); - - std::string cur_display = _construct_ticks_display_string(nb_ticks); - - return cur_display; - } - - // construct progress bar display - std::string _construct_ticks_display_string(int nb) { + std::string _current_ticks_display(int nb) { std::stringstream ticks_strs; for (int i = 0; i < (_max_ticks - 1); ++i) { @@ -237,10 +247,11 @@ class ETAProgressBar: public ProgressBar{ private: // ===== INSTANCE VARIABLES ==== int _max_ticks; // the total number of ticks to print + int _num_ticks; bool _finalized; bool _timer_flag; - time_t start, current_new, last_refresh, current_old; - double ema_rate, progress_old; + time_t time_at_start, current_time, time_at_last_refresh; + float progress_at_last_refresh; std::string time_string; }; From 89d5f02dcc3424c8ea0e9f5209a383ac484fba1a Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:28:36 -0500 Subject: [PATCH 32/48] Improvements to matching algorithms, mostly using vector instead of Rcpp vector --- src/internal.cpp | 681 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 512 insertions(+), 169 deletions(-) diff --git a/src/internal.cpp b/src/internal.cpp index 1b399464..a401f7e2 100644 --- a/src/internal.cpp +++ b/src/internal.cpp @@ -80,6 +80,19 @@ bool caliper_covs_okay(const int& ncc, return true; } +// [[Rcpp::interfaces(cpp)]] +bool caliper_dist_okay(const bool& use_caliper_dist, + const int& i, + const int& j, + const NumericVector& distance, + const double& caliper_dist) { + if (!use_caliper_dist) { + return true; + } + + return std::abs(distance[i] - distance[j]) <= caliper_dist; +} + // [[Rcpp::interfaces(cpp)]] bool mm_okay(const int& r, const int& i, @@ -110,247 +123,577 @@ bool exact_okay(const bool& use_exact, } // [[Rcpp::interfaces(cpp)]] -int find_both(const int& t_id, - const IntegerVector& ind_d_ord, - const IntegerVector& match_d_ord, - const IntegerVector& treat, - const NumericVector& distance, - const LogicalVector& eligible, - const int& gi, - const int& r, - const IntegerVector& mm_ordi, - const int& ncc, - const NumericMatrix& caliper_covs_mat, - const NumericVector& caliper_covs, - const double& caliper_dist, - const bool& use_exact, - const IntegerVector& exact, - const int& aenc, - const IntegerMatrix& antiexact_covs, - const IntegerVector& first_control, - const IntegerVector& last_control) { +std::vector find_control_vec(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi_, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const IntegerVector& first_control, + const IntegerVector& last_control, + const int& ratio = 1, + const int& prev_start = -1) { int ii = match_d_ord[t_id]; + IntegerVector mm_rowi; + std::vector possible_starts; + + + if (r > 1) { + mm_rowi = na_omit(mm_rowi_); + mm_rowi = mm_rowi[as(treat[mm_rowi]) == gi]; + possible_starts.reserve(mm_rowi.size() + 2); + + for (int mmi : mm_rowi) { + possible_starts.push_back(match_d_ord[mmi]); + } + } + else { + possible_starts.reserve(2); + } + + possible_starts.push_back(ii); + + if (prev_start >= 0) { + possible_starts.push_back(match_d_ord[prev_start]); + } + + int iil, iir; + double min_dist; + + if (possible_starts.size() == 1) { + iil = ii; + iir = ii; + min_dist = 0; + } + else { + iil = *std::min_element(possible_starts.begin(), possible_starts.end()); + iir = *std::max_element(possible_starts.begin(), possible_starts.end()); + + if (iil == ii) { + min_dist = std::abs(distance[t_id] - distance[ind_d_ord[iir]]); + } + else if (iir == ii) { + min_dist = std::abs(distance[t_id] - distance[ind_d_ord[iil]]); + } + else { + min_dist = std::max(std::abs(distance[t_id] - distance[ind_d_ord[iil]]), + std::abs(distance[t_id] - distance[ind_d_ord[iir]])); + } + } + int min_ii = first_control[gi]; int max_ii = last_control[gi]; - int iil = ii; - int iir = ii; - int il = -1; - int ir = -1; + double di = distance[t_id]; - bool l_found = (iil <= min_ii); - bool r_found = (iir >= max_ii); + bool l_stop = false; + bool r_stop = false; - double di = distance[t_id]; - double distl, distr; + double dist_c; + + std::vector potential_matches_id; + potential_matches_id.reserve(2 * ratio); + std::vector potential_matches_dist; + potential_matches_dist.reserve(2 * ratio); + + int num_matches_l = 0; + int num_matches_r = 0; - while (!l_found || !r_found) { - if (!l_found) { - if (iil == min_ii) { - l_found = true; - il = -1; + int iz; + int z = 1; + int num_closer_than_dist_c; + + while (!l_stop || !r_stop) { + if (l_stop) { + z = 1; + } + else if (r_stop) { + z = -1; + } + else { + z *= -1; + } + + if (z == -1) { + if (iil <= min_ii || num_matches_l == ratio) { + l_stop = true; + continue; } - else { - iil--; - il = ind_d_ord[iil]; - - //Left - if (eligible[il]) { - if (treat[il] == gi) { - if (mm_okay(r, il, mm_ordi)) { - - distl = std::abs(di - distance[il]); - - if (r_found && ir >= 0 && distl > distr) { - return ir; - } - - if (distl > caliper_dist) { - il = -1; - l_found = true; - } - else { - if (exact_okay(use_exact, t_id, il, exact)) { - if (antiexact_okay(aenc, t_id, il, antiexact_covs)) { - if (caliper_covs_okay(ncc, t_id, il, caliper_covs_mat, caliper_covs)) { - l_found = true; - } - } - } - } - } + + iil += z; + iz = ind_d_ord[iil]; + } + else { + if (iir >= max_ii || num_matches_r == ratio) { + r_stop = true; + continue; + } + + iir += z; + iz = ind_d_ord[iir]; + } + + if (!eligible[iz]) { + continue; + } + + if (treat[iz] != gi) { + continue; + } + + if (!mm_okay(r, iz, mm_rowi)) { + continue; + } + + dist_c = std::abs(di - distance[iz]); + + //If current dist is worse than ratio dists, continue + if (potential_matches_id.size() >= ratio) { + num_closer_than_dist_c = 0; + for (double d : potential_matches_dist) { + if (d < dist_c) { + num_closer_than_dist_c++; + if (num_closer_than_dist_c == ratio) { + break; } } } + + if (num_closer_than_dist_c >= ratio) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + continue; + } } - if (!r_found) { - if (iir == max_ii) { - r_found = true; - ir = -1; + if (dist_c > caliper_dist) { + if (z == -1) { + l_stop = true; } else { - iir++; - ir = ind_d_ord[iir]; - - //Right - if (eligible[ir]) { - if (treat[ir] == gi) { - if (mm_okay(r, ir, mm_ordi)) { - - distr = std::abs(di - distance[ir]); - - if (l_found && il >= 0 && distl <= distr) { - return il; - } - - if (distr > caliper_dist) { - ir = -1; - r_found = true; - } - else { - if (exact_okay(use_exact, t_id, ir, exact)) { - if (antiexact_okay(aenc, t_id, ir, antiexact_covs)) { - if (caliper_covs_okay(ncc, t_id, ir, caliper_covs_mat, caliper_covs)) { - r_found = true; - } - } - } - } - } - } - } + r_stop = true; + } + continue; + } + + if (dist_c < min_dist) { + continue; + } + + if (!exact_okay(use_exact, t_id, iz, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id, iz, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, iz, caliper_covs_mat, caliper_covs)) { + continue; + } + + potential_matches_id.push_back(iz); + potential_matches_dist.push_back(dist_c); + + if (z == -1) { + num_matches_l++; + if (num_matches_l == ratio) { + l_stop = true; + } + } + else { + num_matches_r++; + if (num_matches_r == ratio) { + r_stop = true; } } } - if (il < 0) { - return ir; + int n_potential_matches = potential_matches_id.size(); + + if (n_potential_matches <= 1) { + return potential_matches_id; } - if (ir < 0) { - return il; + if (n_potential_matches <= ratio && + std::is_sorted(potential_matches_dist.begin(), + potential_matches_dist.end())) { + return potential_matches_id; } - if (distl <= distr) { - return il; + std::vector ind(n_potential_matches); + std::iota(ind.begin(), ind.end(), 0); + + std::vector matches_out; + + if (n_potential_matches > ratio) { + std::partial_sort(ind.begin(), ind.begin() + ratio, ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(ratio); + + for (auto it = ind.begin(); it != ind.begin() + ratio; ++it) { + matches_out.push_back(potential_matches_id[*it]); + } } else { - return ir; + std::sort(ind.begin(), ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(n_potential_matches); + + for (auto it = ind.begin(); it != ind.end(); ++it) { + matches_out.push_back(potential_matches_id[*it]); + } } + + return matches_out; } // [[Rcpp::interfaces(cpp)]] -int find_lr(const int& prev_match, - const int& t_id, - const IntegerVector& ind_d_ord, - const IntegerVector& match_d_ord, - const IntegerVector& treat, - const NumericVector& distance, - const LogicalVector& eligible, - const int& gi, - const int& ncc, - const NumericMatrix& caliper_covs_mat, - const NumericVector& caliper_covs, - const double& caliper_dist, - const bool& use_exact, - const IntegerVector& exact, - const int& aenc, - const IntegerMatrix& antiexact_covs, - const IntegerVector& first_control, - const IntegerVector& last_control) { - - int ik, iik; - double dist; - - int prev_pos; +std::vector find_control_mahcovs(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const NumericVector& match_var, + const double& match_var_caliper, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const NumericMatrix& mah_covs, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const bool& use_caliper_dist, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1) { + int ii = match_d_ord[t_id]; - int z; - if (prev_match < 0) { - if (prev_match == -1) { + int iil, iir; + + iil = ii; + iir = ii; + + int min_ii = 0; + int max_ii = match_d_ord.size() - 1; + + bool l_stop = false; + bool r_stop = false; + + double dist_c; + + std::vector> potential_matches; + potential_matches.reserve(ratio); + + std::pair new_match; + + int num_matches_l = 0; + int num_matches_r = 0; + + double mv_i = match_var[t_id]; + double mv_dist; + + int iz; + int z = 1; + + auto dist_comp = [](std::pair a, std::pair b) { + return a.second < b.second; + }; + + while (!l_stop || !r_stop) { + if (l_stop) { + z = 1; + } + else if (r_stop) { z = -1; } else { - z = 1; + z *= -1; } - prev_pos = ii; - } - else { - prev_pos = match_d_ord[prev_match]; - if (prev_pos < ii) { - z = -1; + if (z == -1) { + if (iil <= min_ii || num_matches_l == ratio) { + l_stop = true; + continue; + } + + iil += z; + iz = ind_d_ord[iil]; } else { - z = 1; + if (iir >= max_ii || num_matches_r == ratio) { + r_stop = true; + continue; + } + + iir += z; + iz = ind_d_ord[iir]; } - } - int min_ii = 0; - int max_ii = ind_d_ord.size() - 1; + if (!eligible[iz]) { + continue; + } - if (z == -1) { - min_ii = first_control[gi]; - } - else { - max_ii = last_control[gi]; + if (treat[iz] != gi) { + continue; + } + + if (!mm_okay(r, iz, mm_rowi)) { + continue; + } + + mv_dist = pow(mv_i - match_var[iz], 2); + + if (mv_dist > match_var_caliper) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + + continue; + } + + //If current dist is worse than ratio dists, continue + if (potential_matches.size() == ratio) { + if (potential_matches.back().second < mv_dist) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + + continue; + } + } + + if (!exact_okay(use_exact, t_id, iz, exact)) { + continue; + } + + if (!caliper_dist_okay(use_caliper_dist, t_id, iz, distance, caliper_dist)) { + continue; + } + + if (!antiexact_okay(aenc, t_id, iz, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, iz, caliper_covs_mat, caliper_covs)) { + continue; + } + + dist_c = sum(pow(mah_covs.row(t_id) - mah_covs.row(iz), 2.0)); + + if (!std::isfinite(dist_c)) { + continue; + } + + new_match = std::pair(iz, dist_c); + + if (potential_matches.empty()) { + potential_matches.push_back(new_match); + } + else if (dist_c > potential_matches.back().second) { + if (potential_matches.size() == ratio) { + continue; + } + + potential_matches.push_back(new_match); + } + else if (ratio == 1) { + potential_matches[0] = new_match; + } + else { + if (potential_matches.size() == ratio) { + potential_matches.pop_back(); + } + + if (dist_c > potential_matches.back().second) { + potential_matches.push_back(new_match); + } + else { + potential_matches.insert(std::lower_bound(potential_matches.begin(), potential_matches.end(), + new_match, dist_comp), + new_match); + } + } } - int start = prev_pos + z; - if (start > last_control[gi]) { - start = last_control[gi]; + std::vector matches_out; + matches_out.reserve(potential_matches.size()); + + for (auto p : potential_matches) { + matches_out.push_back(p.first); } - else if (start < first_control[gi]) { - start = first_control[gi]; + + return matches_out; +} + +// [[Rcpp::interfaces(cpp)]] +std::vector find_control_mat(const int& t_id, + const IntegerVector& treat, + const IntegerVector& ind_non_focal, + const NumericVector& distance_mat_row_i, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1) { + + int c_id_i; + double dist_c; + + std::vector potential_matches_id; + + if (ratio < 1) { + return potential_matches_id; } - for (iik = start; iik >= min_ii && iik <= max_ii; iik += z) { + std::vector potential_matches_dist; + double max_dist; + + R_xlen_t nc = distance_mat_row_i.size(); - ik = ind_d_ord[iik]; + potential_matches_id.reserve(nc); + potential_matches_dist.reserve(nc); - if (!eligible[ik]) { + for (R_xlen_t c = 0; c < nc; c++) { + + dist_c = distance_mat_row_i[c]; + + if (potential_matches_id.size() == ratio) { + if (dist_c > max_dist) { + continue; + } + } + + if (dist_c > caliper_dist) { continue; } - if (treat[ik] != gi) { + if (!std::isfinite(dist_c)) { continue; } - dist = std::abs(distance[t_id] - distance[ik]); + c_id_i = ind_non_focal[c]; - if (dist > caliper_dist) { - return -1; + if (!eligible[c_id_i]) { + continue; } - if (!exact_okay(use_exact, t_id, ik, exact)) { + if (treat[c_id_i] != gi) { continue; } - if (!antiexact_okay(aenc, t_id, ik, antiexact_covs)) { + if (!mm_okay(r, c_id_i, mm_rowi)) { continue; } - if (!caliper_covs_okay(ncc, t_id, ik, caliper_covs_mat, caliper_covs)) { + if (!exact_okay(use_exact, t_id, c_id_i, exact)) { continue; } - return ik; + if (!antiexact_okay(aenc, t_id, c_id_i, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; + } + + potential_matches_id.push_back(c_id_i); + potential_matches_dist.push_back(dist_c); + + if (potential_matches_id.size() == 1) { + max_dist = dist_c; + } + else if (dist_c > max_dist) { + max_dist = dist_c; + } } - return -1; -} + int n_potential_matches = potential_matches_id.size(); -// [[Rcpp::interfaces(cpp)]] -void swap_pos(IntegerVector x, - const int& a, - const int& b) { - int xa = x[a]; + if (n_potential_matches <= 1) { + return potential_matches_id; + } + + if (n_potential_matches <= ratio && + std::is_sorted(potential_matches_dist.begin(), + potential_matches_dist.end())) { + return potential_matches_id; + } + + std::vector ind(n_potential_matches); + std::iota(ind.begin(), ind.end(), 0); + + std::vector matches_out; + + if (n_potential_matches > ratio) { + std::partial_sort(ind.begin(), ind.begin() + ratio, ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(ratio); + + for (auto it = ind.begin(); it != ind.begin() + ratio; ++it) { + matches_out.push_back(potential_matches_id[*it]); + } + } + else { + std::sort(ind.begin(), ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(n_potential_matches); + + for (auto it = ind.begin(); it != ind.end(); ++it) { + matches_out.push_back(potential_matches_id[*it]); + } + } - x[a] = x[b]; - x[b] = xa; + return matches_out; } // [[Rcpp::interfaces(cpp)]] From f2e9482e718a1908dca4b1121006388491162057 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:29:16 -0500 Subject: [PATCH 33/48] Rewrote distmat matching to mirror other matching algorithms --- src/nn_matchC.cpp | 460 ------------------------------ src/nn_matchC_closest.cpp | 229 --------------- src/nn_matchC_distmat.cpp | 322 +++++++++++++++++++++ src/nn_matchC_distmat_closest.cpp | 319 +++++++++++++++++++++ 4 files changed, 641 insertions(+), 689 deletions(-) delete mode 100644 src/nn_matchC.cpp delete mode 100644 src/nn_matchC_closest.cpp create mode 100644 src/nn_matchC_distmat.cpp create mode 100644 src/nn_matchC_distmat_closest.cpp diff --git a/src/nn_matchC.cpp b/src/nn_matchC.cpp deleted file mode 100644 index da5c2404..00000000 --- a/src/nn_matchC.cpp +++ /dev/null @@ -1,460 +0,0 @@ -// [[Rcpp::depends(RcppProgress)]] -#include -#include "eta_progress_bar.h" -#include -#include "internal.h" -#include -using namespace Rcpp; - -// [[Rcpp::plugins(cpp11)]] - -// [[Rcpp::export]] -IntegerMatrix nn_matchC(const IntegerVector& treat_, - const IntegerVector& ord, - const IntegerVector& ratio, - const LogicalVector& discarded, - const int& reuse_max, - const int& focal_, - const Nullable& distance_ = R_NilValue, - const Nullable& distance_mat_ = R_NilValue, - const Nullable& exact_ = R_NilValue, - const Nullable& caliper_dist_ = R_NilValue, - const Nullable& caliper_covs_ = R_NilValue, - const Nullable& caliper_covs_mat_ = R_NilValue, - const Nullable& mah_covs_ = R_NilValue, - const Nullable& antiexact_covs_ = R_NilValue, - const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) { - - IntegerVector unique_treat = unique(treat_); - std::sort(unique_treat.begin(), unique_treat.end()); - int g = unique_treat.size(); - IntegerVector treat = match(treat_, unique_treat) - 1; - int focal; - for (focal = 0; focal < g; focal++) { - if (unique_treat[focal] == focal_) { - break; - } - } - - R_xlen_t n = treat.size(); - IntegerVector ind = Range(0, n - 1); - - R_xlen_t i; - int gi; - IntegerVector indt(n); - IntegerVector indt_sep(g + 1); - IntegerVector indt_tmp; - IntegerVector nt(g); - IntegerVector ind_match(n); - ind_match.fill(NA_INTEGER); - - IntegerVector times_matched(n); - times_matched.fill(0); - LogicalVector eligible = !discarded; - - IntegerVector g_c = Range(0, g - 1); - g_c = g_c[g_c != focal]; - - for (gi = 0; gi < g; gi++) { - nt[gi] = sum(treat == gi); - } - - int nf = nt[focal]; - - int max_nc = max(as(nt[g_c])); - - indt_sep[0] = 0; - - for (gi = 0; gi < g; gi++) { - indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; - - indt_tmp = ind[treat == gi]; - - for (i = 0; i < nt[gi]; i++) { - indt[indt_sep[gi] + i] = indt_tmp[i]; - ind_match[indt_tmp[i]] = i; - } - } - - IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - - IntegerVector times_matched_allowed(n); - times_matched_allowed.fill(reuse_max); - times_matched_allowed[ind_focal] = ratio; - - IntegerVector n_eligible(unique_treat.size()); - for (i = 0; i < n; i++) { - if (eligible[i]) { - n_eligible[treat[i]]++; - } - } - - int max_ratio = max(ratio); - - // Output matrix with sample indices of control units - IntegerMatrix mm(nf, max_ratio); - mm.fill(NA_INTEGER); - CharacterVector lab = treat_.names(); - - int min_ind, t_rat; - - //exact - bool use_exact = false; - IntegerVector exact; - if (exact_.isNotNull()) { - exact = as(exact_); - use_exact = true; - } - - //caliper_dist - bool use_caliper_dist = false; - double caliper_dist; - if (caliper_dist_.isNotNull()) { - caliper_dist = as(caliper_dist_); - use_caliper_dist = true; - } - - //caliper_covs - NumericVector caliper_covs; - NumericMatrix caliper_covs_mat; - int ncc = 0; - if (caliper_covs_.isNotNull()) { - caliper_covs = as(caliper_covs_); - caliper_covs_mat = as(caliper_covs_mat_); - ncc = caliper_covs_mat.ncol(); - } - - //dsit_mat and mah_covs - bool use_dist_mat = false; - bool use_mah_covs = false; - NumericMatrix distance_mat, mah_covs; - if (mah_covs_.isNotNull()) { - mah_covs = as(mah_covs_); - use_mah_covs = true; - } - else if (distance_mat_.isNotNull()) { - distance_mat = as(distance_mat_); - use_dist_mat = true; - } - - //distance - NumericVector distance; - if (distance_.isNotNull()) { - distance = distance_; - } - - //anitexact - IntegerMatrix antiexact_covs; - int aenc = 0; - if (antiexact_covs_.isNotNull()) { - antiexact_covs = as(antiexact_covs_); - aenc = antiexact_covs.ncol(); - } - - //reuse_max - bool use_reuse_max = (reuse_max < nf); - - //unit_id - IntegerVector unit_id; - IntegerVector matched_unit_ids; - bool use_unit_id = false; - if (unit_id_.isNotNull()) { - unit_id = as(unit_id_); - use_unit_id = true; - use_reuse_max = true; - matched_unit_ids = rep(NA_INTEGER, max_nc); - } - - IntegerVector c_eligible(max_nc); - NumericVector match_distance(max_nc); - IntegerVector matches_i(1 + max_ratio * (g - 1)); - int k_total; - - //progress bar - int prog_length; - if (use_reuse_max) prog_length = sum(ratio) + 1; - else prog_length = nf + 1; - ETAProgressBar pb; - Progress p(prog_length, disl_prog, pb); - - //Counters - int r, t_id_t_i, t_id_i, c_id_i, c, k; - double ps_diff, dist; - IntegerVector ck_, top_r_matches; - bool ps_diff_calculated; - - int counter = -1; - - //Matching - if (use_reuse_max) { - for (r = 1; r <= max_ratio; r++) { - for (i = 0; i < nf && max(as(n_eligible[g_c])) > 0; i++) { - - counter++; - if (counter % 200 == 0) Rcpp::checkUserInterrupt(); - - t_id_t_i = ord[i] - 1; // index among treated - t_id_i = ind_focal[t_id_t_i]; // index among sample - - if (r > times_matched_allowed[t_id_i]) { - continue; - } - - p.increment(); - - if (!eligible[t_id_i]) { - continue; - } - - k_total = 0; - - for (int gi : g_c) { - k = 0; - - if (n_eligible[gi] > 0) { - for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { - c_id_i = indt[c]; - - if (!eligible[c_id_i]) { - continue; - } - - //Prevent control units being matched to same treated unit again - if (!mm_okay(r, c_id_i, mm.row(t_id_t_i))) { - continue; - } - - if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { - continue; - } - - if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { - continue; - } - - ps_diff_calculated = false; - - if (use_caliper_dist) { - if (use_dist_mat) { - ps_diff = distance_mat(t_id_t_i, ind_match[c_id_i]); - } - else { - ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); - } - - if (ps_diff > caliper_dist) { - continue; - } - - ps_diff_calculated = true; - } - - if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { - continue; - } - - //Compute distances among eligible - if (use_mah_covs) { - dist = sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0)); - } - else if (ps_diff_calculated) { - dist = ps_diff; - } - else if (use_dist_mat) { - dist = distance_mat(t_id_t_i, ind_match[c_id_i]); - } - else { - dist = std::abs(distance[c_id_i] - distance[t_id_i]); - } - - if (!std::isfinite(dist)) { - continue; - } - - c_eligible[k] = c_id_i; - match_distance[k] = dist; - k++; - } - } - - //If no matches... - if (k == 0) { - //If round 1, focal has no possible matches - if (r == 1) { - k_total = 0; - break; - } - continue; - } - - //Find minimum distance and assign - min_ind = 0; - for (c = 1; c < k; c++) { - if (match_distance[c] < match_distance[min_ind]) { - min_ind = c; - } - } - - matches_i[k_total] = c_eligible[min_ind]; - k_total++; - } - - if (k_total == 0) { - eligible[t_id_i] = false; - n_eligible[focal]--; - continue; - } - - //Assign to match matrix - for (c = 0; c < k_total; c++) { - mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; - } - - matches_i[k_total] = t_id_i; - - ck_ = matches_i[Range(0, k_total)]; - - if (use_unit_id) { - ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); - } - - for (int ck : ck_) { - if (!eligible[ck]) { - continue; - } - - times_matched[ck]++; - if (times_matched[ck] >= times_matched_allowed[ck]) { - eligible[ck] = false; - n_eligible[treat[ck]]--; - } - } - } - } - } - else { - for (i = 0; i < nf; i++) { - - counter++; - if (counter % 500 == 0) Rcpp::checkUserInterrupt(); - - t_id_t_i = ord[i] - 1; // index among treated - t_id_i = ind_focal[t_id_t_i]; // index among sample - - p.increment(); - - if (!eligible[t_id_i]) { - continue; - } - - t_rat = ratio[t_id_t_i]; - - k_total = 0; - - for (int gi : g_c) { - k = 0; - - if (n_eligible[gi] > 0) { - for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { - c_id_i = indt[c]; - - if (!eligible[c_id_i]) { - continue; - } - - if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { - continue; - } - - if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { - continue; - } - - ps_diff_calculated = false; - - if (use_caliper_dist) { - if (use_dist_mat) { - ps_diff = distance_mat(t_id_t_i, ind_match[c_id_i]); - } - else { - ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); - } - - if (ps_diff > caliper_dist) { - continue; - } - - ps_diff_calculated = true; - } - - if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { - continue; - } - - //Compute distances among eligible - if (use_mah_covs) { - dist = sqrt(sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0))); - } - else if (ps_diff_calculated) { - dist = ps_diff; - } - else if (use_dist_mat) { - dist = distance_mat(t_id_t_i, ind_match[c_id_i]); - } - else { - dist = std::abs(distance[c_id_i] - distance[t_id_i]); - } - - if (!std::isfinite(dist)) { - continue; - } - - c_eligible[k] = c_id_i; - match_distance[k] = dist; - k++; - } - } - - //If no matches... - if (k == 0) { - k_total = 0; - break; - } - - //If replace and few eligible controls, assign all and move on - - if (k < t_rat) { - t_rat = k; - } - - //Sort distances and assign - top_r_matches = Range(0, k - 1); - - std::partial_sort(top_r_matches.begin(), top_r_matches.begin() + t_rat, top_r_matches.end(), - [&match_distance](int a, int b) {return match_distance[a] < match_distance[b];}); - - for (c = 0; c < t_rat; c++) { - matches_i[k_total] = c_eligible[top_r_matches[c]]; - k_total++; - } - } - - if (k_total == 0) { - continue; - } - - //Assign to match matrix - for (c = 0; c < k_total; c++) { - mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; - } - } - } - - p.update(prog_length); - - mm = mm + 1; // + 1 because C indexing starts at 0 but mm is sent to R - rownames(mm) = lab[ind_focal]; - - return mm; -} diff --git a/src/nn_matchC_closest.cpp b/src/nn_matchC_closest.cpp deleted file mode 100644 index f33bb45b..00000000 --- a/src/nn_matchC_closest.cpp +++ /dev/null @@ -1,229 +0,0 @@ -// [[Rcpp::depends(RcppProgress)]] -#include -#include "eta_progress_bar.h" -#include -#include "internal.h" -using namespace Rcpp; - -// [[Rcpp::plugins(cpp11)]] - -// [[Rcpp::export]] -IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, - const IntegerVector& treat, - const IntegerVector& ratio, - const LogicalVector& discarded, - const int& reuse_max, - const Nullable& exact_ = R_NilValue, - const Nullable& caliper_dist_ = R_NilValue, - const Nullable& caliper_covs_ = R_NilValue, - const Nullable& caliper_covs_mat_ = R_NilValue, - const Nullable& antiexact_covs_ = R_NilValue, - const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) { - - IntegerVector unique_treat = {0, 1}; - int g = unique_treat.size(); - int focal = 1; - - R_xlen_t n = treat.size(); - IntegerVector ind = Range(0, n - 1); - - R_xlen_t i; - int gi; - IntegerVector indt(n); - IntegerVector indt_sep(g + 1); - IntegerVector indt_tmp; - IntegerVector nt(g); - IntegerVector ind_match(n); - ind_match.fill(NA_INTEGER); - - IntegerVector times_matched(n); - times_matched.fill(0); - LogicalVector eligible = !discarded; - - for (gi = 0; gi < g; gi++) { - nt[gi] = sum(treat == gi); - } - - int nf = nt[focal]; - - indt_sep[0] = 0; - - for (gi = 0; gi < g; gi++) { - indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; - - indt_tmp = ind[treat == gi]; - - for (i = 0; i < nt[gi]; i++) { - indt[indt_sep[gi] + i] = indt_tmp[i]; - ind_match[indt_tmp[i]] = i; - } - } - - IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - - IntegerVector times_matched_allowed(n); - times_matched_allowed.fill(reuse_max); - times_matched_allowed[ind_focal] = ratio; - - IntegerVector n_eligible(g); - for (i = 0; i < n; i++) { - if (eligible[i]) { - n_eligible[treat[i]]++; - } - } - - int max_ratio = max(ratio); - - // Output matrix with sample indices of control units - IntegerMatrix mm(nf, max_ratio); - mm.fill(NA_INTEGER); - CharacterVector lab = treat.names(); - - //exact - bool use_exact = false; - IntegerVector exact; - if (exact_.isNotNull()) { - exact = as(exact_); - use_exact = true; - } - - //caliper_dist - double caliper_dist; - if (caliper_dist_.isNotNull()) { - caliper_dist = as(caliper_dist_); - } - else { - caliper_dist = max_finite(distance_mat) + .1; - } - - //caliper_covs - NumericVector caliper_covs; - NumericMatrix caliper_covs_mat; - int ncc; - if (caliper_covs_.isNotNull()) { - caliper_covs = as(caliper_covs_); - caliper_covs_mat = as(caliper_covs_mat_); - ncc = caliper_covs_mat.ncol(); - } - else { - ncc = 0; - } - - //antiexact - IntegerMatrix antiexact_covs; - int aenc; - if (antiexact_covs_.isNotNull()) { - antiexact_covs = as(antiexact_covs_); - aenc = antiexact_covs.ncol(); - } - else { - aenc = 0; - } - - //unit_id - IntegerVector unit_id, ck_; - bool use_unit_id = false; - if (unit_id_.isNotNull()) { - unit_id = as(unit_id_); - use_unit_id = true; - } - - //progress bar - int prog_length; - prog_length = sum(ratio) + 1; - ETAProgressBar pb; - Progress p(prog_length, disl_prog, pb); - - Function o("order"); - - IntegerVector d_ord = o(distance_mat); - d_ord = d_ord - 1; //Because R uses 1-indexing - - gi = 0; - - R_xlen_t r = distance_mat.nrow(); - - int rj, cj, c_id_i, t_id_i; - int counter = -1; - - for (R_xlen_t dj : d_ord) { - counter++; - if (counter % 200 == 0) Rcpp::checkUserInterrupt(); - - if (min(n_eligible) <= 0) { - break; - } - - // If distance is greater than distance caliper, stop the whole thing because - // no remaining distance will be smaller - if (distance_mat[dj] > caliper_dist) { - break; - } - - // Get row and column index of potential pair - rj = dj % r; - cj = dj / r; - - // Get sample indices of members of potential pair - t_id_i = ind_focal[rj]; - - // If either member is discarded, move on - if (!eligible[t_id_i]) { - continue; - } - - c_id_i = indt[indt_sep[gi] + cj]; - - if (!eligible[c_id_i]) { - continue; - } - - // Exact matching criterion - if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { - continue; - } - - // Antiexact criterion - if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { - continue; - } - - // Covariate caliper criterion - if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { - continue; - } - - // If all criteria above are satisfied, potential pair becomes a pair! - mm(rj, sum(!is_na(mm(rj, _)))) = c_id_i; - - // If unit_id used, increase match count of all units with that ID - ck_ = {t_id_i, c_id_i}; - - if (use_unit_id) { - ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); - } - - for (int ck : ck_) { - - if (!eligible[ck]) { - continue; - } - - times_matched[ck]++; - if (times_matched[ck] >= times_matched_allowed[ck]) { - eligible[ck] = false; - n_eligible[treat[ck]]--; - } - } - - p.increment(); - } - - p.update(prog_length); - - mm = mm + 1; - rownames(mm) = lab[ind_focal]; - - return mm; -} diff --git a/src/nn_matchC_distmat.cpp b/src/nn_matchC_distmat.cpp new file mode 100644 index 00000000..fb2d879a --- /dev/null +++ b/src/nn_matchC_distmat.cpp @@ -0,0 +1,322 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, + const IntegerVector& ord, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const int& focal_, + const NumericMatrix& distance_mat, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { + + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; + } + } + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + LogicalVector eligible = !discarded; + + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + std::vector times_matched(n, 0); + + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } + + int max_ratio = max(ratio); + + IntegerVector ind_non_focal = which(treat != focal); + + for (i = 0; i < n - nf; i++) { + ind_match[ind_non_focal[i]] = i; + } + + for (i = 0; i < nf; i++) { + ind_match[ind_focal[i]] = i; + } + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //caliper_dist + double caliper_dist; + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + } + else { + caliper_dist = max_finite(distance_mat) + .1; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //reuse_max + bool use_reuse_max = (reuse_max < nf); + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + use_reuse_max = true; + } + + IntegerVector matches_i(1 + max_ratio * (g - 1)); + int k_total; + + //progress bar + int prog_length; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + R_xlen_t c; + int r, t_id_t_i, t_id_i; + IntegerVector ck_; + std::vector k(max_ratio); + + int counter = 0; + + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (auto it = ord.begin(); it != ord.end() && max(as(n_eligible[g_c])) > 0; ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + if (r > times_matched_allowed[t_id_i]) { + continue; + } + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mat(t_id_i, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + if (k.empty()) { + if (r == 1) { + k_total = 0; + break; + } + continue; + } + + matches_i[k_total] = k[0]; + k_total++; + } + + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + } + } + } + else { + for (auto it = ord.begin(); it != ord.end(); ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mat(t_id_i, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + 1, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + ratio[t_id_t_i]); + + if (k.empty()) { + k_total = 0; + break; + } + + for (int cc : k) { + matches_i[k_total] = cc; + k_total++; + } + } + + if (k_total == 0) { + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + } + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/nn_matchC_distmat_closest.cpp b/src/nn_matchC_distmat_closest.cpp new file mode 100644 index 00000000..c4651556 --- /dev/null +++ b/src/nn_matchC_distmat_closest.cpp @@ -0,0 +1,319 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_distmat_closest(const IntegerVector& treat, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const NumericMatrix& distance_mat, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& close = true, + const bool& disl_prog = false) { + + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + IntegerVector times_matched(n); + times_matched.fill(0); + LogicalVector eligible = !discarded; + + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + } + } + + IntegerVector ind_non_focal = which(treat != focal); + IntegerVector ind_focal = which(treat == focal); + + for (i = 0; i < n - nf; i++) { + ind_match[ind_non_focal[i]] = i; + } + + for (i = 0; i < nf; i++) { + ind_match[ind_focal[i]] = i; + } + + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); + times_matched_allowed[ind_focal] = ratio; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + Function o("order"); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //caliper_dist + double caliper_dist; + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + } + else { + caliper_dist = max_finite(distance_mat) + .1; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + } + + //storing closeness + std::vector t_id, c_id; + std::vector dist; + t_id.reserve(n_eligible[focal]); + c_id.reserve(n_eligible[focal]); + dist.reserve(n_eligible[focal]); + + //progress bar + R_xlen_t prog_length = n_eligible[focal] + sum(ratio) + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + gi = 0; + + IntegerVector ck_; + + int c_id_i, t_id_t_i, t_id_i; + + int counter = 0; + int r = 1; + + IntegerVector heap_ord(n_eligible[focal]); + std::vector k; + k.reserve(1); + R_xlen_t hi; + + IntegerVector::iterator ci; + + std::function cmp; + if (close) { + cmp = [&dist](const int& a, const int& b) {return dist[a] < dist[b];}; + } + else { + cmp = [&dist](const int& a, const int& b) {return dist[a] >= dist[b];}; + } + + for (r = 1; r <= max_ratio; r++) { + for (int ti : ind_focal) { + if (!eligible[ti]) { + continue; + } + + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = ind_match[ti]; + + k = find_control_mat(ti, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + p.increment(); + + if (k.empty()) { + eligible[ti] = false; + n_eligible[focal]--; + continue; + } + + t_id.push_back(ti); + c_id.push_back(k[0]); + dist.push_back(distance_mat(t_id_t_i, ind_match[k[0]])); + } + + nf = dist.size(); + + //Order the list + heap_ord = o(dist, _["decreasing"] = !close); + heap_ord = heap_ord - 1; + + i = 0; + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + hi = heap_ord[i]; + + t_id_i = t_id[hi]; + + if (!eligible[t_id_i]) { + i++; + continue; + } + + t_id_t_i = ind_match[t_id_i]; + + c_id_i = c_id[hi]; + + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + + k = find_control_mat(t_id_i, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + //If no matches... + if (k.empty()) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + c_id[hi] = k[0]; + dist[hi] = distance_mat(t_id_t_i, ind_match[k[0]]); + + // Find new position of pair in heap + ci = std::lower_bound(heap_ord.begin() + i, heap_ord.end(), hi, cmp); + + if (ci != heap_ord.begin() + i) { + std::rotate(heap_ord.begin() + i, heap_ord.begin() + i + 1, ci); + } + + continue; + } + + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + + ck_ = {c_id_i, t_id_i}; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + + p.increment(); + + i++; + } + + t_id.clear(); + c_id.clear(); + dist.clear(); + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file From 70aebd213f1dbf2ea59ec607decff237fd3a4553 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:29:43 -0500 Subject: [PATCH 34/48] Rewrote matching for mahcovs and vec --- src/nn_matchC_mahcovs.cpp | 346 ++++++++++++++++++++++++++++++++++++++ src/nn_matchC_vec.cpp | 246 +++++++++++++++++---------- 2 files changed, 506 insertions(+), 86 deletions(-) create mode 100644 src/nn_matchC_mahcovs.cpp diff --git a/src/nn_matchC_mahcovs.cpp b/src/nn_matchC_mahcovs.cpp new file mode 100644 index 00000000..911f3046 --- /dev/null +++ b/src/nn_matchC_mahcovs.cpp @@ -0,0 +1,346 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, + const IntegerVector& ord, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const int& focal_, + const NumericMatrix& mah_covs, + const Nullable& distance_ = R_NilValue, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; + } + } + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + LogicalVector eligible = !discarded; + + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + std::vector times_matched(n, 0); + + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); + + Function o("order"); + + NumericVector match_var = mah_covs.column(0); + double match_var_caliper = R_PosInf; + + IntegerVector ind_d_ord = o(match_var); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //distance & caliper_dist + bool use_caliper_dist = false; + double caliper_dist; + NumericVector distance; + if (caliper_dist_.isNotNull() && distance_.isNotNull()) { + distance = as(distance_); + caliper_dist = as(caliper_dist_); + use_caliper_dist = true; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + + //Find if caliper placed on match_var + for (int cci = 0; cci < ncc; cci++) { + if (std::equal(caliper_covs_mat.column(cci).begin(), + caliper_covs_mat.column(cci).end(), + match_var.begin(), + match_var.end())) { + match_var_caliper = caliper_covs[cci]; + break; + } + } + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //reuse_max + bool use_reuse_max = (reuse_max < nf); + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + use_reuse_max = true; + } + + IntegerVector matches_i(1 + max_ratio * (g - 1)); + int k_total; + + //progress bar + int prog_length; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + R_xlen_t c; + int r, t_id_t_i, t_id_i; + IntegerVector ck_; + std::vector k(max_ratio); + + int counter = 0; + + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (auto it = ord.begin(); it != ord.end() && max(as(n_eligible[g_c])) > 0; ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + if (r > times_matched_allowed[t_id_i]) { + continue; + } + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mahcovs(t_id_i, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + if (k.empty()) { + if (r == 1) { + k_total = 0; + break; + } + continue; + } + + matches_i[k_total] = k[0]; + k_total++; + } + + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + } + } + } + else { + for (auto it = ord.begin(); it != ord.end(); ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mahcovs(t_id_i, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + 1, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + ratio[t_id_t_i]); + + if (k.empty()) { + k_total = 0; + break; + } + + for (int cc : k) { + matches_i[k_total] = cc; + k_total++; + } + } + + if (k_total == 0) { + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + } + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/nn_matchC_vec.cpp b/src/nn_matchC_vec.cpp index 0b364673..15e7c572 100644 --- a/src/nn_matchC_vec.cpp +++ b/src/nn_matchC_vec.cpp @@ -1,7 +1,5 @@ // [[Rcpp::depends(RcppProgress)]] -#include #include "eta_progress_bar.h" -#include #include "internal.h" using namespace Rcpp; @@ -46,15 +44,18 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, IntegerVector ind_match(n); ind_match.fill(NA_INTEGER); - IntegerVector times_matched(n); - times_matched.fill(0); LogicalVector eligible = !discarded; IntegerVector g_c = Range(0, g - 1); g_c = g_c[g_c != focal]; - for (gi = 0; gi < g; gi++) { - nt[gi] = sum(treat == gi); + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } } int nf = nt[focal]; @@ -74,15 +75,11 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - IntegerVector times_matched_allowed(n); - times_matched_allowed.fill(reuse_max); - times_matched_allowed[ind_focal] = ratio; + std::vector times_matched(n, 0); - IntegerVector n_eligible(g); - for (i = 0; i < n; i++) { - if (eligible[i]) { - n_eligible[treat[i]]++; - } + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; } int max_ratio = max(ratio); @@ -98,7 +95,8 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, IntegerVector ind_d_ord = o(distance); ind_d_ord = ind_d_ord - 1; //location of each unit after sorting - IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; IntegerVector last_control(g); last_control.fill(n - 1); @@ -140,43 +138,148 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, aenc = antiexact_covs.ncol(); } + //reuse_max + bool use_reuse_max = (reuse_max < nf); + //unit_id IntegerVector unit_id; bool use_unit_id = false; if (unit_id_.isNotNull()) { unit_id = as(unit_id_); use_unit_id = true; + use_reuse_max = true; } - IntegerVector matches_i(g); + IntegerVector matches_i(1 + max_ratio * (g - 1)); int k_total; //progress bar - int prog_length = sum(ratio) + 1; + int prog_length; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; ETAProgressBar pb; Progress p(prog_length, disl_prog, pb); - int r, t_id_t_i, t_id_i, c_id_i, c; + R_xlen_t c; + int r, t_id_t_i, t_id_i; IntegerVector ck_; - // bool check = true; + std::vector k(max_ratio); + + int counter = 0; + + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (auto it = ord.begin(); it != ord.end() && max(as(n_eligible[g_c])) > 0; ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; - int counter = -1; + if (r > times_matched_allowed[t_id_i]) { + continue; + } + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); + + k = find_control_vec(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control); + + if (k.empty()) { + if (r == 1) { + k_total = 0; + break; + } + continue; + } + + matches_i[k_total] = k[0]; + k_total++; + } + + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } - for (r = 1; r <= max_ratio; r++) { - for (i = 0; i < nf && max(as(n_eligible[g_c])) > 0; i++) { + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + } + } + } + else { + for (auto it = ord.begin(); it != ord.end(); ++it) { // i: generic looping index // t_id_t_i; index of treated unit to match among treated units // t_id_i: index of treated unit to match among all units counter++; - if (counter % 200 == 0) Rcpp::checkUserInterrupt(); + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } - t_id_t_i = ord[i] - 1; + t_id_t_i = *it - 1; t_id_i = ind_focal[t_id_t_i]; - if (r > times_matched_allowed[t_id_i]) { - continue; - } - p.increment(); if (!eligible[t_id_i]) { @@ -186,81 +289,52 @@ IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, k_total = 0; for (int gi : g_c) { - update_first_and_last_control(first_control, - last_control, - ind_d_ord, - eligible, - treat, - gi); - - c_id_i = find_both(t_id_i, - ind_d_ord, - match_d_ord, - treat, - distance, - eligible, - gi, - r, - mm.row(t_id_t_i), - ncc, - caliper_covs_mat, - caliper_covs, - caliper_dist, - use_exact, - exact, - aenc, - antiexact_covs, - first_control, - last_control); - - if (c_id_i < 0) { - if (r == 1) { - k_total = 0; - break; - } - continue; + k = find_control_vec(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + 1, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control, + ratio[t_id_t_i]); + + if (k.empty()) { + k_total = 0; + break; } - matches_i[k_total] = c_id_i; - k_total++; + for (int cc : k) { + matches_i[k_total] = cc; + k_total++; + } } if (k_total == 0) { - eligible[t_id_i] = false; - n_eligible[focal]--; continue; } for (c = 0; c < k_total; c++) { mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; } - - matches_i[k_total] = t_id_i; - - ck_ = matches_i[Range(0, k_total)]; - - if (use_unit_id) { - ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); - } - - for (int ck : ck_) { - if (!eligible[ck]) { - continue; - } - - times_matched[ck]++; - if (times_matched[ck] >= times_matched_allowed[ck]) { - eligible[ck] = false; - n_eligible[treat[ck]]--; - } - } } } p.update(prog_length); mm = mm + 1; - rownames(mm) = lab[ind_focal]; + rownames(mm) = as(lab[ind_focal]); return mm; } From 5086c71dbe92604afb45e5164c703f60cc8e8e01 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:30:05 -0500 Subject: [PATCH 35/48] Updates for performance and to use more STL --- src/nn_matchC_mahcovs_closest.cpp | 382 ++++++++++++++---------------- src/nn_matchC_vec_closest.cpp | 335 +++++++++++++------------- 2 files changed, 351 insertions(+), 366 deletions(-) diff --git a/src/nn_matchC_mahcovs_closest.cpp b/src/nn_matchC_mahcovs_closest.cpp index 667f711e..cf09ff0c 100644 --- a/src/nn_matchC_mahcovs_closest.cpp +++ b/src/nn_matchC_mahcovs_closest.cpp @@ -1,9 +1,6 @@ // [[Rcpp::depends(RcppProgress)]] -#include #include "eta_progress_bar.h" -#include #include "internal.h" -#include using namespace Rcpp; // [[Rcpp::plugins(cpp11)]] @@ -21,6 +18,7 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, const Nullable& caliper_covs_mat_ = R_NilValue, const Nullable& antiexact_covs_ = R_NilValue, const Nullable& unit_id_ = R_NilValue, + const bool& close = true, const bool& disl_prog = false) { IntegerVector unique_treat = {0, 1}; @@ -39,12 +37,18 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, IntegerVector ind_match(n); ind_match.fill(NA_INTEGER); - IntegerVector times_matched(n); - times_matched.fill(0); LogicalVector eligible = !discarded; - for (gi = 0; gi < g; gi++) { - nt[gi] = sum(treat == gi); + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } } int nf = nt[focal]; @@ -64,15 +68,11 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - IntegerVector times_matched_allowed(n); - times_matched_allowed.fill(reuse_max); - times_matched_allowed[ind_focal] = ratio; + std::vector times_matched(n, 0); - IntegerVector n_eligible(g); - for (i = 0; i < n; i++) { - if (eligible[i]) { - n_eligible[treat[i]]++; - } + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; } int max_ratio = max(ratio); @@ -82,6 +82,17 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, mm.fill(NA_INTEGER); CharacterVector lab = treat.names(); + Function o("order"); + + NumericVector match_var = mah_covs.column(0); + double match_var_caliper = R_PosInf; + + IntegerVector ind_d_ord = o(match_var); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; + //exact bool use_exact = false; IntegerVector exact; @@ -90,18 +101,14 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, use_exact = true; } - //distance + //distance & caliper_dist bool use_caliper_dist = false; - double caliper_dist, ps_diff; + double caliper_dist; NumericVector distance; - if (distance_.isNotNull()) { - distance = distance_; - - //caliper_dist - if (caliper_dist_.isNotNull()) { - caliper_dist = as(caliper_dist_); - use_caliper_dist = true; - } + if (caliper_dist_.isNotNull() && distance_.isNotNull()) { + distance = as(distance_); + caliper_dist = as(caliper_dist_); + use_caliper_dist = true; } //caliper_covs @@ -112,6 +119,17 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, caliper_covs = as(caliper_covs_); caliper_covs_mat = as(caliper_covs_mat_); ncc = caliper_covs_mat.ncol(); + + //Find if caliper placed on match_var + for (int cci = 0; cci < ncc; cci++) { + if (std::equal(caliper_covs_mat.column(cci).begin(), + caliper_covs_mat.column(cci).end(), + match_var.begin(), + match_var.end())) { + match_var_caliper = caliper_covs[cci]; + break; + } + } } //antiexact @@ -131,241 +149,199 @@ IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, } //storing closeness - IntegerVector t_id = ind_focal; - IntegerVector c_id = rep(-1, nf); - NumericVector dist = rep(R_PosInf, nf); + std::vector t_id, c_id; + std::vector dist; + t_id.reserve(n_eligible[focal]); + c_id.reserve(n_eligible[focal]); + dist.reserve(n_eligible[focal]); //progress bar - int prog_length = nf + sum(ratio) + 1; + R_xlen_t prog_length = n_eligible[focal] + sum(ratio) + 1; ETAProgressBar pb; Progress p(prog_length, disl_prog, pb); gi = 0; - IntegerVector ck_, t_inds; - IntegerVector c_eligible(nt[gi]); - NumericVector match_distance(nt[gi]); + IntegerVector ck_; - R_xlen_t c; - int c_id_i, t_id_t_i; - int t_id_i = -1; - double dist_c; - bool any_match_found; + int c_id_i, t_id_t_i, t_id_i; - int counter = -1; + int counter = 0; int r = 1; - //Find closest control unit to each treated unit - for (i = 0; i < nf; i++) { - counter++; - if (counter % 200 == 0) Rcpp::checkUserInterrupt(); - - p.increment(); - - t_id_i = ind_focal[i]; - - if (!eligible[t_id_i]) { - continue; - } - - t_inds = which(t_id == t_id_i); + IntegerVector heap_ord; + std::vector k; + k.reserve(1); + R_xlen_t hi; - any_match_found = false; + std::function cmp; + if (close) { + cmp = [&dist](const int& a, const int& b) {return dist[a] < dist[b];}; + } + else { + cmp = [&dist](const int& a, const int& b) {return dist[a] >= dist[b];}; + } - for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { - c_id_i = indt[c]; + IntegerVector::iterator ci; - if (!eligible[c_id_i]) { - continue; - } + for (r = 1; r <= max_ratio; r++) { + //Find closest control unit to each treated unit + for (int ti : ind_focal) { - if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + if (!eligible[ti]) { continue; } - if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { - continue; + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); } - if (use_caliper_dist) { - ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); - - if (ps_diff > caliper_dist) { - continue; - } - } - - if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { - continue; - } - - //Compute distances among eligible - dist_c = sqrt(sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0))); - - if (!std::isfinite(dist_c)) { + t_id_t_i = ind_match[ti]; + + k = find_control_mahcovs(ti, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + p.increment(); + + if (k.empty()) { + eligible[ti] = false; + n_eligible[focal]--; continue; } - if (any_match_found) { - if (dist_c < dist[i]) { - c_id[i] = c_id_i; - dist[i] = dist_c; - } - } - else { - c_id[i] = c_id_i; - dist[i] = dist_c; - any_match_found = true; - } + t_id.push_back(ti); + c_id.push_back(k[0]); + dist.push_back(sum(pow(mah_covs.row(ti) - mah_covs.row(k[0]), 2.0))); } - if (!any_match_found) { - eligible[t_id_i] = false; - n_eligible[focal]--; - } - } - - //Order the list - // Use base::order() because faster than Rcpp implementation of order() - Function o("order"); - IntegerVector heap_ord = o(dist); - heap_ord = heap_ord - 1; - - //Go down the list; update as needed - R_xlen_t hi; - bool find_new; - - i = 0; - while (min(n_eligible) > 0 && i < nf) { - counter++; - if (counter % 200 == 0) Rcpp::checkUserInterrupt(); - - hi = heap_ord[i]; - - t_id_i = t_id[hi]; - - if (!eligible[t_id_i]) { - i++; - continue; - } - - r = times_matched[t_id_i] + 1; + nf = dist.size(); - t_id_t_i = ind_match[t_id_i]; + //Order the list + heap_ord = o(dist, _["decreasing"] = !close); + heap_ord = heap_ord - 1; - c_id_i = c_id[hi]; + i = 0; + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } - find_new = false; - if (!eligible[c_id_i]) { - find_new = true; - } - else if (!mm_okay(r, c_id_i, mm.row(t_id_t_i))) { - find_new = true; - } + hi = heap_ord[i]; - if (find_new) { - // If control isn't eligible, find new control and try again - any_match_found = false; + t_id_i = t_id[hi]; - for (c = indt_sep[gi]; c < indt_sep[gi + 1]; c++) { - c_id_i = indt[c]; + if (!eligible[t_id_i]) { + i++; + continue; + } - if (!eligible[c_id_i]) { - continue; - } + t_id_t_i = ind_match[t_id_i]; - //Prevent control units being matched to same treated unit again - if (!mm_okay(r, c_id_i, mm.row(t_id_t_i))) { - continue; - } + c_id_i = c_id[hi]; - if (!exact_okay(use_exact, t_id_i, c_id_i, exact)) { + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + + k = find_control_mahcovs(t_id_i, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + //If no matches... + if (k.empty()) { + eligible[t_id_i] = false; + n_eligible[focal]--; continue; } - if (!antiexact_okay(aenc, t_id_i, c_id_i, antiexact_covs)) { - continue; - } + c_id[hi] = k[0]; + dist[hi] = sum(pow(mah_covs.row(t_id_i) - mah_covs.row(k[0]), 2.0)); - if (use_caliper_dist) { - ps_diff = std::abs(distance[c_id_i] - distance[t_id_i]); + // Find new position of pair in heap + ci = std::lower_bound(heap_ord.begin() + i, heap_ord.end(), hi, cmp); - if (ps_diff > caliper_dist) { - continue; - } + if (ci != heap_ord.begin() + i) { + std::rotate(heap_ord.begin() + i, heap_ord.begin() + i + 1, ci); } - if (!caliper_covs_okay(ncc, t_id_i, c_id_i, caliper_covs_mat, caliper_covs)) { - continue; - } + continue; + } - //Compute distances among eligible - dist_c = sum(pow(mah_covs.row(t_id_i) - mah_covs.row(c_id_i), 2.0)); + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; - if (!std::isfinite(dist_c)) { - continue; - } + ck_ = {c_id_i, t_id_i}; - if (any_match_found) { - if (dist_c < dist[hi]) { - c_id[hi] = c_id_i; - dist[hi] = dist_c; - } - } - else { - c_id[hi] = c_id_i; - dist[hi] = dist_c; - any_match_found = true; - } + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); } - //If no matches... - if (!any_match_found) { - eligible[t_id_i] = false; - n_eligible[focal]--; - continue; - } + for (int ck : ck_) { - //Find new position of pair in heap - for (c = i; c < nf - 1; c++) { - if (dist[heap_ord[c]] < dist[heap_ord[c + 1]]) { - break; + if (!eligible[ck]) { + continue; } - swap_pos(heap_ord, c, c + 1); + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } } - continue; - } - - mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; - - ck_ = {c_id_i, t_id_i}; - - if (use_unit_id) { - ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); - } - - for (int ck : ck_) { - - if (!eligible[ck]) { - continue; - } + p.increment(); - times_matched[ck]++; - if (times_matched[ck] >= times_matched_allowed[ck]) { - eligible[ck] = false; - n_eligible[treat[ck]]--; - } + i++; } - p.increment(); + t_id.clear(); + c_id.clear(); + dist.clear(); } p.update(prog_length); mm = mm + 1; - rownames(mm) = lab[ind_focal]; + rownames(mm) = as(lab[ind_focal]); return mm; -} +} \ No newline at end of file diff --git a/src/nn_matchC_vec_closest.cpp b/src/nn_matchC_vec_closest.cpp index 0b2abd18..10748795 100644 --- a/src/nn_matchC_vec_closest.cpp +++ b/src/nn_matchC_vec_closest.cpp @@ -1,9 +1,6 @@ // [[Rcpp::depends(RcppProgress)]] -#include #include "eta_progress_bar.h" -#include #include "internal.h" -#include using namespace Rcpp; // [[Rcpp::plugins(cpp11)]] @@ -20,6 +17,7 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, const Nullable& caliper_covs_mat_ = R_NilValue, const Nullable& antiexact_covs_ = R_NilValue, const Nullable& unit_id_ = R_NilValue, + const bool& close = true, const bool& disl_prog = false) { IntegerVector unique_treat = {0, 1}; @@ -38,15 +36,18 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, IntegerVector ind_match(n); ind_match.fill(NA_INTEGER); - IntegerVector times_matched(n); - times_matched.fill(0); LogicalVector eligible = !discarded; // IntegerVector g_c = Range(0, g - 1); // g_c = g_c[g_c != focal]; - for (gi = 0; gi < g; gi++) { - nt[gi] = sum(treat == gi); + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } } int nf = nt[focal]; @@ -66,15 +67,11 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - IntegerVector times_matched_allowed(n); - times_matched_allowed.fill(reuse_max); - times_matched_allowed[ind_focal] = ratio; + std::vector times_matched(n, 0); - IntegerVector n_eligible(g); - for (i = 0; i < n; i++) { - if (eligible[i]) { - n_eligible[treat[i]]++; - } + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; } int max_ratio = max(ratio); @@ -84,13 +81,14 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, mm.fill(NA_INTEGER); CharacterVector lab = treat.names(); - //Use base::order() because faster than Rcpp implementation of order() + //Use base::order() because faster than C++ std::sort() Function o("order"); IntegerVector ind_d_ord = o(distance); - ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + ind_d_ord = ind_d_ord - 1; - IntegerVector match_d_ord = match(ind, ind_d_ord) - 1; + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; IntegerVector last_control(g); last_control.fill(n - 1); @@ -141,25 +139,14 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, } //storing closeness - IntegerVector t_id = rep_each(ind_focal, 2); - IntegerVector c_id = rep(-1, 2 * nf); - NumericVector dist = rep(R_PosInf, 2 * nf); - for (i = 0; i < nf; i++) { - c_id[2 * i] = -2; - } - - LogicalVector skipped_once = rep(false, nf); - - //progress bar - int prog_length = sum(ratio) + 1; - ETAProgressBar pb; - Progress p(prog_length, disl_prog, pb); - - IntegerVector ck_; - - int t_id_i, c_id_i, t_id_t_i, c, k; + std::vector t_id, c_id; + std::vector dist; + t_id.reserve(n_eligible[focal]); + c_id.reserve(n_eligible[focal]); + dist.reserve(n_eligible[focal]); gi = 0; + update_first_and_last_control(first_control, last_control, ind_d_ord, @@ -167,173 +154,195 @@ IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, treat, gi); - //Find left and right matches for each treated unit - for (i = 0; i < (2 * nf); i++) { - t_id_i = t_id[i]; + IntegerVector ck_; - if (!eligible[t_id_i]) { - continue; - } + int c_id_i, t_id_t_i, t_id_i; - c_id_i = c_id[i]; - - k = find_lr(c_id_i, - t_id_i, - ind_d_ord, - match_d_ord, - treat, - distance, - eligible, - gi, - ncc, - caliper_covs_mat, - caliper_covs, - caliper_dist, - use_exact, - exact, - aenc, - antiexact_covs, - first_control, - last_control); - - if (k < 0) { - continue; - } + int counter = 0; + int r = 1; - c_id[i] = k; - dist[i] = std::abs(distance[t_id_i] - distance[k]); - } + IntegerVector heap_ord; + std::vector k(1); + R_xlen_t hi; - //Order the list - IntegerVector heap_ord = o(dist); - heap_ord = heap_ord - 1; + //progress bar + R_xlen_t prog_length = sum(ratio) + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); - //Go down the list; update as needed - int hi; - int counter = -1; - bool find_new = false; + IntegerVector::iterator ci; - i = 0; - while (min(n_eligible) > 0 && i < (2 * nf)) { + std::function cmp; + if (close) { + cmp = [&dist](const int& a, const int& b) {return dist[a] < dist[b];}; + } + else { + cmp = [&dist](const int& a, const int& b) {return dist[a] >= dist[b];}; + } - counter++; - if (counter % 200 == 0) Rcpp::checkUserInterrupt(); + for (r = 1; r <= max_ratio; r++) { + //Find closest control unit to each treated unit + for (int ti : ind_focal) { - hi = heap_ord[i]; + if (!eligible[ti]) { + continue; + } - if (dist[hi] >= caliper_dist) { - break; - } + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } - t_id_i = t_id[hi]; + t_id_t_i = ind_match[ti]; + + k = find_control_vec(ti, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control); + + if (k.empty()) { + eligible[ti] = false; + n_eligible[focal]--; + continue; + } - if (!eligible[t_id_i]) { - i++; - continue; + t_id.push_back(ti); + c_id.push_back(k[0]); + dist.push_back(std::abs(distance[ti] - distance[k[0]])); } - c_id_i = c_id[hi]; + nf = dist.size(); - t_id_t_i = ind_match[t_id_i]; + //Order the list + heap_ord = o(dist, _["decreasing"] = !close); + heap_ord = heap_ord - 1; - if (c_id_i < 0) { - if (skipped_once[t_id_t_i]) { - eligible[t_id_i] = false; - n_eligible[focal]--; - } - else { - skipped_once[t_id_t_i] = true; + i = 0; + + //Go through ordered list and assign matches, re-matching when necessary + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); } - i++; - continue; - } - find_new = false; - if (!eligible[c_id_i]) { - find_new = true; - } - else if (!mm_okay(times_matched[t_id_i] + 1, c_id_i, mm.row(t_id_t_i))) { - find_new = true; - } + hi = heap_ord[i]; + + t_id_i = t_id[hi]; - if (find_new) { - // If control isn't eligible, find new control and try again - - update_first_and_last_control(first_control, - last_control, - ind_d_ord, - eligible, - treat, - gi); - - k = find_lr(c_id_i, - t_id_i, - ind_d_ord, - match_d_ord, - treat, - distance, - eligible, - gi, - ncc, - caliper_covs_mat, - caliper_covs, - caliper_dist, - use_exact, - exact, - aenc, - antiexact_covs, - first_control, - last_control); - - c_id[hi] = k; - - //If no new control found, continue - if (k < 0) { + if (!eligible[t_id_i]) { + i++; continue; } - dist[hi] = std::abs(distance[t_id_i] - distance[k]); + t_id_t_i = ind_match[t_id_i]; + + c_id_i = c_id[hi]; + + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); + + k = find_control_vec(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control, + 1, + c_id_i); + + //If no matches... + if (k.empty()) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + c_id[hi] = k[0]; + dist[hi] = std::abs(distance[t_id_i] - distance[k[0]]); - //Find new position of pair in heap - for (c = i; c < (2 * nf) - 1; c++) { - if (dist[heap_ord[c]] < dist[heap_ord[c + 1]]) { - break; + // Find new position of pair in heap + ci = std::lower_bound(heap_ord.begin() + i, heap_ord.end(), hi, + cmp); + + if (ci != heap_ord.begin() + i) { + std::rotate(heap_ord.begin() + i, heap_ord.begin() + i + 1, ci); } - swap_pos(heap_ord, c, c + 1); + continue; } - continue; - } + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; - mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + ck_ = {c_id_i, t_id_i}; - ck_ = {c_id_i, t_id_i}; + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } - if (use_unit_id) { - ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); - } + for (int ck : ck_) { - for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } - if (!eligible[ck]) { - continue; + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } } - times_matched[ck]++; - if (times_matched[ck] >= times_matched_allowed[ck]) { - eligible[ck] = false; - n_eligible[treat[ck]]--; - } + p.increment(); + + i++; } - p.increment(); + t_id.clear(); + c_id.clear(); + dist.clear(); } p.update(prog_length); mm = mm + 1; - rownames(mm) = lab[ind_focal]; + rownames(mm) = as(lab[ind_focal]); return mm; -} +} \ No newline at end of file From 1a867540b98c4ce7aeb8ba07f030775303e9ac36 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:30:12 -0500 Subject: [PATCH 36/48] Rcpp updates --- src/internal.h | 115 +++++++++++++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 42 deletions(-) diff --git a/src/internal.h b/src/internal.h index 1a42e007..86c08b19 100644 --- a/src/internal.h +++ b/src/internal.h @@ -2,6 +2,11 @@ #define INTERNAL_H #include +#include +#include +#include +#include +#include using namespace Rcpp; IntegerVector tabulateC_(const IntegerVector& bins, @@ -9,44 +14,68 @@ IntegerVector tabulateC_(const IntegerVector& bins, IntegerVector which(const LogicalVector& x); -int find_both(const int& t_id, - const IntegerVector& ind_d_ord, - const IntegerVector& match_d_ord, - const IntegerVector& treat, - const NumericVector& distance, - const LogicalVector& eligible, - const int& gi, - const int& r, - const IntegerVector& mm_ordi, - const int& ncc, - const NumericMatrix& caliper_covs_mat, - const NumericVector& caliper_covs, - const double& caliper_dist, - const bool& use_exact, - const IntegerVector& exact, - const int& aenc, - const IntegerMatrix& antiexact_covs, - const IntegerVector& first_control, - const IntegerVector& last_control); +std::vector find_control_vec(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi_, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const IntegerVector& first_control, + const IntegerVector& last_control, + const int& ratio = 1, + const int& prev_start = -1); -int find_lr(const int& prev_match, - const int& t_id, - const IntegerVector& ind_d_ord, - const IntegerVector& match_d_ord, - const IntegerVector& treat, - const NumericVector& distance, - const LogicalVector& eligible, - const int& gi, - const int& ncc, - const NumericMatrix& caliper_covs_mat, - const NumericVector& caliper_covs, - const double& caliper_dist, - const bool& use_exact, - const IntegerVector& exact, - const int& aenc, - const IntegerMatrix& antiexact_covs, - const IntegerVector& first_control, - const IntegerVector& last_control); +std::vector find_control_mahcovs(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const NumericVector& match_var, + const double& match_var_caliper, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const NumericMatrix& mah_covs, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const bool& use_caliper_dist, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1); + +std::vector find_control_mat(const int& t_id, + const IntegerVector& treat, + const IntegerVector& ind_non_focal, + const NumericVector& distance_mat_row_i, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1); bool antiexact_okay(const int& aenc, const int& i, @@ -59,6 +88,12 @@ bool caliper_covs_okay(const int& ncc, const NumericMatrix& caliper_covs_mat, const NumericVector& caliper_covs); +bool caliper_dist_okay(const bool& use_caliper_dist, + const int& i, + const int& j, + const NumericVector& distance, + const double& caliper_dist); + bool mm_okay(const int& r, const int& i, const IntegerVector& mm_rowi); @@ -68,10 +103,6 @@ bool exact_okay(const bool& use_exact, const int& j, const IntegerVector& exact); -void swap_pos(IntegerVector x, - const int& a, - const int& b); - double max_finite(const NumericVector& x); double min_finite(const NumericVector& x); @@ -83,4 +114,4 @@ void update_first_and_last_control(IntegerVector first_control, const IntegerVector& treat, const int& gi); -#endif \ No newline at end of file +#endif From 583a68719fda9d4dbf2b1fd9a209209e15430bab Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:30:34 -0500 Subject: [PATCH 37/48] Vignette updates --- vignettes/matching-methods.Rmd | 16 ++++++++++------ vignettes/references.bib | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/vignettes/matching-methods.Rmd b/vignettes/matching-methods.Rmd index ac6483a3..a9e229a2 100644 --- a/vignettes/matching-methods.Rmd +++ b/vignettes/matching-methods.Rmd @@ -52,6 +52,8 @@ Nearest neighbor matching requires the specification of a distance measure to de When using a matching ratio greater than 1 (i.e., when more than 1 control units are requested to be matched to each treated unit), matching occurs in a cycle, where each treated unit is first paired with one control unit, and then each treated unit is paired with a second control unit, etc. Ties are broken deterministically based on the order of the units in the dataset to ensure that multiple runs of the same specification yield the same result (unless the matching order is requested to be random). +Nearest neighbor matching is implemented in `MatchIt` using internal C++ code through `Rcpp`. When matching on a propensity score, this makes matching extremely fast, even for large datasets. Using a caliper on the propensity score (described below) makes it even faster. Run times may be a bit longer when matching on other distance measures (e.g., the Mahalanobis distance). In contrast to optimal pair matching (described below), nearest neighbor matching does not require computing the full distance matrix between units, which makes it more applicable to large datasets. + ### Optimal Pair Matching (`method = "optimal"`) Optimal pair matching (often just called optimal matching) is very similar to nearest neighbor matching in that it attempts to pair each treated unit with one or more control units. Unlike nearest neighbor matching, however, it is "optimal" rather than greedy; it is optimal in the sense that it attempts to choose matches that collectively optimize an overall criterion [@hansen2006; @gu1993]. The criterion used is the sum of the absolute pair distances in the matched sample. See `?method_optimal` for the documentation for `matchit()` with `method = "optimal"`. Optimal pair matching in `MatchIt` depends on the `fullmatch()` function in the `optmatch` package [@hansen2006]. @@ -121,7 +123,7 @@ Cardinality and profile matching are pure subset selection methods that involve Subset selection is performed by solving a mixed integer programming optimization problem with linear constraints. The problem involves maximizing the size of the matched sample subject to constraints on balance and sample size. For cardinality matching, the balance constraints refer to the mean difference for each covariate between the matched treated and control groups, and the sample size constraints require the matched treated and control groups to be the same size (or differ by a user-supplied factor). For profile matching, the balance constraints refer to the mean difference for each covariate between each treatment group and the target distribution; for the ATE, this requires the mean of each covariate in each treatment group to be within a given tolerance of the mean of the covariate in the full sample, and for the ATT, this requires the mean of each covariate in the control group to be within a given tolerance of the mean of the covariate in the treated group, which is left intact. The balance tolerances are controlled by the `tols` and `std.tols` arguments. One can also create pairs in the matched sample by using the `mahvars` argument, which requests that optimal Mahalanobis matching be done after subset selection; doing so can add additional precision and robustness [@zubizarretaMatchingBalancePairing2014]. -The optimization problem requires a special solver to solve. Currently, the available options in `MatchIt` are the GLPK solver (through the `Rglpk` package), the SYMPHONY solver (through the `Rsymphony` package), and the Gurobi solver (through the `gurobi` package). The differences among the solvers are in performance; Gurobi is by far the best (fastest, least likely to fail to find a solution), but it is proprietary (though has a free trial and academic license) and is a bit more complicated to install. The `designmatch` package also provides an implementation of cardinality matching with more options than `MatchIt` offers. +The optimization problem requires a special solver to solve. Currently, the available options in `MatchIt` are the HiGHS solver (through the `highs` package), the GLPK solver (through the `Rglpk` package), the SYMPHONY solver (through the `Rsymphony` package), and the Gurobi solver (through the `gurobi` package). The differences among the solvers are in performance; Gurobi is by far the best (fastest, least likely to fail to find a solution), but it is proprietary (though has a free trial and academic license) and is a bit more complicated to install. HiGHS is the default due to being open source, easily installed, and with performance comparable to Gurobi. The `designmatch` package also provides an implementation of cardinality matching with more options than `MatchIt` offers. ## Customizing the Matching Specification @@ -163,19 +165,21 @@ Anti-exact matching adds a restriction such that a treated and control unit with ### Matching with replacement (`replace`) -Nearest neighbor matching and genetic matching have the option of matching with or without replacement, and this is controlled by the `replace` argument. Matching without replacement means that each control unit is matched to only one treated unit, while matching with replacement means that control units can be reused and matched to multiple treated units. Matching without replacement carries certain statistical benefits in that weights for each unit can be omitted or are more straightforward to include and dependence between units depends only on pair membership. Special standard error estimators are sometimes required for estimating effects after matching with replacement [@austin2020a], and methods for accounting for uncertainty are not well understood for non-continuous outcomes. Matching with replacement will tend to yield better balance though, because the problem of "running out" of close control units to match to treated units is avoided, though the reuse of control units will decrease the effect sample size, thereby worsening precision [@austin2013b]. (This problem occurs in the Lalonde dataset used in `vignette("MatchIt")`, which is why nearest neighbor matching without replacement is not very effective there.) After matching with replacement, control units are assigned to more than one subclass, so the `get_matches()` function should be used instead of `match.data()` after matching with replacement if subclasses are to be used in follow-up analyses; see `vignette("estimating-effects")` for details. +Nearest neighbor matching and genetic matching have the option of matching with or without replacement, and this is controlled by the `replace` argument. Matching without replacement means that each control unit is matched to only one treated unit, while matching with replacement means that control units can be reused and matched to multiple treated units. Matching without replacement carries certain statistical benefits in that weights for each unit can be omitted or are more straightforward to include and dependence between units depends only on pair membership. However, it is not asymptotically consistent unless the propensity scores for all treated units are below .5 and there are many more control units than treated units [@savjeInconsistencyMatchingReplacement2022]. Special standard error estimators are sometimes required for estimating effects after matching with replacement [@austin2020a], and methods for accounting for uncertainty are not well understood for non-continuous outcomes. Matching with replacement will tend to yield better balance though, because the problem of "running out" of close control units to match to treated units is avoided, though the reuse of control units will decrease the effect sample size, thereby worsening precision [@austin2013b]. (This problem occurs in the Lalonde dataset used in `vignette("MatchIt")`, which is why nearest neighbor matching without replacement is not very effective there.) After matching with replacement, control units are assigned to more than one subclass, so the `get_matches()` function should be used instead of `match.data()` after matching with replacement if subclasses are to be used in follow-up analyses; see `vignette("estimating-effects")` for details. The `reuse.max` argument can also be used with `method = "nearest"` to control how many times each control unit can be reused as a match. Setting `reuse.max = 1` is equivalent to requiring matching without replacement (i.e., because each control can be used only once). Other values allow control units to be matched more than once, though only up to the specified number of times. Higher values will tend to improve balance at the cost of precision. ### $k$:1 matching (`ratio`) -The most common form of matching, 1:1 matching, involves pairing one control unit with each treated unit. To perform $k$:1 matching (e.g., 2:1 or 3:1), which pairs (up to) $k$ control units with each treated unit, the `ratio` argument can be specified. Performing $k$:1 matching can preserve precision by preventing too many control units from being unmatched and dropped from the matched sample, though the gain in precision by increasing $k$ diminishes rapidly after 4 [@rosenbaum2020]. Importantly, for $k>1$, the matches after the first match will generally be worse than the first match in terms of closeness to the treated unit, so increasing $k$ can also worsen balance. @austin2010a found that 1:1 or 1:2 matching generally performed best in terms of mean squared error. In general, it makes sense to use higher values of $k$ while ensuring that balance is satisfactory. +The most common form of matching, 1:1 matching, involves pairing one control unit with each treated unit. To perform $k$:1 matching (e.g., 2:1 or 3:1), which pairs (up to) $k$ control units with each treated unit, the `ratio` argument can be specified. Performing $k$:1 matching can preserve precision by preventing too many control units from being unmatched and dropped from the matched sample, though the gain in precision by increasing $k$ diminishes rapidly after 4 [@rosenbaum2020]. Importantly, for $k>1$, the matches after the first match will generally be worse than the first match in terms of closeness to the treated unit, so increasing $k$ can also worsen balance [@rassenOnetomanyPropensityScore2012]. @austin2010a found that 1:1 or 1:2 matching generally performed best in terms of mean squared error. In general, it makes sense to use higher values of $k$ while ensuring that balance is satisfactory. With nearest neighbor and optimal pair matching, variable $k$:1 matching, in which the number of controls matched to each treated unit varies, can also be used; this can have improved performance over "fixed" $k$:1 matching [@ming2000; @rassenOnetomanyPropensityScore2012]. See `?method_nearest` and `?method_optimal` for information on implementing variable $k$:1 matching. ### Matching order (`m.order`) -For nearest neighbor matching (including genetic matching), units are matched in an order, and that order can affect the quality of individual matches and of the resulting matched sample. With `method = "nearest"`, the allowable options to `m.order` to control the matching order are `"largest"`, `"smallest"`, `"closest"`, `"random"`, and `"data"`. With `method = "genetic"`, all but `"closest"` can be used. Requesting `"largest"` means that treated units with the largest propensity scores, i.e., those least like the control units, will be matched first, which prevents them from having bad matches after all the close control units have been used up. `"smallest"` means that treated units with the smallest propensity scores are matched first. `"closest"` means that potential pairs with the smallest distance between units will be matched first, which ensures that the best possible matches are included in the matched sample but can yield poor matches for units whose best match is far from them; this makes it particularly useful when matching with a caliper. `"random"` matches in a random order and `"data"` matches in order of the data. A propensity score is required for `"largest"` and `"smallest"` but not for the other options. @rubin1973 recommends using `"largest"` or `"random"`, though @austin2013b recommends against `"largest"` and instead favors `"closest"` or `"random"`. +For nearest neighbor matching (including genetic matching), units are matched in an order, and that order can affect the quality of individual matches and of the resulting matched sample. With `method = "nearest"`, the allowable options to `m.order` to control the matching order are `"largest"`, `"smallest"`, `"closest"`, `"farthest"`, `"random"`, and `"data"`. With `method = "genetic"`, all but `"closest"` and `"farthest"` can be used. Requesting `"largest"` means that treated units with the largest propensity scores, i.e., those least like the control units, will be matched first, which prevents them from having bad matches after all the close control units have been used up. `"smallest"` means that treated units with the smallest propensity scores are matched first. `"closest"` means that potential pairs with the smallest distance between units will be matched first, which ensures that the best possible matches are included in the matched sample but can yield poor matches for units whose best match is far from them; this makes it particularly useful when matching with a caliper. `"farthest"` means that closest pairs with the largest distance between them will be matched first, which ensures the hardest units to match are given the best chance to find matches. `"random"` matches in a random order, and `"data"` matches in order of the data. A propensity score is required for `"largest"` and `"smallest"` but not for the other options. + +@rubin1973 recommends using `"largest"` or `"random"`, though @austin2013b recommends against `"largest"` and instead favors `"closest"` or `"random"`. `"closest"` and `"smallest"` are best for prioritizing the best possible matches, while `"farthest"` and `"largest"` are best for preventing extreme pairwise distances between matched units. ## Choosing a Matching Method @@ -187,7 +191,7 @@ If the target of inference is the ATE, optimal or generalized full matching, sub Because exact and coarsened exact matching aim to balance the entire joint distribution of covariates, they are the most powerful methods. If it is possible to perform exact matching, this method should be used. If continuous covariates are present, coarsened exact matching can be tried. Care should be taken with retaining the target population and ensuring enough matched units remain; unless the control pool is much larger than the treated pool, it is likely some (or many) treated units will be discarded, thereby changing the estimand and possibly dramatically reducing precision. These methods are typically only available in the most optimistic of circumstances, but they should be used first when those circumstances arise. It may also be useful to combine exact or coarsened exact matching on some covariates with another form of matching on the others (i.e., by using the `exact` argument). -When estimating the ATE, either subclassification, full matching, or profile matching can be used. Optimal and generalized full matching can be effective because they optimize a balance criterion, often leading to better balance. With full matching, it's also possible to exact match on some variables and match using the Mahalanobis distance, eliminating the need to estimate propensity scores. Profile matching also ensures good balance, but because units are only given weights of zero or one, a solution may not be feasible and many units may have to be discarded. For large datasets, neither optimal full matching nor profile matching may be possible, in which case generalized full matching and subclassification are faster solutions. When using subclassification, the number of subclasses should be varied. With large samples, higher numbers of subclasses tend to yield better performance; one should not immediately settle for the default (6) or the often-cited recommendation of 5 without trying several other numbers. +When estimating the ATE, either subclassification, full matching, or profile matching can be used. Optimal and generalized full matching can be effective because they optimize a balance criterion, often leading to better balance. With full matching, it's also possible to exact match on some variables and match using the Mahalanobis distance, eliminating the need to estimate propensity scores. Profile matching also ensures good balance, but because units are only given weights of zero or one, a solution may not be feasible and many units may have to be discarded. For large datasets, neither optimal full matching nor profile matching may be possible, in which case generalized full matching and subclassification are faster solutions. When using subclassification, the number of subclasses should be varied. With large samples, higher numbers of subclasses tend to yield better performance; one should not immediately settle for the default (6) or the often-cited recommendation of 5 without trying several other numbers. The documentation for `cobalt::bal.compute()` contains an example of using balance to select the optimal number of subclasses. When estimating the ATT, a variety of methods can be tried. Genetic matching can perform well at achieving good balance because it directly optimizes covariate balance. With larger datasets, it may take a long time to reach a good solution (though that solution will tend to be good as well). Profile matching also will achieve good balance if a solution is feasible because balance is controlled by the user. Optimal pair matching and nearest neighbor matching without replacement tend to perform similarly to each other; nearest neighbor matching may be preferable for large datasets that cannot be handled by optimal matching. Nearest neighbor, optimal, and genetic matching allow some customizations like including covariates on which to exactly match, using the Mahalanobis distance instead of a propensity score difference, and performing $k$:1 matching with $k>1$. Nearest neighbor matching with replacement, full matching, and subclassification all involve weighting the control units with nonuniform weights, which often allows for improved balancing capabilities but can be accompanied by a loss in effective sample size, even when all units are retained. There is no reason not to try many of these methods, varying parameters here and there, in search of good balance and high remaining sample size. As previously mentioned, no single method can be recommended above all others because the optimal specification depends on the unique qualities of each dataset. @@ -195,7 +199,7 @@ When the target population is less important, for example, when engaging in trea It is important not to rely excessively on theoretical or simulation-based findings or specific recommendations when making choices about the best matching method to use. For example, although nearest neighbor matching without replacement balance covariates better than did subclassification with five or ten subclasses in Austin's [-@austin2009c] simulation, this does not imply it will be superior in all datasets. Likewise, though @rosenbaum1985a and @austin2011a both recommend using a caliper of .2 standard deviations of the logit of the propensity score, this does not imply that caliper will be optimal in all scenarios, and other widths should be tried, though it should be noted that tightening the caliper on the propensity score can sometimes degrade performance [@king2019]. -For large datasets (i.e., in 10,000s to millions), some matching methods will be too slow to be used at scale. Instead, users should consider generalized full matching, subclassification, or coarsened exact matching, which are all very fast and designed to work with large datasets. +For large datasets (i.e., in 10,000s to millions), some matching methods will be too slow to be used at scale. Instead, users should consider generalized full matching, subclassification, or coarsened exact matching, which are all very fast and designed to work with large datasets. Nearest neighbor matching on the propensity score has been optimized to run quickly for large datasets as well. ## Reporting the Matching Specification diff --git a/vignettes/references.bib b/vignettes/references.bib index d2cc7fb1..f6a9929e 100644 --- a/vignettes/references.bib +++ b/vignettes/references.bib @@ -1655,3 +1655,17 @@ @article{rassenOnetomanyPropensityScore2012 note = {{\_}eprint: https://onlinelibrary.wiley.com/doi/pdf/10.1002/pds.3263}, langid = {en} } + +@article{savjeInconsistencyMatchingReplacement2022, + title = {On the inconsistency of matching without replacement}, + author = {{Sävje}, F}, + year = {2022}, + month = {06}, + date = {2022-06-01}, + journal = {Biometrika}, + pages = {551--558}, + volume = {109}, + number = {2}, + doi = {10.1093/biomet/asab035}, + url = {https://doi.org/10.1093/biomet/asab035} +} From 2c695922aeeb6ddacaea278a60b18dfba212abbb Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:30:42 -0500 Subject: [PATCH 38/48] Doc updates --- man/matchit.Rd | 97 ++++++++++++++++++++++++------------------- man/method_cem.Rd | 23 +++++++--- man/method_nearest.Rd | 19 +++++---- 3 files changed, 83 insertions(+), 56 deletions(-) diff --git a/man/matchit.Rd b/man/matchit.Rd index 3579b734..7ad2b399 100644 --- a/man/matchit.Rd +++ b/man/matchit.Rd @@ -2,7 +2,6 @@ % Please edit documentation in R/matchit.R \name{matchit} \alias{matchit} -\alias{print.matchit} \title{Matching for Causal Inference} \usage{ matchit( @@ -26,10 +25,9 @@ matchit( ratio = 1, verbose = FALSE, include.obj = FALSE, + normalize = TRUE, ... ) - -\method{print}{matchit}(x, ...) } \arguments{ \item{formula}{a two-sided \code{\link{formula}} object containing the treatment and @@ -47,6 +45,7 @@ will be sought in the environment.} \code{\link[=method_nearest]{"nearest"}} for nearest neighbor matching (on the propensity score by default), \code{\link[=method_optimal]{"optimal"}} for optimal pair matching, \code{\link[=method_full]{"full"}} for optimal +full matching, \code{\link[=method_quick]{"quick"}} for generalized (quick) full matching, \code{\link[=method_genetic]{"genetic"}} for genetic matching, \code{\link[=method_cem]{"cem"}} for coarsened exact matching, \code{\link[=method_exact]{"exact"}} for exact matching, @@ -77,12 +76,11 @@ scores.} \item{distance.options}{a named list containing additional arguments supplied to the function that estimates the distance measure as determined -by the argument to \code{distance}. See \link{distance} for an +by the argument to \code{distance}. See \code{\link{distance}} for an example of its use.} \item{estimand}{a string containing the name of the target estimand desired. -Can be one of \code{"ATT"} or \code{"ATC"}. Some methods accept \code{"ATE"} -as well. Default is \code{"ATT"}. See Details and the individual methods +Can be one of \code{"ATT"}, \code{"ATC"}, or \code{"ATE"}. Default is \code{"ATT"}. See Details and the individual methods pages for information on how this argument is used.} \item{exact}{for methods that allow it, for which variables exact matching @@ -129,7 +127,7 @@ incorporated into propensity score models and balance statistics. Can also be specified as a string containing the name of variable in \code{data} to be used or a one-sided formula with the variable on the right-hand side (e.g., \code{~ SW}). Not all propensity score models accept sampling -weights; see \link{distance} for information on which do and do not, +weights; see \code{\link{distance}} for information on which do and do not, and see \code{vignette("sampling-weights")} for details on how to use sampling weights in a matching analysis.} @@ -172,11 +170,11 @@ the matching process in the output, i.e., by the functions from other packages \code{matchit()} calls. What is included depends on the matching method. Default is \code{FALSE}.} +\item{normalize}{\code{logical}; whether to rescale the nonzero weights in each treatment group to have an average of 1. Default is \code{TRUE}. See "How Matching Weights Are Computed" below for more details.} + \item{\dots}{additional arguments passed to the functions used in the matching process. See the individual methods pages for information on what -additional arguments are allowed for each method. Ignored for \code{print()}.} - -\item{x}{a \code{matchit} object.} +additional arguments are allowed for each method.} } \value{ When \code{method} is something other than \code{"subclass"}, a @@ -185,21 +183,21 @@ When \code{method} is something other than \code{"subclass"}, a \item{match.matrix}{a matrix containing the matches. The rownames correspond to the treated units and the values in each row are the names (or indices) of the control units matched to each treated unit. When treated units are -matched to different numbers of control units (e.g., with exact matching or +matched to different numbers of control units (e.g., with variable ratio matching or matching with a caliper), empty spaces will be filled with \code{NA}. Not -included when \code{method} is \code{"full"}, \code{"cem"} (unless \code{k2k = TRUE}), \code{"exact"}, or \code{"cardinality"}.} +included when \code{method} is \code{"full"}, \code{"cem"} (unless \code{k2k = TRUE}), \code{"exact"}, \code{"quick"}, or \code{"cardinality"} (unless \code{mahvars} is supplied and \code{ratio} is an integer).} \item{subclass}{a factor containing matching pair/stratum membership for each unit. Unmatched units -will have a value of \code{NA}. Not included when \code{replace = TRUE}.} +will have a value of \code{NA}. Not included when \code{replace = TRUE} or when \code{method = "cardinality"} unless \code{mahvars} is supplied and \code{ratio} is an integer.} \item{weights}{a numeric vector of estimated matching weights. Unmatched and discarded units will have a weight of zero.} \item{model}{the fit object of the model used to estimate propensity scores when \code{distance} is -specified and not \code{"mahalanobis"} or a numeric vector. When +specified as a method of estimating propensity scores. When \code{reestimate = TRUE}, this is the model estimated after discarding units.} \item{X}{a data frame of covariates mentioned in \code{formula}, -\code{exact}, \code{mahvars}, and \code{antiexact}.} +\code{exact}, \code{mahvars}, \code{caliper}, and \code{antiexact}.} \item{call}{the \code{matchit()} call.} \item{info}{information on the matching method and distance measures used.} @@ -249,10 +247,11 @@ pages: \item \code{\link{method_nearest}} for nearest neighbor matching \item \code{\link{method_optimal}} for optimal pair matching \item \code{\link{method_full}} for optimal full matching +\item \code{\link{method_quick}} for generalized (quick) full matching \item \code{\link{method_genetic}} for genetic matching \item \code{\link{method_cem}} for coarsened exact matching \item \code{\link{method_exact}} for exact matching -\item \code{\link{method_cardinality}} for cardinality and template matching +\item \code{\link{method_cardinality}} for cardinality and profile matching \item \code{\link{method_subclass}} for subclassification } @@ -276,7 +275,7 @@ matching on any of the included covariates and on the propensity score if specified. All arguments other than \code{distance}, \code{discard}, and \code{reestimate} will be ignored. -See \link{distance} for details on the several ways to +See \code{\link{distance}} for details on the several ways to specify the \code{distance}, \code{link}, and \code{distance.options} arguments to estimate propensity scores and create distance measures. @@ -285,7 +284,7 @@ to one and returned as such in the \code{matchit()} output (see section Value, below). The following rules are used: 1) if \code{0} is one of the values, it will be considered the control and the other value the treated; 2) otherwise, if the variable is a factor, \code{levels(treat)[1]} will be -considered control and the other variable the treated; 3) otherwise, +considered control and the other value the treated; 3) otherwise, \code{sort(unique(treat))[1]} will be considered control and the other value the treated. It is safest to ensure the treatment variable is a \code{0/1} variable. @@ -320,22 +319,32 @@ stratification weights. Matching weights are computed in one of two ways depending on whether matching was done with replacement or not. +\subsection{Matching without replacement and subclassification}{ -For matching \emph{without} replacement (except for cardinality matching), each +For matching \emph{without} replacement (except for cardinality matching), including subclassification, each unit is assigned to a subclass, which represents the pair they are a part of (in the case of k:1 matching) or the stratum they belong to (in the case of exact matching, coarsened exact matching, full matching, or subclassification). The formula for computing the weights depends on the argument supplied to \code{estimand}. A new "stratum propensity score" -(\code{sp}) is computed as the proportion of units in each stratum that are -in the treated group, and all units in that stratum are assigned that +(\eqn{p^s_i}) is computed for each unit \eqn{i} as \eqn{p^s_i = \frac{1}{n_s}\sum_{j: s_j =s_i}{I(A_j=1)}} where \eqn{n_s} is the size of subclass \eqn{s} and \eqn{I(A_j=1)} is 1 if unit \eqn{j} is treated and 0 otherwise. That is, the stratum propensity score for stratum \eqn{s} is the proportion of units in stratum \eqn{s} that are +in the treated group, and all units in stratum \eqn{s} are assigned that stratum propensity score. This is distinct from the propensity score used for matching, if any. Weights are then computed using the standard formulas for -inverse probability weights with the stratum propensity score inserted: for the ATT, weights are 1 for the treated -units and \code{sp/(1-sp)} for the control units; for the ATC, weights are -\code{(1-sp)/sp} for the treated units and 1 for the control units; for the -ATE, weights are \code{1/sp} for the treated units and \code{1/(1-sp)} for the -control units. For cardinality matching, all matched units receive a weight +inverse probability weights with the stratum propensity score inserted: +\itemize{ +\item for the ATT, weights are 1 for the treated +units and \eqn{\frac{p^s}{1-p^s}} for the control units +\item for the ATC, weights are +\eqn{\frac{1-p^s}{p^s}} for the treated units and 1 for the control units +\item for the ATE, weights are \eqn{\frac{1}{p^s}} for the treated units and \eqn{\frac{1}{1-p^s}} for the +control units. +} + +For cardinality matching, all matched units receive a weight of 1. +} + +\subsection{Matching witht replacement}{ For matching \emph{with} replacement, units are not assigned to unique strata. For the ATT, each treated unit gets a weight of 1. Each control unit is weighted @@ -347,14 +356,24 @@ matched to it, the control unit in question would get a weight of 1/3 + 1/2 = 5/6. For the ATC, the same is true with the treated and control labels switched. The weights are computed using the \code{match.matrix} component of the \code{matchit()} output object. +} + +\subsection{Normalized weights}{ -In each treatment group, weights are divided by the mean of the nonzero +When \code{normalize = TRUE} (the default), in each treatment group, weights are divided by the mean of the nonzero weights in that treatment group to make the weights sum to the number of -units in that treatment group. If sampling weights are included through the +units in that treatment group (i.e., to have an average of 1). +} + +\subsection{Sampling weights}{ + +If sampling weights are included through the \code{s.weights} argument, they will be included in the \code{matchit()} output object but not incorporated into the matching weights. \code{\link[=match.data]{match.data()}}, which extracts the matched set from a \code{matchit} object, combines the matching weights and sampling weights. +} + } } \examples{ @@ -421,20 +440,14 @@ Statistical Software}, 42(8). \doi{10.18637/jss.v042.i08} } \seealso{ \code{\link[=summary.matchit]{summary.matchit()}} for balance assessment after matching, \code{\link[=plot.matchit]{plot.matchit()}} for plots of covariate balance and propensity score overlap after matching. - -\code{vignette("MatchIt")} for an introduction to matching with -\emph{MatchIt}; \code{vignette("matching-methods")} for descriptions of the -variety of matching methods and options available; -\code{vignette("assessing-balance")} for information on assessing the -quality of a matching specification; \code{vignette("estimating-effects")} -for instructions on how to estimate treatment effects after matching; and -\code{vignette("sampling-weights")} for a guide to using \emph{MatchIt} with -sampling weights. +\itemize{ +\item \code{vignette("MatchIt")} for an introduction to matching with \emph{MatchIt} +\item \code{vignette("matching-methods")} for descriptions of the variety of matching methods and options available +\item \code{vignette("assessing-balance")} for information on assessing the quality of a matching specification +\item \code{vignette("estimating-effects")} for instructions on how to estimate treatment effects after matching +\item \code{vignette("sampling-weights")} for a guide to using \emph{MatchIt} with sampling weights. +} } \author{ -Daniel Ho (\email{dho@law.stanford.edu}); Kosuke Imai -(\email{imai@harvard.edu}); Gary King (\email{king@harvard.edu}); -Elizabeth Stuart (\email{estuart@jhsph.edu}) - -Version 4.0.0 update by Noah Greifer (\email{noah.greifer@gmail.com}) +Daniel Ho, Kosuke Imai, Gary King, and Elizabeth Stuart wrote the original package. Starting with version 4.0.0, Noah Greifer is the primary maintainer and developer. } diff --git a/man/method_cem.Rd b/man/method_cem.Rd index 5cac4310..50318511 100644 --- a/man/method_cem.Rd +++ b/man/method_cem.Rd @@ -54,18 +54,29 @@ units will be dropped (e.g., if there are more treated than control units in the stratum, the treated units without a match will be dropped). The \code{k2k.method} argument controls how the distance between units is calculated. } -\item{\code{k2k.method}}{ \code{character}; how the distance +\item{\code{k2k.method}}{\code{character}; how the distance between units should be calculated if \code{k2k = TRUE}. Allowable arguments include \code{NULL} (for random matching), any argument to \code{\link[=distance]{distance()}} for computing a distance matrix from covariates (e.g., \code{"mahalanobis"}), or any allowable argument to \code{method} in \code{\link[=dist]{dist()}}. Matching will take place on the original -(non-coarsened) variables. The default is \code{"mahalanobis"}. } -\item{\code{mpower}}{ if \code{k2k.method = "minkowski"}, the power used in -creating the distance. This is passed to the \code{p} argument of \code{\link[=dist]{dist()}}.} +(non-coarsened) variables. The default is \code{"mahalanobis"}. +} +\item{\code{mpower}}{if \code{k2k.method = "minkowski"}, the power used in +creating the distance. This is passed to the \code{p} argument of \code{\link[=dist]{dist()}}. +} +\item{\code{m.order}}{\code{character}; the order that the matching takes place when \code{k2k = TRUE}. Allowable options +include \code{"closest"}, where matching takes place in +ascending order of the smallest distance between units; \code{"farthest"}, where matching takes place in +descending order of the smallest distance between units; \code{"random"}, where matching takes place +in a random order; and \code{"data"} where matching takes place based on the +order of units in the data. When \code{m.order = "random"}, results may differ +across different runs of the same code unless a seed is set and specified +with \code{\link[=set.seed]{set.seed()}}. The default of \code{NULL} corresponds to \code{"data"}. See \code{\link{method_nearest}} for more information. +} } -The arguments \code{distance} (and related arguments), \code{exact}, \code{mahvars}, \code{discard} (and related arguments), \code{replace}, \code{m.order}, \code{caliper} (and related arguments), and \code{ratio} are ignored with a warning.} +The arguments \code{distance} (and related arguments), \code{exact}, \code{mahvars}, \code{discard} (and related arguments), \code{replace}, \code{caliper} (and related arguments), and \code{ratio} are ignored with a warning.} } \description{ In \code{\link[=matchit]{matchit()}}, setting \code{method = "cem"} performs coarsened exact @@ -109,7 +120,7 @@ matches, the covariates can be manually coarsened outside of Setting \code{k2k = TRUE} is equivalent to first doing coarsened exact matching with \code{k2k = FALSE} and then supplying stratum membership as an exact matching variable (i.e., in \code{exact}) to another call to -\code{matchit()} with \code{method = "nearest"}, \code{distance = "mahalanobis"} and an argument to \code{discard} denoting unmatched units. +\code{matchit()} with \code{method = "nearest"}. It is also equivalent to performing nearest neighbor matching supplying coarsened versions of the variables to \code{exact}, except that \code{method = "cem"} automatically coarsens the continuous variables. The diff --git a/man/method_nearest.Rd b/man/method_nearest.Rd index d0652bf1..e25b70a4 100644 --- a/man/method_nearest.Rd +++ b/man/method_nearest.Rd @@ -53,13 +53,14 @@ into propensity score models and balance statistics.} include \code{"largest"}, where matching takes place in descending order of distance measures; \code{"smallest"}, where matching takes place in ascending order of distance measures; \code{"closest"}, where matching takes place in -ascending order of the distance between units; \code{"random"}, where matching takes place +ascending order of the smallest distance between units; \code{"farthest"}, where matching takes place in +descending order of the smallest distance between units; \code{"random"}, where matching takes place in a random order; and \code{"data"} where matching takes place based on the order of units in the data. When \code{m.order = "random"}, results may differ across different runs of the same code unless a seed is set and specified with \code{\link[=set.seed]{set.seed()}}. The default of \code{NULL} corresponds to \code{"largest"} when a propensity score is estimated or supplied as a vector and \code{"data"} -otherwise.} +otherwise. See Details for more information.} \item{caliper}{the width(s) of the caliper(s) used for caliper matching. Two units with a difference on a caliper variable larger than the caliper will not be paired. See Details and Examples.} @@ -68,7 +69,7 @@ are in standard deviation units (\code{TRUE}) or raw units (\code{FALSE}).} \item{ratio}{how many control units should be matched to each treated unit for k:1 matching. For variable ratio matching, see section "Variable Ratio -Matching" in Details below.} +Matching" in Details below. When \code{ratio} is greater than 1, all treated units will be attempted to be matched with a control unit before any treated unit is matched with a second control unit, etc. This reduces the possibility that control units will be used up before some treated units receive any matches.} \item{min.controls, max.controls}{for variable ratio matching, the minimum and maximum number of controls units to be matched to each treated unit. See @@ -226,14 +227,16 @@ specified, it is set to 1 by default. \code{min.controls} must be less than Examples below for an example of their use. } -\subsection{Using \code{m.order = "closest"}}{ +\subsection{Using \code{m.order = "closest"} or \code{"farthest"}}{ -As of version 4.6.0, \code{m.order} can be set to \code{"closest"}, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to \code{m.order = "smallest"}. When \code{distance} is a vector or a method of estimating a propensity score, an algorithm described by \href{https://doi.org/10.1002/pds.3263}{Rassen et al. (2012)} is used. +\code{m.order} can be set to \code{"closest"} or \code{"farthest"}, which work regardless of how the distance measure is specified. This matches in order of the distance between units. First, all the closest match is found for all treated units and the pairwise distances computed; when \code{m.order = "closest"} the pair with the smallest of the distances is matched first, and when \code{m.order = "farthest"}, the pair with the largest of the distances is matched first. Then, the pair with the second smallest (or largest) is matched second. If the matched control is ineligible (i.e., because it has already been used in a prior match), a new match is found for the treated unit, the new pair's distance is re-computed, and the pairs are re-ordered by distance. + +Using \code{m.order = "closest"} ensures that the best possible matches are given priority, and in that sense should perform similarly to \code{m.order = "smallest"}. It can be used to ensure the best matches, especially when matching with a caliper. Using \code{m.order = "farthest"} ensures that the hardest units to match are given their best chance to find a close match, and in that sense should perform similarly to \code{m.order = "largest"}. It can be used to reduce the possibility of extreme imbalance when there are hard-to-match units competing for controls. Note that \code{m.order = "farthest"} \strong{does not} implement "far matching" (i.e., finding the farthest control unit from each treated unit); it defines the order in which the closest matches are selected. } \subsection{Reproducibility}{ -Nearest neighbor matching involves a random component only when \code{m.order = "random"}, so a seed must be set in that case using \code{\link[=set.seed]{set.seed()}} to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. +Nearest neighbor matching involves a random component only when \code{m.order = "random"} (or when the propensity is estimated using a method with randomness; see \code{\link{distance}} for details), so a seed must be set in that case using \code{\link[=set.seed]{set.seed()}} to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. } } \note{ @@ -306,6 +309,6 @@ within \emph{MatchIt}. For example, a sentence might read: a call to \code{matchit()}. \code{\link[=method_optimal]{method_optimal()}} for optimal pair matching, which is similar to -nearest neighbor matching except that an overall distance criterion is -minimized. +nearest neighbor matching without replacement except that an overall distance criterion is +minimized (i.e., as an alternative to specifying \code{m.order}). } From d8a36f82ea06b76e0f5deff7f62ca31e77043fd9 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:30:49 -0500 Subject: [PATCH 39/48] Rcpp updates --- R/RcppExports.R | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 7d57c39d..054d9e79 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,6 +1,10 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 +all_equal_to <- function(x, y) { + .Call(`_MatchIt_all_equal_to`, x, y) +} + eucdistC_N1xN0 <- function(x, t) { .Call(`_MatchIt_eucdistC_N1xN0`, x, t) } @@ -13,24 +17,28 @@ has_n_unique <- function(x, n) { .Call(`_MatchIt_has_n_unique`, x, n) } -nn_matchC <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance_ = NULL, distance_mat_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, mah_covs_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC`, treat_, ord, ratio, discarded, reuse_max, focal_, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog) +nn_matchC_distmat <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance_mat, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_distmat`, treat_, ord, ratio, discarded, reuse_max, focal_, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +} + +nn_matchC_distmat_closest <- function(treat, ratio, discarded, reuse_max, distance_mat, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, close = TRUE, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_distmat_closest`, treat, ratio, discarded, reuse_max, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog) } -nn_matchC_closest <- function(distance_mat, treat, ratio, discarded, reuse_max, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC_closest`, distance_mat, treat, ratio, discarded, reuse_max, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +nn_matchC_mahcovs <- function(treat_, ord, ratio, discarded, reuse_max, focal_, mah_covs, distance_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_mahcovs`, treat_, ord, ratio, discarded, reuse_max, focal_, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) } -nn_matchC_mahcovs_closest <- function(treat, ratio, discarded, reuse_max, mah_covs, distance_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC_mahcovs_closest`, treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +nn_matchC_mahcovs_closest <- function(treat, ratio, discarded, reuse_max, mah_covs, distance_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, close = TRUE, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_mahcovs_closest`, treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog) } nn_matchC_vec <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { .Call(`_MatchIt_nn_matchC_vec`, treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) } -nn_matchC_vec_closest <- function(treat, ratio, discarded, reuse_max, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC_vec_closest`, treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +nn_matchC_vec_closest <- function(treat, ratio, discarded, reuse_max, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, close = TRUE, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_vec_closest`, treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog) } pairdistsubC <- function(x, t, s) { From 2656f98c1dde146402f857760fe0b4b5c01b87ec Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:35:52 -0500 Subject: [PATCH 40/48] README updates --- README.Rmd | 8 ++++---- README.md | 40 +++++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/README.Rmd b/README.Rmd index f19f5818..ee518502 100644 --- a/README.Rmd +++ b/README.Rmd @@ -49,7 +49,7 @@ We can check covariate balance for the original and matched samples using `summa summary(m.out) ``` -At the top is balance for the original sample. Below that is balance in the matched sample, followed by the percent reduction in imbalance and the sample sizes before and after matching. Smaller values for the balance statistics indicate better balance. (In this case, good balance was not achieved and other matching methods should be tried). We can plot the standardized mean differences in a Love plot for a clean, visual display of balance across the sample: +At the top is balance for the original sample. Below that is balance in the matched sample. Smaller values for the balance statistics indicate better balance. (In this case, fairly good balance was achieved, but other matching methods should be tried). We can plot the standardized mean differences in a Love plot for a clean, visual display of balance across the sample: ```{r, fig.alt ="Love plot of balance before and after matching."} #Plot balance @@ -77,9 +77,9 @@ install.packages("MatchIt") To install a development version, which may have a bug fixed or a new feature, run the following: ```{r, eval=F} -install.packages("remotes") #If not yet installed +install.packages("pak") #If not yet installed -remotes::install_github("ngreifer/MatchIt") +pak::pkg_install("ngreifer/MatchIt") ``` -This will require R to compile C++ code, which might require additional software be installed on your computer. If you need the development version but can't compile the package, ask the maintainer for a binary version of the package. \ No newline at end of file +This will require R to compile C++ code, which might require additional software to be installed on your computer. If you need the development version but can't compile the package, ask the maintainer for a binary version of the package. \ No newline at end of file diff --git a/README.md b/README.md index 4d65126e..98a54446 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,9 @@ performed. m.out ``` - #> A matchit object + #> A `matchit` object #> - method: 1:1 nearest neighbor matching with replacement - #> - distance: Mahalanobis - #> - number of obs.: 614 (original), 261 (matched) + #> - distance: Mahalanobis - number of obs.: 614 (original), 264 (matched) #> - target estimand: ATT #> - covariates: age, educ, race, married, nodegree, re74, re75 @@ -72,31 +71,30 @@ summary(m.out) #> #> Summary of Balance for Matched Data: #> Means Treated Means Control Std. Mean Diff. Var. Ratio eCDF Mean eCDF Max Std. Pair Dist. - #> age 25.8162 25.5405 0.0385 0.6524 0.0466 0.1892 0.4827 - #> educ 10.3459 10.4270 -0.0403 1.1636 0.0077 0.0378 0.1963 + #> age 25.8162 25.5405 0.0385 0.6531 0.0466 0.1892 0.4827 + #> educ 10.3459 10.4270 -0.0403 1.1649 0.0077 0.0378 0.1963 #> raceblack 0.8432 0.8432 0.0000 . 0.0000 0.0000 0.0000 #> racehispan 0.0595 0.0595 0.0000 . 0.0000 0.0000 0.0000 #> racewhite 0.0973 0.0973 0.0000 . 0.0000 0.0000 0.0000 #> married 0.1892 0.1784 0.0276 . 0.0108 0.0108 0.0276 #> nodegree 0.7081 0.7081 0.0000 . 0.0000 0.0000 0.0000 - #> re74 2095.5737 1788.6941 0.0628 1.5690 0.0311 0.1730 0.2494 - #> re75 1532.0553 1087.7420 0.1380 2.1221 0.0330 0.0865 0.2360 + #> re74 2095.5737 1788.6941 0.0628 1.5707 0.0311 0.1730 0.2494 + #> re75 1532.0553 1087.7420 0.1380 2.1244 0.0330 0.0865 0.2360 #> #> Sample Sizes: #> Control Treated - #> All 429 185 - #> Matched (ESS) 33 185 - #> Matched 76 185 - #> Unmatched 353 0 - #> Discarded 0 0 + #> All 429. 185 + #> Matched (ESS) 34.19 185 + #> Matched 79. 185 + #> Unmatched 350. 0 + #> Discarded 0. 0 At the top is balance for the original sample. Below that is balance in -the matched sample, followed by the percent reduction in imbalance and -the sample sizes before and after matching. Smaller values for the -balance statistics indicate better balance. (In this case, good balance -was not achieved and other matching methods should be tried). We can -plot the standardized mean differences in a Love plot for a clean, -visual display of balance across the sample: +the matched sample. Smaller values for the balance statistics indicate +better balance. (In this case, fairly good balance was achieved, but +other matching methods should be tried). We can plot the standardized +mean differences in a Love plot for a clean, visual display of balance +across the sample: ``` r #Plot balance @@ -147,12 +145,12 @@ To install a development version, which may have a bug fixed or a new feature, run the following: ``` r -install.packages("remotes") #If not yet installed +install.packages("pak") #If not yet installed -remotes::install_github("ngreifer/MatchIt") +pak::pkg_install("ngreifer/MatchIt") ``` This will require R to compile C++ code, which might require additional -software be installed on your computer. If you need the development +software to be installed on your computer. If you need the development version but can’t compile the package, ask the maintainer for a binary version of the package. From d84ebf1a17ad3cb2d53824734f7b92762b9b683d Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:37:56 -0500 Subject: [PATCH 41/48] Added rhub workflow --- .github/workflows/rhub.yaml | 95 +++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 .github/workflows/rhub.yaml diff --git a/.github/workflows/rhub.yaml b/.github/workflows/rhub.yaml new file mode 100644 index 00000000..74ec7b05 --- /dev/null +++ b/.github/workflows/rhub.yaml @@ -0,0 +1,95 @@ +# R-hub's generic GitHub Actions workflow file. It's canonical location is at +# https://github.com/r-hub/actions/blob/v1/workflows/rhub.yaml +# You can update this file to a newer version using the rhub2 package: +# +# rhub::rhub_setup() +# +# It is unlikely that you need to modify this file manually. + +name: R-hub +run-name: "${{ github.event.inputs.id }}: ${{ github.event.inputs.name || format('Manually run by {0}', github.triggering_actor) }}" + +on: + workflow_dispatch: + inputs: + config: + description: 'A comma separated list of R-hub platforms to use.' + type: string + default: 'linux,windows,macos' + name: + description: 'Run name. You can leave this empty now.' + type: string + id: + description: 'Unique ID. You can leave this empty now.' + type: string + +jobs: + + setup: + runs-on: ubuntu-latest + outputs: + containers: ${{ steps.rhub-setup.outputs.containers }} + platforms: ${{ steps.rhub-setup.outputs.platforms }} + + steps: + # NO NEED TO CHECKOUT HERE + - uses: r-hub/actions/setup@v1 + with: + config: ${{ github.event.inputs.config }} + id: rhub-setup + + linux-containers: + needs: setup + if: ${{ needs.setup.outputs.containers != '[]' }} + runs-on: ubuntu-latest + name: ${{ matrix.config.label }} + strategy: + fail-fast: false + matrix: + config: ${{ fromJson(needs.setup.outputs.containers) }} + container: + image: ${{ matrix.config.container }} + + steps: + - uses: r-hub/actions/checkout@v1 + - uses: r-hub/actions/platform-info@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + - uses: r-hub/actions/setup-deps@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + - uses: r-hub/actions/run-check@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + + other-platforms: + needs: setup + if: ${{ needs.setup.outputs.platforms != '[]' }} + runs-on: ${{ matrix.config.os }} + name: ${{ matrix.config.label }} + strategy: + fail-fast: false + matrix: + config: ${{ fromJson(needs.setup.outputs.platforms) }} + + steps: + - uses: r-hub/actions/checkout@v1 + - uses: r-hub/actions/setup-r@v1 + with: + job-config: ${{ matrix.config.job-config }} + token: ${{ secrets.RHUB_TOKEN }} + - uses: r-hub/actions/platform-info@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + - uses: r-hub/actions/setup-deps@v1 + with: + job-config: ${{ matrix.config.job-config }} + token: ${{ secrets.RHUB_TOKEN }} + - uses: r-hub/actions/run-check@v1 + with: + job-config: ${{ matrix.config.job-config }} + token: ${{ secrets.RHUB_TOKEN }} From bcee422c07feef4249c0e076a78a83dd3523b3dc Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 02:56:42 -0500 Subject: [PATCH 42/48] Temp update without gurobi --- DESCRIPTION | 4 ++-- R/matchit2cardinality.R | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index e650109a..9b9efdc4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -33,7 +33,8 @@ Imports: Rcpp, utils, stats, - graphics + graphics, + grDevices Suggests: optmatch (>= 0.10.6), Matching, @@ -56,7 +57,6 @@ Suggests: highs, Rglpk, Rsymphony, - gurobi, knitr, rmarkdown, testthat (>= 3.0.0) diff --git a/R/matchit2cardinality.R b/R/matchit2cardinality.R index 093a6988..57e2c0af 100644 --- a/R/matchit2cardinality.R +++ b/R/matchit2cardinality.R @@ -722,9 +722,9 @@ dispatch_optimizer <- function(solver = "glpk", obj, mat, dir, rhs, types, max = dir[dir == "<="] <- "<" dir[dir == ">="] <- ">" dir[dir == "=="] <- "=" - opt.out <- gurobi::gurobi(list(A = mat, obj = obj, sense = dir, rhs = rhs, vtype = types, - modelsense = "max", lb = lb, ub = ub), - params = list(OutputFlag = as.integer(verbose), TimeLimit = time)) + # opt.out <- gurobi::gurobi(list(A = mat, obj = obj, sense = dir, rhs = rhs, vtype = types, + # modelsense = "max", lb = lb, ub = ub), + # params = list(OutputFlag = as.integer(verbose), TimeLimit = time)) } else if (solver == "highs") { rhs_h <- lhs_h <- rhs From ccfb4964e7dedee6697695b69e27be890a8da00b Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 10:52:53 -0500 Subject: [PATCH 43/48] Vignette updates to improve checks --- _archive/MatchIt_A3_estimating_effects2.Rmd | 2 +- vignettes/MatchIt.Rmd | 2 +- vignettes/assessing-balance.Rmd | 2 +- vignettes/estimating-effects.Rmd | 2 +- vignettes/matching-methods.Rmd | 2 +- vignettes/sampling-weights.Rmd | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/_archive/MatchIt_A3_estimating_effects2.Rmd b/_archive/MatchIt_A3_estimating_effects2.Rmd index ae5cc82a..5a1847a5 100644 --- a/_archive/MatchIt_A3_estimating_effects2.Rmd +++ b/_archive/MatchIt_A3_estimating_effects2.Rmd @@ -8,7 +8,7 @@ output: vignette: > %\VignetteIndexEntry{Estimating Effects} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib --- diff --git a/vignettes/MatchIt.Rmd b/vignettes/MatchIt.Rmd index eff196bf..08fda507 100644 --- a/vignettes/MatchIt.Rmd +++ b/vignettes/MatchIt.Rmd @@ -7,7 +7,7 @@ output: toc: yes vignette: | %\VignetteIndexEntry{MatchIt: Getting Started} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true diff --git a/vignettes/assessing-balance.Rmd b/vignettes/assessing-balance.Rmd index a227673c..e7e74a22 100644 --- a/vignettes/assessing-balance.Rmd +++ b/vignettes/assessing-balance.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Assessing Balance} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true diff --git a/vignettes/estimating-effects.Rmd b/vignettes/estimating-effects.Rmd index 7f336efc..30f40123 100644 --- a/vignettes/estimating-effects.Rmd +++ b/vignettes/estimating-effects.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Estimating Effects After Matching} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true diff --git a/vignettes/matching-methods.Rmd b/vignettes/matching-methods.Rmd index a9e229a2..021649b5 100644 --- a/vignettes/matching-methods.Rmd +++ b/vignettes/matching-methods.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Matching Methods} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true diff --git a/vignettes/sampling-weights.Rmd b/vignettes/sampling-weights.Rmd index 32da472f..c4d11b9c 100644 --- a/vignettes/sampling-weights.Rmd +++ b/vignettes/sampling-weights.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Matching with Sampling Weights} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true From b04a79512ab9bfe1b8a1ba036b72841d746bf0bc Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 12:30:44 -0500 Subject: [PATCH 44/48] Vignette updates --- vignettes/estimating-effects.Rmd | 4 ++-- vignettes/sampling-weights.Rmd | 29 ++++++++++++----------------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/vignettes/estimating-effects.Rmd b/vignettes/estimating-effects.Rmd index 30f40123..0a49d87e 100644 --- a/vignettes/estimating-effects.Rmd +++ b/vignettes/estimating-effects.Rmd @@ -447,7 +447,7 @@ boot.ci(boot_out, type = "perc") ```{r, include = FALSE} b <- { - if (boot_ok) boot.ci(boot_out, type = "perc") + if (boot_ok) boot::boot.ci(boot_out, type = "perc") else list(t0 = 1.347, percent = c(0, 0, 0, 1.144, 1.891)) } ``` @@ -524,7 +524,7 @@ boot.ci(cluster_boot_out, type = "perc") ```{r, include = FALSE} b <- { - if (boot_ok) boot.ci(cluster_boot_out, type = "perc") + if (boot_ok) boot::boot.ci(cluster_boot_out, type = "perc") else list(t0 = 1.588, percent = c(0,0,0, 1.348, 1.877)) } ``` diff --git a/vignettes/sampling-weights.Rmd b/vignettes/sampling-weights.Rmd index c4d11b9c..6a608af9 100644 --- a/vignettes/sampling-weights.Rmd +++ b/vignettes/sampling-weights.Rmd @@ -67,6 +67,10 @@ SW <- gen_SW(X) Y_C <- gen_Y_C(A, X) d <- data.frame(A, X, Y_C, SW) + +eval_est <- (requireNamespace("optmatch", quietly = TRUE) && + requireNamespace("marginaleffects", quietly = TRUE) && + requireNamespace("sandwich", quietly = TRUE)) ``` ## Introduction @@ -86,17 +90,11 @@ library("MatchIt") When using sampling weights with propensity score matching, one has the option of including the sampling weights in the model used to estimate the propensity scores. Although evidence is mixed on whether this is required [@austin2016; @lenis2019], it can be a good idea. The choice should depend on whether including the sampling weights improves the quality of the matches. Specifications including and excluding sampling weights should be tried to determine which is preferred. To supply sampling weights to the propensity score-estimating function in `matchit()`, the sampling weights variable should be supplied to the `s.weights` argument. It can be supplied either as a numerical vector containing the sampling weights, or a string or one-sided formula with the name of the sampling weights variable in the supplied dataset. Below we demonstrate including sampling weights into propensity scores estimated using logistic regression for optimal full matching for the average treatment effect in the population (ATE) (note that all methods and steps apply the same way to all forms of matching and all estimands). -```{asis, echo = !requireNamespace("optmatch", quietly = TRUE)} +```{asis, echo = eval_est} Note: if the `optmatch`, `marginaleffects`, or `sandwich` packages are not available, the subsequent lines will not run. ``` -```{r, include=FALSE} -#In case packages goes offline, don't run lines below -if (!requireNamespace("optmatch", quietly = TRUE) || - !requireNamespace("marginaleffects", quietly = TRUE) || - !requireNamespace("sandwich", quietly = TRUE)) knitr::opts_chunk$set(eval = FALSE) -``` -```{r} +```{r, eval = eval_est} mF_s <- matchit(A ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9, data = d, method = "full", distance = "glm", @@ -108,7 +106,7 @@ Notice that the description of the matching specification when the `matchit` obj Now let's perform full matching on a propensity score that does not include the sampling weights in its estimation. Here we use the same specification as was used in `vignette("estimating-effects")`. -```{r} +```{r, eval = eval_est} mF <- matchit(A ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9, data = d, method = "full", distance = "glm", @@ -118,7 +116,7 @@ mF Notice that there is no mention of sampling weights in the description of the matching specification. However, to properly assess balance and estimate effects, we need the sampling weights to be included in the `matchit` object, even if they were not used at all in the matching. To do so, we use the function `add_s.weights()`, which adds sampling weights to the supplied `matchit` objects. -```{r} +```{r, eval = eval_est} mF <- add_s.weights(mF, ~SW) mF @@ -134,7 +132,7 @@ Now we need to decide which matching specification is the best to use for effect We'll use `summary()` to examine balance on the two matching specifications. With sampling weights included, the balance statistics for the unmatched data are weighted by the sampling weights. The balance statistics for the matched data are weighted by the product of the sampling weights and the matching weights. It is the product of these weights that will be used in estimating the treatment effect. Below we use `summary()` to display balance for the two matching specifications. No additional arguments to `summary()` are required for it to use the sampling weights; as long as they are in the `matchit` object (either due to being supplied with the `s.weights` argument in the call to `matchit()` or to being added afterward by `add_s.weights()`), they will be correctly incorporated into the balance statistics. -```{r} +```{r, eval = eval_est} #Balance before matching and for the SW propensity score full matching summary(mF_s) @@ -152,7 +150,7 @@ Estimating the treatment effect after matching is straightforward when using sam Below we estimate the effect of `A` on `Y_C` in the matched and sampling weighted sample, adjusting for the covariates to improve precision and decrease bias. -```{r} +```{r, eval = eval_est} md_F_s <- match.data(mF_s) fit <- lm(Y_C ~ A * (X1 + X2 + X3 + X4 + X5 + @@ -163,15 +161,12 @@ library("marginaleffects") avg_comparisons(fit, variables = "A", vcov = ~subclass, - newdata = subset(md_F_s, A == 1), - wts = "weights") + newdata = subset(A == 1), + wts = "SW") ``` Note that `match.data()` and `get_weights()` have the option `include.s.weights`, which, when set to `FALSE`, makes it so the returned weights do not incorporate the sampling weights and are simply the matching weights. Because one might to forget to multiply the two sets of weights together, it is easier to just use the default of `include.s.weights = TRUE` and ignore the sampling weights in the rest of the analysis (because they are already included in the returned weights). `avg_comparisons()` also works more smoothly when the weights supplied to `weights` is a single variable rather than the product of two. -```{r, include=FALSE, eval=TRUE} -knitr::opts_chunk$set(eval = TRUE) -``` ## Code to Generate Data used in Examples ```{r, eval = FALSE} From 9282bb99abf2cf2e0a6df2d51dffa09bf3b0a010 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 18:42:12 -0500 Subject: [PATCH 45/48] Cleaning and organization --- R/aux_functions.R | 14 ++--- R/discard.R | 6 +-- R/get_weights_from_subclass.R | 2 +- R/input_processing.R | 4 +- R/match.data.R | 9 ++-- R/match.qoi.R | 4 +- R/matchit.R | 59 +++++++++------------ R/matchit2cardinality.R | 97 ++++++++++++++++++----------------- R/matchit2cem.R | 14 ++--- R/matchit2full.R | 17 +++--- R/matchit2genetic.R | 8 +-- R/matchit2nearest.R | 7 ++- R/matchit2optimal.R | 18 +++---- R/matchit2quick.R | 15 +++--- R/matchit2subclass.R | 2 +- R/plot.matchit.R | 4 +- R/rbind.matchdata.R | 2 +- R/summary.matchit.R | 6 +-- R/utils.R | 12 +++-- 19 files changed, 152 insertions(+), 148 deletions(-) diff --git a/R/aux_functions.R b/R/aux_functions.R index 4dc54dbc..b616d886 100644 --- a/R/aux_functions.R +++ b/R/aux_functions.R @@ -71,7 +71,7 @@ info.to.method <- function(info) { else "without replacement" } - firstup(do.call("paste", c(unname(out.list), list(sep = " ")))) + firstup(do.call("paste", unname(out.list))) } info.to.distance <- function(info) { @@ -153,11 +153,9 @@ exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = " } lev <- { - if (include_vars) { - sprintf("%s = %s", - names(X)[i], - add_quotes(unique_x, is.character(X[[i]]) || is.factor(X[[i]]))) - } + if (include_vars) sprintf("%s = %s", + names(X)[i], + add_quotes(unique_x, chk::vld_character_or_factor(X[[i]]))) else if (is_null(justify)) unique_x else format(unique_x, justify = justify) } @@ -190,7 +188,9 @@ get.covs.matrix <- function(formula = NULL, data = NULL) { na.action = na.pass) chars.in.mf <- vapply(mf, is.character, logical(1L)) - mf[chars.in.mf] <- lapply(mf[chars.in.mf], as.factor) + for (i in which(chars.in.mf)) { + mf[[i]] <- as.factor(mf[[i]]) + } mf <- droplevels(mf) diff --git a/R/discard.R b/R/discard.R index 74f2bf9a..87ee45d1 100644 --- a/R/discard.R +++ b/R/discard.R @@ -4,7 +4,7 @@ discard <- function(treat, pscore = NULL, option = NULL) { if (is_null(option)){ # keep all units - return(setNames(rep(FALSE, n.obs), names(treat))) + return(rep_with(FALSE, treat)) } if (is.logical(option) && length(option) == n.obs && !anyNA(option)) { @@ -20,7 +20,7 @@ discard <- function(treat, pscore = NULL, option = NULL) { if (option == "none") { # keep all units - return(setNames(rep(FALSE, n.obs), names(treat))) + return(rep_with(FALSE, treat)) } if (is_null(pscore)) { @@ -50,7 +50,7 @@ discard <- function(treat, pscore = NULL, option = NULL) { # X <- model.matrix(reformulate(names(covs), intercept = FALSE), data = covs, # contrasts.arg = lapply(Filter(is.factor, covs), # function(x) contrasts(x, contrasts = nlevels(x) == 1))) - # discarded <- rep(FALSE, n.obs) + # discarded <- rep.int(FALSE, n.obs) # if (option == "hull.control"){ # discard units not in T convex hull # wif <- WhatIf::whatif(cfact = X[treat==0,], data = X[treat==1,]) # discarded[treat==0] <- !wif$in.hull diff --git a/R/get_weights_from_subclass.R b/R/get_weights_from_subclass.R index edbf2a9a..bd1dd333 100644 --- a/R/get_weights_from_subclass.R +++ b/R/get_weights_from_subclass.R @@ -16,7 +16,7 @@ get_weights_from_subclass <- function(psclass, treat, estimand = "ATT") { .err("No control units were matched") } - weights <- setNames(rep(0.0, length(treat)), names(treat)) + weights <- rep_with(0.0, treat) if (!is.factor(psclass)) { psclass <- factor(psclass, nmax = min(length(i1), length(i0))) diff --git a/R/input_processing.R b/R/input_processing.R index ed850f26..387baa6c 100644 --- a/R/input_processing.R +++ b/R/input_processing.R @@ -441,7 +441,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N #Check std.caliper chk::chk_logical(std.caliper) if (length(std.caliper) == 1L) { - std.caliper <- setNames(rep.int(std.caliper, length(caliper)), names(caliper)) + std.caliper <- rep_with(std.caliper, caliper) } else if (length(std.caliper) == length(caliper)) { names(std.caliper) <- names(caliper) @@ -464,7 +464,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N else if (cal.in.covs[x]) var <- covs[[x]] else var <- mahcovs[[x]] - is.factor(var) || is.character(var) + chk::vld_character_or_factor(var) }, logical(1L)) if (any(cat.vars)) { diff --git a/R/match.data.R b/R/match.data.R index a9a52033..a3c88302 100644 --- a/R/match.data.R +++ b/R/match.data.R @@ -305,9 +305,12 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc matched <- as.data.frame(matrix(NA_character_, nrow = nrow(mm) + sum(!is.na(mm)), ncol = 3)) names(matched) <- c(id, subclass, weights) - matched[[id]] <- c(as.vector(tmm[!is.na(tmm)]), rownames(mm)) - matched[[subclass]] <- c(as.vector(col(tmm)[!is.na(tmm)]), seq_len(nrow(mm))) - matched[[weights]] <- c(1/num.matches[matched[[subclass]][seq_len(sum(!is.na(mm)))]], rep(1, nrow(mm))) + matched[[id]] <- c(as.vector(tmm[!is.na(tmm)]), + rownames(mm)) + matched[[subclass]] <- c(as.vector(col(tmm)[!is.na(tmm)]), + seq_len(nrow(mm))) + matched[[weights]] <- c(1 / num.matches[matched[[subclass]][seq_len(sum(!is.na(mm)))]], + rep.int(1, nrow(mm))) if (is_not_null(object$s.weights) && include.s.weights) { matched[[weights]] <- matched[[weights]] * object$s.weights[matched[[id]]] diff --git a/R/match.qoi.R b/R/match.qoi.R index ad47f891..371d41a9 100644 --- a/R/match.qoi.R +++ b/R/match.qoi.R @@ -6,7 +6,7 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, un <- is_null(ww) bin.var <- all(xx == 0 | xx == 1) - xsum <- rep(NA_real_, 7) + xsum <- rep.int(NA_real_, 7L) if (standardize) names(xsum) <- c("Means Treated","Means Control", "Std. Mean Diff.", "Var. Ratio", "eCDF Mean", "eCDF Max", "Std. Pair Dist.") @@ -170,7 +170,7 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { n.obs <- length(x) if (is_null(w)) { - w <- rep(1, n.obs) + w <- rep.int(1, n.obs) } if (has_n_unique(x, 2) && all(x == 0 | x == 1)) { diff --git a/R/matchit.R b/R/matchit.R index 3cd6bbfa..67224dcc 100644 --- a/R/matchit.R +++ b/R/matchit.R @@ -49,8 +49,7 @@ #' argument controlling the link function used in estimating the distance #' measure. Allowable options depend on the specific `distance` value #' specified. See [`distance`] for allowable options with each -#' option. The default is `"logit"`, which, along with `distance = "glm"`, identifies the default measure as logistic regression propensity -#' scores. +#' option. The default is `"logit"`, which, along with `distance = "glm"`, identifies the default measure as logistic regression propensity scores. #' @param distance.options a named list containing additional arguments #' supplied to the function that estimates the distance measure as determined #' by the argument to `distance`. See [`distance`] for an @@ -70,8 +69,7 @@ #' within propensity score calipers, where the propensity scores are computed #' using `formula` and `distance`. Can be specified as a string #' containing the names of variables in `data` to be used or a one-sided -#' formula with the desired variables on the right-hand side (e.g., `~ X3 + X4`). See the individual methods pages for information on whether and how -#' this argument is used. +#' formula with the desired variables on the right-hand side (e.g., `~ X3 + X4`). See the individual methods pages for information on whether and how this argument is used. #' @param antiexact for methods that allow it, for which variables anti-exact #' matching should take place. Anti-exact matching ensures paired individuals #' do not have the same value of the anti-exact matching variable(s). Can be @@ -246,8 +244,7 @@ #' treated unit across its matches. For example, if a control unit was matched #' to a treated unit that had two other control units matched to it, and that #' same control was matched to a treated unit that had one other control unit -#' matched to it, the control unit in question would get a weight of 1/3 + 1/2 -#' = 5/6. For the ATC, the same is true with the treated and control labels +#' matched to it, the control unit in question would get a weight of \eqn{1/3 + 1/2 = 5/6}. For the ATC, the same is true with the treated and control labels #' switched. The weights are computed using the `match.matrix` component #' of the `matchit()` output object. #' @@ -268,7 +265,7 @@ #' @return When `method` is something other than `"subclass"`, a #' `matchit` object with the following components: #' -#' \item{match.matrix}{a matrix containing the matches. The rownames correspond +#' \item{match.matrix}{a matrix containing the matches. The row names correspond #' to the treated units and the values in each row are the names (or indices) #' of the control units matched to each treated unit. When treated units are #' matched to different numbers of control units (e.g., with variable ratio matching or @@ -284,13 +281,10 @@ #' specified as a method of estimating propensity scores. When #' `reestimate = TRUE`, this is the model estimated after discarding #' units.} -#' \item{X}{a data frame of covariates mentioned in `formula`, -#' `exact`, `mahvars`, `caliper`, and `antiexact`.} +#' \item{X}{a data frame of covariates mentioned in `formula`, `exact`, `mahvars`, `caliper`, and `antiexact`.} #' \item{call}{the `matchit()` call.} -#' \item{info}{information on the matching method and -#' distance measures used.} -#' \item{estimand}{the argument supplied to -#' `estimand`.} +#' \item{info}{information on the matching method and distance measures used.} +#' \item{estimand}{the argument supplied to `estimand`.} #' \item{formula}{the `formula` supplied.} #' \item{treat}{a vector of treatment status converted to zeros (0) and ones #' (1) if not already in that format.} @@ -298,16 +292,11 @@ #' values (i.e., propensity scores) when `distance` is supplied as a #' method of estimating propensity scores or a numeric vector.} #' \item{discarded}{a logical vector denoting whether each observation was -#' discarded (`TRUE`) or not (`FALSE`) by the argument to -#' `discard`.} -#' \item{s.weights}{the vector of sampling weights supplied to -#' the `s.weights` argument, if any.} -#' \item{exact}{a one-sided formula -#' containing the variables, if any, supplied to `exact`.} -#' \item{mahvars}{a one-sided formula containing the variables, if any, -#' supplied to `mahvars`.} -#' \item{obj}{when `include.obj = TRUE`, an -#' object containing the intermediate results of the matching procedure. See +#' discarded (`TRUE`) or not (`FALSE`) by the argument to `discard`.} +#' \item{s.weights}{the vector of sampling weights supplied to the `s.weights` argument, if any.} +#' \item{exact}{a one-sided formula containing the variables, if any, supplied to `exact`.} +#' \item{mahvars}{a one-sided formula containing the variables, if any, supplied to `mahvars`.} +#' \item{obj}{when `include.obj = TRUE`, an object containing the intermediate results of the matching procedure. See #' the individual methods pages for what this component will contain.} #' #' When `method = "subclass"`, a `matchit.subclass` object with the same @@ -386,7 +375,7 @@ #' discard = "control", subclass = 10) #' s.out1 #' summary(s.out1, un = TRUE) -#' + #' @export matchit <- function(formula, data = NULL, @@ -429,7 +418,7 @@ matchit <- function(formula, } #Process formula and data inputs - if (is_null(formula) || !rlang::is_formula(formula, lhs = TRUE)) { + if (!rlang::is_formula(formula, lhs = TRUE)) { .err("`formula` must be a formula relating treatment to covariates") } @@ -453,8 +442,8 @@ matchit <- function(formula, reestimate = reestimate, s.weights = s.weights, replace = replace, ratio = ratio, m.order = m.order, estimand = estimand) - if (is_not_null(ignored.inputs)) { - for (i in ignored.inputs) assign(i, NULL) + for (i in ignored.inputs) { + assign(i, NULL) } #Process replace @@ -692,13 +681,15 @@ matchit <- function(formula, for (i in X.list.nm) { X_tmp <- get0(i, inherits = FALSE) - if (is_not_null(X_tmp)) { - if (is_null(X)) { - X <- X_tmp - } - else if (!all(hasName(X, names(X_tmp)))) { - X <- cbind(X, X_tmp[!names(X_tmp) %in% names(X)]) - } + if (is_null(X_tmp)) { + next + } + + if (is_null(X)) { + X <- X_tmp + } + else if (!all(hasName(X, names(X_tmp)))) { + X <- cbind(X, X_tmp[!names(X_tmp) %in% names(X)]) } } diff --git a/R/matchit2cardinality.R b/R/matchit2cardinality.R index 57e2c0af..5f3d119d 100644 --- a/R/matchit2cardinality.R +++ b/R/matchit2cardinality.R @@ -270,7 +270,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, replace = FALSE, mahvars = NULL, exact = NULL, estimand = "ATT", verbose = FALSE, tols = .05, std.tols = TRUE, - solver = "glpk", time = 1*60, ...) { + solver = "highs", time = 1*60, ...) { .cat_verbose("Cardinality matching... \n", verbose = verbose) @@ -293,7 +293,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, lab <- names(treat) - weights <- setNames(rep(0, length(treat)), lab) + weights <- rep_with(0, treat) X <- get.covs.matrix(formula, data = data) @@ -322,7 +322,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, mahcovs <- transform_covariates(mahvars, data = data, method = "mahalanobis", s.weights = s.weights, treat = treat, discarded = discarded) - pair <- setNames(rep(NA_character_, length(treat)), lab) + pair <- rep_with(NA_character_, treat) #Set max problem size to Inf and return to original value after match omps <- getOption("optmatch_max_problem_size") @@ -338,7 +338,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, chk::chk_numeric(tols) if (length(tols) == 1L) { - tols <- rep(tols, ncol(X)) + tols <- rep.int(tols, ncol(X)) } else if (length(tols) == max(assign)) { tols <- tols[assign] @@ -349,7 +349,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, chk::chk_logical(std.tols) if (length(std.tols) == 1L) { - std.tols <- rep(std.tols, ncol(X)) + std.tols <- rep.int(std.tols, ncol(X)) } else if (length(std.tols) == max(assign)) { std.tols <- std.tols[assign] @@ -381,9 +381,10 @@ matchit2cardinality <- function(treat, data, discarded, formula, opt.out <- setNames(vector("list", nlevels(ex)), levels(ex)) for (e in levels(ex)[cc]) { - if (nlevels(ex) > 1 && verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) + if (nlevels(ex) > 1) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) } .e <- which(!discarded & ex == e) @@ -425,7 +426,9 @@ matchit2cardinality <- function(treat, data, discarded, formula, mm <- psclass <- NULL } - if (length(opt.out) == 1L) out <- out[[1]] + if (length(opt.out) == 1L) { + out <- out[[1]] + } res <- list(match.matrix = mm, subclass = psclass, @@ -450,15 +453,15 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #Check inputs if (is_null(s.weights)) { - s.weights <- rep(1, n) + s.weights <- rep.int(1, n) } else { - for (i in tvals) { - s.weights[treat == i] <- s.weights[treat == i]/mean(s.weights[treat == i]) - } + s.weights <- .make_sum_to_n(s.weights, treat) } - if (is_null(focal)) focal <- tvals[length(tvals)] + if (is_null(focal)) { + focal <- tvals[length(tvals)] + } chk::chk_number(time) chk::chk_gt(time, 0) @@ -483,15 +486,15 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #Objective function: total sample size O <- c( s.weights, #weight for each unit - rep(0, nt) #slack coefs for each sample size (n1, n0) + rep.int(0, nt) #slack coefs for each sample size (n1, n0) ) #Constraint matrix target.means <- apply(X, 2, wm, w = s.weights) C <- matrix(0, nrow = nt * (1 + 2*ncol(X)), ncol = length(O)) - Crhs <- rep(0, nrow(C)) - Cdir <- rep("==", nrow(C)) + Crhs <- rep.int(0, nrow(C)) + Cdir <- rep.int("==", nrow(C)) for (i in seq_len(nt)) { #Num in group i = ni @@ -513,7 +516,7 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #If ratio != 0, constrain n0 to be ratio*n1 if (nt == 2L && is.finite(ratio)) { - C_ratio <- c(rep(0, n), rep(-1, nt)) + C_ratio <- c(rep.int(0, n), rep.int(-1, nt)) C_ratio[n + which(tvals == focal)] <- ratio C <- rbind(C, C_ratio) Crhs <- c(Crhs, 0) @@ -521,13 +524,13 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight } #Coef types - types <- c(rep("B", n), #Matching weights - rep("C", nt)) #Slack coefs for matched group size + types <- c(rep.int("B", n), #Matching weights + rep.int("C", nt)) #Slack coefs for matched group size - lower.bound <- c(rep(0, n), - rep(1, nt)) - upper.bound <- c(rep(1, n), - rep(Inf, nt)) + lower.bound <- c(rep.int(0, n), + rep.int(1, nt)) + upper.bound <- c(rep.int(1, n), + rep.int(Inf, nt)) } else if (match_type == "profile_att") { #Find largest control group that matches treated group @@ -538,8 +541,8 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #Objective function: size of matched control group O <- c( - rep(1, n0), #weights for each non-focal unit - rep(0, nt - 1) #slack coef for size of non-focal groups + rep.int(1, n0), #weights for each non-focal unit + rep.int(0, nt - 1) #slack coef for size of non-focal groups ) #Constraint matrix @@ -547,8 +550,8 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #One row per constraint, one column per coef C <- matrix(0, nrow = (nt - 1) * (1 + 2*ncol(X)), ncol = length(O)) - Crhs <- rep(0, nrow(C)) - Cdir <- rep("==", nrow(C)) + Crhs <- rep.int(0, nrow(C)) + Cdir <- rep.int("==", nrow(C)) for (i in seq_len(nt - 1)) { #Num in group i = ni @@ -569,13 +572,13 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight } #Coef types - types <- c(rep("B", n0), #Matching weights - rep("C", nt - 1)) #Slack for num control matched + types <- c(rep.int("B", n0), #Matching weights + rep.int("C", nt - 1)) #Slack for num control matched - lower.bound <- c(rep(0, n0), - rep(0, nt - 1)) - upper.bound <- c(rep(1, n0), - rep(Inf, nt - 1)) + lower.bound <- c(rep.int(0, n0), + rep.int(0, nt - 1)) + upper.bound <- c(rep.int(1, n0), + rep.int(Inf, nt - 1)) } else if (match_type == "cardinality") { #True cardinality matching: find largest balanced sample @@ -591,8 +594,8 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight t_combs <- combn(tvals, 2, simplify = FALSE) C <- matrix(0, nrow = nt + 2*ncol(X)*length(t_combs), ncol = length(O)) - Crhs <- rep(0, nrow(C)) - Cdir <- rep("==", nrow(C)) + Crhs <- rep.int(0, nrow(C)) + Cdir <- rep.int("==", nrow(C)) for (i in seq_len(nt)) { #Num in group i = ni @@ -616,13 +619,13 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight } #Coef types - types <- c(rep("B", n), #Matching weights - rep("C", 1)) #Slack coef for treated group size (n1) + types <- c(rep.int("B", n), #Matching weights + rep.int("C", 1)) #Slack coef for treated group size (n1) - lower.bound <- c(rep(0, n), - rep(0, 1)) - upper.bound <- c(rep(1, n), - rep(min(tabulateC(treat)), 1)) + lower.bound <- c(rep.int(0, n), + rep.int(0, 1)) + upper.bound <- c(rep.int(1, n), + rep.int(min(tabulateC(treat)), 1)) } weights <- NULL @@ -643,7 +646,7 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight weights <- round(sol[seq_len(n)]) } else if (match_type %in% c("profile_att")) { - weights <- rep(1, n) + weights <- rep.int(1, n) weights[treat != focal] <- round(sol[seq_len(n0)]) } @@ -700,7 +703,7 @@ cardinality_error_report <- function(out, solver) { } } -dispatch_optimizer <- function(solver = "glpk", obj, mat, dir, rhs, types, max = TRUE, +dispatch_optimizer <- function(solver = "highs", obj, mat, dir, rhs, types, max = TRUE, lb = NULL, ub = NULL, time = NULL, verbose = FALSE) { if (solver == "glpk") { dir[dir == "="] <- "==" @@ -722,9 +725,9 @@ dispatch_optimizer <- function(solver = "glpk", obj, mat, dir, rhs, types, max = dir[dir == "<="] <- "<" dir[dir == ">="] <- ">" dir[dir == "=="] <- "=" - # opt.out <- gurobi::gurobi(list(A = mat, obj = obj, sense = dir, rhs = rhs, vtype = types, - # modelsense = "max", lb = lb, ub = ub), - # params = list(OutputFlag = as.integer(verbose), TimeLimit = time)) + opt.out <- gurobi::gurobi(list(A = mat, obj = obj, sense = dir, rhs = rhs, vtype = types, + modelsense = "max", lb = lb, ub = ub), + params = list(OutputFlag = as.integer(verbose), TimeLimit = time)) } else if (solver == "highs") { rhs_h <- lhs_h <- rhs diff --git a/R/matchit2cem.R b/R/matchit2cem.R index c1b76076..cc37f52a 100644 --- a/R/matchit2cem.R +++ b/R/matchit2cem.R @@ -391,7 +391,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), ...) #Process cutpoints if (!is.list(cutpoints)) { - cutpoints <- setNames(rep(list(cutpoints), sum(is.numeric.cov)), names(X)[is.numeric.cov]) + cutpoints <- setNames(rep.int(list(cutpoints), sum(is.numeric.cov)), names(X)[is.numeric.cov]) } if (is_null(names(cutpoints))) { @@ -431,7 +431,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), ...) ngettext(nnnic, "its", "their")), n = nnnic) } - bad.cuts <- setNames(rep(FALSE, length(cutpoints)), names(cutpoints)) + bad.cuts <- rep_with(FALSE, cutpoints) for (i in names(cutpoints)) { if (is_null(cutpoints[[i]])) { @@ -470,7 +470,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), ...) } if (is_null(X)) { - return(setNames(rep(1L, length(treat)), names(treat))) + return(rep_with(1L, treat)) } #Create bins for numeric variables @@ -549,7 +549,7 @@ do_k2k <- function(treat, X, subclass, k2k.method = "mahalanobis", mpower = 2, s else { k2k.method <- "euclidean" X.match <- NULL - distance <- rep(0.0, length(treat)) + distance <- rep.int(0.0, length(treat)) } reuse.max <- 1L @@ -557,7 +557,7 @@ do_k2k <- function(treat, X, subclass, k2k.method = "mahalanobis", mpower = 2, s if (k2k.method %in% matchit_distances()) { discarded <- is.na(subclass) - ratio <- rep(1L, sum(treat == focal)) + ratio <- rep.int(1L, sum(treat == focal)) mm <- nn_matchC_dispatch(treat, focal, ratio, discarded, reuse.max, distance, NULL, subclass, caliper.dist, caliper.covs, caliper.covs.mat, X.match, @@ -570,9 +570,9 @@ do_k2k <- function(treat, X, subclass, k2k.method = "mahalanobis", mpower = 2, s for (s in levels(subclass)) { .e <- which(subclass == s) treat_ <- treat[.e] - discarded_ <- rep(FALSE, length(.e)) + discarded_ <- rep.int(FALSE, length(.e)) ex_ <- NULL - ratio_ <- rep(1L, sum(treat_ == focal)) + ratio_ <- rep.int(1L, sum(treat_ == focal)) distance_mat <- as.matrix(dist(X.match[.e,,drop = FALSE], method = k2k.method, p = mpower))[treat_ == focal, treat_ != focal, drop = FALSE] diff --git a/R/matchit2full.R b/R/matchit2full.R index b9f9429a..94af3619 100644 --- a/R/matchit2full.R +++ b/R/matchit2full.R @@ -279,7 +279,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, } } else { - ex <- factor(rep("_", length(treat_)), levels = "_") + ex <- gl(1, length(treat_), labels = "_") cc <- 1 } @@ -346,10 +346,12 @@ matchit2full <- function(treat, formula, data, distance, discarded, } #Initialize pair membership; must include names - pair <- setNames(rep(NA_character_, length(treat)), names(treat)) + pair <- rep_with(NA_character_, treat) p <- setNames(vector("list", nlevels(ex)), levels(ex)) - t_df <- data.frame(treat) + A$data <- data.frame(treat) #just to get rownames; not actually used in matching + A$min.controls <- min.controls + A$max.controls <- max.controls for (e in levels(ex)[cc]) { if (nlevels(ex) > 1L) { @@ -372,13 +374,10 @@ matchit2full <- function(treat, formula, data, distance, discarded, next } + A$x <- mo_ + matchit_try({ - p[[e]] <- do.call(optmatch::fullmatch, - c(list(mo_, - min.controls = min.controls, - max.controls = max.controls, - data = t_df), #just to get rownames; not actually used in matching - A)) + p[[e]] <- do.call(optmatch::fullmatch, A) }, from = "optmatch") pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") diff --git a/R/matchit2genetic.R b/R/matchit2genetic.R index d6269d05..13e61ffc 100644 --- a/R/matchit2genetic.R +++ b/R/matchit2genetic.R @@ -338,7 +338,7 @@ matchit2genetic <- function(treat, data, distance, discarded, X <- cbind(X, ex) - exact.log <- c(rep(FALSE, ncol(X) - 1), TRUE) + exact.log <- c(rep.int(FALSE, ncol(X) - 1L), TRUE) } else { exact.log <- ex <- NULL @@ -353,7 +353,9 @@ matchit2genetic <- function(treat, data, distance, discarded, X <- cbind(X, calcovs) #Expand exact.log for newly added covariates - if (is_not_null(exact.log)) exact.log <- c(exact.log, rep(FALSE, ncol(calcovs))) + if (is_not_null(exact.log)) { + exact.log <- c(exact.log, rep.int(FALSE, ncol(calcovs))) + } } else { cov.cals <- NULL @@ -367,7 +369,7 @@ matchit2genetic <- function(treat, data, distance, discarded, }, numeric(1L)) #cal needs one value per variable in X - cal <- setNames(rep(Inf, ncol(X)), colnames(X)) + cal <- setNames(rep.int(Inf, ncol(X)), colnames(X)) #First put covariate calipers into cal if (is_not_null(cov.cals)) { diff --git a/R/matchit2nearest.R b/R/matchit2nearest.R index 4205e690..1d68d439 100644 --- a/R/matchit2nearest.R +++ b/R/matchit2nearest.R @@ -385,9 +385,8 @@ matchit2nearest <- function(treat, data, distance, discarded, #Process antiexact antiexactcovs <- NULL if (is_not_null(antiexact)) { - antiexactcovs <- model.frame(antiexact, data) - antiexactcovs <- do.call("cbind", lapply(seq_len(ncol(antiexactcovs)), function(i) { - unclass(as.factor(antiexactcovs[[i]])) + antiexactcovs <- do.call("cbind", lapply(model.frame(antiexact, data), function(i) { + unclass(as.factor(i)) })) } @@ -492,7 +491,7 @@ matchit2nearest <- function(treat, data, distance, discarded, kmin <- n1 - kmax - 1 kmed <- m - (min.controls * kmin + max.controls * kmax) - ratio0 <- c(rep(min.controls, kmin), kmed, rep(max.controls, kmax)) + ratio0 <- c(rep.int(min.controls, kmin), kmed, rep.int(max.controls, kmax)) #Make sure no units are assigned 0 matches while (any(ratio0 == 0)) { diff --git a/R/matchit2optimal.R b/R/matchit2optimal.R index 33090f4f..0624a694 100644 --- a/R/matchit2optimal.R +++ b/R/matchit2optimal.R @@ -323,7 +323,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, } } else { - ex <- factor(rep("_", length(treat_)), levels = "_") + ex <- gl(1, length(treat_), labels = "_") cc <- 1 e_ratios <- setNames(sum(treat_ == 0)/sum(treat_ == 1), levels(ex)) @@ -377,7 +377,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, } #Initialize pair membership; must include names - pair <- setNames(rep(NA_character_, length(treat)), names(treat)) + pair <- rep_with(NA_character_, treat) p <- setNames(vector("list", nlevels(ex)), levels(ex)) t_df <- data.frame(treat_) @@ -420,14 +420,14 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, max.controls_ <- max.controls } + A$x <- mo_ + A$mean.controls <- ratio_ + A$min.controls <- min.controls_ + A$max.controls <- max.controls_ + A$data <- t_df[ex == e,, drop = FALSE] #just to get rownames; not actually used in matching + matchit_try({ - p[[e]] <- do.call(optmatch::fullmatch, - c(list(mo_, - mean.controls = ratio_, - min.controls = min.controls_, - max.controls = max.controls_, - data = t_df[ex == e,, drop = FALSE]), #just to get rownames; not actually used in matching - A)) + p[[e]] <- do.call(optmatch::fullmatch, A) }, from = "optmatch") pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") diff --git a/R/matchit2quick.R b/R/matchit2quick.R index f3589ac0..4a8a62de 100644 --- a/R/matchit2quick.R +++ b/R/matchit2quick.R @@ -184,7 +184,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, } } else { - ex <- factor(rep("_", length(treat_)), levels = "_") + ex <- gl(1, length(treat_), labels = "_") cc <- 1 } @@ -215,8 +215,10 @@ matchit2quick <- function(treat, formula, data, distance, discarded, } } + A$caliper <- caliper + #Initialize pair membership; must include names - pair <- setNames(rep(NA_character_, length(treat)), names(treat)) + pair <- rep_with(NA_character_, treat) p <- setNames(vector("list", nlevels(ex)), levels(ex)) for (e in levels(ex)[cc]) { @@ -226,14 +228,11 @@ matchit2quick <- function(treat, formula, data, distance, discarded, verbose = verbose) } - distcovs_ <- distcovs[ex == e,, drop = FALSE] + A$distances <- distcovs[ex == e,, drop = FALSE] + A$treatments <- treat_[ex == e] matchit_try({ - p[[e]] <- do.call(quickmatch::quickmatch, - c(list(distcovs_, - treatments = treat_[ex == e], - caliper = caliper), - A)) + p[[e]] <- do.call(quickmatch::quickmatch, A) }, from = "quickmatch") pair[which(ex == e)[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") diff --git a/R/matchit2subclass.R b/R/matchit2subclass.R index 93fbaed8..018c3eb3 100644 --- a/R/matchit2subclass.R +++ b/R/matchit2subclass.R @@ -201,7 +201,7 @@ matchit2subclass <- function(treat, distance, discarded, quantile(distance, probs = sprobs, na.rm = TRUE)) ## Calculating Subclasses - psclass <- setNames(rep(NA_integer_, n.obs), names(treat)) + psclass <- rep_with(NA_integer_, treat) psclass[!discarded] <- as.integer(findInterval(distance[!discarded], q, all.inside = TRUE)) if (!has_n_unique(na.omit(psclass), subclass)) { diff --git a/R/plot.matchit.R b/R/plot.matchit.R index 0c48d3ec..8cffd29a 100644 --- a/R/plot.matchit.R +++ b/R/plot.matchit.R @@ -473,7 +473,9 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, } chars.in.X <- vapply(X, is.character, logical(1L)) - X[chars.in.X] <- lapply(X[chars.in.X], factor) + for (i in which(chars.in.X)) { + X[[i]] <- factor(X[[i]]) + } X <- droplevels(X) diff --git a/R/rbind.matchdata.R b/R/rbind.matchdata.R index 295743ce..db6a3ba5 100644 --- a/R/rbind.matchdata.R +++ b/R/rbind.matchdata.R @@ -95,7 +95,7 @@ rbind.matchdata <- function(..., deparse.level = 1) { attrs <- c("distance", "weights", "subclass", "id") attr_list <- setNames(vector("list", length(attrs)), attrs) - key_attrs <- setNames(rep(NA_character_, length(attrs)), attrs) + key_attrs <- setNames(rep.int(NA_character_, length(attrs)), attrs) for (i in attrs) { attr_list[[i]] <- unlist(lapply(md_list, function(m) { diff --git a/R/summary.matchit.R b/R/summary.matchit.R index 4be27821..c40f0b3e 100644 --- a/R/summary.matchit.R +++ b/R/summary.matchit.R @@ -727,13 +727,13 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" #Attempt to extract data from matchit object; same as match.data() data.found <- FALSE for (i in 1:4) { - if (i == 2) { + if (i == 2L) { data <- try(eval(object$call$data, envir = environment(object$formula)), silent = TRUE) } - else if (i == 3) { + else if (i == 3L) { data <- try(eval(object$call$data, envir = parent.frame()), silent = TRUE) } - else if (i == 4) { + else if (i == 4L) { data <- object[["model"]][["data"]] } diff --git a/R/utils.R b/R/utils.R index 8e1cce4b..d9f6fc42 100644 --- a/R/utils.R +++ b/R/utils.R @@ -165,8 +165,8 @@ interaction2 <- function(..., sep = ".", lex.order = TRUE) { lev <- { if (is.null(lex.order)) unique(out) - else if (lex.order) unique(out[order(do.call("paste", c(args_char, sep = sep)))]) - else unique(out[order(do.call("paste", c(rev(args_char), sep = sep)))]) + else if (lex.order) unique(out[order(do.call("paste", c(args_char, list(sep = sep))))]) + else unique(out[order(do.call("paste", c(rev(args_char), list(sep = sep))))]) } factor(out, levels = lev) @@ -178,7 +178,7 @@ binarize <- function(variable, zero = NULL, one = NULL) { var.name <- deparse1(substitute(variable)) if (has_n_unique(variable, 1L)) { - return(setNames(rep.int(1L, length(variable)), names(variable))) + return(rep_with(1L, variable)) } if (!has_n_unique(variable, 2L)) { @@ -427,6 +427,7 @@ diff1 <- function(x) { x } +#Extract variables from ..., similar to ...elt(), by name without evaluating list(...) ...get <- function(x, ...) { m <- match(x, ...names(), 0L) @@ -437,6 +438,11 @@ diff1 <- function(x) { ...elt(m) } +#Helper function to fill named vectors with x and given names of y +rep_with <- function(x, y) { + setNames(rep.int(x, length(y)), names(y)) +} + #cat() if verbose = TRUE (default sep = "", line wrapping) .cat_verbose <- function(..., verbose = TRUE, sep = "") { if (!verbose) { From 57426b3a6c13ab051839c1add8f6f4fdfdbd7b9c Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 18:42:26 -0500 Subject: [PATCH 46/48] Doc and vignette updates --- man/matchit.Rd | 36 +++++++++++---------------------- vignettes/assessing-balance.Rmd | 2 +- vignettes/matching-methods.Rmd | 2 +- vignettes/sampling-weights.Rmd | 4 ++-- 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/man/matchit.Rd b/man/matchit.Rd index 7ad2b399..d5976b61 100644 --- a/man/matchit.Rd +++ b/man/matchit.Rd @@ -71,8 +71,7 @@ used.} argument controlling the link function used in estimating the distance measure. Allowable options depend on the specific \code{distance} value specified. See \code{\link{distance}} for allowable options with each -option. The default is \code{"logit"}, which, along with \code{distance = "glm"}, identifies the default measure as logistic regression propensity -scores.} +option. The default is \code{"logit"}, which, along with \code{distance = "glm"}, identifies the default measure as logistic regression propensity scores.} \item{distance.options}{a named list containing additional arguments supplied to the function that estimates the distance measure as determined @@ -96,8 +95,7 @@ propensity scores. Usually used to perform Mahalanobis distance matching within propensity score calipers, where the propensity scores are computed using \code{formula} and \code{distance}. Can be specified as a string containing the names of variables in \code{data} to be used or a one-sided -formula with the desired variables on the right-hand side (e.g., \code{~ X3 + X4}). See the individual methods pages for information on whether and how -this argument is used.} +formula with the desired variables on the right-hand side (e.g., \code{~ X3 + X4}). See the individual methods pages for information on whether and how this argument is used.} \item{antiexact}{for methods that allow it, for which variables anti-exact matching should take place. Anti-exact matching ensures paired individuals @@ -180,7 +178,7 @@ additional arguments are allowed for each method.} When \code{method} is something other than \code{"subclass"}, a \code{matchit} object with the following components: -\item{match.matrix}{a matrix containing the matches. The rownames correspond +\item{match.matrix}{a matrix containing the matches. The row names correspond to the treated units and the values in each row are the names (or indices) of the control units matched to each treated unit. When treated units are matched to different numbers of control units (e.g., with variable ratio matching or @@ -196,13 +194,10 @@ the model used to estimate propensity scores when \code{distance} is specified as a method of estimating propensity scores. When \code{reestimate = TRUE}, this is the model estimated after discarding units.} -\item{X}{a data frame of covariates mentioned in \code{formula}, -\code{exact}, \code{mahvars}, \code{caliper}, and \code{antiexact}.} +\item{X}{a data frame of covariates mentioned in \code{formula}, \code{exact}, \code{mahvars}, \code{caliper}, and \code{antiexact}.} \item{call}{the \code{matchit()} call.} -\item{info}{information on the matching method and -distance measures used.} -\item{estimand}{the argument supplied to -\code{estimand}.} +\item{info}{information on the matching method and distance measures used.} +\item{estimand}{the argument supplied to \code{estimand}.} \item{formula}{the \code{formula} supplied.} \item{treat}{a vector of treatment status converted to zeros (0) and ones (1) if not already in that format.} @@ -210,16 +205,11 @@ distance measures used.} values (i.e., propensity scores) when \code{distance} is supplied as a method of estimating propensity scores or a numeric vector.} \item{discarded}{a logical vector denoting whether each observation was -discarded (\code{TRUE}) or not (\code{FALSE}) by the argument to -\code{discard}.} -\item{s.weights}{the vector of sampling weights supplied to -the \code{s.weights} argument, if any.} -\item{exact}{a one-sided formula -containing the variables, if any, supplied to \code{exact}.} -\item{mahvars}{a one-sided formula containing the variables, if any, -supplied to \code{mahvars}.} -\item{obj}{when \code{include.obj = TRUE}, an -object containing the intermediate results of the matching procedure. See +discarded (\code{TRUE}) or not (\code{FALSE}) by the argument to \code{discard}.} +\item{s.weights}{the vector of sampling weights supplied to the \code{s.weights} argument, if any.} +\item{exact}{a one-sided formula containing the variables, if any, supplied to \code{exact}.} +\item{mahvars}{a one-sided formula containing the variables, if any, supplied to \code{mahvars}.} +\item{obj}{when \code{include.obj = TRUE}, an object containing the intermediate results of the matching procedure. See the individual methods pages for what this component will contain.} When \code{method = "subclass"}, a \code{matchit.subclass} object with the same @@ -352,8 +342,7 @@ as the sum of the inverse of the number of control units matched to the same treated unit across its matches. For example, if a control unit was matched to a treated unit that had two other control units matched to it, and that same control was matched to a treated unit that had one other control unit -matched to it, the control unit in question would get a weight of 1/3 + 1/2 -= 5/6. For the ATC, the same is true with the treated and control labels +matched to it, the control unit in question would get a weight of \eqn{1/3 + 1/2 = 5/6}. For the ATC, the same is true with the treated and control labels switched. The weights are computed using the \code{match.matrix} component of the \code{matchit()} output object. } @@ -427,7 +416,6 @@ s.out1 <- matchit(treat ~ age + educ + race + nodegree + discard = "control", subclass = 10) s.out1 summary(s.out1, un = TRUE) - } \references{ Ho, D. E., Imai, K., King, G., & Stuart, E. A. (2007). Matching diff --git a/vignettes/assessing-balance.Rmd b/vignettes/assessing-balance.Rmd index e7e74a22..dd4358d8 100644 --- a/vignettes/assessing-balance.Rmd +++ b/vignettes/assessing-balance.Rmd @@ -281,7 +281,7 @@ love.plot(m.out, stats = c("m", "ks"), poly = 2, abs = TRUE, position = "bottom") ``` -The `love.plot()` documentation explains what each of these arguments do and the several other ones available. See `vignette("cobalt_A4_love.plot", package = "cobalt")` for other advanced customization of `love.plot()`. +The `love.plot()` documentation explains what each of these arguments do and the several other ones available. See `vignette("love.plot", package = "cobalt")` for other advanced customization of `love.plot()`. ### `bal.plot()` diff --git a/vignettes/matching-methods.Rmd b/vignettes/matching-methods.Rmd index 021649b5..2d667937 100644 --- a/vignettes/matching-methods.Rmd +++ b/vignettes/matching-methods.Rmd @@ -32,7 +32,7 @@ Matching is nonparametric in the sense that the estimated weights and pruning of It is important to note that this implementation of matching differs from the methods described by Abadie and Imbens [-@abadie2006; -@abadie2016] and implemented in the `Matching` R package and `teffects` routine in Stata. That form of matching is *matching imputation*, where the missing potential outcomes for each unit are imputed using the observed outcomes of paired units. This is a critical distinction because matching imputation is a specific estimation method with its own effect and standard error estimators, in contrast to subset selection, which is a preprocessing method that does not require specific estimators and is broadly compatible with other parametric and nonparametric analyses. The benefits of matching imputation are that its theoretical properties (i.e., the rate of convergence and asymptotic variance of the estimator) are well understood, it can be used in a straightforward way to estimate not just the average treatment effect in the treated (ATT) but also the average treatment effect in the population (ATE), and additional effective matching methods can be used in the imputation (e.g., kernel matching). The benefits of matching as nonparametric preprocessing are that it is far more flexible with respect to the types of effects that can be estimated because it does not involve any specific estimator, its empirical and finite-sample performance has been examined in depth and is generally well understood, and it aligns well with the design of experiments, which are more familiar to non-technical audiences. -In addition to subset selection, matching often (though not always) involves a form of *stratification*, the assignment of units to pairs or strata containing multiple units. The distinction between subset selection and stratification is described by @zubizarreta2014, who separate them into two separate steps. In `MatchIt`, with almost all matching methods, subset selection is performed by stratification; for example, treated units are paired with control units, and unpaired units are then dropped from the matched sample. With some methods, subclasses are used to assign matching or stratification weights to individual units, which increase or decrease each unit's leverage in a subsequent analysis. There has been some debate about the importance of stratification after subset selection; while some authors have argued that, with some forms of matching, pair membership is incidental [@stuart2008; @schafer2008], others have argued that correctly incorporating pair membership into effect estimation can improve the quality of inferences [@austin2014a; @wan2019]. For methods that allow it, `MatchIt` includes stratum membership as an additional output of each matching specification. How these strata can be used is detailed in `vignette("Estimating Effects")`. +In addition to subset selection, matching often (though not always) involves a form of *stratification*, the assignment of units to pairs or strata containing multiple units. The distinction between subset selection and stratification is described by @zubizarreta2014, who separate them into two separate steps. In `MatchIt`, with almost all matching methods, subset selection is performed by stratification; for example, treated units are paired with control units, and unpaired units are then dropped from the matched sample. With some methods, subclasses are used to assign matching or stratification weights to individual units, which increase or decrease each unit's leverage in a subsequent analysis. There has been some debate about the importance of stratification after subset selection; while some authors have argued that, with some forms of matching, pair membership is incidental [@stuart2008; @schafer2008], others have argued that correctly incorporating pair membership into effect estimation can improve the quality of inferences [@austin2014a; @wan2019]. For methods that allow it, `MatchIt` includes stratum membership as an additional output of each matching specification. How these strata can be used is detailed in `vignette("estimating-effects")`. At the heart of `MatchIt` are three classes of methods: distance matching, stratum matching, and pure subset selection. *Distance matching* involves considering a focal group (usually the treated group) and selecting members of the non-focal group (i.e., the control group) to pair with each member of the focal group based on the *distance* between units, which can be computed in one of several ways. Members of either group that are not paired are dropped from the sample. Nearest neighbor matching (`method = "nearest"`), optimal pair matching (`method = "optimal"`), optimal full matching (`method = "full"`), generalized full matching (`method = "quick"`), and genetic matching (`method = "genetic"`) are the methods of distance matching implemented in `MatchIt`. Typically, only the average treatment in the treated (ATT) or average treatment in the control (ATC), if the control group is the focal group, can be estimated after distance matching in `MatchIt` (full matching is an exception, described later). diff --git a/vignettes/sampling-weights.Rmd b/vignettes/sampling-weights.Rmd index 6a608af9..0d5bbbd5 100644 --- a/vignettes/sampling-weights.Rmd +++ b/vignettes/sampling-weights.Rmd @@ -146,7 +146,7 @@ Note that had we not added sampling weights to `mF`, the matching specification ## Estimating the Effect -Estimating the treatment effect after matching is straightforward when using sampling weights. Effects are estimated in the same way as when sampling weights are excluded, except that the matching weights must be multiplied by the sampling weights to yield accurate, generalizable estimates. `match.data()` and `get_matches()` do this automatically, so the weights produced by these functions already are a product of the matching weights and the sampling weights. Note this will only be true if sampling weights are incorporated into the `matchit` object. +Estimating the treatment effect after matching is straightforward when using sampling weights. Effects are estimated in the same way as when sampling weights are excluded, except that the matching weights must be multiplied by the sampling weights for use in the outcome model to yield accurate, generalizable estimates. `match.data()` and `get_matches()` do this automatically, so the weights produced by these functions already are a product of the matching weights and the sampling weights. Note this will only be true if sampling weights are incorporated into the `matchit` object. With `avg_comparisons()`, only the sampling weights should be included when estimating the treatment effect. Below we estimate the effect of `A` on `Y_C` in the matched and sampling weighted sample, adjusting for the covariates to improve precision and decrease bias. @@ -165,7 +165,7 @@ avg_comparisons(fit, wts = "SW") ``` -Note that `match.data()` and `get_weights()` have the option `include.s.weights`, which, when set to `FALSE`, makes it so the returned weights do not incorporate the sampling weights and are simply the matching weights. Because one might to forget to multiply the two sets of weights together, it is easier to just use the default of `include.s.weights = TRUE` and ignore the sampling weights in the rest of the analysis (because they are already included in the returned weights). `avg_comparisons()` also works more smoothly when the weights supplied to `weights` is a single variable rather than the product of two. +Note that `match.data()` and `get_weights()` have the option `include.s.weights`, which, when set to `FALSE`, makes it so the returned weights do not incorporate the sampling weights and are simply the matching weights. Because one might to forget to multiply the two sets of weights together, it is easier to just use the default of `include.s.weights = TRUE` and ignore the sampling weights in the rest of the analysis (because they are already included in the returned weights). ## Code to Generate Data used in Examples From 6c32225e04b049d4603ce93e1308288858cbf72c Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 18:44:44 -0500 Subject: [PATCH 47/48] Doc and vignette updates --- R/matchit.R | 2 +- man/matchit.Rd | 2 +- vignettes/estimating-effects.Rmd | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/R/matchit.R b/R/matchit.R index 67224dcc..c5011e66 100644 --- a/R/matchit.R +++ b/R/matchit.R @@ -30,7 +30,7 @@ #' matching, [`"cem"`][method_cem] for coarsened exact matching, #' [`"exact"`][method_exact] for exact matching, #' [`"cardinality"`][method_cardinality] for cardinality and -#' template matching, and [`"subclass"`][method_subclass] for +#' profile matching, and [`"subclass"`][method_subclass] for #' subclassification. When set to `NULL`, no matching will occur, but #' propensity score estimation and common support restrictions will still occur #' if requested. See the linked pages for each method for more details on what diff --git a/man/matchit.Rd b/man/matchit.Rd index d5976b61..5f83cd6a 100644 --- a/man/matchit.Rd +++ b/man/matchit.Rd @@ -50,7 +50,7 @@ full matching, \code{\link[=method_genetic]{"genetic"}} for genetic matching, \code{\link[=method_cem]{"cem"}} for coarsened exact matching, \code{\link[=method_exact]{"exact"}} for exact matching, \code{\link[=method_cardinality]{"cardinality"}} for cardinality and -template matching, and \code{\link[=method_subclass]{"subclass"}} for +profile matching, and \code{\link[=method_subclass]{"subclass"}} for subclassification. When set to \code{NULL}, no matching will occur, but propensity score estimation and common support restrictions will still occur if requested. See the linked pages for each method for more details on what diff --git a/vignettes/estimating-effects.Rmd b/vignettes/estimating-effects.Rmd index 0a49d87e..63714d0e 100644 --- a/vignettes/estimating-effects.Rmd +++ b/vignettes/estimating-effects.Rmd @@ -109,7 +109,7 @@ This guide is structured as follows: first, information on the concepts related Before an effect is estimated, the estimand must be specified and clarified. Although some aspects of the estimand depend not only on how the effect is estimated after matching but also on the matching method itself, other aspects must be considered at the time of effect estimation and interpretation. Here, we consider three aspects of the estimand: the population the effect is meant to generalize to (the target population), the effect measure, and whether the effect is marginal or conditional. -**The target population.** Different matching methods allow you to estimate effects that can generalize to different target populations. The most common estimand in matching is the average treatment effect in the treated (ATT), which is the average effect of treatment for those who receive treatment. This estimand is estimable for matching methods that do not change the treated units (i.e., by weighting or discarding units) and is requested in `matchit()` by setting `estimand = "ATT"` (which is the default). The average treatment effect in the population (ATE) is the average effect of treatment for the population from which the sample is a random sample. This estimand is estimable only for methods that allow the ATE and either do not discard units from the sample or explicit target full sample balance, which in `MatchIt` is limited to full matching, subclassification, and template matching when setting `estimand = "ATE"`. When treated units are discarded (e.g., through the use of common support restrictions, calipers, cardinality matching, or [coarsened] exact matching), the estimand corresponds to neither the population ATT nor the population ATE, but rather to an average treatment effect in the remaining matched sample (ATM), which may not correspond to any specific target population. See @greiferChoosingEstimandWhen2021 for a discussion on the substantive considerations involved when choosing the target population of the estimand. +**The target population.** Different matching methods allow you to estimate effects that can generalize to different target populations. The most common estimand in matching is the average treatment effect in the treated (ATT), which is the average effect of treatment for those who receive treatment. This estimand is estimable for matching methods that do not change the treated units (i.e., by weighting or discarding units) and is requested in `matchit()` by setting `estimand = "ATT"` (which is the default). The average treatment effect in the population (ATE) is the average effect of treatment for the population from which the sample is a random sample. This estimand is estimable only for methods that allow the ATE and either do not discard units from the sample or explicit target full sample balance, which in `MatchIt` is limited to full matching, subclassification, and profile matching when setting `estimand = "ATE"`. When treated units are discarded (e.g., through the use of common support restrictions, calipers, cardinality matching, or [coarsened] exact matching), the estimand corresponds to neither the population ATT nor the population ATE, but rather to an average treatment effect in the remaining matched sample (ATM), which may not correspond to any specific target population. See @greiferChoosingEstimandWhen2021 for a discussion on the substantive considerations involved when choosing the target population of the estimand. **Marginal and conditional effects.** A marginal effect is a comparison between the expected potential outcome under treatment and the expected potential outcome under control. This is the same quantity estimated in randomized trials without blocking or covariate adjustment and is particularly useful for quantifying the overall effect of a policy or population-wide intervention. A conditional effect is the comparison between the expected potential outcomes in the treatment groups within strata. This is useful for identifying the effect of a treatment for an individual patient or a subset of the population. @@ -199,7 +199,7 @@ For almost all matching methods, whether a caliper, common support restriction, [^3]: The matching weights are not necessary when performing 1:1 matching, but we include them here for generality. When weights are not necessary, including them does not affect the estimates. Because it may not always be clear when weights are required, we recommend always including them. -There are a few adjustments that need to be made for certain scenarios, which we describe in the section "Adjustments to the Standard Case". These adjustments include for the following cases: when matching for the ATE rather than the ATT, for matching with replacement, for matching with a method that doesn't involve creating pairs (e.g., cardinality and template matching and coarsened exact matching), for subclassification, for estimating effects with binary outcomes, and for estimating effects with survival outcomes. You must read the Standard Case to understand the basic procedure before reading about these special scenarios. +There are a few adjustments that need to be made for certain scenarios, which we describe in the section "Adjustments to the Standard Case". These adjustments include for the following cases: when matching for the ATE rather than the ATT, for matching with replacement, for matching with a method that doesn't involve creating pairs (e.g., cardinality and profile matching and coarsened exact matching), for subclassification, for estimating effects with binary outcomes, and for estimating effects with survival outcomes. You must read the Standard Case to understand the basic procedure before reading about these special scenarios. Here, we demonstrate the faster analytic approach to estimating confidence intervals; for the bootstrap approach, see the section "Using Bootstrapping to Estimate Confidence Intervals" below. @@ -269,7 +269,7 @@ Because control units do not belong to unique pairs, there is no pair membership #### Matching without pairing -Some matching methods do not involve creating pairs; these include cardinality and template matching with `mahvars = NULL` (the default), exact matching, and coarsened exact matching with `k2k = FALSE` (the default). The only change that needs to be made to the Standard Case is that one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()` to use robust SEs instead of cluster-robust SEs. Remember that if matching is done for the ATE (even if units are dropped), the `newdata` argument should be dropped. +Some matching methods do not involve creating pairs; these include cardinality and profile matching with `mahvars = NULL` (the default), exact matching, and coarsened exact matching with `k2k = FALSE` (the default). The only change that needs to be made to the Standard Case is that one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()` to use robust SEs instead of cluster-robust SEs. Remember that if matching is done for the ATE (even if units are dropped), the `newdata` argument should be dropped. #### Propensity score subclassification @@ -338,7 +338,7 @@ For the marginal OR, the only thing that needs to change is that `comparison` sh There are several measures of effect size for survival outcomes. When using the Cox proportional hazards model, the quantity of interest is the hazard ratio (HR) between the treated and control groups. As with the OR, the HR is non-collapsible, which means the estimated HR will only be a valid estimate of the marginal HR when no other covariates are included in the model. Other effect measures, such as the difference in mean survival times or probability of survival after a given time, can be treated just like continuous and binary outcomes as previously described. -For the HR, we cannot compute average marginal effects and must use the coefficient on treatment in a Cox model fit without covariates[^8]. This means that we cannot use the procedures from the Standard Case. Here we describe estimating the marginal HR using `coxph()` from the `survival` package. (See `help("coxph", package = "survival")` for more information on this model.) To request cluster-robust SEs as recommended by @austin2013a, we need to supply pair membership (stored in the `subclass` column of `md`) to the `cluster` argument and set `robust = TRUE`. For matching methods that don't involve pairing (e.g., cardinality and template matching and [coarsened] exact matching), we can omit the `cluster` argument (but keep `robust = TRUE`)[^9]. +For the HR, we cannot compute average marginal effects and must use the coefficient on treatment in a Cox model fit without covariates[^8]. This means that we cannot use the procedures from the Standard Case. Here we describe estimating the marginal HR using `coxph()` from the `survival` package. (See `help("coxph", package = "survival")` for more information on this model.) To request cluster-robust SEs as recommended by @austin2013a, we need to supply pair membership (stored in the `subclass` column of `md`) to the `cluster` argument and set `robust = TRUE`. For matching methods that don't involve pairing (e.g., cardinality and profile matching and [coarsened] exact matching), we can omit the `cluster` argument (but keep `robust = TRUE`)[^9]. [^8]: It is not immediately clear how to estimate a marginal HR when covariates are included in the outcome model; though @austin2020 describe several ways of including covariates in a model to estimate the marginal HR, they do not develop SEs and little research has been done on this method, so we will not present it here. Instead, we fit a simple Cox model with the treatment as the sole predictor. From d3dfd92e6594b79551cffd888482aa62ce77b304 Mon Sep 17 00:00:00 2001 From: Noah Greifer Date: Tue, 12 Nov 2024 18:44:55 -0500 Subject: [PATCH 48/48] Prop for submission --- DESCRIPTION | 3 ++- NEWS.md | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9b9efdc4..23dcbcc4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: MatchIt -Version: 4.5.5.9003 +Version: 4.6.0 Title: Nonparametric Preprocessing for Parametric Causal Inference Description: Selects matched samples of the original treated and control groups with similar covariate distributions -- can be @@ -48,6 +48,7 @@ Suggests: randomForest (>= 4.7-1), glmnet (>= 4.0), gbm (>= 2.1.7), + gurobi, cobalt (>= 4.2.3), boot, marginaleffects (>= 0.19.0), diff --git a/NEWS.md b/NEWS.md index 6f31204c..11f970db 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,7 +6,7 @@ output: `MatchIt` News and Updates ====== -# MatchIt (development version) +# MatchIt 4.6.0 Most improvements are related to performance. Some of these dramatically improve speeds for large datasets. Most come from improvements to `Rcpp` code.