diff --git a/.github/workflows/rhub.yaml b/.github/workflows/rhub.yaml new file mode 100644 index 00000000..74ec7b05 --- /dev/null +++ b/.github/workflows/rhub.yaml @@ -0,0 +1,95 @@ +# R-hub's generic GitHub Actions workflow file. It's canonical location is at +# https://github.com/r-hub/actions/blob/v1/workflows/rhub.yaml +# You can update this file to a newer version using the rhub2 package: +# +# rhub::rhub_setup() +# +# It is unlikely that you need to modify this file manually. + +name: R-hub +run-name: "${{ github.event.inputs.id }}: ${{ github.event.inputs.name || format('Manually run by {0}', github.triggering_actor) }}" + +on: + workflow_dispatch: + inputs: + config: + description: 'A comma separated list of R-hub platforms to use.' + type: string + default: 'linux,windows,macos' + name: + description: 'Run name. You can leave this empty now.' + type: string + id: + description: 'Unique ID. You can leave this empty now.' + type: string + +jobs: + + setup: + runs-on: ubuntu-latest + outputs: + containers: ${{ steps.rhub-setup.outputs.containers }} + platforms: ${{ steps.rhub-setup.outputs.platforms }} + + steps: + # NO NEED TO CHECKOUT HERE + - uses: r-hub/actions/setup@v1 + with: + config: ${{ github.event.inputs.config }} + id: rhub-setup + + linux-containers: + needs: setup + if: ${{ needs.setup.outputs.containers != '[]' }} + runs-on: ubuntu-latest + name: ${{ matrix.config.label }} + strategy: + fail-fast: false + matrix: + config: ${{ fromJson(needs.setup.outputs.containers) }} + container: + image: ${{ matrix.config.container }} + + steps: + - uses: r-hub/actions/checkout@v1 + - uses: r-hub/actions/platform-info@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + - uses: r-hub/actions/setup-deps@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + - uses: r-hub/actions/run-check@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + + other-platforms: + needs: setup + if: ${{ needs.setup.outputs.platforms != '[]' }} + runs-on: ${{ matrix.config.os }} + name: ${{ matrix.config.label }} + strategy: + fail-fast: false + matrix: + config: ${{ fromJson(needs.setup.outputs.platforms) }} + + steps: + - uses: r-hub/actions/checkout@v1 + - uses: r-hub/actions/setup-r@v1 + with: + job-config: ${{ matrix.config.job-config }} + token: ${{ secrets.RHUB_TOKEN }} + - uses: r-hub/actions/platform-info@v1 + with: + token: ${{ secrets.RHUB_TOKEN }} + job-config: ${{ matrix.config.job-config }} + - uses: r-hub/actions/setup-deps@v1 + with: + job-config: ${{ matrix.config.job-config }} + token: ${{ secrets.RHUB_TOKEN }} + - uses: r-hub/actions/run-check@v1 + with: + job-config: ${{ matrix.config.job-config }} + token: ${{ secrets.RHUB_TOKEN }} diff --git a/DESCRIPTION b/DESCRIPTION index 17606cf5..23dcbcc4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: MatchIt -Version: 4.5.5.9000 +Version: 4.6.0 Title: Nonparametric Preprocessing for Parametric Causal Inference Description: Selects matched samples of the original treated and control groups with similar covariate distributions -- can be @@ -33,7 +33,8 @@ Imports: Rcpp, utils, stats, - graphics + graphics, + grDevices Suggests: optmatch (>= 0.10.6), Matching, @@ -43,20 +44,20 @@ Suggests: rpart, mgcv, CBPS (>= 0.17), - dbarts, + dbarts (>= 0.9-28), randomForest (>= 4.7-1), glmnet (>= 4.0), gbm (>= 2.1.7), + gurobi, cobalt (>= 4.2.3), boot, - marginaleffects (>= 0.11.0), + marginaleffects (>= 0.19.0), sandwich (>= 2.5-1), survival, RcppProgress (>= 0.4.2), highs, Rglpk, Rsymphony, - gurobi, knitr, rmarkdown, testthat (>= 3.0.0) @@ -71,5 +72,5 @@ URL: https://kosukeimai.github.io/MatchIt/, BugReports: https://github.com/kosukeimai/MatchIt/issues VignetteBuilder: knitr Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index b986c168..54f41307 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -21,13 +21,11 @@ export(scaled_euclidean_dist) import(graphics) import(stats) importFrom(Rcpp,evalCpp) -importFrom(Rcpp,sourceCpp) importFrom(grDevices,devAskNewPage) importFrom(grDevices,nclass.FD) importFrom(grDevices,nclass.Sturges) importFrom(grDevices,nclass.scott) importFrom(utils,capture.output) importFrom(utils,combn) -importFrom(utils,setTxtProgressBar) -importFrom(utils,txtProgressBar) +importFrom(utils,hasName) useDynLib(MatchIt, .registration = TRUE) diff --git a/NEWS.md b/NEWS.md index f6817261..11f970db 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,10 +6,36 @@ output: `MatchIt` News and Updates ====== -# MatchIt (development version) +# MatchIt 4.6.0 + +Most improvements are related to performance. Some of these dramatically improve speeds for large datasets. Most come from improvements to `Rcpp` code. + +* When using `method = "nearest"`, `m.order` can now be set to `"farthest"` to prioritize hard-to-match treated units. Note this **does not** implement "far matching" but simply changes the order in which the closest matches are selected. + +* Speed improvements to `method = "nearest"`, especially when matching on a propensity score. + +* Speed improvements to `summary()` when `pair.dist = TRUE` and a `match.matrix` component is not included in the output (e.g., for `method = "full"` or `method = "quick"`). + +* Speed improvements to `method = "subclass"` with `min.n` greater than 0. + +* A new `normalize` argument has been added to `matchit()`. When set to `TRUE` (the default, which used to be the only option), the nonzero weights in each treatment group are rescaled to have an average of 1. When `FALSE`, the weights generated directly by the matching are returned instead. + +* When using `method = "nearest"` with `m.order = "closest"`, the full distance matrix is no longer computed, which increases support for larger samples. This uses an adaptation of an algorithm described by [Rassen et al. (2012)](https://doi.org/10.1002/pds.3263). + +* When using `method = "nearest"` with `verbose = TRUE`, the progress bar now displays an estimate of how much time remains. + +* When using `method = "nearest"` with `m.order = "closest"` and `ratio` greater than 1, all eligible units will receive their first match before any receive their second, etc. Previously, the closest pairs would be matched regardless of whether other units had been matched. This ensures consistency with other `m.order` arguments. + +* Speed and memory improvements to `method = "cem"` with many covariates and a large sample size. Previous versions used a Cartesian expansion of all levels of factor variables, which could easily explode. + +* When using `method = "cem"` with `k2k = TRUE`, `m.order` can be set to select the matching order. Allowable options include `"data"` (the default), `"closest"`, `"farthest"`, and `"random"`. `"closest"` is recommended, but `"data"` is the default for now to remain consistent with previous versions. + +* Documentation updates. * Fixed a bug when using `method = "optimal"` or `method = "full"` with `discard` specified and `data` given as a tibble (`tbl_df` object). (#185) +* Fixed a bug when using `method = "cardinality"` with a single covariate. (#194) + # MatchIt 4.5.5 * When using `method = "cardinality"`, a new solver, HiGHS, can be requested by setting `solver = "highs"`, which relies on the `highs` package. This is much faster and more reliable than GLPK and is free and easy to install as a regular R package with no additional requirements. diff --git a/R/MatchIt-package.R b/R/MatchIt-package.R index d38adc78..5efe8e1a 100644 --- a/R/MatchIt-package.R +++ b/R/MatchIt-package.R @@ -2,7 +2,6 @@ "_PACKAGE" ## usethis namespace: start -#' #' @import graphics #' @import stats #' @importFrom grDevices devAskNewPage @@ -10,11 +9,9 @@ #' @importFrom grDevices nclass.scott #' @importFrom grDevices nclass.Sturges #' @importFrom Rcpp evalCpp -#' @importFrom Rcpp sourceCpp #' @importFrom utils capture.output #' @importFrom utils combn -#' @importFrom utils setTxtProgressBar -#' @importFrom utils txtProgressBar +#' @importFrom utils hasName #' @useDynLib MatchIt, .registration = TRUE ## usethis namespace: end NULL diff --git a/R/RcppExports.R b/R/RcppExports.R index 96fa2764..054d9e79 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,36 +1,68 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -dist_to_matrixC <- function(d) { - .Call(`_MatchIt_dist_to_matrixC`, d) +all_equal_to <- function(x, y) { + .Call(`_MatchIt_all_equal_to`, x, y) } -nn_matchC <- function(treat_, ord_, ratio, discarded, reuse_max, distance_ = NULL, distance_mat_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, mah_covs_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC`, treat_, ord_, ratio, discarded, reuse_max, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog) +eucdistC_N1xN0 <- function(x, t) { + .Call(`_MatchIt_eucdistC_N1xN0`, x, t) } -nn_matchC_closest <- function(distance_mat, treat, ratio, discarded, reuse_max, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC_closest`, distance_mat, treat, ratio, discarded, reuse_max, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +get_splitsC <- function(x, caliper) { + .Call(`_MatchIt_get_splitsC`, x, caliper) } -nn_matchC_vec <- function(treat_, ord_, ratio_, discarded_, reuse_max, distance_, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { - .Call(`_MatchIt_nn_matchC_vec`, treat_, ord_, ratio_, discarded_, reuse_max, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +has_n_unique <- function(x, n) { + .Call(`_MatchIt_has_n_unique`, x, n) } -pairdistsubC <- function(x_, t_, s_, num_sub) { - .Call(`_MatchIt_pairdistsubC`, x_, t_, s_, num_sub) +nn_matchC_distmat <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance_mat, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_distmat`, treat_, ord, ratio, discarded, reuse_max, focal_, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) } -subclass2mmC <- function(subclass, treat, focal) { - .Call(`_MatchIt_subclass2mmC`, subclass, treat, focal) +nn_matchC_distmat_closest <- function(treat, ratio, discarded, reuse_max, distance_mat, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, close = TRUE, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_distmat_closest`, treat, ratio, discarded, reuse_max, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog) +} + +nn_matchC_mahcovs <- function(treat_, ord, ratio, discarded, reuse_max, focal_, mah_covs, distance_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_mahcovs`, treat_, ord, ratio, discarded, reuse_max, focal_, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +} + +nn_matchC_mahcovs_closest <- function(treat, ratio, discarded, reuse_max, mah_covs, distance_ = NULL, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, close = TRUE, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_mahcovs_closest`, treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog) +} + +nn_matchC_vec <- function(treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_vec`, treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog) +} + +nn_matchC_vec_closest <- function(treat, ratio, discarded, reuse_max, distance, exact_ = NULL, caliper_dist_ = NULL, caliper_covs_ = NULL, caliper_covs_mat_ = NULL, antiexact_covs_ = NULL, unit_id_ = NULL, close = TRUE, disl_prog = FALSE) { + .Call(`_MatchIt_nn_matchC_vec_closest`, treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog) +} + +pairdistsubC <- function(x, t, s) { + .Call(`_MatchIt_pairdistsubC`, x, t, s) +} + +subclass2mmC <- function(subclass_, treat, focal) { + .Call(`_MatchIt_subclass2mmC`, subclass_, treat, focal) +} + +mm2subclassC <- function(mm, treat, focal = NULL) { + .Call(`_MatchIt_mm2subclassC`, mm, treat, focal) +} + +subclass_scootC <- function(subclass_, treat_, x_, min_n) { + .Call(`_MatchIt_subclass_scootC`, subclass_, treat_, x_, min_n) } tabulateC <- function(bins, nbins = NULL) { .Call(`_MatchIt_tabulateC`, bins, nbins) } -weights_matrixC <- function(mm, treat) { - .Call(`_MatchIt_weights_matrixC`, mm, treat) +weights_matrixC <- function(mm, treat_, focal = NULL) { + .Call(`_MatchIt_weights_matrixC`, mm, treat_, focal) } # Register entry points for exported C++ functions diff --git a/R/add_s.weights.R b/R/add_s.weights.R index 7754a255..9f4c08f8 100644 --- a/R/add_s.weights.R +++ b/R/add_s.weights.R @@ -61,63 +61,84 @@ add_s.weights <- function(m, chk::chk_is(m, "matchit") - if (!is.null(s.weights)) { - if (!is.numeric(s.weights)) { - if (is.null(data)) { - if (!is.null(m$model)) { - env <- attributes(terms(m$model))$.Environment - } else { - env <- parent.frame() - } - data <- eval(m$call$data, envir = env) - if (length(data) == 0) { - .err("a dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") - } + if (is_null(s.weights)) { + return(m) + } + + if (!is.numeric(s.weights)) { + if (is_null(data)) { + if (is_not_null(m$model)) { + env <- attributes(terms(m$model))$.Environment } else { - if (!is.data.frame(data)) { - if (is.matrix(data)) data <- as.data.frame.matrix(data) - else .err("`data` must be a data frame") - } - if (nrow(data) != length(m$treat)) { - .err("`data` must have as many rows as there were units in the original call to `matchit()`") - } + env <- parent.frame() } - if (is.character(s.weights)) { - if (is.null(data) || !is.data.frame(data)) { - .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") - } - if (!all(s.weights %in% names(data))) { - .err("the name supplied to `s.weights` must be a variable in `data`") - } - s.weights.form <- reformulate(s.weights) - s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + data <- eval(m$call$data, envir = env) + + if (is_null(data)) { + .err("a dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") } - else if (rlang::is_formula(s.weights)) { - s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) - s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + } + else { + if (!is.data.frame(data)) { + if (!is.matrix(data)) { + .err("`data` must be a data frame") + } + data <- as.data.frame.matrix(data) } - else { - .err("`s.weights` must be supplied as a numeric vector, string, or one-sided formula") + + if (nrow(data) != length(m$treat)) { + .err("`data` must have as many rows as there were units in the original call to `matchit()`") } } - chk::chk_not_any_na(s.weights) - if (length(s.weights) != length(m$treat)) .err("`s.weights` must be the same length as the treatment vector") + if (is.character(s.weights)) { + if (is_null(data) || !is.data.frame(data)) { + .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") + } - names(s.weights) <- names(m$treat) + if (!all(hasName(data, s.weights))) { + .err("the name supplied to `s.weights` must be a variable in `data`") + } - attr(s.weights, "in_ps") <- isTRUE(all.equal(s.weights, m$s.weights)) + s.weights.form <- reformulate(s.weights) + s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - m$s.weights <- s.weights + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") + } + + s.weights <- s.weights[[1L]] + } + else if (rlang::is_formula(s.weights)) { + s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) + s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") + } + + s.weights <- s.weights[[1L]] + } + else { + .err("`s.weights` must be supplied as a numeric vector, string, or one-sided formula") + } + } - m$nn <- nn(m$treat, m$weights, m$discarded, s.weights) + chk::chk_not_any_na(s.weights) + + if (length(s.weights) != length(m$treat)) { + .err("`s.weights` must be the same length as the treatment vector") } + names(s.weights) <- names(m$treat) + + attr(s.weights, "in_ps") <- isTRUE(all.equal(s.weights, m$s.weights)) + + m$s.weights <- s.weights + + m$nn <- nn(m$treat, m$weights, m$discarded, s.weights) + m } diff --git a/R/aux_functions.R b/R/aux_functions.R index 54ffa38c..b616d886 100644 --- a/R/aux_functions.R +++ b/R/aux_functions.R @@ -1,100 +1,39 @@ -#Auxiliary functions; some from WeightIt - #Function to ensure no subclass is devoid of both treated and control units by "scooting" units -#from other subclasses. From WeightIt. +#from other subclasses. subclass_scoot <- function(sub, treat, x, min.n = 1) { #Reassigns subclasses so there are no empty subclasses - #for each treatment group. Copied from WeightIt with - #slight modifications. - - treat <- as.character(treat) - unique.treat <- unique(treat, nmax = 2) - - names(x) <- seq_along(x) - names(sub) <- seq_along(sub) - original.order <- names(x) + #for each treatment group. + subtab <- table(treat, sub) - nsub <- length(unique(sub)) + if (all(subtab >= min.n)) { + return(sub) + } - #Turn subs into a contiguous sequence - sub <- setNames(setNames(seq_len(nsub), sort(unique(sub)))[as.character(sub)], - original.order) + nsub <- ncol(subtab) - if (any(table(treat) < nsub * min.n)) { + if (any(rowSums(subtab) < nsub * min.n)) { .err(sprintf("not enough units to fit %s treated and control %s in each subclass", min.n, ngettext(min.n, "unit", "units"))) } - for (t in unique.treat) { - if (length(x[treat == t]) == nsub) { - sub[treat == t] <- seq_len(nsub) - } - } - - sub_tab <- table(treat, sub) - - if (any(sub_tab < min.n)) { - - soft_thresh <- function(x, minus = 1) { - x <- x - minus - x[x < 0] <- 0 - x - } - - for (t in unique.treat) { - for (n in seq_len(min.n)) { - while (any(sub_tab[t,] == 0)) { - first_0 <- which(sub_tab[t,] == 0)[1] - - if (first_0 == nsub || - (first_0 != 1 && - sum(soft_thresh(sub_tab[t, seq(1, first_0 - 1)]) / abs(first_0 - seq(1, first_0 - 1))) >= - sum(soft_thresh(sub_tab[t, seq(first_0 + 1, nsub)]) / abs(first_0 - seq(first_0 + 1, nsub))))) { - #If there are more and closer nonzero subs to the left... - first_non0_to_left <- max(seq(1, first_0 - 1)[sub_tab[t, seq(1, first_0 - 1)] > 0]) - - name_to_move <- names(sub)[which(x == max(x[treat == t & sub == first_non0_to_left]) & treat == t & sub == first_non0_to_left)[1]] - - sub[name_to_move] <- first_0 - sub_tab[t, first_0] <- 1L - sub_tab[t, first_non0_to_left] <- sub_tab[t, first_non0_to_left] - 1L - - } - else { - #If there are more and closer nonzero subs to the right... - first_non0_to_right <- min(seq(first_0 + 1, nsub)[sub_tab[t, seq(first_0 + 1, nsub)] > 0]) - - name_to_move <- names(sub)[which(x == min(x[treat == t & sub == first_non0_to_right]) & treat == t & sub == first_non0_to_right)[1]] - - sub[name_to_move] <- first_0 - sub_tab[t, first_0] <- 1L - sub_tab[t, first_non0_to_right] <- sub_tab[t, first_non0_to_right] - 1L - } - } - - sub_tab[t,] <- sub_tab[t,] - 1 - } - } - - #Unsort - sub <- sub[names(sub)] - } - - sub + subclass_scootC(as.integer(sub), as.integer(treat), + as.numeric(x), as.integer(min.n)) } #Create info component of matchit object -create_info <- function(method, fn1, link, discard, replace, ratio, mahalanobis, transform, subclass, antiexact, distance_is_matrix) { +create_info <- function(method, fn1, link, discard, replace, ratio, + mahalanobis, transform, subclass, antiexact, + distance_is_matrix) { info <- list(method = method, - distance = if (is.null(fn1)) NULL else sub("distance2", "", fn1, fixed = TRUE), - link = if (is.null(link)) NULL else link, + distance = if (is_null(fn1)) NULL else sub("distance2", "", fn1, fixed = TRUE), + link = if (is_null(link)) NULL else link, discard = discard, - replace = if (!is.null(method) && method %in% c("nearest", "genetic")) replace else NULL, - ratio = if (!is.null(method) && method %in% c("nearest", "optimal", "genetic")) ratio else NULL, - max.controls = if (!is.null(method) && method %in% c("nearest", "optimal")) attr(ratio, "max.controls") else NULL, + replace = if (is_not_null(method) && method %in% c("nearest", "genetic")) replace else NULL, + ratio = if (is_not_null(method) && method %in% c("nearest", "optimal", "genetic")) ratio else NULL, + max.controls = if (is_not_null(method) && method %in% c("nearest", "optimal")) attr(ratio, "max.controls") else NULL, mahalanobis = mahalanobis, transform = transform, - subclass = if (!is.null(method) && method == "subclass") length(unique(subclass[!is.na(subclass)])) else NULL, + subclass = if (is_not_null(method) && method == "subclass") length(unique(subclass[!is.na(subclass)])) else NULL, antiexact = antiexact, distance_is_matrix = distance_is_matrix) info @@ -104,36 +43,47 @@ create_info <- function(method, fn1, link, discard, replace, ratio, mahalanobis, info.to.method <- function(info) { out.list <- setNames(vector("list", 3), c("kto1", "type", "replace")) - out.list[["kto1"]] <- if (!is.null(info$ratio)) paste0(if (!is.null(info$max.controls)) "variable ratio ", round(info$ratio, 2), ":1") else NULL - out.list[["type"]] <- if (is.null(info$method)) "none (no matching)" else - switch(info$method, - "exact" = "exact matching", - "cem" = "coarsened exact matching", - "nearest" = "nearest neighbor matching", - "optimal" = "optimal pair matching", - "full" = "optimal full matching", - "quick" = "generalized full matching", - "genetic" = "genetic matching", - "subclass" = paste0("subclassification (", info$subclass, " subclasses)"), - "cardinality" = "cardinality matching", - if (is.null(attr(info$method, "method"))) "an unspecified matching method" - else attr(info$method, "method")) - out.list[["replace"]] <- if (!is.null(info$replace) && info$method %in% c("nearest", "genetic")) { - if (info$replace) "with replacement" + + out.list[["kto1"]] <- { + if (is_not_null(info$ratio)) paste0(if (is_not_null(info$max.controls)) "variable ratio ", round(info$ratio, 2), ":1") + else NULL + } + + out.list[["type"]] <- { + if (is_null(info$method)) "none (no matching)" + else switch(info$method, + "exact" = "exact matching", + "cem" = "coarsened exact matching", + "nearest" = "nearest neighbor matching", + "optimal" = "optimal pair matching", + "full" = "optimal full matching", + "quick" = "generalized full matching", + "genetic" = "genetic matching", + "subclass" = sprintf("subclassification (%s subclasses)", info$subclass), + "cardinality" = "cardinality matching", + if (is_null(attr(info$method, "method"))) "an unspecified matching method" + else attr(info$method, "method")) + } + + out.list[["replace"]] <- { + if (is_null(info$replace) || !info$method %in% c("nearest", "genetic")) NULL + else if (info$replace) "with replacement" else "without replacement" - } else NULL + } - firstup(do.call("paste", c(unname(out.list), list(sep = " ")))) + firstup(do.call("paste", unname(out.list))) } info.to.distance <- function(info) { distance <- info$distance link <- info$link - if (!is.null(link) && startsWith(as.character(link), "linear")) { + if (is_not_null(link) && startsWith(as.character(link), "linear")) { linear <- TRUE link <- sub("linear.", "", as.character(link)) } - else linear <- FALSE + else { + linear <- FALSE + } if (distance == "glm") { if (link == "logit") dist <- "logistic regression" @@ -177,295 +127,80 @@ info.to.distance <- function(info) { dist } -#Function to turn a vector into a string with "," and "and" or "or" for clean messages. 'and.or' -#controls whether words are separated by "and" or "or"; 'is.are' controls whether the list is -#followed by "is" or "are" (to avoid manually figuring out if plural); quotes controls whether -#quotes should be placed around words in string. From WeightIt. -word_list <- function(word.list = NULL, and.or = "and", is.are = FALSE, quotes = FALSE) { - #When given a vector of strings, creates a string of the form "a and b" - #or "a, b, and c" - #If is.are, adds "is" or "are" appropriately - L <- length(word.list) - word.list <- add_quotes(word.list, quotes) - - if (L == 0) { - out <- "" - attr(out, "plural") <- FALSE - } - else { - word.list <- word.list[!word.list %in% c(NA_character_, "")] - L <- length(word.list) - if (L == 0) { - out <- "" - attr(out, "plural") <- FALSE - } - else if (L == 1) { - out <- word.list - if (is.are) out <- paste(out, "is") - attr(out, "plural") <- FALSE - } - else { - and.or <- match_arg(and.or, c("and", "or")) - if (L == 2) { - out <- paste(word.list, collapse = paste0(" ", and.or, " ")) - } - else { - out <- paste(paste(word.list[seq_len(L - 1)], collapse = ", "), - word.list[L], sep = paste0(", ", and.or, " ")) - - } - if (is.are) out <- paste(out, "are") - attr(out, "plural") <- TRUE - } - - } - - out -} - -#Add quotes to a string -add_quotes <- function(x, quotes = 2L) { - if (isFALSE(quotes)) return(x) - - if (isTRUE(quotes)) quotes <- 2 - - if (chk::vld_string(quotes)) x <- paste0(quotes, x, quotes) - else if (chk::vld_whole_number(quotes)) { - if (as.integer(quotes) == 0) return(x) - else if (as.integer(quotes) == 1) x <- paste0("\'", x, "\'") - else if (as.integer(quotes) == 2) x <- paste0("\"", x, "\"") - else stop("`quotes` must be boolean, 1, 2, or a string.") - } - else { - stop("`quotes` must be boolean, 1, 2, or a string.") - } - - x -} - -#More informative and cleaner version of base::match.arg(). Uses chk. -match_arg <- function(arg, choices, several.ok = FALSE) { - #Replaces match.arg() but gives cleaner error message and processing - #of arg. - if (missing(arg)) - stop("No argument was supplied to match_arg.") - arg.name <- deparse1(substitute(arg), width.cutoff = 500L) - - if (missing(choices)) { - formal.args <- formals(sys.function(sysP <- sys.parent())) - choices <- eval(formal.args[[as.character(substitute(arg))]], - envir = sys.frame(sysP)) +#Make interaction vector out of matrix of covs; similar to interaction() +exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE, justify = "right") { + if (is_null(nam)) { + nam <- rownames(X) } - if (length(arg) == 0) return(choices[1L]) - - if (several.ok) { - chk::chk_character(arg, add_quotes(arg.name, "`")) + if (is.matrix(X)) { + X <- as.data.frame.matrix(X) } - else { - chk::chk_string(arg, add_quotes(arg.name, "`")) - if (identical(arg, choices)) return(arg[1L]) + else if (!is.list(X)) { + stop("X must be a matrix, data frame, or list.") } - i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE) - if (all(i == 0L)) - .err(sprintf("the argument to `%s` should be %s%s.", - arg.name, ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), - word_list(choices, and.or = "or", quotes = 2))) - i <- i[i > 0L] - - choices[i] -} + X <- X[lengths(X) > 0] -#Turn a vector into a 0/1 vector. 'zero' and 'one' can be supplied to make it clear which is -#which; otherwise, a guess is used. From WeightIt. -binarize <- function(variable, zero = NULL, one = NULL) { - var.name <- deparse1(substitute(variable)) - if (length(unique(variable)) > 2) { - stop(sprintf("Cannot binarize %s: more than two levels.", var.name)) - } - if (is.character(variable) || is.factor(variable)) { - variable <- factor(variable, nmax = 2) - unique.vals <- levels(variable) - } - else { - unique.vals <- unique(variable, nmax = 2) + if (is_null(X)) { + return(NULL) } - if (is.null(zero)) { - if (is.null(one)) { - if (can_str2num(unique.vals)) { - variable.numeric <- str2num(variable) - } - else { - variable.numeric <- as.numeric(variable) - } - - if (0 %in% variable.numeric) zero <- 0 - else zero <- min(variable.numeric, na.rm = TRUE) - - return(setNames(as.integer(variable.numeric != zero), names(variable))) - } - else { - if (one %in% unique.vals) return(setNames(as.integer(variable == one), names(variable))) - else stop("The argument to 'one' is not the name of a level of variable.") + for (i in seq_along(X)) { + unique_x <- { + if (is.factor(X[[i]])) levels(X[[i]]) + else sort(unique(X[[i]])) } - } - else { - if (zero %in% unique.vals) return(setNames(as.integer(variable != zero), names(variable))) - else stop("The argument to 'zero' is not the name of a level of variable.") - } -} -#Make interaction vector out of matrix of covs; similar to interaction() -exactify <- function(X, nam = NULL, sep = "|", include_vars = FALSE) { - if (is.null(nam)) nam <- rownames(X) - if (is.matrix(X)) X <- setNames(lapply(seq_len(ncol(X)), function(i) X[,i]), colnames(X)) - if (!is.list(X)) stop("X must be a matrix, data frame, or list.") - - if (include_vars) { - for (i in seq_along(X)) { - if (is.character(X[[i]]) || is.factor(X[[i]])) { - X[[i]] <- sprintf('%s = "%s"', names(X)[i], X[[i]]) - } - else { - X[[i]] <- sprintf('%s = %s', names(X)[i], X[[i]]) - } - } - } - else { - for (i in seq_along(X)) { - if (is.factor(X[[i]])) { - X[[i]] <- format(levels(X[[i]]), justify = "right")[X[[i]]] - } - else { - X[[i]] <- format(X[[i]], justify = "right") - } + lev <- { + if (include_vars) sprintf("%s = %s", + names(X)[i], + add_quotes(unique_x, chk::vld_character_or_factor(X[[i]]))) + else if (is_null(justify)) unique_x + else format(unique_x, justify = justify) } - } - - out <- do.call("paste", c(X, sep = sep)) - if (!is.null(nam)) names(out) <- nam - out -} - -#Determine whether a character vector can be coerced to numeric -can_str2num <- function(x) { - nas <- is.na(x) - suppressWarnings(x_num <- as.numeric(as.character(x[!nas]))) - !anyNA(x_num) -} - -#Cleanly coerces a character vector to numeric; best to use after can_str2num() -str2num <- function(x) { - nas <- is.na(x) - suppressWarnings(x_num <- as.numeric(as.character(x))) - is.na(x_num)[nas] <- TRUE - x_num -} - -#Capitalize first letter of string -firstup <- function(x) { - substr(x, 1, 1) <- toupper(substr(x, 1, 1)) - x -} - -#Capitalize first letter of each word -capwords <- function(s, strict = FALSE) { - cap <- function(s) paste0(toupper(substring(s, 1, 1)), - {s <- substring(s, 2); if(strict) tolower(s) else s}, - collapse = " ") - sapply(strsplit(s, split = " "), cap, USE.NAMES = !is.null(names(s))) -} - -#Clean printing of data frames with numeric and NA elements. -round_df_char <- function(df, digits, pad = "0", na_vals = "") { - #Digits is passed to round(). pad is used to replace trailing zeros so decimal - #lines up. Should be "0" or " "; "" (the empty string) un-aligns decimals. - #na_vals is what NA should print as. - - if (NROW(df) == 0 || NCOL(df) == 0) return(as.matrix(df)) - if (!is.data.frame(df)) df <- as.data.frame.matrix(df, stringsAsFactors = FALSE) - rn <- rownames(df) - cn <- colnames(df) - - infs <- o.negs <- array(FALSE, dim = dim(df)) - nas <- is.na(df) - nums <- vapply(df, is.numeric, logical(1)) - infs[,nums] <- vapply(which(nums), function(i) !nas[,i] & !is.finite(df[[i]]), logical(NROW(df))) - - for (i in which(!nums)) { - if (can_str2num(df[[i]])) { - df[[i]] <- str2num(df[[i]]) - nums[i] <- TRUE - } + X[[i]] <- factor(X[[i]], levels = unique_x, labels = lev) } - o.negs[,nums] <- !nas[,nums] & df[nums] < 0 & round(df[nums], digits) == 0 - df[nums] <- round(df[nums], digits = digits) + out <- interaction2(X, sep = sep, lex.order = if (include_vars) TRUE else NULL) - for (i in which(nums)) { - df[[i]] <- format(df[[i]], scientific = FALSE, justify = "none", trim = TRUE, - drop0trailing = !identical(as.character(pad), "0")) - - if (!identical(as.character(pad), "0") && any(grepl(".", df[[i]], fixed = TRUE))) { - s <- strsplit(df[[i]], ".", fixed = TRUE) - lengths <- lengths(s) - digits.r.of.. <- rep(0, NROW(df)) - digits.r.of..[lengths > 1] <- nchar(vapply(s[lengths > 1], `[[`, character(1L), 2)) - max.dig <- max(digits.r.of..) - - dots <- ifelse(lengths > 1, "", if (as.character(pad) != "") "." else pad) - pads <- vapply(max.dig - digits.r.of.., function(n) paste(rep(pad, n), collapse = ""), character(1L)) - - df[[i]] <- paste0(df[[i]], dots, pads) - } + if (is_null(nam)) { + return(out) } - df[o.negs] <- paste0("-", df[o.negs]) - - # Insert NA placeholders - df[nas] <- na_vals - df[infs] <- "N/A" - - if (length(rn) > 0) rownames(df) <- rn - if (length(cn) > 0) names(df) <- cn - - as.matrix(df) -} - -#Generalized inverse; port of MASS::ginv() -generalized_inverse <- function(sigma) { - sigmasvd <- svd(sigma) - pos <- sigmasvd$d > max(1e-8 * sigmasvd$d[1L], 0) - sigma_inv <- sigmasvd$v[, pos, drop = FALSE] %*% (sigmasvd$d[pos]^-1 * t(sigmasvd$u[, pos, drop = FALSE])) - sigma_inv + setNames(out, nam) } #Get covariates (RHS) vars from formula get.covs.matrix <- function(formula = NULL, data = NULL) { - if (is.null(formula)) { + if (is_null(formula)) { fnames <- colnames(data) - fnames[!startsWith(fnames, "`")] <- paste0("`", fnames[!startsWith(fnames, "`")], "`") + fnames[!startsWith(fnames, "`")] <- add_quotes(fnames[!startsWith(fnames, "`")], "`") formula <- reformulate(fnames) } - else formula <- update(terms(formula, data = data), NULL ~ . + 1) + else { + formula <- update(terms(formula, data = data), NULL ~ . + 1) + } mf <- model.frame(terms(formula, data = data), data, na.action = na.pass) chars.in.mf <- vapply(mf, is.character, logical(1L)) - mf[chars.in.mf] <- lapply(mf[chars.in.mf], factor) + for (i in which(chars.in.mf)) { + mf[[i]] <- as.factor(mf[[i]]) + } mf <- droplevels(mf) X <- model.matrix(formula, data = mf, contrasts.arg = lapply(Filter(is.factor, mf), contrasts, contrasts = FALSE)) + assign <- attr(X, "assign")[-1] - X <- X[,-1,drop=FALSE] + X <- X[,-1, drop = FALSE] + attr(X, "assign") <- assign X @@ -473,7 +208,9 @@ get.covs.matrix <- function(formula = NULL, data = NULL) { #Extracts and names the "assign" attribute from get.covs.matrix() get_assign <- function(mat) { - if (is.null(attr(mat, "assign"))) return(NULL) + if (is_null(attr(mat, "assign"))) { + return(NULL) + } setNames(attr(mat, "assign"), colnames(mat)) } @@ -481,78 +218,36 @@ get_assign <- function(mat) { #Convert match.matrix (mm) using numerical indices to using char rownames nummm2charmm <- function(nummm, treat) { #Assumes nummm has rownames - charmm <- array(NA_character_, dim = dim(nummm), dimnames = dimnames(nummm)) + charmm <- array(NA_character_, dim = dim(nummm), + dimnames = dimnames(nummm)) charmm[] <- names(treat)[nummm] charmm } charmm2nummm <- function(charmm, treat) { - nummm <- array(NA_integer_, dim = dim(charmm)) + nummm <- array(NA_integer_, dim = dim(charmm), + dimnames = dimnames(charmm)) n_index <- setNames(seq_along(treat), names(treat)) nummm[] <- n_index[charmm] nummm } #Get subclass from match.matrix. Only to be used if replace = FALSE. See subclass2mmC.cpp for reverse. -mm2subclass <- function(mm, treat) { - lab <- names(treat) - ind1 <- which(treat == 1) - - subclass <- setNames(rep(NA_character_, length(treat)), lab) - no.match <- is.na(mm) - subclass[ind1[!no.match[,1]]] <- ind1[!no.match[,1]] - subclass[mm[!no.match]] <- ind1[row(mm)[!no.match]] - - subclass <- setNames(factor(subclass, nmax = length(ind1)), lab) - levels(subclass) <- seq_len(nlevels(subclass)) - - subclass -} - -#(Weighted) variance that uses special formula for binary variables -wvar <- function(x, bin.var = NULL, w = NULL) { - if (is.null(w)) w <- rep(1, length(x)) - if (is.null(bin.var)) bin.var <- all(x == 0 | x == 1) - - w <- w / sum(w) #weights normalized to sum to 1 - mx <- sum(w * x) #weighted mean - - if (bin.var) { - return(mx * (1 - mx)) - } - - #Reliability weights variance; same as cov.wt() - sum(w * (x - mx)^2)/(1 - sum(w^2)) -} - -#Weighted mean faster than weighted.mean() -wm <- function(x, w = NULL, na.rm = TRUE) { - if (is.null(w)) { - if (anyNA(x)) { - if (!na.rm) return(NA_real_) - nas <- which(is.na(x)) - x <- x[-nas] - } - return(sum(x)/length(x)) - } - - if (anyNA(x) || anyNA(w)) { - if (!na.rm) return(NA_real_) - nas <- which(is.na(x) | is.na(w)) - x <- x[-nas] - w <- w[-nas] +mm2subclass <- function(mm, treat, focal = NULL) { + if (!is.integer(mm)) { + mm <- charmm2nummm(mm, treat) } - sum(x*w)/sum(w) + mm2subclassC(mm, treat, focal) } #Pooled within-group (weighted) covariance by group-mean centering covariates. Used #in Mahalanobis distance pooled_cov <- function(X, t, w = NULL) { unique_t <- unique(t) - if (is.null(dim(X))) X <- matrix(X, nrow = length(X)) + if (is_null(dim(X))) X <- matrix(X, nrow = length(X)) - if (is.null(w)) { + if (is_null(w)) { n <- nrow(X) for (i in unique_t) { in_t <- which(t == i) @@ -560,6 +255,7 @@ pooled_cov <- function(X, t, w = NULL) { X[in_t, j] <- X[in_t, j] - mean(X[in_t, j]) } } + return(cov(X)*(n-1)/(n-length(unique_t))) } @@ -569,18 +265,23 @@ pooled_cov <- function(X, t, w = NULL) { X[in_t, j] <- X[in_t, j] - wm(X[in_t, j], w[in_t]) } } + cov.wt(X, w)$cov } pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportional") { contribution <- match_arg(contribution, c("proportional", "equal")) unique_t <- unique(t) - if (is.null(dim(X))) X <- matrix(X, nrow = length(X)) + if (is_null(dim(X))) X <- matrix(X, nrow = length(X)) n <- nrow(X) - if (is.null(bin.var)) bin.var <- apply(X, 2, function(x) all(x == 0 | x == 1)) + + if (is_null(bin.var)) { + bin.var <- apply(X, 2, function(x) all(x == 0 | x == 1)) + } if (contribution == "equal") { vars <- matrix(0, nrow = length(unique_t), ncol = ncol(X)) + for (i in seq_along(unique_t)) { in_t <- which(t == unique_t[i]) vars[i,] <- vapply(seq_len(ncol(X)), function(j) { @@ -589,6 +290,7 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion wvar(x[in_t], w = w[in_t], bin.var = b) }, numeric(1L)) } + pooled_var <- colMeans(vars) } else { @@ -597,7 +299,7 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion b <- bin.var[j] if (b) { - if (is.null(w)) { + if (is_null(w)) { v <- vapply(unique_t, function(i) { sxi <- sum(x[t == i]) ni <- sum(t == i) @@ -615,7 +317,7 @@ pooled_sd <- function(X, t, w = NULL, bin.var = NULL, contribution = "proportion } } else { - if (is.null(w)) { + if (is_null(w)) { for (i in unique_t) { x[t==i] <- x[t==i] - wm(x[t==i]) } @@ -643,11 +345,20 @@ ESS <- function(w) { #Compute sample sizes nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { - if (is.null(discarded)) discarded <- rep(FALSE, length(treat)) - if (is.null(s.weights)) s.weights <- rep(1, length(treat)) + if (is_null(discarded)) { + discarded <- rep.int(FALSE, length(treat)) + } + + if (is_null(s.weights)) { + s.weights <- rep.int(1, length(treat)) + } + weights <- weights * s.weights - n <- matrix(0, ncol=2, nrow=6, dimnames = list(c("All (ESS)", "All", "Matched (ESS)","Matched", "Unmatched","Discarded"), - c("Control", "Treated"))) + + n <- matrix(0, ncol = 2L, nrow = 6L, + dimnames = list(c("All (ESS)", "All", "Matched (ESS)", + "Matched", "Unmatched","Discarded"), + c("Control", "Treated"))) # Control Treated n["All (ESS)",] <- c(ESS(s.weights[treat==0]), ESS(s.weights[treat==1])) @@ -664,17 +375,23 @@ nn <- function(treat, weights, discarded = NULL, s.weights = NULL) { qn <- function(treat, subclass, discarded = NULL) { treat <- factor(treat, levels = 0:1, labels = c("Control", "Treated")) - if (is.null(discarded)) discarded <- rep(FALSE, length(treat)) + + if (is_null(discarded)) { + discarded <- rep.int(FALSE, length(treat)) + } + qn <- table(treat[!discarded], subclass[!discarded]) if (any(is.na(subclass) & !discarded)) { qn <- cbind(qn, table(treat[is.na(subclass) & !discarded])) colnames(qn)[ncol(qn)] <- "Unmatched" } + if (any(discarded)) { qn <- cbind(qn, table(treat[discarded])) colnames(qn)[ncol(qn)] <- "Discarded" } + qn <- rbind(qn, colSums(qn)) rownames(qn)[nrow(qn)] <- "Total" @@ -684,138 +401,21 @@ qn <- function(treat, subclass, discarded = NULL) { qn } -#Faster diff() -diff1 <- function(x) { - x[-1] - x[-length(x)] -} - -#cumsum() for probabilities to ensure they are between 0 and 1 -.cumsum_prob <- function(x) { - s <- cumsum(x) - s / s[length(s)] -} - -#Make vector sum to 1, optionally by group -.make_sum_to_1 <- function(x, by = NULL) { - if (is.null(by)) { - return(x / sum(x)) - } - - for (i in unique(by)) { - in_i <- which(by == i) - x[in_i] <- x[in_i] / sum(x[in_i]) - } - - x -} - -#Make vector sum to n (average of 1), optionally by group -.make_sum_to_n <- function(x, by = NULL) { - if (is.null(by)) { - return(length(x) * x / sum(x)) - } - - for (i in unique(by)) { - in_i <- which(by == i) - x[in_i] <- length(in_i) * x[in_i] / sum(x[in_i]) - } - - x -} - -#Functions for error handling; based on chk and rlang -pkg_caller_call <- function(start = 1) { - package.funs <- c(getNamespaceExports(utils::packageName()), - .getNamespaceInfo(asNamespace(utils::packageName()), "S3methods")[, 3]) - k <- start #skip checking pkg_caller_call() - e_max <- start - while (!is.null(e <- rlang::caller_call(k))) { - if (!is.null(n <- rlang::call_name(e)) && - n %in% package.funs) e_max <- k - k <- k + 1 - } - rlang::caller_call(e_max) -} - -.err <- function(...) { - chk::err(..., call = pkg_caller_call(start = 2)) -} -.wrn <- function(...) { - chk::wrn(...) -} -.msg <- function(...) { - chk::msg(...) -} - -#De-bugged version of chk::chk_null_or() -.chk_null_or <- function(x, chk, ..., x_name = NULL) { - if (is.null(x_name)) { - x_name <- deparse1(substitute(x)) - } - - x_name <- add_quotes(x_name, "`") - - if (is.null(x)) { - return(invisible(x)) - } - - tryCatch(chk(x, ..., x_name = x_name), - error = function(e) { - msg <- sub("[.]$", " or `NULL`.", - conditionMessage(e)) - chk::err(msg, .subclass = "chk_error") - }) -} - -.chk_formula <- function(x, sides = NULL, x_name = NULL) { - if (is.null(sides)) { - if (rlang::is_formula(x)) { - return(invisible(x)) - } - if (is.null(x_name)) { - x_name <- chk::deparse_backtick_chk(substitute(x)) - } - chk::abort_chk(x_name, " must be a formula", - x = x) - } - else if (sides == 1) { - if (rlang::is_formula(x, lhs = FALSE)) { - return(invisible(x)) - } - if (is.null(x_name)) { - x_name <- chk::deparse_backtick_chk(substitute(x)) - } - chk::abort_chk(x_name, " must be a formula with no left-hand side", - x = x) - } - else if (sides == 2) { - if (rlang::is_formula(x, lhs = TRUE)) { - return(invisible(x)) - } - if (is.null(x_name)) { - x_name <- chk::deparse_backtick_chk(substitute(x)) - } - chk::abort_chk(x_name, " must be a formula with a left-hand side", - x = x) - } - else stop("`sides` must be NULL, 1, or 2") -} - #Function to capture and print errors and warnings better matchit_try <- function(expr, from = NULL, dont_warn_if = NULL) { tryCatch({ withCallingHandlers({ - expr + expr }, warning = function(w) { - if (is.null(dont_warn_if) || !grepl(dont_warn_if, conditionMessage(w), fixed = TRUE)) { - if (is.null(from)) .wrn(conditionMessage(w), tidy = FALSE) + if (is_null(dont_warn_if) || !grepl(dont_warn_if, conditionMessage(w), fixed = TRUE)) { + if (is_null(from)) .wrn(conditionMessage(w), tidy = FALSE) else .wrn(sprintf("(from %s) %s", from, conditionMessage(w)), tidy = FALSE) } invokeRestart("muffleWarning") })}, error = function(e) { - if (is.null(from)) .err(conditionMessage(e), tidy = FALSE) + if (is_null(from)) .err(conditionMessage(e), tidy = FALSE) else .err(sprintf("(from %s) %s", from, conditionMessage(e)), tidy = FALSE) }) } \ No newline at end of file diff --git a/R/discard.R b/R/discard.R index 764673e3..87ee45d1 100644 --- a/R/discard.R +++ b/R/discard.R @@ -2,9 +2,9 @@ discard <- function(treat, pscore = NULL, option = NULL) { n.obs <- length(treat) - if (length(option) == 0){ + if (is_null(option)){ # keep all units - return(setNames(rep(FALSE, n.obs), names(treat))) + return(rep_with(FALSE, treat)) } if (is.logical(option) && length(option) == n.obs && !anyNA(option)) { @@ -20,10 +20,10 @@ discard <- function(treat, pscore = NULL, option = NULL) { if (option == "none") { # keep all units - return(setNames(rep(FALSE, n.obs), names(treat))) + return(rep_with(FALSE, treat)) } - if (is.null(pscore)) { + if (is_null(pscore)) { .err('`discard` must be a logical vector or "none" in the absence of a propensity score') } @@ -50,7 +50,7 @@ discard <- function(treat, pscore = NULL, option = NULL) { # X <- model.matrix(reformulate(names(covs), intercept = FALSE), data = covs, # contrasts.arg = lapply(Filter(is.factor, covs), # function(x) contrasts(x, contrasts = nlevels(x) == 1))) - # discarded <- rep(FALSE, n.obs) + # discarded <- rep.int(FALSE, n.obs) # if (option == "hull.control"){ # discard units not in T convex hull # wif <- WhatIf::whatif(cfact = X[treat==0,], data = X[treat==1,]) # discarded[treat==0] <- !wif$in.hull diff --git a/R/dist_functions.R b/R/dist_functions.R index 98ab444f..7ffa48e4 100644 --- a/R/dist_functions.R +++ b/R/dist_functions.R @@ -104,15 +104,14 @@ #' #' @author Noah Greifer #' @seealso [`distance`], [matchit()], [dist()] (which is used -#' internally to compute Euclidean distances) +#' internally to compute some Euclidean distances) #' #' \pkgfun{optmatch}{match_on}, which provides similar functionality but with fewer #' options and a focus on efficient storage of the output. #' #' @references #' -#' Rosenbaum, P. R. (2010). *Design of observational studies*. -#' Springer. +#' Rosenbaum, P. R. (2010). *Design of observational studies*. Springer. #' #' Rosenbaum, P. R., & Rubin, D. B. (1985). Constructing a Control Group Using #' Multivariate Matched Sampling Methods That Incorporate the Propensity Score. @@ -221,32 +220,35 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano #If allvariables have no variance, use Euclidean to avoid errors #If some have no variance, removes those to avoid messing up distances no_variance <- which(apply(X, 2, function(x) abs(max(x) - min(x)) < sqrt(.Machine$double.eps))) + if (length(no_variance) == ncol(X)) { method <- "euclidean" X <- X[, 1, drop = FALSE] } - else if (length(no_variance) > 0) { + else if (is_not_null(no_variance)) { X <- X[, -no_variance, drop = FALSE] } method <- match_arg(method, matchit_distances()) - if (is.null(discarded)) discarded <- rep(FALSE, nrow(X)) + if (is_null(discarded)) { + discarded <- rep.int(FALSE, nrow(X)) + } if (method == "mahalanobis") { # X <- sweep(X, 2, colMeans(X)) - if (is.null(var)) { + if (is_null(var)) { X <- scale(X) + #NOTE: optmatch and Rubin (1980) use pooled within-group covariance matrix - if (!is.null(treat)) { - var <- pooled_cov(X[!discarded,, drop = FALSE], treat[!discarded], s.weights[!discarded]) - } - else if (is.null(s.weights)) { - var <- cov(X[!discarded,, drop = FALSE]) - } - else { - var <- cov.wt(X[!discarded,, drop = FALSE], s.weights[!discarded])$cov + var <- { + if (is_not_null(treat)) + pooled_cov(X[!discarded,, drop = FALSE], treat[!discarded], s.weights[!discarded]) + else if (is_null(s.weights)) + cov(X[!discarded,, drop = FALSE]) + else + cov.wt(X[!discarded,, drop = FALSE], s.weights[!discarded])$cov } } else if (!is.cov_like(var)) { @@ -269,12 +271,18 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano #Rosenbaum (2010, ch8) X_r <- matrix(0, nrow = sum(!discarded), ncol = ncol(X), dimnames = list(rownames(X)[!discarded], colnames(X))) - for (i in seq_len(ncol(X_r))) X_r[,i] <- rank(X[!discarded, i]) - if (is.null(s.weights)) var_r <- cov(X_r) - else var_r <- cov.wt(X_r, s.weights[!discarded])$cov + for (i in seq_len(ncol(X_r))) { + X_r[,i] <- rank(X[!discarded, i]) + } + + var_r <- { + if (is_null(s.weights)) cov(X_r) + else cov.wt(X_r, s.weights[!discarded])$cov + } + + multiplier <- sd(seq_len(sum(!discarded))) / sqrt(diag(var_r)) - multiplier <- sd(seq_len(sum(!discarded)))/sqrt(diag(var_r)) var_r <- var_r * outer(multiplier, multiplier, "*") inv_var <- NULL @@ -289,17 +297,16 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano if (any(discarded)) { X_r <- array(0, dim = dim(X), dimnames = dimnames(X)) - for (i in seq_len(ncol(X_r))) X_r[!discarded,i] <- rank(X[!discarded,i]) + for (i in seq_len(ncol(X_r))) { + X_r[!discarded,i] <- rank(X[!discarded,i]) + } } X <- mahalanobize(X_r, inv_var) } - else if (method == "euclidean") { - #Do nothing - } else if (method == "scaled_euclidean") { - if (is.null(var)) { - if (!is.null(treat)) { + if (is_null(var)) { + if (is_not_null(treat)) { sds <- pooled_sd(X[!discarded,, drop = FALSE], treat[!discarded], s.weights[!discarded]) } else { @@ -320,34 +327,34 @@ transform_covariates <- function(formula = NULL, data = NULL, method = "mahalano X[,i] <- X[,i]/sds[i] } } + else if (method == "euclidean") { + #Do nothing + } attr(X, "treat") <- treat + X } #Internal function for fast(ish) Euclidean distance eucdist_internal <- function(X, treat = NULL) { - if (is.null(dim(X))) X <- as.matrix(X) - - if (is.null(treat)) { - if (ncol(X) == 1) { - d <- abs(outer(drop(X), drop(X), "-")) - } - else { - d <- dist_to_matrixC(dist(X)) + if (is_null(treat)) { + d <- { + if (NCOL(X) == 1L) abs(outer(drop(X), drop(X), "-")) + else as.matrix(dist(X)) } + dimnames(d) <- list(rownames(X), rownames(X)) } else { treat_l <- as.logical(treat) - if (ncol(X) == 1) { - d <- abs(outer(X[treat_l,], X[!treat_l,], "-")) - } - else { - d <- dist(X) - d <- dist_to_matrixC(d)[treat_l, !treat_l, drop = FALSE] + + d <- { + if (NCOL(X) == 1L) abs(outer(X[treat_l], X[!treat_l], "-")) + else eucdistC_N1xN0(X, as.integer(treat)) } + dimnames(d) <- list(rownames(X)[treat_l], rownames(X)[!treat_l]) } @@ -358,8 +365,11 @@ eucdist_internal <- function(X, treat = NULL) { #to ensure same result as when non-factor binary variable supplied (see optmatch:::contr.match_on) get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { - if (is.null(formula)) { - if (is.null(colnames(data))) colnames(data) <- paste0("X", seq_len(ncol(data))) + if (is_null(formula)) { + if (is_null(colnames(data))) { + colnames(data) <- paste0("X", seq_len(ncol(data))) + } + fnames <- colnames(data) fnames[!startsWith(fnames, "`")] <- add_quotes(fnames[!startsWith(fnames, "`")], "`") data <- as.data.frame(data) @@ -387,9 +397,9 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { contrasts.arg = lapply(Filter(is.factor, mf), function(x) contrasts(x, contrasts = FALSE)/sqrt(2))) - if (ncol(X) > 1) { - assign <- attr(X, "assign")[-1] - X <- X[, -1, drop = FALSE] + if (ncol(X) > 1L) { + assign <- attr(X, "assign")[-1L] + X <- X[, -1L, drop = FALSE] } attr(X, "assign") <- assign @@ -399,22 +409,30 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { } .check_X <- function(X) { - if (isTRUE(attr(X, "checked"))) return(X) + if (isTRUE(attr(X, "checked"))) { + return(X) + } treat <- attr(X, "treat") - if (is.data.frame(X)) X <- as.matrix(X) - else if (is.numeric(X) && is.null(dim(X))) { + if (is.data.frame(X)) { + X <- as.matrix(X) + } + else if (is.numeric(X) && is_null(dim(X))) { X <- matrix(X, nrow = length(X), dimnames = list(names(X), NULL)) } - if (anyNA(X)) .err("missing values are not allowed in the covariates") - if (any(!is.finite(X))) .err("Non-finite values are not allowed in the covariates") + chk::chk_not_any_na(X, "the covariates") + + if (any(!is.finite(X))) { + .err("non-finite values are not allowed in the covariates") + } if (!is.numeric(X) || length(dim(X)) != 2) { stop("bad X") } + attr(X, "checked") <- TRUE attr(X, "treat") <- treat X @@ -422,7 +440,7 @@ get.covs.matrix.for.dist <- function(formula = NULL, data = NULL) { is.cov_like <- function(var, X) { is.numeric(var) && - length(dim(var)) == 2 && + length(dim(var)) == 2L && (missing(X) || all(dim(var) == ncol(X))) && isSymmetric(var) && all(diag(var) >= 0) diff --git a/R/distance2_methods.R b/R/distance2_methods.R index fd58d9d0..f16b57eb 100644 --- a/R/distance2_methods.R +++ b/R/distance2_methods.R @@ -96,8 +96,11 @@ #' generalized boosted modeling as in *twang*; here, the number of trees is #' chosen based on cross-validation or out-of-bag error, rather than based on #' optimizing balance. \pkg{twang} should not be cited when using this method -#' to estimate propensity scores. } -#' \item{`"lasso"`, `"ridge"`, `"elasticnet"`}{ The propensity +#' to estimate propensity scores. Note that because there is a random component to choosing the tuning +#' parameter, results will vary across runs unless a [seed][set.seed] is +#' set.} +#' \item{`"lasso"`, `"ridge"`, `"elasticnet"`}{ +#' The propensity #' scores are estimated using a lasso, ridge, or elastic net model, #' respectively. The `formula` supplied to `matchit()` is processed #' with [model.matrix()] and passed to \pkgfun{glmnet}{cv.glmnet}, and @@ -128,7 +131,8 @@ #' directly to \pkgfun{randomForest}{randomForest}, and #' \pkgfun{randomForest}{predict.randomForest} is used to compute the propensity #' scores. The `link` argument is ignored, and predicted probabilities are -#' always returned as the distance measure.} +#' always returned as the distance measure. Note that because there is a random component, results will vary across runs unless a [seed][set.seed] is +#' set. } #' \item{`"nnet"`}{ The #' propensity scores are estimated using a single-hidden-layer neural network. #' The `formula` supplied to `matchit()` is passed directly to @@ -156,35 +160,35 @@ #' the linear predictor instead of the predicted probabilities. When #' `s.weights` is supplied to `matchit()`, it will not be passed to #' `bart2` because the `weights` argument in `bart2` does not -#' correspond to sampling weights. } +#' correspond to sampling weights. Note that because there is a random component to choosing the tuning +#' parameter, results will vary across runs unless the `seed` argument is supplied to `distance.options`. Note that setting a seed using [set.seed()] is not sufficient to guarantee reproducibility unless single-threading is used. See \pkgfun{dbarts}{bart2} for details.} #' } #' #' ## Methods for computing distances from covariates #' -#' The following methods involve computing a distance matrix from the covariates themselves -#' without estimating a propensity score. Calipers on the distance measure and -#' common support restrictions cannot be used, and the `distance` -#' component of the output object will be empty because no propensity scores -#' are estimated. The `link` and `distance.options` arguments are -#' ignored with these methods. See the individual matching methods pages for -#' whether these distances are allowed and how they are used. Each of these -#' distance measures can also be calculated outside `matchit()` using its -#' [corresponding function][euclidean_dist]. +#' The following methods involve computing a distance matrix from the covariates +#' themselves without estimating a propensity score. Calipers on the distance +#' measure and common support restrictions cannot be used, and the `distance` +#' component of the output object will be empty because no propensity scores are +#' estimated. The `link` and `distance.options` arguments are ignored with these +#' methods. See the individual matching methods pages for whether these +#' distances are allowed and how they are used. Each of these distance measures +#' can also be calculated outside `matchit()` using its [corresponding +#' function][euclidean_dist]. #' #' \describe{ #' \item{`"euclidean"`}{ The Euclidean distance is the raw -#' distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - -#' x_j)'}} It is sensitive to the scale of the covariates, so covariates with +#' distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - x_j)'}} It is sensitive to the scale of the covariates, so covariates with #' larger scales will take higher priority. } -#' \item{`"scaled_euclidean"`}{ The scaled Euclidean distance is the +#' \item{`"scaled_euclidean"`}{ +#' The scaled Euclidean distance is the #' Euclidean distance computed on the scaled (i.e., standardized) covariates. #' This ensures the covariates are on the same scale. The covariates are #' standardized using the pooled within-group standard deviations, computed by #' treatment group-mean centering each covariate before computing the standard -#' deviation in the full sample. } -#' \item{`"mahalanobis"`}{ The -#' Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - -#' x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group +#' deviation in the full sample. +#' } +#' \item{`"mahalanobis"`}{ The Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group #' covariance matrix of the covariates, computed by treatment group-mean #' centering each covariate before computing the covariance in the full sample. #' This ensures the variables are on the same scale and accounts for the @@ -197,37 +201,36 @@ #' Mahalanobis distance but is not affinely invariant. } #' } #' -#' To perform Mahalanobis distance matching *and* estimate propensity -#' scores to be used for a purpose other than matching, the `mahvars` -#' argument should be used along with a different specification to -#' `distance`. See the individual matching method pages for details on how -#' to use `mahvars`. +#' To perform Mahalanobis distance matching *and* estimate propensity scores to +#' be used for a purpose other than matching, the `mahvars` argument should be +#' used along with a different specification to `distance`. See the individual +#' matching method pages for details on how to use `mahvars`. #' #' ## Distances supplied as a numeric vector or matrix #' -#' `distance` can also be supplied as a numeric vector whose values will be taken to -#' function like propensity scores; their pairwise difference will define the -#' distance between units. This might be useful for supplying propensity scores -#' computed outside `matchit()` or resupplying `matchit()` with -#' propensity scores estimated previously without having to recompute them. +#' `distance` can also be supplied as a numeric vector whose values will be +#' taken to function like propensity scores; their pairwise difference will +#' define the distance between units. This might be useful for supplying +#' propensity scores computed outside `matchit()` or resupplying `matchit()` +#' with propensity scores estimated previously without having to recompute them. #' #' `distance` can also be supplied as a matrix whose values represent the #' pairwise distances between units. The matrix should either be a square, with #' a row and column for each unit (e.g., as the output of a call to -#' `as.matrix(`[`dist`]`(.))`), or have as many rows as there are treated -#' units and as many columns as there are control units (e.g., as the output of -#' a call to [mahalanobis_dist()] or \pkgfun{optmatch}{match_on}). Distance values -#' of `Inf` will disallow the corresponding units to be matched. When -#' `distance` is a supplied as a numeric vector or matrix, `link` and -#' `distance.options` are ignored. +#' `as.matrix(`[`dist`]`(.))`), or have as many rows as there are treated units +#' and as many columns as there are control units (e.g., as the output of a call +#' to [mahalanobis_dist()] or \pkgfun{optmatch}{match_on}). Distance values of +#' `Inf` will disallow the corresponding units to be matched. When `distance` is +#' a supplied as a numeric vector or matrix, `link` and `distance.options` are +#' ignored. #' -#' @note -#' In versions of *MatchIt* prior to 4.0.0, `distance` was -#' specified in a slightly different way. When specifying arguments using the -#' old syntax, they will automatically be converted to the corresponding method -#' in the new syntax but a warning will be thrown. `distance = "logit"`, -#' the old default, will still work in the new syntax, though `distance = "glm", link = "logit"` is preferred (note that these are the default -#' settings and don't need to be made explicit). +#' @note In versions of *MatchIt* prior to 4.0.0, `distance` was specified in a +#' slightly different way. When specifying arguments using the old syntax, they +#' will automatically be converted to the corresponding method in the new syntax +#' but a warning will be thrown. `distance = "logit"`, the old default, will +#' still work in the new syntax, though `distance = "glm", link = "logit"` is +#' preferred (note that these are the default settings and don't need to be made +#' explicit). #' #' @examples #' data("lalonde") @@ -293,16 +296,16 @@ NULL #distance2glm----------------- distance2glm <- function(formula, data = NULL, link = "logit", ...) { - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) - A <- list(...) - A[!names(A) %in% c(names(formals(glm)), names(formals(glm.control)))] <- NULL + args <- c(names(formals(glm)), names(formals(glm.control))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL res <- do.call("glm", c(list(formula = formula, data = data, family = quasibinomial(link = link)), A)) - if (linear) pred <- predict(res, type = "link") - else pred <- predict(res, type = "response") + pred <- predict(res, type = if (linear) "link" else "response") list(model = res, distance = pred) } @@ -311,7 +314,7 @@ distance2glm <- function(formula, data = NULL, link = "logit", ...) { distance2gam <- function(formula, data = NULL, link = "logit", ...) { rlang::check_installed("mgcv") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) A <- list(...) @@ -322,8 +325,7 @@ distance2gam <- function(formula, data = NULL, link = "logit", ...) { weights = weights), A), quote = TRUE) - if (linear) pred <- predict(res, type = "link") - else pred <- predict(res, type = "response") + pred <- predict(res, type = if (linear) "link" else "response") list(model = res, distance = as.numeric(pred)) } @@ -331,8 +333,11 @@ distance2gam <- function(formula, data = NULL, link = "logit", ...) { #distance2rpart----------------- distance2rpart <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("rpart") - A <- list(...) - A[!names(A) %in% c(names(formals(rpart::rpart)), names(formals(rpart::rpart.control)))] <- NULL + + args <- c(names(formals(rpart::rpart)), names(formals(rpart::rpart.control))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + A$formula <- formula A$data <- data A$method <- "class" @@ -357,25 +362,30 @@ distance2nnet <- function(formula, data = NULL, link = NULL, ...) { distance2cbps <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("CBPS") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") A <- list(...) A[["standardized"]] <- FALSE - if (is.null(A[["ATT"]])) { - if (is.null(A[["estimand"]])) A[["ATT"]] <- 1 + + if (is_null(A[["ATT"]])) { + if (is_null(A[["estimand"]])) { + A[["ATT"]] <- 1 + } else { estimand <- toupper(A[["estimand"]]) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) A[["ATT"]] <- switch(estimand, "ATT" = 1, "ATC" = 2, 0) } } - if (is.null(A[["method"]])) { + + if (is_null(A[["method"]])) { A[["method"]] <- if (isFALSE(A[["over"]])) "exact" else "over" } + A[c("estimand", "over")] <- NULL - if (!is.null(A[["weights"]])) { + if (is_not_null(A[["weights"]])) { A[["sample.weights"]] <- A[["weights"]] A[["weights"]] <- NULL } @@ -394,17 +404,18 @@ distance2cbps <- function(formula, data = NULL, link = NULL, ...) { distance2bart <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("dbarts") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") + + args <- c(names(formals(dbarts::bart2)), names(formals(dbarts::dbartsControl))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL - A <- list(...) - A[!names(A) %in% c(names(formals(dbarts::bart2)), names(formals(dbarts::dbartsControl)))] <- NULL A$formula <- formula A$data <- data res <- do.call(dbarts::bart2, A) - if (linear) pred <- fitted(res, type = "link") - else pred <- fitted(res, type = "response") + pred <- fitted(res, type = if (linear) "link" else "response") list(model = res, distance = pred) } @@ -412,7 +423,7 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { # distance2bart <- function(formula, data, link = NULL, ...) { # rlang::check_installed("BART") # -# if (!is.null(link) && startsWith(as.character(link), "linear")) { +# if (is_not_null(link) && startsWith(as.character(link), "linear")) { # linear <- TRUE # link <- sub("linear.", "", as.character(link), fixed = TRUE) # } @@ -421,7 +432,7 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { # #Keep link probit because default in matchit is logit but probit is much faster with BART # link <- "probit" # -# # if (is.null(link)) link <- "probit" +# # if (is_null(link)) link <- "probit" # # else if (!link %in% c("probit", "logit")) { # # stop("'link' must be \"probit\" or \"logit\" with distance = \"bart\".", call. = FALSE) # # } @@ -436,7 +447,7 @@ distance2bart <- function(formula, data = NULL, link = NULL, ...) { # # A <- list(...) # -# if (!is.null(A[["mc.cores"]]) && A[["mc.cores"]][1] > 1) fun <- BART::mc.gbart +# if (is_not_null(A[["mc.cores"]]) && A[["mc.cores"]][1] > 1) fun <- BART::mc.gbart # else fun <- BART::gbart # # res <- do.call(fun, c(list(X, @@ -459,26 +470,35 @@ distance2randomforest <- function(formula, data = NULL, link = NULL, ...) { newdata[[treatvar]] <- factor(newdata[[treatvar]], levels = c("0", "1")) res <- randomForest::randomForest(formula, data = newdata, ...) - list(model = res, distance = predict(res, type = "prob")[,"1"]) + list(model = res, distance = predict(res, type = "prob")[,"1"]) } #distance2glmnet-------------- distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("glmnet") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") if (linear) link <- sub("linear.", "", as.character(link), fixed = TRUE) - A <- list(...) - s <- A[["s"]] - A[!names(A) %in% c(names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet)))] <- NULL + s <- ...get("s", ...) + if (is_null(s)) { + s <- "lambda.1se" + } - if (is.null(link)) link <- "logit" - if (link == "logit") A$family <- "binomial" - else if (link == "log") A$family <- "poisson" - else A$family <- binomial(link = link) + args <- c(names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL - if (is.null(A[["alpha"]])) A[["alpha"]] <- .5 + if (is_null(link)) link <- "logit" + + A$family <- switch(link, + "logit" = "binomial", + "log" = "poisson", + binomial(link = link)) + + if (is_null(A[["alpha"]])) { + A[["alpha"]] <- .5 + } mf <- model.frame(formula, data = data) @@ -487,52 +507,70 @@ distance2elasticnet <- function(formula, data = NULL, link = NULL, ...) { res <- do.call(glmnet::cv.glmnet, A) - if (is.null(s)) s <- "lambda.1se" - pred <- drop(predict(res, newx = A$x, s = s, - type = if (linear) "link" else "response")) + type = if (linear) "link" else "response")) list(model = res, distance = pred) } distance2lasso <- function(formula, data = NULL, link = NULL, ...) { - A <- list(...) - A$alpha <- 1 - do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), - quote = TRUE) + if ("alpha" %in% ...names()) { + args <- c("s", names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + + A$alpha <- 1 + do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), + quote = TRUE) + } + else { + distance2elasticnet(formula = formula, data = data, link = link, alpha = 1, ...) + } } distance2ridge <- function(formula, data = NULL, link = NULL, ...) { - A <- list(...) - A$alpha <- 0 - do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), - quote = TRUE) + if ("alpha" %in% ...names()) { + args <- c("s", names(formals(glmnet::glmnet)), names(formals(glmnet::cv.glmnet))) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL + + A$alpha <- 0 + do.call("distance2elasticnet", c(list(formula, data = data, link = link), A), + quote = TRUE) + } + else { + distance2elasticnet(formula = formula, data = data, link = link, alpha = 0, ...) + } } #distance2gbm-------------- distance2gbm <- function(formula, data = NULL, link = NULL, ...) { rlang::check_installed("gbm") - linear <- !is.null(link) && startsWith(as.character(link), "linear") + linear <- is_not_null(link) && startsWith(as.character(link), "linear") A <- list(...) - method <- A[["method"]] - A[!names(A) %in% names(formals(gbm::gbm))] <- NULL + method <- ...get("method", ...) + + args <- names(formals(gbm::gbm)) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL A$formula <- formula A$data <- data A$distribution <- "bernoulli" - if (is.null(A[["n.trees"]])) A[["n.trees"]] <- 1e4 - if (is.null(A[["interaction.depth"]])) A[["interaction.depth"]] <- 3 - if (is.null(A[["shrinkage"]])) A[["shrinkage"]] <- .01 - if (is.null(A[["bag.fraction"]])) A[["bag.fraction"]] <- 1 - if (is.null(A[["cv.folds"]])) A[["cv.folds"]] <- 5 - if (is.null(A[["keep.data"]])) A[["keep.data"]] <- FALSE + if (is_null(A[["n.trees"]])) A[["n.trees"]] <- 1e4 + if (is_null(A[["interaction.depth"]])) A[["interaction.depth"]] <- 3 + if (is_null(A[["shrinkage"]])) A[["shrinkage"]] <- .01 + if (is_null(A[["bag.fraction"]])) A[["bag.fraction"]] <- 1 + if (is_null(A[["cv.folds"]])) A[["cv.folds"]] <- 5 + if (is_null(A[["keep.data"]])) A[["keep.data"]] <- FALSE if (A[["cv.folds"]] <= 1 && A[["bag.fraction"]] == 1) { .err('either `bag.fraction` must be less than 1 or `cv.folds` must be greater than 1 when using `distance = "gbm"`') } - if (is.null(method)) { + + if (is_null(method)) { if (A[["bag.fraction"]] < 1) method <- "OOB" else method <- "cv" } @@ -547,5 +585,5 @@ distance2gbm <- function(formula, data = NULL, link = NULL, ...) { pred <- drop(predict(res, newdata = data, n.trees = best.tree, type = if (linear) "link" else "response")) - list(model = res, distance = pred) + list(model = res, distance = pred) } diff --git a/R/get_weights_from_mm.R b/R/get_weights_from_mm.R index a7713d0e..64390a94 100644 --- a/R/get_weights_from_mm.R +++ b/R/get_weights_from_mm.R @@ -1,15 +1,22 @@ -get_weights_from_mm <- function(match.matrix, treat) { +get_weights_from_mm <- function(match.matrix, treat, focal = NULL) { - if (!is.integer(match.matrix)) match.matrix <- charmm2nummm(match.matrix, treat) + if (!is.integer(match.matrix)) { + match.matrix <- charmm2nummm(match.matrix, treat) + } - weights <- weights_matrixC(match.matrix, treat) + weights <- weights_matrixC(match.matrix, treat, focal) - if (sum(weights) == 0) + if (all_equal_to(weights, 0)) { .err("No units were matched") - if (sum(weights[treat == 1]) == 0) + } + + if (all_equal_to(weights[treat == 1], 0)) { .err("No treated units were matched") - if (sum(weights[treat == 0]) == 0) + } + + if (all_equal_to(weights[treat == 0], 0)) { .err("No control units were matched") + } setNames(weights, names(treat)) -} +} \ No newline at end of file diff --git a/R/get_weights_from_subclass.R b/R/get_weights_from_subclass.R index b6e2baa1..bd1dd333 100644 --- a/R/get_weights_from_subclass.R +++ b/R/get_weights_from_subclass.R @@ -2,52 +2,58 @@ get_weights_from_subclass <- function(psclass, treat, estimand = "ATT") { NAsub <- is.na(psclass) - i1 <- treat == 1 & !NAsub - i0 <- treat == 0 & !NAsub + i1 <- which(treat == 1 & !NAsub) + i0 <- which(treat == 0 & !NAsub) - weights <- setNames(rep(0, length(treat)), names(treat)) + if (is_null(i1)) { + if (is_null(i0)) { + .err("No units were matched") + } - if (!is.factor(psclass)) { - psclass <- factor(psclass, nmax = min(sum(i1), sum(i0))) - levels(psclass) <- seq_len(nlevels(psclass)) + .err("No treated units were matched") + } + else if (is_null(i0)) { + .err("No control units were matched") } - treated_by_sub <- setNames(tabulateC(psclass[i1], nlevels(psclass)), levels(psclass)) - control_by_sub <- setNames(tabulateC(psclass[i0], nlevels(psclass)), levels(psclass)) + weights <- rep_with(0.0, treat) + + if (!is.factor(psclass)) { + psclass <- factor(psclass, nmax = min(length(i1), length(i0))) + } - total_by_sub <- treated_by_sub + control_by_sub + treated_by_sub <- tabulate(psclass[i1], nlevels(psclass)) + control_by_sub <- tabulate(psclass[i0], nlevels(psclass)) - psclass <- as.character(psclass) + psclass <- unclass(psclass) if (estimand == "ATT") { - weights[i1] <- 1 + weights[i1] <- 1.0 weights[i0] <- (treated_by_sub/control_by_sub)[psclass[i0]] - - #Weights average 1 - weights[i0] <- .make_sum_to_n(weights[i0]) } else if (estimand == "ATC") { weights[i1] <- (control_by_sub/treated_by_sub)[psclass[i1]] - weights[i0] <- 1 - - #Weights average 1 - weights[i1] <- .make_sum_to_n(weights[i1]) + weights[i0] <- 1.0 } else if (estimand == "ATE") { - weights[i1] <- (total_by_sub/treated_by_sub)[psclass[i1]] - weights[i0] <- (total_by_sub/control_by_sub)[psclass[i0]] - - #Weights average 1 - weights[i1] <- .make_sum_to_n(weights[i1]) - weights[i0] <- .make_sum_to_n(weights[i0]) + weights[i1] <- 1.0 + (control_by_sub/treated_by_sub)[psclass[i1]] + weights[i0] <- 1.0 + (treated_by_sub/control_by_sub)[psclass[i0]] } - if (sum(weights) == 0) - .err("No units were matched") - if (sum(weights[treat == 1]) == 0) - .err("No treated units were matched") - if (sum(weights[treat == 0]) == 0) - .err("No control units were matched") - weights } + +# get_weights_from_subclass2 <- function(psclass, treat, estimand = "ATT") { +# +# weights <- weights_subclassC(psclass, treat, +# switch(estimand, "ATT" = 1, "ATC" = 0, NULL)) +# +# if (sum(weights) == 0) +# .err("No units were matched") +# if (sum(weights[treat == 1]) == 0) +# .err("No treated units were matched") +# if (sum(weights[treat == 0]) == 0) +# .err("No control units were matched") +# +# weights +# } \ No newline at end of file diff --git a/R/input_processing.R b/R/input_processing.R index 305229d1..387baa6c 100644 --- a/R/input_processing.R +++ b/R/input_processing.R @@ -4,34 +4,37 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, ratio, m.order, estimand, ..., min.controls = NULL, max.controls = NULL) { - null.method <- is.null(method) + null.method <- is_null(method) + if (null.method) { method <- "NULL" } else { - method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", "genetic", "subclass", "cardinality", + method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", + "genetic", "subclass", "cardinality", "quick")) } ignored.inputs <- character(0) error.inputs <- character(0) + if (null.method) { for (i in c("exact", "mahvars", "antiexact", "caliper", "std.caliper", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "exact") { for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "cem") { - for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + for (i in c("distance", "exact", "mahvars", "antiexact", "caliper", "std.caliper", "discard", "reestimate", "replace", "ratio", "min.controls", "max.controls")) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -39,7 +42,7 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "nearest") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } @@ -48,14 +51,14 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "optimal") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "caliper", "std.caliper", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -64,14 +67,14 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "full") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "ratio", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -79,27 +82,27 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "genetic") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("min.controls", "max.controls")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "cardinality") { for (i in c("distance", "antiexact", "caliper", "std.caliper", "reestimate", "replace", "min.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } else if (method == "subclass") { for (i in c("exact", "mahvars", "antiexact", "caliper", "std.caliper", "replace", "ratio", "min.controls", "max.controls", "m.order")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } @@ -107,47 +110,67 @@ check.inputs <- function(mcall, method, distance, exact, mahvars, antiexact, else if (method == "quick") { if (is.character(distance) && distance %in% matchit_distances()) { for (e in c("mahvars", "reestimate")) { - if (e %in% names(mcall) && !is.null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { + if (hasName(mcall, e) && is_not_null(e_ <- get0(e, inherits = FALSE)) && !identical(e_, formals(matchit)[[e]])) { error.inputs <- c(error.inputs, e) } } } for (i in c("replace", "ratio", "min.controls", "max.controls", "m.order", "antiexact")) { - if (i %in% names(mcall) && !is.null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { + if (hasName(mcall, i) && is_not_null(i_ <- get0(i, inherits = FALSE)) && !identical(i_, formals(matchit)[[i]])) { ignored.inputs <- c(ignored.inputs, i) } } } - if (length(ignored.inputs) > 0) .wrn(sprintf("the %s %s not used with `method = %s` and will be ignored", - ngettext(length(ignored.inputs), "argument", "arguments"), - word_list(ignored.inputs, quotes = 1, is.are = TRUE), - add_quotes(method, quotes = !null.method))) - if (length(error.inputs) > 0) .err(sprintf("the %s %s not used with `method = %s` and `distance = \"%s\"`", - ngettext(length(error.inputs), "argument", "arguments"), - word_list(error.inputs, quotes = 1, is.are = TRUE), - add_quotes(method, quotes = !null.method), - distance)) + if (is_not_null(ignored.inputs)) { + .wrn(sprintf("the argument%%s %s %%r not used with `method = %s` and will be ignored", + word_list(ignored.inputs, quotes = "`"), + add_quotes(method, quotes = !null.method)), + n = length(ignored.inputs)) + } + + if (is_not_null(error.inputs)) { + .err(sprintf("the argument%%s %s %%r not used with `method = %s` and `distance = %s`", + word_list(error.inputs, quotes = "`"), + add_quotes(method, quotes = !null.method), + add_quotes(distance)), + n = length(error.inputs)) + } + ignored.inputs } #Check treatment for type, binary, missing, num. rows check_treat <- function(treat = NULL, X = NULL) { - if (is.null(treat)) { - if (is.null(X) || is.null(attr(X, "treat"))) return(NULL) + if (is_null(treat)) { + if (is_null(X) || is_null(attr(X, "treat"))) { + return(NULL) + } + treat <- attr(X, "treat") } - if (isTRUE(attr(treat, "checked"))) return(treat) - if (!is.atomic(treat) || !is.null(dim(treat))) { + if (isTRUE(attr(treat, "checked"))) { + return(treat) + } + + if (!is.atomic(treat) || is_not_null(dim(treat))) { .err("the treatment must be a vector") } - if (anyNA(treat)) .err("missing values are not allowed in the treatment") - if (length(unique(treat)) != 2) .err("the treatment must be a binary variable") - if (!is.null(X) && length(treat) != nrow(X)) .err("the treatment and covariates must have the same number of units") + if (anyNA(treat)) { + .err("missing values are not allowed in the treatment") + } + + if (!has_n_unique(treat, 2L)) { + .err("the treatment must be a binary variable") + } + + if (is_not_null(X) && length(treat) != nrow(X)) { + .err("the treatment and covariates must have the same number of units") + } treat <- binarize(treat) #make 0/1 attr(treat, "checked") <- TRUE @@ -156,15 +179,16 @@ check_treat <- function(treat = NULL, X = NULL) { #Function to process distance and give warnings about new syntax process.distance <- function(distance, method = NULL, treat) { - if (is.null(distance)) { - if (!is.null(method) && !method %in% c("cem", "exact", "cardinality")) { + if (is_null(distance)) { + if (is_not_null(method) && !method %in% c("cem", "exact", "cardinality")) { .err(sprintf("`distance` cannot be `NULL` with `method = \"%s\"`", method)) } + return(distance) } - if (is.character(distance) && length(distance) == 1) { + if (chk::vld_string(distance)) { allowable.distances <- c( #Propensity score methods "glm", "cbps", "gam", "nnet", "rpart", "bart", @@ -176,15 +200,19 @@ process.distance <- function(distance, method = NULL, treat) { if (tolower(distance) %in% c("cauchit", "cloglog", "linear.cloglog", "linear.log", "linear.logit", "linear.probit", "linear.cauchit", "log", "probit")) { link <- tolower(distance) + .wrn(sprintf("`distance = \"%s\"` will be deprecated; please use `distance = \"glm\", link = \"%s\"` in the future", - distance, link)) + distance, link)) + distance <- "glm" attr(distance, "link") <- link } else if (tolower(distance) %in% tolower(c("GAMcloglog", "GAMlog", "GAMlogit", "GAMprobit"))) { link <- tolower(substr(distance, 4, nchar(distance))) + .wrn(sprintf("`distance = \"%s\"` will be deprecated; please use `distance = \"gam\", link = \"%s\"` in the future", - distance, link)) + distance, link)) + distance <- "gam" attr(distance, "link") <- link } @@ -198,7 +226,7 @@ process.distance <- function(distance, method = NULL, treat) { else if (!tolower(distance) %in% allowable.distances) { .err("the argument supplied to `distance` is not an allowable value. See `help(\"distance\")` for allowable options") } - else if (!is.null(method) && method == "subclass" && tolower(distance) %in% matchit_distances()) { + else if (is_not_null(method) && method == "subclass" && tolower(distance) %in% matchit_distances()) { .err(sprintf("`distance` cannot be %s with `method = \"subclass\"`", add_quotes(distance))) } @@ -206,109 +234,143 @@ process.distance <- function(distance, method = NULL, treat) { distance <- tolower(distance) } + return(distance) } - else if (!is.numeric(distance) || (!is.null(dim(distance)) && length(dim(distance)) != 2)) { + + if (!is.numeric(distance) || (is_not_null(dim(distance)) && length(dim(distance)) != 2)) { .err("`distance` must be a string with the name of the distance measure to be used or a numeric vector or matrix containing distance measures") } - else if (is.matrix(distance) && (is.null(method) || !method %in% c("nearest", "optimal", "full"))) { + + if (is.matrix(distance) && (is_null(method) || !method %in% c("nearest", "optimal", "full"))) { .err(sprintf("`distance` cannot be supplied as a matrix with `method = %s`", - add_quotes(method, quotes = !is.null(method)))) + add_quotes(method, quotes = is_not_null(method)))) } - if (is.numeric(distance)) { - if (is.matrix(distance)) { - dim.distance <- dim(distance) - if (all(dim.distance == length(treat))) { - if (!is.null(rownames(distance))) distance <- distance[names(treat),, drop = FALSE] - if (!is.null(colnames(distance))) distance <- distance[,names(treat), drop = FALSE] - distance <- distance[treat == 1, treat == 0, drop = FALSE] + + if (is.matrix(distance)) { + dim.distance <- dim(distance) + + if (all_equal_to(dim.distance, length(treat))) { + if (is_not_null(rownames(distance))) { + distance <- distance[names(treat),, drop = FALSE] } - else if (all(dim.distance == c(sum(treat==1), sum(treat==0)))) { - if (!is.null(rownames(distance))) distance <- distance[names(treat)[treat == 1],, drop = FALSE] - if (!is.null(colnames(distance))) distance <- distance[,names(treat)[treat == 0], drop = FALSE] + + if (is_not_null(colnames(distance))) { + distance <- distance[,names(treat), drop = FALSE] } - else { - .err("when supplied as a matrix, `distance` must have dimensions NxN or N1xN0. See `help(\"distance\")` for details") + + distance <- distance[treat == 1, treat == 0, drop = FALSE] + } + else if (dim.distance[1L] == sum(treat == 1) && + dim.distance[2L] == sum(treat == 0)) { + if (is_not_null(rownames(distance))) { + distance <- distance[names(treat)[treat == 1],, drop = FALSE] + } + + if (is_not_null(colnames(distance))) { + distance <- distance[,names(treat)[treat == 0], drop = FALSE] } } else { - if (length(distance) != length(treat)) { - .err("`distance` must be the same length as the dataset if specified as a numeric vector") - } + .err("when supplied as a matrix, `distance` must have dimensions NxN or N1xN0. See `help(\"distance\")` for details") } - - chk::chk_not_any_na(distance) + } + else if (length(distance) != length(treat)) { + .err("`distance` must be the same length as the dataset if specified as a numeric vector") } + chk::chk_not_any_na(distance) + distance } #Function to check ratio is acceptable process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.controls = NULL) { #Should be run after process.inputs() and ignored inputs set to NULL - ratio.null <- length(ratio) == 0 + ratio.null <- is_null(ratio) ratio.na <- !ratio.null && anyNA(ratio) - if (is.null(method)) return(1) + if (is_null(method)) { + return(1) + } + if (method %in% c("nearest", "optimal")) { - if (ratio.null) ratio <- 1 - else if (ratio.na) .err("`ratio` cannot be `NA`") - else if (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 1) { - .err("`ratio` must be a single number greater than or equal to 1") + if (ratio.null) { + ratio <- 1 + } + else { + chk::chk_number(ratio) + chk::chk_gte(ratio, 1) } - if (is.null(max.controls)) { + if (is_null(max.controls)) { if (!chk::vld_whole_number(ratio)) { .err("`ratio` must be a whole number when `max.controls` is not specified") } + ratio <- round(ratio) } - else if (anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1) { - .err("`max.controls` must be a single positive number") - } else { - if (ratio <= 1) .err("`ratio` must be greater than 1 for variable ratio matching") + chk::chk_count(max.controls) - max.controls <- ceiling(max.controls) - if (max.controls <= ratio) .err("`max.controls` must be greater than `ratio` for variable ratio matching") + if (ratio == 1) { + .err("`ratio` must be greater than 1 for variable ratio matching") + } + + if (max.controls <= ratio) { + .err("`max.controls` must be greater than `ratio` for variable ratio matching") + } - if (is.null(min.controls)) min.controls <- 1 - else if (anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1) { - .err("`max.controls` must be a single positive number") + if (is_null(min.controls)) { + min.controls <- 1 + } + else { + chk::chk_count(min.controls) } - else min.controls <- floor(min.controls) - if (min.controls < 1) .err("`min.controls` cannot be less than 1 for variable ratio matching") - if (min.controls >= ratio) .err("`min.controls` must be less than `ratio` for variable ratio matching") + if (min.controls < 1) { + .err("`min.controls` cannot be less than 1 for variable ratio matching") + } + + if (min.controls >= ratio) { + .err("`min.controls` must be less than `ratio` for variable ratio matching") + } } } else if (method == "full") { - if (is.null(max.controls)) max.controls <- Inf - else if ((anyNA(max.controls) || !is.atomic(max.controls) || !is.numeric(max.controls) || length(max.controls) > 1)) { - .err("`max.controls` must be a single positive number") + if (is_null(max.controls)) { + max.controls <- Inf + } + else { + chk::chk_number(max.controls) + chk::chk_gt(max.controls, 0) } - if (is.null(min.controls)) min.controls <- 0 - else if ((anyNA(min.controls) || !is.atomic(min.controls) || !is.numeric(min.controls) || length(min.controls) > 1)) { - .err("`min.controls` must be a single positive number") + if (is_null(min.controls)) { + min.controls <- 0 + } + else { + chk::chk_number(min.controls) + chk::chk_gt(min.controls, 0) } ratio <- 1 #Just to get min.controls and max.controls out } else if (method == "genetic") { - if (ratio.null) ratio <- 1 - else if (ratio.na) .err("`ratio` cannot be `NA`") - else if (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 1 || - !chk::vld_whole_number(ratio)) { - .err("`ratio` must be a single whole number greater than or equal to 1") + if (ratio.null) { + ratio <- 1 + } + else { + chk::chk_count(ratio) } - ratio <- round(ratio) min.controls <- max.controls <- NULL } else if (method == "cardinality") { - if (ratio.null) ratio <- 1 - else if (!ratio.na && (!is.atomic(ratio) || !is.numeric(ratio) || length(ratio) > 1 || ratio < 0)) { + if (ratio.null) { + ratio <- 1 + } + else if (!ratio.na && (!chk::vld_number(ratio) || !chk::vld_gte(ratio, 0))) { .err("`ratio` must be a single positive number or `NA`") } @@ -318,7 +380,7 @@ process.ratio <- function(ratio, method = NULL, ..., min.controls = NULL, max.co min.controls <- max.controls <- NULL } - if (!is.null(ratio)) { + if (is_not_null(ratio)) { attr(ratio, "min.controls") <- min.controls attr(ratio, "max.controls") <- max.controls } @@ -339,16 +401,20 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N #If std, export standardized versions #Check need for caliper - if (length(caliper) == 0 || is.null(method) || !method %in% c("nearest", "genetic", "full", "quick")) return(NULL) + if (is_null(caliper) || is_null(method) || !method %in% c("nearest", "genetic", "full", "quick")) { + return(NULL) + } #Check if form of caliper is okay - if (!is.atomic(caliper) || !is.numeric(caliper)) .err("`caliper` must be a numeric vector") + if (!is.atomic(caliper) || !is.numeric(caliper)) { + .err("`caliper` must be a numeric vector") + } #Check caliper names - if (length(caliper) == 1 && (is.null(names(caliper)) || identical(names(caliper), ""))) { + if (length(caliper) == 1L && (is_null(names(caliper)) || identical(names(caliper), ""))) { names(caliper) <- "" } - else if (is.null(names(caliper))) { + else if (is_null(names(caliper))) { .err("`caliper` must be a named vector with names corresponding to the variables for which a caliper is to be applied") } else if (anyNA(names(caliper))) { @@ -358,7 +424,7 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N .err("no more than one entry in `caliper` can have no name") } - if (any(names(caliper) == "") && is.null(distance)) { + if (hasName(caliper, "") && is_null(distance)) { .err("all entries in `caliper` must be named when `distance` does not correspond to a propensity score") } @@ -368,49 +434,41 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N cal.in.mahcovs <- setNames(names(caliper) %in% names(mahcovs), names(caliper)) if (any(names(caliper) != "" & !cal.in.covs & !cal.in.data)) { .err(paste0("All variables named in `caliper` must be in `data`. Variables not in `data`:\n\t", - paste0(names(caliper)[names(caliper) != "" & !cal.in.data & !cal.in.covs & !cal.in.mahcovs], collapse = ", ")), tidy = FALSE) + paste0(names(caliper)[names(caliper) != "" & !cal.in.data & !cal.in.covs & !cal.in.mahcovs], collapse = ", ")), + tidy = FALSE) } #Check std.caliper chk::chk_logical(std.caliper) - if (length(std.caliper) == 1) { - std.caliper <- setNames(rep.int(std.caliper, length(caliper)), names(caliper)) + if (length(std.caliper) == 1L) { + std.caliper <- rep_with(std.caliper, caliper) + } + else if (length(std.caliper) == length(caliper)) { + names(std.caliper) <- names(caliper) } - else if (length(std.caliper) != length(caliper)) { + else { .err("`std.caliper` must be the same length as `caliper`") } - else names(std.caliper) <- names(caliper) #Remove trivial calipers caliper <- caliper[is.finite(caliper)] - num.unique <- vapply(names(caliper), function(x) { - if (x == "") var <- distance - else if (cal.in.data[x]) var <- data[[x]] - else if (cal.in.covs[x]) var <- covs[[x]] - else var <- mahcovs[[x]] - - length(unique(var)) - }, integer(1L)) - - caliper <- caliper[num.unique > 1] - - if (length(caliper) == 0) return(NULL) + if (is_null(caliper)) { + return(NULL) + } #Ensure no calipers on categorical variables cat.vars <- vapply(names(caliper), function(x) { - if (num.unique[names(num.unique) == x] == 2) return(TRUE) - if (x == "") var <- distance else if (cal.in.data[x]) var <- data[[x]] else if (cal.in.covs[x]) var <- covs[[x]] else var <- mahcovs[[x]] - is.factor(var) || is.character(var) + chk::vld_character_or_factor(var) }, logical(1L)) if (any(cat.vars)) { - .err(paste0("Calipers cannot be used with binary, factor, or character variables. Offending variables:\n\t", + .err(paste0("Calipers cannot be used with factor or character variables. Offending variables:\n\t", paste0(ifelse(names(caliper) == "", "", names(caliper))[cat.vars], collapse = ", ")), tidy = FALSE) } @@ -420,9 +478,10 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N chk::chk_not_any_na(std.caliper) if (any(std.caliper)) { - if ("" %in% names(std.caliper) && isTRUE(std.caliper[names(std.caliper) == ""]) && is.matrix(distance)) { + if (hasName(std.caliper, "") && isTRUE(std.caliper[names(std.caliper) == ""]) && is.matrix(distance)) { .err("when `distance` is supplied as a matrix and a caliper for it is specified, `std.caliper` must be `FALSE` for the distance measure") } + caliper[std.caliper] <- caliper[std.caliper] * vapply(names(caliper)[std.caliper], function(x) { if (x == "") sd(distance[!discarded]) else if (cal.in.data[x]) sd(data[[x]][!discarded]) @@ -442,27 +501,30 @@ process.caliper <- function(caliper = NULL, method = NULL, data = NULL, covs = N #Function to process replace argument process.replace <- function(replace, method = NULL, ..., reuse.max = NULL) { - if (is.null(method)) return(FALSE) + if (is_null(method)) { + return(FALSE) + } + + if (is_null(replace)) { + replace <- FALSE + } - if (is.null(replace)) replace <- FALSE chk::chk_flag(replace) if (method %in% c("nearest")) { - if (is.null(reuse.max)) { - if (replace) reuse.max <- .Machine$integer.max - else reuse.max <- 1L - } - else if (length(reuse.max) == 1 && is.numeric(reuse.max) && - (!is.finite(reuse.max) || reuse.max > .Machine$integer.max) && - !anyNA(reuse.max)) { - reuse.max <- .Machine$integer.max + if (is_null(reuse.max)) { + reuse.max <- if (replace) .Machine$integer.max else 1L } - else if (abs(reuse.max - round(reuse.max)) > 1e-8 || length(reuse.max) != 1 || - anyNA(reuse.max) || reuse.max < 1) { - .err("`reuse.max` must be a positive integer of length 1") + else { + chk::chk_count(reuse.max) + chk::chk_gte(reuse.max, 1) + + if (reuse.max > .Machine$integer.max) { + reuse.max <- .Machine$integer.max + } } - replace <- reuse.max != 1L + replace <- reuse.max > 1L attr(replace, "reuse.max") <- as.integer(reuse.max) } @@ -474,17 +536,21 @@ process.replace <- function(replace, method = NULL, ..., reuse.max = NULL) { process.variable.input <- function(x, data = NULL) { n <- deparse1(substitute(x)) - if (is.null(x)) return(NULL) + if (is_null(x)) { + return(NULL) + } if (is.character(x)) { - if (is.null(data) || !is.data.frame(data)) { + if (is_null(data) || !is.data.frame(data)) { .err(sprintf("if `%s` is specified as strings, a data frame containing the named variables must be supplied to `data`", n)) } - if (!all(x %in% names(data))) { + + if (!all(hasName(data, x))) { .err(sprintf("All names supplied to `%s` must be variables in `data`. Variables not in `data`:\n\t%s", n, paste(add_quotes(setdiff(x, names(data))), collapse = ", ")), tidy = FALSE) } + x <- reformulate(x) } else if (rlang::is_formula(x)) { @@ -495,9 +561,10 @@ process.variable.input <- function(x, data = NULL) { } x_covs <- model.frame(x, data, na.action = "na.pass") + if (anyNA(x_covs)) { .err(sprintf("missing values are not allowed in the covariates named in `%s`", n)) } x_covs -} \ No newline at end of file +} diff --git a/R/match.data.R b/R/match.data.R index ce538411..a3c88302 100644 --- a/R/match.data.R +++ b/R/match.data.R @@ -178,57 +178,67 @@ match.data <- function(object, group = "all", distance = "distance", weights = " chk::chk_is(object, "matchit") - if (is.null(data)) { - env <- environment(object$formula) - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - env <- parent.frame() - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - data <- object[["model"]][["data"]] - if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { - .err("a valid dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") - } - } + data.found <- FALSE + for (i in 1:4) { + if (i == 2) { + data <- try(eval(object$call$data, envir = environment(object$formula)), silent = TRUE) + } + else if (i == 3) { + data <- try(eval(object$call$data, envir = parent.frame()), silent = TRUE) + } + else if (i == 4) { + data <- object[["model"]][["data"]] + } + + if (!null_or_error(data) && length(dim(data)) == 2L && nrow(data) == length(object[["treat"]])) { + data.found <- TRUE + break } } + if (!data.found) { + .err("a valid dataset could not be found. Please supply an argument to `data` containing the original dataset used in the matching") + } + if (!is.data.frame(data)) { - if (is.matrix(data)) data <- as.data.frame.matrix(data) - else .err("`data` must be a data frame") + if (!is.matrix(data)) { + .err("`data` must be a data frame") + } + data <- as.data.frame.matrix(data) } + if (nrow(data) != length(object$treat)) { .err("`data` must have as many rows as there were units in the original call to `matchit()`") } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { chk::chk_not_null(distance) chk::chk_string(distance) - if (distance %in% names(data)) { + if (hasName(data, distance)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for distance using the `distance` argument", add_quotes(distance))) } data[[distance]] <- object$distance } - if (!is.null(object$weights)) { + if (is_not_null(object$weights)) { chk::chk_not_null(weights) chk::chk_string(weights) - if (weights %in% names(data)) { + if (hasName(data, weights)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for weights using the `weights` argument", add_quotes(weights))) } data[[weights]] <- object$weights - if (!is.null(object$s.weights) && include.s.weights) { + if (is_not_null(object$s.weights) && include.s.weights) { data[[weights]] <- data[[weights]] * object$s.weights } } - if (!is.null(object$subclass)) { + if (is_not_null(object$subclass)) { chk::chk_not_null(subclass) chk::chk_string(subclass) - if (subclass %in% names(data)) { + if (hasName(data, subclass)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for subclass using the `subclass` argument", add_quotes(subclass))) } @@ -237,7 +247,7 @@ match.data <- function(object, group = "all", distance = "distance", weights = " treat <- object$treat - if (drop.unmatched && !is.null(object$weights)) { + if (drop.unmatched && is_not_null(object$weights)) { data <- data[object$weights > 0,,drop = FALSE] treat <- treat[object$weights > 0] } @@ -246,9 +256,9 @@ match.data <- function(object, group = "all", distance = "distance", weights = " if (group == "treated") data <- data[treat == 1,,drop = FALSE] else if (group == "control") data <- data[treat == 0,,drop = FALSE] - if (!is.null(object$distance)) attr(data, "distance") <- distance - if (!is.null(object$weights)) attr(data, "weights") <- weights - if (!is.null(object$subclass)) attr(data, "subclass") <- subclass + if (is_not_null(object$distance)) attr(data, "distance") <- distance + if (is_not_null(object$weights)) attr(data, "weights") <- weights + if (is_not_null(object$subclass)) attr(data, "subclass") <- subclass class(data) <- c("matchdata", class(data)) @@ -262,7 +272,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc chk::chk_is(object, "matchit") - if (is.null(object$match.matrix)) { + if (is_null(object$match.matrix)) { .err("a match.matrix component must be present in the matchit object, which does not occur with all types of matching. Use `match.data()` instead") } @@ -275,7 +285,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc chk::chk_not_null(id) chk::chk_string(id) - if (id %in% names(m.data)) { + if (hasName(m.data, id)) { .err(sprintf("%s is already the name of a variable in the data. Please choose another name for id using the `id` argument", add_quotes(id))) } @@ -283,7 +293,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc m.data[[id]] <- names(object$treat)[object$weights > 0] for (i in c(weights, subclass)) { - if (i %in% names(m.data)) m.data[[i]] <- NULL + if (hasName(m.data, i)) m.data[[i]] <- NULL } mm <- object$match.matrix @@ -295,11 +305,14 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc matched <- as.data.frame(matrix(NA_character_, nrow = nrow(mm) + sum(!is.na(mm)), ncol = 3)) names(matched) <- c(id, subclass, weights) - matched[[id]] <- c(as.vector(tmm[!is.na(tmm)]), rownames(mm)) - matched[[subclass]] <- c(as.vector(col(tmm)[!is.na(tmm)]), seq_len(nrow(mm))) - matched[[weights]] <- c(1/num.matches[matched[[subclass]][seq_len(sum(!is.na(mm)))]], rep(1, nrow(mm))) + matched[[id]] <- c(as.vector(tmm[!is.na(tmm)]), + rownames(mm)) + matched[[subclass]] <- c(as.vector(col(tmm)[!is.na(tmm)]), + seq_len(nrow(mm))) + matched[[weights]] <- c(1 / num.matches[matched[[subclass]][seq_len(sum(!is.na(mm)))]], + rep.int(1, nrow(mm))) - if (!is.null(object$s.weights) && include.s.weights) { + if (is_not_null(object$s.weights) && include.s.weights) { matched[[weights]] <- matched[[weights]] * object$s.weights[matched[[id]]] } @@ -310,7 +323,7 @@ get_matches <- function(object, distance = "distance", weights = "weights", subc out[[subclass]] <- factor(out[[subclass]], labels = seq_len(nrow(mm))) - if (!is.null(object$distance)) attr(out, "distance") <- distance + if (is_not_null(object$distance)) attr(out, "distance") <- distance attr(out, "weights") <- weights attr(out, "subclass") <- subclass attr(out, "id") <- id diff --git a/R/match.qoi.R b/R/match.qoi.R index c4b4f01e..371d41a9 100644 --- a/R/match.qoi.R +++ b/R/match.qoi.R @@ -1,11 +1,12 @@ ## Functions to calculate summary stats -bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s.d.denom = "treated", standardize = FALSE, +bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, + s.d.denom = "treated", standardize = FALSE, compute.pair.dist = TRUE) { - un <- is.null(ww) + un <- is_null(ww) bin.var <- all(xx == 0 | xx == 1) - xsum <- rep(NA_real_, 7) + xsum <- rep.int(NA_real_, 7L) if (standardize) names(xsum) <- c("Means Treated","Means Control", "Std. Mean Diff.", "Var. Ratio", "eCDF Mean", "eCDF Max", "Std. Pair Dist.") @@ -21,8 +22,8 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s. too.small <- sum(ww[i1] != 0) < 2 && sum(ww[i0] != 0) < 2 - xsum["Means Treated"] <- wm(xx[i1], ww[i1], na.rm=TRUE) - xsum["Means Control"] <- wm(xx[i0], ww[i0], na.rm=TRUE) + xsum["Means Treated"] <- wm(xx[i1], ww[i1], na.rm = TRUE) + xsum["Means Control"] <- wm(xx[i0], ww[i0], na.rm = TRUE) mdiff <- xsum["Means Treated"] - xsum["Means Control"] @@ -46,12 +47,17 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s. } xsum[3] <- mdiff/std - if (!un && compute.pair.dist) xsum[7] <- pair.dist(xx, tt, subclass, mm, std) + if (!un && compute.pair.dist) { + xsum[7] <- pair.dist(xx, tt, subclass, mm, std) + } } } else { xsum[3] <- mdiff - if (!un && compute.pair.dist) xsum[7] <- pair.dist(xx, tt, subclass, mm) + + if (!un && compute.pair.dist) { + xsum[7] <- pair.dist(xx, tt, subclass, mm) + } } if (bin.var) { @@ -67,7 +73,8 @@ bal1var <- function(xx, tt, ww = NULL, s.weights, subclass = NULL, mm = NULL, s. xsum } -bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", standardize = FALSE, which.subclass = NULL) { +bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", + standardize = FALSE, which.subclass = NULL) { #Within-subclass balance statistics bin.var <- all(xx == 0 | xx == 1) in.sub <- !is.na(subclass) & subclass == which.subclass @@ -84,10 +91,10 @@ bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", i1 <- which(in.sub & tt == 1) i0 <- which(in.sub & tt == 0) - too.small <- length(i1) < 2 && length(i0) < 2 + too.small <- length(i1) < 2L && length(i0) < 2L - xsum["Subclass","Means Treated"] <- wm(xx[i1], s.weights[i1], na.rm=TRUE) - xsum["Subclass","Means Control"] <- wm(xx[i0], s.weights[i0], na.rm=TRUE) + xsum["Subclass","Means Treated"] <- wm(xx[i1], s.weights[i1], na.rm = TRUE) + xsum["Subclass","Means Control"] <- wm(xx[i0], s.weights[i0], na.rm = TRUE) mdiff <- xsum["Subclass","Means Treated"] - xsum["Subclass","Means Control"] @@ -131,9 +138,13 @@ bal1var.subclass <- function(xx, tt, s.weights, subclass, s.d.denom = "treated", } #Compute within-pair/subclass distances -pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL, fast = TRUE) { +pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL) { - if (!is.null(mm)) { + if (is_not_null(subclass)) { + mpdiff <- pairdistsubC(as.numeric(xx), as.integer(tt), + as.integer(subclass)) + } + else if (is_not_null(mm)) { names(xx) <- names(tt) xx_t <- xx[rownames(mm)] xx_c <- matrix(0, nrow = nrow(mm), ncol = ncol(mm)) @@ -141,29 +152,12 @@ pair.dist <- function(xx, tt, subclass = NULL, mm = NULL, std = NULL, fast = TRU mpdiff <- mean(abs(xx_t - xx_c), na.rm = TRUE) } - else if (!is.null(subclass)) { - if (!fast) { - dists <- unlist(lapply(levels(subclass), function(s) { - t1 <- which(!is.na(subclass) & subclass == s & tt == 1) - t0 <- which(!is.na(subclass) & subclass == s & tt == 0) - if (length(t1) == 1 || length(t0) == 1) { - xx[t1] - xx[t0] - } - else { - outer(xx[t1], xx[t0], "-") - } - })) - mpdiff <- mean(abs(dists)) - } - else { - mpdiff <- pairdistsubC(as.numeric(xx), as.integer(tt), - as.integer(subclass), nlevels(subclass)) - } + else { + return(NA_real_) } - else return(NA_real_) - if (!is.null(std) && abs(mpdiff) > 1e-8) { - mpdiff <- mpdiff/std + if (is_not_null(std) && abs(mpdiff) > 1e-8) { + return(mpdiff/std) } mpdiff @@ -175,12 +169,15 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { n.obs <- length(x) - if (is.null(w)) w <- rep(1, n.obs) + if (is_null(w)) { + w <- rep.int(1, n.obs) + } - if (all(x == 0 | x == 1)) { + if (has_n_unique(x, 2) && all(x == 0 | x == 1)) { t1 <- t == t[1] #For binary variables, just difference in means ediff <- abs(wm(x[t1], w[t1]) - wm(x[-t1], w[-t1])) + return(c(meandiff = ediff, maxdiff = ediff)) } @@ -239,9 +236,9 @@ qqsum <- function(x, t, w = NULL, standardize = FALSE) { method = "constant", ties = "ordered")$y } } + ediff <- abs(x1 - x0) } c(meandiff = mean(ediff), maxdiff = max(ediff)) - } \ No newline at end of file diff --git a/R/matchit.R b/R/matchit.R index 49c2b66e..c5011e66 100644 --- a/R/matchit.R +++ b/R/matchit.R @@ -1,7 +1,5 @@ #' Matching for Causal Inference #' -#' @aliases matchit print.matchit -#' #' @description #' `matchit()` is the main function of *MatchIt* and performs #' pairing, subset selection, and subclassification with the aim of creating @@ -27,11 +25,12 @@ #' [`"nearest"`][method_nearest] for nearest neighbor matching (on #' the propensity score by default), [`"optimal"`][method_optimal] #' for optimal pair matching, [`"full"`][method_full] for optimal +#' full matching, [`"quick"`][method_quick] for generalized (quick) #' full matching, [`"genetic"`][method_genetic] for genetic #' matching, [`"cem"`][method_cem] for coarsened exact matching, #' [`"exact"`][method_exact] for exact matching, #' [`"cardinality"`][method_cardinality] for cardinality and -#' template matching, and [`"subclass"`][method_subclass] for +#' profile matching, and [`"subclass"`][method_subclass] for #' subclassification. When set to `NULL`, no matching will occur, but #' propensity score estimation and common support restrictions will still occur #' if requested. See the linked pages for each method for more details on what @@ -50,15 +49,13 @@ #' argument controlling the link function used in estimating the distance #' measure. Allowable options depend on the specific `distance` value #' specified. See [`distance`] for allowable options with each -#' option. The default is `"logit"`, which, along with `distance = "glm"`, identifies the default measure as logistic regression propensity -#' scores. +#' option. The default is `"logit"`, which, along with `distance = "glm"`, identifies the default measure as logistic regression propensity scores. #' @param distance.options a named list containing additional arguments #' supplied to the function that estimates the distance measure as determined -#' by the argument to `distance`. See [distance] for an +#' by the argument to `distance`. See [`distance`] for an #' example of its use. #' @param estimand a string containing the name of the target estimand desired. -#' Can be one of `"ATT"` or `"ATC"`. Some methods accept `"ATE"` -#' as well. Default is `"ATT"`. See Details and the individual methods +#' Can be one of `"ATT"`, `"ATC"`, or `"ATE"`. Default is `"ATT"`. See Details and the individual methods #' pages for information on how this argument is used. #' @param exact for methods that allow it, for which variables exact matching #' should take place. Can be specified as a string containing the names of @@ -72,8 +69,7 @@ #' within propensity score calipers, where the propensity scores are computed #' using `formula` and `distance`. Can be specified as a string #' containing the names of variables in `data` to be used or a one-sided -#' formula with the desired variables on the right-hand side (e.g., `~ X3 + X4`). See the individual methods pages for information on whether and how -#' this argument is used. +#' formula with the desired variables on the right-hand side (e.g., `~ X3 + X4`). See the individual methods pages for information on whether and how this argument is used. #' @param antiexact for methods that allow it, for which variables anti-exact #' matching should take place. Anti-exact matching ensures paired individuals #' do not have the same value of the anti-exact matching variable(s). Can be @@ -99,7 +95,7 @@ #' be specified as a string containing the name of variable in `data` to #' be used or a one-sided formula with the variable on the right-hand side #' (e.g., `~ SW`). Not all propensity score models accept sampling -#' weights; see [distance] for information on which do and do not, +#' weights; see [`distance`] for information on which do and do not, #' and see `vignette("sampling-weights")` for details on how to use #' sampling weights in a matching analysis. #' @param replace for methods that allow it, whether matching should be done @@ -134,10 +130,10 @@ #' the matching process in the output, i.e., by the functions from other #' packages `matchit()` calls. What is included depends on the matching #' method. Default is `FALSE`. +#' @param normalize `logical`; whether to rescale the nonzero weights in each treatment group to have an average of 1. Default is `TRUE`. See "How Matching Weights Are Computed" below for more details. #' @param \dots additional arguments passed to the functions used in the #' matching process. See the individual methods pages for information on what -#' additional arguments are allowed for each method. Ignored for `print()`. -#' @param x a `matchit` object. +#' additional arguments are allowed for each method. #' #' @details #' Details for the various matching methods can be found at the following help @@ -145,10 +141,11 @@ #' * [`method_nearest`] for nearest neighbor matching #' * [`method_optimal`] for optimal pair matching #' * [`method_full`] for optimal full matching +#' * [`method_quick`] for generalized (quick) full matching #' * [`method_genetic`] for genetic matching #' * [`method_cem`] for coarsened exact matching #' * [`method_exact`] for exact matching -#' * [`method_cardinality`] for cardinality and template matching +#' * [`method_cardinality`] for cardinality and profile matching #' * [`method_subclass`] for subclassification #' #' The pages contain information on what the method does, which of the arguments above are @@ -171,7 +168,7 @@ #' specified. All arguments other than `distance`, `discard`, and #' `reestimate` will be ignored. #' -#' See [distance] for details on the several ways to +#' See [`distance`] for details on the several ways to #' specify the `distance`, `link`, and `distance.options` #' arguments to estimate propensity scores and create distance measures. #' @@ -180,7 +177,7 @@ #' Value, below). The following rules are used: 1) if `0` is one of the #' values, it will be considered the control and the other value the treated; #' 2) otherwise, if the variable is a factor, `levels(treat)[1]` will be -#' considered control and the other variable the treated; 3) otherwise, +#' considered control and the other value the treated; 3) otherwise, #' `sort(unique(treat))[1]` will be considered control and the other value #' the treated. It is safest to ensure the treatment variable is a `0/1` #' variable. @@ -217,36 +214,49 @@ #' Matching weights are computed in one of two ways depending on whether matching was done with replacement #' or not. #' -#' For matching *without* replacement (except for cardinality matching), each +#' ### Matching without replacement and subclassification +#' +#' For matching *without* replacement (except for cardinality matching), including subclassification, each #' unit is assigned to a subclass, which represents the pair they are a part of #' (in the case of k:1 matching) or the stratum they belong to (in the case of #' exact matching, coarsened exact matching, full matching, or #' subclassification). The formula for computing the weights depends on the #' argument supplied to `estimand`. A new "stratum propensity score" -#' (`sp`) is computed as the proportion of units in each stratum that are -#' in the treated group, and all units in that stratum are assigned that +#' (\eqn{p^s_i}) is computed for each unit \eqn{i} as \eqn{p^s_i = \frac{1}{n_s}\sum_{j: s_j =s_i}{I(A_j=1)}} where \eqn{n_s} is the size of subclass \eqn{s} and \eqn{I(A_j=1)} is 1 if unit \eqn{j} is treated and 0 otherwise. That is, the stratum propensity score for stratum \eqn{s} is the proportion of units in stratum \eqn{s} that are +#' in the treated group, and all units in stratum \eqn{s} are assigned that #' stratum propensity score. This is distinct from the propensity score used for matching, if any. Weights are then computed using the standard formulas for -#' inverse probability weights with the stratum propensity score inserted: for the ATT, weights are 1 for the treated -#' units and `sp/(1-sp)` for the control units; for the ATC, weights are -#' `(1-sp)/sp` for the treated units and 1 for the control units; for the -#' ATE, weights are `1/sp` for the treated units and `1/(1-sp)` for the -#' control units. For cardinality matching, all matched units receive a weight +#' inverse probability weights with the stratum propensity score inserted: +#' * for the ATT, weights are 1 for the treated +#' units and \eqn{\frac{p^s}{1-p^s}} for the control units +#' * for the ATC, weights are +#' \eqn{\frac{1-p^s}{p^s}} for the treated units and 1 for the control units +#' * for the ATE, weights are \eqn{\frac{1}{p^s}} for the treated units and \eqn{\frac{1}{1-p^s}} for the +#' control units. +#' +#' For cardinality matching, all matched units receive a weight #' of 1. #' +#' ### Matching witht replacement +#' #' For matching *with* replacement, units are not assigned to unique strata. For #' the ATT, each treated unit gets a weight of 1. Each control unit is weighted #' as the sum of the inverse of the number of control units matched to the same #' treated unit across its matches. For example, if a control unit was matched #' to a treated unit that had two other control units matched to it, and that #' same control was matched to a treated unit that had one other control unit -#' matched to it, the control unit in question would get a weight of 1/3 + 1/2 -#' = 5/6. For the ATC, the same is true with the treated and control labels +#' matched to it, the control unit in question would get a weight of \eqn{1/3 + 1/2 = 5/6}. For the ATC, the same is true with the treated and control labels #' switched. The weights are computed using the `match.matrix` component #' of the `matchit()` output object. #' -#' In each treatment group, weights are divided by the mean of the nonzero +#' ### Normalized weights +#' +#' When `normalize = TRUE` (the default), in each treatment group, weights are divided by the mean of the nonzero #' weights in that treatment group to make the weights sum to the number of -#' units in that treatment group. If sampling weights are included through the +#' units in that treatment group (i.e., to have an average of 1). +#' +#' ### Sampling weights +#' +#' If sampling weights are included through the #' `s.weights` argument, they will be included in the `matchit()` #' output object but not incorporated into the matching weights. #' [match.data()], which extracts the matched set from a `matchit` object, @@ -255,30 +265,26 @@ #' @return When `method` is something other than `"subclass"`, a #' `matchit` object with the following components: #' -#' \item{match.matrix}{a matrix containing the matches. The rownames correspond +#' \item{match.matrix}{a matrix containing the matches. The row names correspond #' to the treated units and the values in each row are the names (or indices) #' of the control units matched to each treated unit. When treated units are -#' matched to different numbers of control units (e.g., with exact matching or +#' matched to different numbers of control units (e.g., with variable ratio matching or #' matching with a caliper), empty spaces will be filled with `NA`. Not -#' included when `method` is `"full"`, `"cem"` (unless `k2k -#' = TRUE`), `"exact"`, or `"cardinality"`.} +#' included when `method` is `"full"`, `"cem"` (unless `k2k = TRUE`), `"exact"`, `"quick"`, or `"cardinality"` (unless `mahvars` is supplied and `ratio` is an integer).} #' \item{subclass}{a factor #' containing matching pair/stratum membership for each unit. Unmatched units -#' will have a value of `NA`. Not included when `replace = TRUE`.} +#' will have a value of `NA`. Not included when `replace = TRUE` or when `method = "cardinality"` unless `mahvars` is supplied and `ratio` is an integer.} #' \item{weights}{a numeric vector of estimated matching weights. Unmatched and #' discarded units will have a weight of zero.} #' \item{model}{the fit object of #' the model used to estimate propensity scores when `distance` is -#' specified and not `"mahalanobis"` or a numeric vector. When +#' specified as a method of estimating propensity scores. When #' `reestimate = TRUE`, this is the model estimated after discarding #' units.} -#' \item{X}{a data frame of covariates mentioned in `formula`, -#' `exact`, `mahvars`, and `antiexact`.} +#' \item{X}{a data frame of covariates mentioned in `formula`, `exact`, `mahvars`, `caliper`, and `antiexact`.} #' \item{call}{the `matchit()` call.} -#' \item{info}{information on the matching method and -#' distance measures used.} -#' \item{estimand}{the argument supplied to -#' `estimand`.} +#' \item{info}{information on the matching method and distance measures used.} +#' \item{estimand}{the argument supplied to `estimand`.} #' \item{formula}{the `formula` supplied.} #' \item{treat}{a vector of treatment status converted to zeros (0) and ones #' (1) if not already in that format.} @@ -286,16 +292,11 @@ #' values (i.e., propensity scores) when `distance` is supplied as a #' method of estimating propensity scores or a numeric vector.} #' \item{discarded}{a logical vector denoting whether each observation was -#' discarded (`TRUE`) or not (`FALSE`) by the argument to -#' `discard`.} -#' \item{s.weights}{the vector of sampling weights supplied to -#' the `s.weights` argument, if any.} -#' \item{exact}{a one-sided formula -#' containing the variables, if any, supplied to `exact`.} -#' \item{mahvars}{a one-sided formula containing the variables, if any, -#' supplied to `mahvars`.} -#' \item{obj}{when `include.obj = TRUE`, an -#' object containing the intermediate results of the matching procedure. See +#' discarded (`TRUE`) or not (`FALSE`) by the argument to `discard`.} +#' \item{s.weights}{the vector of sampling weights supplied to the `s.weights` argument, if any.} +#' \item{exact}{a one-sided formula containing the variables, if any, supplied to `exact`.} +#' \item{mahvars}{a one-sided formula containing the variables, if any, supplied to `mahvars`.} +#' \item{obj}{when `include.obj = TRUE`, an object containing the intermediate results of the matching procedure. See #' the individual methods pages for what this component will contain.} #' #' When `method = "subclass"`, a `matchit.subclass` object with the same @@ -304,24 +305,18 @@ #' distance measure cutpoints used to define the subclasses. See #' [`method_subclass`] for details. #' -#' @author Daniel Ho (\email{dho@@law.stanford.edu}); Kosuke Imai -#' (\email{imai@@harvard.edu}); Gary King (\email{king@@harvard.edu}); -#' Elizabeth Stuart (\email{estuart@@jhsph.edu}) -#' -#' Version 4.0.0 update by Noah Greifer (\email{noah.greifer@@gmail.com}) +#' @author Daniel Ho, Kosuke Imai, Gary King, and Elizabeth Stuart wrote the original package. Starting with version 4.0.0, Noah Greifer is the primary maintainer and developer. #' #' @seealso [summary.matchit()] for balance assessment after matching, [plot.matchit()] for plots of covariate balance and propensity score overlap after matching. #' -#' `vignette("MatchIt")` for an introduction to matching with -#' *MatchIt*; `vignette("matching-methods")` for descriptions of the -#' variety of matching methods and options available; -#' `vignette("assessing-balance")` for information on assessing the -#' quality of a matching specification; `vignette("estimating-effects")` -#' for instructions on how to estimate treatment effects after matching; and -#' `vignette("sampling-weights")` for a guide to using *MatchIt* with -#' sampling weights. +#' * `vignette("MatchIt")` for an introduction to matching with *MatchIt* +#' * `vignette("matching-methods")` for descriptions of the variety of matching methods and options available +#' * `vignette("assessing-balance")` for information on assessing the quality of a matching specification +#' * `vignette("estimating-effects")` for instructions on how to estimate treatment effects after matching +#' * `vignette("sampling-weights")` for a guide to using *MatchIt* with sampling weights. #' -#' @references Ho, D. E., Imai, K., King, G., & Stuart, E. A. (2007). Matching +#' @references +#' Ho, D. E., Imai, K., King, G., & Stuart, E. A. (2007). Matching #' as Nonparametric Preprocessing for Reducing Model Dependence in Parametric #' Causal Inference. *Political Analysis*, 15(3), 199–236. \doi{10.1093/pan/mpl013} #' @@ -380,7 +375,7 @@ #' discard = "control", subclass = 10) #' s.out1 #' summary(s.out1, un = TRUE) -#' + #' @export matchit <- function(formula, data = NULL, @@ -402,6 +397,7 @@ matchit <- function(formula, ratio = 1, verbose = FALSE, include.obj = FALSE, + normalize = TRUE, ...) { #Checking input format @@ -409,22 +405,22 @@ matchit <- function(formula, mcall <- match.call() ## Process method - .chk_null_or(method, chk::chk_string) - if (length(method) == 1 && is.character(method)) { - method <- tolower(method) - method <- match_arg(method, c("exact", "cem", "nearest", "optimal", "full", "genetic", "subclass", "cardinality", - "quick")) - fn2 <- paste0("matchit2", method) - } - else if (is.null(method)) { + chk::chk_null_or(method, vld = chk::vld_string) + if (is_null(method)) { fn2 <- "matchit2null" } else { - .err("`method` must be the name of a supported matching method. See `?matchit` for allowable options") + method <- tolower(method) + method <- match_arg(method, c("exact", "cem", "nearest", "optimal", + "full", "genetic", "subclass", "cardinality", + "quick")) + fn2 <- paste0("matchit2", method) } #Process formula and data inputs - .chk_formula(formula, sides = 2) + if (!rlang::is_formula(formula, lhs = TRUE)) { + .err("`formula` must be a formula relating treatment to covariates") + } treat.form <- update(terms(formula, data = data), . ~ 0) treat.mf <- model.frame(treat.form, data = data, na.action = "na.pass") @@ -432,7 +428,9 @@ matchit <- function(formula, #Check and binarize treat treat <- check_treat(treat) - if (length(treat) == 0) .err("the treatment cannot be `NULL`") + if (is_null(treat)) { + .err("the treatment cannot be `NULL`") + } names(treat) <- rownames(treat.mf) @@ -444,8 +442,8 @@ matchit <- function(formula, reestimate = reestimate, s.weights = s.weights, replace = replace, ratio = ratio, m.order = m.order, estimand = estimand) - if (length(ignored.inputs) > 0) { - for (i in ignored.inputs) assign(i, NULL) + for (i in ignored.inputs) { + assign(i, NULL) } #Process replace @@ -455,40 +453,51 @@ matchit <- function(formula, ratio <- process.ratio(ratio, method, ...) #Process s.weights - if (!is.null(s.weights)) { + if (is_not_null(s.weights)) { if (is.character(s.weights)) { - if (is.null(data) || !is.data.frame(data)) { + if (is_null(data) || !is.data.frame(data)) { .err("if `s.weights` is specified a string, a data frame containing the named variable must be supplied to `data`") } - if (!all(s.weights %in% names(data))) { + + if (!all(hasName(data, s.weights))) { .err("the name supplied to `s.weights` must be a variable in `data`") } + s.weights.form <- reformulate(s.weights) s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") + } + + s.weights <- s.weights[[1L]] } - else if (inherits(s.weights, "formula")) { + else if (rlang::is_formula(s.weights)) { s.weights.form <- update(terms(s.weights, data = data), NULL ~ .) s.weights <- model.frame(s.weights.form, data, na.action = "na.pass") - if (ncol(s.weights) != 1) .err("`s.weights` can only contain one named variable") - s.weights <- s.weights[[1]] + + if (ncol(s.weights) != 1L) { + .err("`s.weights` can only contain one named variable") + } + + s.weights <- s.weights[[1L]] } else if (!is.numeric(s.weights)) { .err("`s.weights` must be supplied as a numeric vector, string, or one-sided formula") } chk::chk_not_any_na(s.weights) - if (length(s.weights) != n.obs) .err("`s.weights` must be the same length as the treatment vector") + if (length(s.weights) != n.obs) { + .err("`s.weights` must be the same length as the treatment vector") + } names(s.weights) <- names(treat) - } #Process distance function is.full.mahalanobis <- FALSE fn1 <- NULL - if (is.null(method) || !method %in% c("exact", "cem", "cardinality")) { + if (is_null(method) || !method %in% c("exact", "cem", "cardinality")) { distance <- process.distance(distance, method, treat) if (is.numeric(distance)) { @@ -507,7 +516,7 @@ matchit <- function(formula, } #Process covs - if (!is.null(fn1) && fn1 == "distance2gam") { + if (is_not_null(fn1) && fn1 == "distance2gam") { rlang::check_installed("mgcv") env <- environment(formula) covs.formula <- mgcv::interpret.gam(formula)$fake.formula @@ -529,7 +538,10 @@ matchit <- function(formula, .err(paste0("Missing and non-finite values are not allowed in the covariates. Covariates with missingness or non-finite values:\n\t", paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) } - if (is.character(covs[[i]])) covs[[i]] <- factor(covs[[i]]) + + if (is.character(covs[[i]])) { + covs[[i]] <- factor(covs[[i]]) + } } #Process exact, mahvars, and antiexact @@ -542,8 +554,11 @@ matchit <- function(formula, antiexactcovs <- process.variable.input(antiexact, data) antiexact <- attr(antiexactcovs, "terms") + chk::chk_flag(verbose) + chk::chk_flag(normalize) + #Estimate distance, discard from common support, optionally re-estimate distance - if (is.null(fn1) || is.full.mahalanobis) { + if (is_null(fn1) || is.full.mahalanobis) { #No distance measure dist.model <- distance <- link <- NULL } @@ -551,25 +566,36 @@ matchit <- function(formula, dist.model <- link <- NULL } else { - if (verbose) { - cat("Estimating propensity scores... \n") - } + .cat_verbose("Estimating propensity scores... \n", verbose = verbose) - if (!is.null(s.weights)) { + if (is_not_null(s.weights)) { attr(s.weights, "in_ps") <- !distance %in% c("bart") } #Estimate distance - if (is.null(distance.options$formula)) distance.options$formula <- formula - if (is.null(distance.options$data)) distance.options$data <- data - if (is.null(distance.options$verbose)) distance.options$verbose <- verbose - if (is.null(distance.options$estimand)) distance.options$estimand <- estimand - if (is.null(distance.options$weights) && !fn1 %in% c("distance2bart")) { + if (is_null(distance.options)) { + distance.options <- list(formula = formula, + data = data, + verbose = verbose, + estimand = estimand) + } + else { + chk::chk_list(distance.options) + + if (is_null(distance.options$formula)) distance.options$formula <- formula + if (is_null(distance.options$data)) distance.options$data <- data + if (is_null(distance.options$verbose)) distance.options$verbose <- verbose + if (is_null(distance.options$estimand)) distance.options$estimand <- estimand + } + + if (is_null(distance.options$weights) && !fn1 %in% c("distance2bart")) { distance.options$weights <- s.weights } - if (!is.null(attr(distance, "link"))) distance.options$link <- attr(distance, "link") - else distance.options$link <- link + distance.options$link <- { + if (is_not_null(attr(distance, "link"))) attr(distance, "link") + else link + } dist.out <- do.call(fn1, distance.options, quote = TRUE) @@ -585,7 +611,7 @@ matchit <- function(formula, } #Process discard - if (is.null(fn1) || is.full.mahalanobis || fn1 == "distance2user") { + if (is_null(fn1) || is.full.mahalanobis || identical(fn1, "distance2user")) { discarded <- discard(treat, distance, discard) } else { @@ -597,7 +623,7 @@ matchit <- function(formula, if (length(distance.options[[i]]) == n.obs) { distance.options[[i]] <- distance.options[[i]][!discarded] } - else if (length(dim(distance.options[[i]])) == 2 && nrow(distance.options[[i]]) == n.obs) { + else if (length(dim(distance.options[[i]])) == 2L && nrow(distance.options[[i]]) == n.obs) { distance.options[[i]] <- distance.options[[i]][!discarded,,drop = FALSE] } } @@ -610,12 +636,17 @@ matchit <- function(formula, #Process caliper calcovs <- NULL - if (!is.null(caliper)) { + if (is_not_null(caliper)) { caliper <- process.caliper(caliper, method, data, covs, mahcovs, distance, discarded, std.caliper) - if (!is.null(attr(caliper, "cal.formula"))) { - calcovs <- model.frame(attr(caliper, "cal.formula"), data, na.action = "na.pass") - if (anyNA(calcovs)) .err("missing values are not allowed in the covariates named in `caliper`") + if (is_not_null(attr(caliper, "cal.formula"))) { + calcovs <- model.frame(attr(caliper, "cal.formula"), data = data, + na.action = "na.pass") + + if (anyNA(calcovs)) { + .err("missing values are not allowed in the covariates named in `caliper`") + } + attr(caliper, "cal.formula") <- NULL } } @@ -629,31 +660,51 @@ matchit <- function(formula, antiexact = antiexact, ...), quote = TRUE) + weights <- match.out[["weights"]] + + #Normalize weights + if (normalize) { + wi <- which(weights > 0) + weights[wi] <- .make_sum_to_n(weights[wi], treat[wi]) + } + info <- create_info(method, fn1, link, discard, replace, ratio, - mahalanobis = is.full.mahalanobis || !is.null(mahvars), + mahalanobis = is.full.mahalanobis || is_not_null(mahvars), transform = attr(is.full.mahalanobis, "transform"), subclass = match.out$subclass, antiexact = colnames(antiexactcovs), - distance_is_matrix = !is.null(distance) && is.matrix(distance)) + distance_is_matrix = is_not_null(distance) && is.matrix(distance)) + + #Create X output, removing duplicate variables + X.list.nm <- c("covs", "exactcovs", "mahcovs", "calcovs", "antiexactcovs") + X <- NULL + for (i in X.list.nm) { + X_tmp <- get0(i, inherits = FALSE) + + if (is_null(X_tmp)) { + next + } - #Create X.list for X output, removing duplicate variables - X.list <- list(covs, exactcovs, mahcovs, calcovs, antiexactcovs) - all.covs <- lapply(X.list, names) - for (i in seq_along(X.list)[-1]) if (!is.null(X.list[[i]])) X.list[[i]][names(X.list[[i]]) %in% unlist(all.covs[1:(i-1)])] <- NULL - X.list[vapply(X.list, is.null, logical(1L))] <- NULL + if (is_null(X)) { + X <- X_tmp + } + else if (!all(hasName(X, names(X_tmp)))) { + X <- cbind(X, X_tmp[!names(X_tmp) %in% names(X)]) + } + } ## putting all the results together out <- list( match.matrix = match.out[["match.matrix"]], subclass = match.out[["subclass"]], - weights = match.out[["weights"]], - X = do.call("cbind", X.list), + weights = weights, + X = X, call = mcall, info = info, estimand = estimand, formula = formula, treat = treat, - distance = if (!is.null(distance) && !is.matrix(distance)) setNames(distance, names(treat)), + distance = if (is_not_null(distance) && !is.matrix(distance)) setNames(distance, names(treat)), discarded = discarded, s.weights = s.weights, exact = exact, @@ -664,74 +715,100 @@ matchit <- function(formula, obj = if (include.obj) match.out[["obj"]] ) - out[vapply(out, is.null, logical(1L))] <- NULL + out[lengths(out) == 0L] <- NULL class(out) <- class(match.out) + out } -#' @export -#' @rdname matchit +#' @exportS3Method print matchit print.matchit <- function(x, ...) { info <- x[["info"]] - cal <- !is.null(x[["caliper"]]) + cal <- is_not_null(x[["caliper"]]) dis <- c("both", "control", "treat")[pmatch(info$discard, c("both", "control", "treat"), 0L)] - disl <- length(dis) > 0 - nm <- is.null(x[["method"]]) - cat("A matchit object") - cat(paste0("\n - method: ", info.to.method(info))) + disl <- is_not_null(dis) + nm <- is_null(x[["method"]]) + + cat("A `matchit` object\n") + + cat(sprintf(" - method: %s\n", info.to.method(info))) - if (!is.null(info$distance) || info$mahalanobis) { - cat("\n - distance: ") + if (is_not_null(info$distance) || info$mahalanobis) { + cat(" - distance: ") if (info$mahalanobis) { - if (is.null(info$transform)) #mahvars used + if (is_null(info$transform)) #mahvars used cat("Mahalanobis") else { cat(capwords(gsub("_", " ", info$transform, fixed = TRUE))) } } - if (!is.null(info$distance) && !info$distance %in% matchit_distances()) { + + if (is_not_null(info$distance) && !info$distance %in% matchit_distances()) { if (info$mahalanobis) cat(" [matching]\n ") if (info$distance_is_matrix) cat("User-defined (matrix)") else if (info$distance != "user") cat("Propensity score") - else if (!is.null(attr(info$distance, "custom"))) cat(attr(info$distance, "custom")) + else if (is_not_null(attr(info$distance, "custom"))) cat(attr(info$distance, "custom")) else cat("User-defined") if (cal || disl) { - cal.ps <- "" %in% names(x[["caliper"]]) - cat(" [") - cat(paste(c("matching", "subclassification", "caliper", "common support")[c(!nm && !info$mahalanobis && info$method != "subclass", !nm && info$method == "subclass", cal.ps, disl)], collapse = ", ")) - cat("]") + cal.ps <- hasName(x[["caliper"]], "") + cat(sprintf(" [%s]\n", + paste(c("matching", "subclassification", "caliper", "common support")[c(!nm && !info$mahalanobis && info$method != "subclass", !nm && info$method == "subclass", cal.ps, disl)], collapse = ", "))) } + if (info$distance != "user") { - cat("\n - estimated with ") - cat(info.to.distance(info)) - if (!is.null(x[["s.weights"]])) { - if (isTRUE(attr(x[["s.weights"]], "in_ps"))) - cat("\n - sampling weights included in estimation") - else cat("\n - sampling weights not included in estimation") + cat(sprintf(" - estimated with %s\n", + info.to.distance(info))) + if (is_not_null(x[["s.weights"]])) { + cat(sprintf(" - sampling weights %sincluded in estimation\n", + if (isTRUE(attr(x[["s.weights"]], "in_ps"))) "" else "not ")) } } } } + if (cal) { - cat(paste0("\n - caliper: ", paste(vapply(seq_along(x[["caliper"]]), function(z) paste0(if (names(x[["caliper"]])[z] == "") "" else names(x[["caliper"]])[z], - " (", format(round(x[["caliper"]][z], 3)), ")"), character(1L)), - collapse = ", "))) + cat(sprintf(" - caliper: %s\n", + paste(vapply(seq_along(x[["caliper"]]), + function(z) { + sprintf("%s (%s)", + if (names(x[["caliper"]])[z] == "") "" + else names(x[["caliper"]])[z], + format(round(x[["caliper"]][z], 3))) + }, character(1L)), + collapse = ", "))) } + if (disl) { - cat("\n - common support: ") - if (dis == "both") cat("units from both groups") - else if (dis == "treat") cat("treated units") - else if (dis == "control") cat("control units") - cat(" dropped") + cat(sprintf(" - common support: %s dropped\n", + switch(dis, + "both" = "units from both groups", + "treat" = "treated units", + "control" = "control units"))) } - cat(paste0("\n - number of obs.: ", length(x[["treat"]]), " (original)", if (!all(x[["weights"]] == 1)) paste0(", ", sum(x[["weights"]] != 0), " (matched)"))) - if (!is.null(x[["s.weights"]])) cat("\n - sampling weights: present") - if (!is.null(x[["estimand"]])) cat(paste0("\n - target estimand: ", x[["estimand"]])) - if (!is.null(x[["X"]])) cat(paste0("\n - covariates: ", ifelse(length(names(x[["X"]])) > 40, "too many to name", paste(names(x[["X"]]), collapse = ", ")))) - cat("\n") + + cat(sprintf(" - number of obs.: %s (original)%s\n", + length(x[["treat"]]), + if (!all_equal_to(x[["weights"]], 1)) + sprintf(", %s (matched)", sum(x[["weights"]] != 0)) + else "")) + + if (is_not_null(x[["s.weights"]])) { + cat(" - sampling weights: present\n") + } + + if (is_not_null(x[["estimand"]])) { + cat(sprintf(" - target estimand: %s\n", x[["estimand"]])) + } + + if (is_not_null(x[["X"]])) { + cat(sprintf(" - covariates: %s\n", + if (length(names(x[["X"]])) > 40L) "too many to name" + else paste(names(x[["X"]]), collapse = ", "))) + } + invisible(x) } diff --git a/R/matchit2cardinality.R b/R/matchit2cardinality.R index 12c65330..5f3d119d 100644 --- a/R/matchit2cardinality.R +++ b/R/matchit2cardinality.R @@ -265,24 +265,24 @@ #' # covariates included if possible. NULL -matchit2cardinality <- function(treat, data, discarded, formula, - ratio = 1, focal = NULL, s.weights = NULL, - replace = FALSE, mahvars = NULL, exact = NULL, - estimand = "ATT", verbose = FALSE, - tols = .05, std.tols = TRUE, - solver = "glpk", time = 1*60, ...){ +matchit2cardinality <- function(treat, data, discarded, formula, + ratio = 1, focal = NULL, s.weights = NULL, + replace = FALSE, mahvars = NULL, exact = NULL, + estimand = "ATT", verbose = FALSE, + tols = .05, std.tols = TRUE, + solver = "highs", time = 1*60, ...) { - if (verbose) { - cat("Cardinality matching... \n") - } + .cat_verbose("Cardinality matching... \n", verbose = verbose) tvals <- unique(treat) nt <- length(tvals) estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - if (!is.null(focal)) { - if (!focal %in% tvals) .err("`focal` must be a value of the treatment") + if (is_not_null(focal)) { + if (!focal %in% tvals) { + .err("`focal` must be a value of the treatment") + } } else if (estimand == "ATC") { focal <- min(tvals) @@ -293,15 +293,18 @@ matchit2cardinality <- function(treat, data, discarded, formula, lab <- names(treat) - weights <- setNames(rep(0, length(treat)), lab) + weights <- rep_with(0, treat) X <- get.covs.matrix(formula, data = data) - if (!is.null(exact)) { - ex <- factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) + if (is_not_null(exact)) { + ex <- exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE) - cc <- do.call("intersect", lapply(tvals, function(t) as.integer(ex)[treat == t])) - if (length(cc) == 0) .err("no matches were found") + cc <- Reduce("intersect", lapply(tvals, function(t) unclass(ex)[treat==t])) + + if (is_null(cc)) { + .err("no matches were found") + } } else { ex <- gl(1, length(treat)) @@ -309,15 +312,17 @@ matchit2cardinality <- function(treat, data, discarded, formula, } #Process mahvars - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { if (!is.finite(ratio) || !chk::vld_whole_number(ratio)) { .err("`mahvars` can only be used with `method = \"cardinality\"` when `ratio` is a whole number") } + rlang::check_installed("optmatch") + mahcovs <- transform_covariates(mahvars, data = data, method = "mahalanobis", s.weights = s.weights, treat = treat, discarded = discarded) - pair <- setNames(rep(NA_character_, length(treat)), lab) + pair <- rep_with(NA_character_, treat) #Set max problem size to Inf and return to original value after match omps <- getOption("optmatch_max_problem_size") @@ -332,8 +337,8 @@ matchit2cardinality <- function(treat, data, discarded, formula, assign <- get_assign(X) chk::chk_numeric(tols) - if (length(tols) == 1) { - tols <- rep(tols, ncol(X)) + if (length(tols) == 1L) { + tols <- rep.int(tols, ncol(X)) } else if (length(tols) == max(assign)) { tols <- tols[assign] @@ -343,8 +348,8 @@ matchit2cardinality <- function(treat, data, discarded, formula, } chk::chk_logical(std.tols) - if (length(std.tols) == 1) { - std.tols <- rep(std.tols, ncol(X)) + if (length(std.tols) == 1L) { + std.tols <- rep.int(std.tols, ncol(X)) } else if (length(std.tols) == max(assign)) { std.tols <- std.tols[assign] @@ -355,6 +360,8 @@ matchit2cardinality <- function(treat, data, discarded, formula, #Apply std.tols if (any(std.tols)) { + std.tols <- which(std.tols) + sds <- { if (estimand == "ATE") { pooled_sd(X[, std.tols, drop = FALSE], t = treat, @@ -366,37 +373,37 @@ matchit2cardinality <- function(treat, data, discarded, formula, } } - zero.sds <- sds < 1e-10 - - X[,std.tols][,!zero.sds] <- scale(X[, std.tols, drop = FALSE][,!zero.sds, drop = FALSE], - center = FALSE, scale = sds[!zero.sds]) + for (i in which(sds >= 1e-10)) { + X[,std.tols[i]] <- X[,std.tols[i]] / sds[i] + } } opt.out <- setNames(vector("list", nlevels(ex)), levels(ex)) for (e in levels(ex)[cc]) { - if (nlevels(ex) > 1 && verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) + if (nlevels(ex) > 1) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) } - in.exact <- which(!discarded & ex == e) + .e <- which(!discarded & ex == e) - treat_in.exact <- treat[in.exact] + treat_in.exact <- treat[.e] out <- cardinality_matchit(treat = treat_in.exact, - X = X[in.exact,, drop = FALSE], + X = X[.e,, drop = FALSE], estimand = estimand, tols = tols, - s.weights = s.weights[in.exact], + s.weights = s.weights[.e], ratio = ratio, focal = focal, tvals = tvals, solver = solver, time = time, verbose = verbose) - weights[in.exact] <- out[["weights"]] + weights[.e] <- out[["weights"]] opt.out[[e]] <- out[["opt.out"]] - if (!is.null(mahvars)) { - mo <- eucdist_internal(mahcovs[in.exact[out[["weights"]] > 0],, drop = FALSE], + if (is_not_null(mahvars)) { + mo <- eucdist_internal(mahcovs[.e[out[["weights"]] > 0],, drop = FALSE], treat_in.exact[out[["weights"]] > 0]) pm <- optmatch::pairmatch(mo, @@ -407,7 +414,7 @@ matchit2cardinality <- function(treat, data, discarded, formula, } } - if (!is.null(pair)) { + if (is_not_null(pair)) { psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) names(psclass) <- names(treat) @@ -419,14 +426,16 @@ matchit2cardinality <- function(treat, data, discarded, formula, mm <- psclass <- NULL } - if (length(opt.out) == 1L) out <- out[[1]] + if (length(opt.out) == 1L) { + out <- out[[1]] + } res <- list(match.matrix = mm, subclass = psclass, weights = weights, obj = opt.out) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" @@ -439,27 +448,36 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight time = 2*60, verbose = FALSE) { n <- length(treat) - if (is.null(tvals)) tvals <- if (is.factor(treat)) levels(treat) else sort(unique(treat)) + if (is_null(tvals)) tvals <- if (is.factor(treat)) levels(treat) else sort(unique(treat)) nt <- length(tvals) #Check inputs - if (is.null(s.weights)) s.weights <- rep(1, n) - else for (i in tvals) s.weights[treat == i] <- s.weights[treat == i]/mean(s.weights[treat == i]) + if (is_null(s.weights)) { + s.weights <- rep.int(1, n) + } + else { + s.weights <- .make_sum_to_n(s.weights, treat) + } - if (is.null(focal)) focal <- tvals[length(tvals)] + if (is_null(focal)) { + focal <- tvals[length(tvals)] + } chk::chk_number(time) chk::chk_gt(time, 0) chk::chk_string(solver) solver <- match_arg(solver, c("highs", "glpk", "symphony", "gurobi")) + rlang::check_installed(switch(solver, glpk = "Rglpk", symphony = "Rsymphony", gurobi = "gurobi", highs = "highs")) #Select match type - if (estimand == "ATE") match_type <- "profile_ate" - else if (!is.finite(ratio)) match_type <- "profile_att" - else match_type <- "cardinality" + match_type <- { + if (estimand == "ATE") "profile_ate" + else if (!is.finite(ratio)) "profile_att" + else "cardinality" + } #Set objective and constraints if (match_type == "profile_ate") { @@ -468,15 +486,15 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #Objective function: total sample size O <- c( s.weights, #weight for each unit - rep(0, nt) #slack coefs for each sample size (n1, n0) + rep.int(0, nt) #slack coefs for each sample size (n1, n0) ) #Constraint matrix target.means <- apply(X, 2, wm, w = s.weights) C <- matrix(0, nrow = nt * (1 + 2*ncol(X)), ncol = length(O)) - Crhs <- rep(0, nrow(C)) - Cdir <- rep("==", nrow(C)) + Crhs <- rep.int(0, nrow(C)) + Cdir <- rep.int("==", nrow(C)) for (i in seq_len(nt)) { #Num in group i = ni @@ -498,7 +516,7 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #If ratio != 0, constrain n0 to be ratio*n1 if (nt == 2L && is.finite(ratio)) { - C_ratio <- c(rep(0, n), rep(-1, nt)) + C_ratio <- c(rep.int(0, n), rep.int(-1, nt)) C_ratio[n + which(tvals == focal)] <- ratio C <- rbind(C, C_ratio) Crhs <- c(Crhs, 0) @@ -506,13 +524,13 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight } #Coef types - types <- c(rep("B", n), #Matching weights - rep("C", nt)) #Slack coefs for matched group size + types <- c(rep.int("B", n), #Matching weights + rep.int("C", nt)) #Slack coefs for matched group size - lower.bound <- c(rep(0, n), - rep(1, nt)) - upper.bound <- c(rep(1, n), - rep(Inf, nt)) + lower.bound <- c(rep.int(0, n), + rep.int(1, nt)) + upper.bound <- c(rep.int(1, n), + rep.int(Inf, nt)) } else if (match_type == "profile_att") { #Find largest control group that matches treated group @@ -523,8 +541,8 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #Objective function: size of matched control group O <- c( - rep(1, n0), #weights for each non-focal unit - rep(0, nt - 1) #slack coef for size of non-focal groups + rep.int(1, n0), #weights for each non-focal unit + rep.int(0, nt - 1) #slack coef for size of non-focal groups ) #Constraint matrix @@ -532,8 +550,8 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight #One row per constraint, one column per coef C <- matrix(0, nrow = (nt - 1) * (1 + 2*ncol(X)), ncol = length(O)) - Crhs <- rep(0, nrow(C)) - Cdir <- rep("==", nrow(C)) + Crhs <- rep.int(0, nrow(C)) + Cdir <- rep.int("==", nrow(C)) for (i in seq_len(nt - 1)) { #Num in group i = ni @@ -554,13 +572,13 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight } #Coef types - types <- c(rep("B", n0), #Matching weights - rep("C", nt - 1)) #Slack for num control matched + types <- c(rep.int("B", n0), #Matching weights + rep.int("C", nt - 1)) #Slack for num control matched - lower.bound <- c(rep(0, n0), - rep(0, nt - 1)) - upper.bound <- c(rep(1, n0), - rep(Inf, nt - 1)) + lower.bound <- c(rep.int(0, n0), + rep.int(0, nt - 1)) + upper.bound <- c(rep.int(1, n0), + rep.int(Inf, nt - 1)) } else if (match_type == "cardinality") { #True cardinality matching: find largest balanced sample @@ -576,8 +594,8 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight t_combs <- combn(tvals, 2, simplify = FALSE) C <- matrix(0, nrow = nt + 2*ncol(X)*length(t_combs), ncol = length(O)) - Crhs <- rep(0, nrow(C)) - Cdir <- rep("==", nrow(C)) + Crhs <- rep.int(0, nrow(C)) + Cdir <- rep.int("==", nrow(C)) for (i in seq_len(nt)) { #Num in group i = ni @@ -601,13 +619,13 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight } #Coef types - types <- c(rep("B", n), #Matching weights - rep("C", 1)) #Slack coef for treated group size (n1) + types <- c(rep.int("B", n), #Matching weights + rep.int("C", 1)) #Slack coef for treated group size (n1) - lower.bound <- c(rep(0, n), - rep(0, 1)) - upper.bound <- c(rep(1, n), - rep(min(tabulateC(treat)), 1)) + lower.bound <- c(rep.int(0, n), + rep.int(0, 1)) + upper.bound <- c(rep.int(1, n), + rep.int(min(tabulateC(treat)), 1)) } weights <- NULL @@ -628,7 +646,7 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight weights <- round(sol[seq_len(n)]) } else if (match_type %in% c("profile_att")) { - weights <- rep(1, n) + weights <- rep.int(1, n) weights[treat != focal] <- round(sol[seq_len(n0)]) } @@ -651,7 +669,7 @@ cardinality_matchit <- function(treat, X, estimand = "ATT", tols = .05, s.weight cardinality_error_report <- function(out, solver) { if (solver == "glpk") { if (out$status == 1) { - if (all(out$solution == 0)) { + if (all_equal_to(out$solution, 0)) { .err("the optimization problem may be infeasible. Try increasing the value of `tols`.\nSee `?method_cardinality` for additional details") } .wrn("the optimizer failed to find an optimal solution in the time alotted. The returned solution may not be optimal.\nSee `?method_cardinality` for additional details") @@ -685,7 +703,8 @@ cardinality_error_report <- function(out, solver) { } } -dispatch_optimizer <- function(solver = "glpk", obj, mat, dir, rhs, types, max = TRUE, lb = NULL, ub = NULL, time = NULL, verbose = FALSE) { +dispatch_optimizer <- function(solver = "highs", obj, mat, dir, rhs, types, max = TRUE, + lb = NULL, ub = NULL, time = NULL, verbose = FALSE) { if (solver == "glpk") { dir[dir == "="] <- "==" opt.out <- Rglpk::Rglpk_solve_LP(obj = obj, mat = mat, dir = dir, rhs = rhs, max = max, diff --git a/R/matchit2cem.R b/R/matchit2cem.R index 1f396b62..cc37f52a 100644 --- a/R/matchit2cem.R +++ b/R/matchit2cem.R @@ -75,18 +75,29 @@ #' the stratum, the treated units without a match will be dropped). The #' `k2k.method` argument controls how the distance between units is #' calculated. } -#' \item{`k2k.method`}{ `character`; how the distance +#' \item{`k2k.method`}{`character`; how the distance #' between units should be calculated if `k2k = TRUE`. Allowable arguments #' include `NULL` (for random matching), any argument to #' [distance()] for computing a distance matrix from covariates #' (e.g., `"mahalanobis"`), or any allowable argument to `method` in #' [dist()]. Matching will take place on the original -#' (non-coarsened) variables. The default is `"mahalanobis"`. } -#' \item{`mpower`}{ if `k2k.method = "minkowski"`, the power used in -#' creating the distance. This is passed to the `p` argument of [dist()].} +#' (non-coarsened) variables. The default is `"mahalanobis"`. +#' } +#' \item{`mpower`}{if `k2k.method = "minkowski"`, the power used in +#' creating the distance. This is passed to the `p` argument of [dist()]. +#' } +#' \item{`m.order`}{`character`; the order that the matching takes place when `k2k = TRUE`. Allowable options +#' include `"closest"`, where matching takes place in +#' ascending order of the smallest distance between units; `"farthest"`, where matching takes place in +#' descending order of the smallest distance between units; `"random"`, where matching takes place +#' in a random order; and `"data"` where matching takes place based on the +#' order of units in the data. When `m.order = "random"`, results may differ +#' across different runs of the same code unless a seed is set and specified +#' with [set.seed()]. The default of `NULL` corresponds to `"data"`. See [`method_nearest`] for more information. +#' } #' } #' -#' The arguments `distance` (and related arguments), `exact`, `mahvars`, `discard` (and related arguments), `replace`, `m.order`, `caliper` (and related arguments), and `ratio` are ignored with a warning. +#' The arguments `distance` (and related arguments), `exact`, `mahvars`, `discard` (and related arguments), `replace`, `caliper` (and related arguments), and `ratio` are ignored with a warning. #' #' @section Outputs: #' @@ -109,8 +120,7 @@ #' Setting `k2k = TRUE` is equivalent to first doing coarsened exact #' matching with `k2k = FALSE` and then supplying stratum membership as an #' exact matching variable (i.e., in `exact`) to another call to -#' `matchit()` with `method = "nearest"`, `distance = -#' "mahalanobis"` and an argument to `discard` denoting unmatched units. +#' `matchit()` with `method = "nearest"`. #' It is also equivalent to performing nearest neighbor matching supplying #' coarsened versions of the variables to `exact`, except that #' `method = "cem"` automatically coarsens the continuous variables. The @@ -259,81 +269,82 @@ #' k2k = TRUE, k2k.method = "mahalanobis") #' m.out2 #' summary(m.out2, un = FALSE) -#' + NULL -matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, verbose = FALSE, ...) { - if (length(covs) == 0) { +matchit2cem <- function(treat, covs, estimand = "ATT", s.weights = NULL, m.order = NULL, verbose = FALSE, ...) { + if (is_null(covs)) { .err("Covariates must be specified in the input formula to use coarsened exact matching") } - if (verbose) cat("Coarsened exact matching...\n") + .cat_verbose("Coarsened exact matching... \n", verbose = verbose) - A <- list(...) + # if (isTRUE(A[["k2k"]])) { + # if (!has_n_unique(treat, 2L)) { + # .err("`k2k` cannot be `TRUE` with a multi-category treatment") + # } + # } estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - #Uses in-house cem, no need for cem package. See cem_matchit.R for code. - strat <- do.call("cem_matchit", c(list(treat = treat, X = covs, estimand = estimand, - s.weights = s.weights), - A[names(A) %in% names(formals(cem_matchit))]), - quote = TRUE) - - levels(strat) <- seq_len(nlevels(strat)) - names(strat) <- names(treat) + #Uses in-house cem, no need for cem package. + strat <- cem_matchit(treat = treat, X = covs, ...) mm <- NULL - if (isTRUE(A[["k2k"]])) { - mm <- nummm2charmm(subclass2mmC(strat, treat, focal = switch(estimand, "ATC" = 0, 1)), - treat) + if (isTRUE(...get("k2k", ...))) { + focal <- switch(estimand, "ATC" = 0, 1) + + mm <- do_k2k(treat = treat, + X = covs, + subclass = strat, + s.weights = s.weights, + focal = focal, + m.order = m.order, + verbose = verbose, + ...) + + strat <- mm2subclass(mm, treat, focal = focal) + levels(strat) <- seq_len(nlevels(strat)) + + mm <- nummm2charmm(mm, treat) + + weights <- get_weights_from_mm(mm, treat, focal) + } + else { + levels(strat) <- seq_len(nlevels(strat)) + + weights <- get_weights_from_subclass(strat, treat, estimand) } - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) res <- list(match.matrix = mm, subclass = strat, - weights = get_weights_from_subclass(strat, treat, estimand)) + weights = weights) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res } -cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k = FALSE, - k2k.method = "mahalanobis", mpower = 2, s.weights = NULL, - estimand = "ATT") { +cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), ...) { #In-house implementation of cem. Basically the same except: #treat is a vector if treatment status, not the name of a variable #X is a data.frame of covariates #when cutpoints are given as integer or string, they define the number of bins, not the number of breakpoints. "ss" is no longer allowed. - #When k2k = TRUE, subclasses are created for each pair, mimicking true matching, not each covariate combo. - #k2k.method is used instead of method. When k2k.method = NULL, units are matched based on order rather than random. Default is "mahalanobis" (not available in cem). - #k2k now works with single covariates (previously it was ignored). k2k uses original variables, not coarsened versions - - if (k2k) { - if (length(unique(treat)) > 2) { - .err("`k2k` cannot be `TRUE` with a multi-category treatment") - } - if (!is.null(k2k.method)) { - k2k.method <- tolower(k2k.method) - k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, - method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") - } + for (i in seq_along(X)) { + if (is.ordered(X[[i]])) X[[i]] <- unclass(X[[i]]) } - for (i in names(X)) { - if (is.ordered(X[[i]])) X[[i]] <- as.numeric(X[[i]]) - } is.numeric.cov <- setNames(vapply(X, is.numeric, logical(1L)), names(X)) #Process grouping - if (length(grouping) > 0) { - if (!is.list(grouping) || is.null(names(grouping))) { + if (is_not_null(grouping)) { + if (!is.list(grouping) || is_null(names(grouping))) { .err("`grouping` must be a named list of grouping values with an element for each variable whose values are to be binned") } @@ -353,8 +364,10 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k !is.list(g) || !all(vapply(g, function(gg) is.atomic(gg) && is.vector(gg), logical(1L))) }, logical(1L))] + nbg <- length(bag.groupings) - if (nbg > 0) { + + if (nbg > 0L) { .err(paste0("Each entry in the list supplied to `groupings` must be a list with entries containing values of the corresponding variable.", "\nIncorrectly specified variable%s:\n\t"), paste(bag.groupings, collapse = ", "), @@ -378,24 +391,29 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k #Process cutpoints if (!is.list(cutpoints)) { - cutpoints <- setNames(lapply(names(X)[is.numeric.cov], function(i) cutpoints), names(X)[is.numeric.cov]) + cutpoints <- setNames(rep.int(list(cutpoints), sum(is.numeric.cov)), names(X)[is.numeric.cov]) } - if (is.null(names(cutpoints))) { + if (is_null(names(cutpoints))) { .err("`cutpoints` must be a named list of binning values with an element for each numeric variable") } + bad.names <- setdiff(names(cutpoints), names(X)) + nb <- length(bad.names) - if (nb > 0) { + + if (nb > 0L) { .wrn(sprintf("the variable%%s %s named in `cutpoints` %%r not in the variables supplied to `matchit()` and will be ignored", word_list(bad.names, quotes = 2, and.or = "and")), n = nb) cutpoints[bad.names] <- NULL } - if (length(grouping) > 0) { + if (is_not_null(grouping)) { grouping.cutpoint.names <- intersect(names(grouping), names(cutpoints)) + ngc <- length(grouping.cutpoint.names) - if (ngc > 0) { + + if (ngc > 0L) { .wrn(sprintf("the variable%%s %s %%r named in both `grouping` and `cutpoints`; %s entr%%y%%s in `cutpoints` will be ignored", word_list(grouping.cutpoint.names, quotes = 2, and.or = "and"), ngettext(ngc, "its", "their")), n = ngc) @@ -404,43 +422,41 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k } non.numeric.in.cutpoints <- intersect(names(X)[!is.numeric.cov], names(cutpoints)) + nnnic <- length(non.numeric.in.cutpoints) - if (nnnic > 0) { + + if (nnnic > 0L) { .wrn(sprintf("the variable%%s %s named in `cutpoints` %%r not numeric and %s cutpoints will not be applied. Use `grouping` for non-numeric variables", word_list(non.numeric.in.cutpoints, quotes = 2, and.or = "and"), ngettext(nnnic, "its", "their")), n = nnnic) } - bad.cuts <- setNames(rep(FALSE, length(cutpoints)), names(cutpoints)) + bad.cuts <- rep_with(FALSE, cutpoints) + for (i in names(cutpoints)) { - if (length(cutpoints[[i]]) == 0) { + if (is_null(cutpoints[[i]])) { cutpoints[[i]] <- "sturges" } - else if (length(cutpoints[[i]]) == 1) { - if (is.na(cutpoints[[i]])) is.numeric.cov[i] <- FALSE #Will not be binned - else if (is.character(cutpoints[[i]])) { - bad.cuts[i] <- !(startsWith(cutpoints[[i]], "q") && can_str2num(substring(cutpoints[[i]], 2))) && - is.na(pmatch(cutpoints[[i]], c("sturges", "fd", "scott"))) - } - else if (is.numeric(cutpoints[[i]])) { - if (!is.finite(cutpoints[[i]]) || cutpoints[[i]] < 0) { - bad.cuts[i] <- TRUE - } - else if (cutpoints[[i]] == 0) { - is.numeric.cov[i] <- FALSE #Will not be binned - } - else if (cutpoints[[i]] == 1) { - X[[i]] <- NULL #Removing from X, still in X.match - is.numeric.cov <- is.numeric.cov[names(is.numeric.cov) != i] - } - } - else { - bad.cuts[i] <- TRUE - } - } - else { + else if (length(cutpoints[[i]]) > 1L) { bad.cuts[i] <- !is.numeric(cutpoints[[i]]) } + else if (is.na(cutpoints[[i]])) { + is.numeric.cov[i] <- FALSE #Will not be binned + } + else if (is.character(cutpoints[[i]])) { + bad.cuts[i] <- !(startsWith(cutpoints[[i]], "q") && can_str2num(substring(cutpoints[[i]], 2))) && + is.na(pmatch(cutpoints[[i]], c("sturges", "fd", "scott"))) + } + else if (!is.numeric(cutpoints[[i]]) || !is.finite(cutpoints[[i]]) || cutpoints[[i]] < 0) { + bad.cuts[i] <- TRUE + } + else if (cutpoints[[i]] == 0) { + is.numeric.cov[i] <- FALSE #Will not be binned + } + else if (cutpoints[[i]] == 1) { + X[[i]] <- NULL #Removing from X, still in X.match + is.numeric.cov <- is.numeric.cov[names(is.numeric.cov) != i] + } } if (any(bad.cuts)) { @@ -449,14 +465,20 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k "\n\t- a single number corresponding to the number of bins", "\n\t- a numeric vector containing the cut points separating bins", "\nIncorrectly specified variable%s:\n\t"), - paste(names(cutpoints)[bad.cuts], collapse = ", "), + paste(names(cutpoints)[bad.cuts], collapse = ", "), tidy = FALSE, n = sum(bad.cuts)) } + if (is_null(X)) { + return(rep_with(1L, treat)) + } + #Create bins for numeric variables for (i in names(X)[is.numeric.cov]) { - if (is.null(cutpoints) || !i %in% names(cutpoints)) bins <- "sturges" - else bins <- cutpoints[[i]] + bins <- { + if (is_not_null(cutpoints) && any(names(cutpoints) == i)) cutpoints[[i]] + else "sturges" + } if (is.character(bins)) { if (startsWith(bins, "q") || can_str2num(substring(bins, 2))) { @@ -474,7 +496,7 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k } } - if (length(bins) == 1) { + if (length(bins) == 1L) { #cutpoints is number of bins, unlike in cem breaks <- seq(min(X[[i]]), max(X[[i]]), length = bins + 1) breaks[c(1, bins + 1)] <- c(-Inf, Inf) @@ -486,70 +508,83 @@ cem_matchit <- function(treat, X, cutpoints = "sturges", grouping = list(), k2k X[[i]] <- findInterval(X[[i]], breaks) } - if (length(X) == 0) { - subclass <- setNames(rep(1, length(treat)), names(treat)) - } - else { - #Exact match - xx <- exactify(X, names(treat)) - cc <- do.call("intersect", unname(split(xx, treat))) + #Exact match + ex <- unclass(exactify(X, names(treat))) - if (length(cc) == 0) { - .err("no units were matched. Try coarsening the variables further or decrease the number of variables to match on") - } + cc <- Reduce("intersect", lapply(unique(treat), function(t) ex[treat==t])) - subclass <- setNames(match(xx, cc), names(treat)) + if (is_null(cc)) { + .err("no units were matched. Try coarsening the variables further or decrease the number of variables to match on") } - extra.sub <- max(subclass, na.rm = TRUE) - - if (k2k) { - - na.sub <- is.na(subclass) + setNames(factor(match(ex, cc), nmax = length(cc)), names(treat)) +} - s <- switch(estimand, "ATC" = 0, 1) +do_k2k <- function(treat, X, subclass, k2k.method = "mahalanobis", mpower = 2, s.weights = NULL, + focal, m.order = "data", verbose = FALSE, k2k = TRUE, ...) { + #Note: need k2k argument to prevent partial matching for k2k.method - for (i in which(tabulateC(subclass[!na.sub]) > 2)) { + m.order <- match_arg(m.order, c("data", "random", "closest", "farthest")) - in.sub <- which(!na.sub & subclass == i) + .cat_verbose("K:K matching...\n", verbose = verbose) - #Compute distance matrix; all 0s if k2k.method = NULL for matched based on data order - if (is.null(k2k.method)) { - dist.mat <- matrix(0, nrow = sum(treat[in.sub] == s), ncol = sum(treat[in.sub] != s), - dimnames = list(names(treat)[in.sub][treat[in.sub] == s], - names(treat)[in.sub][treat[in.sub] != s])) + if (is_not_null(k2k.method)) { + chk::chk_string(k2k.method) + k2k.method <- tolower(k2k.method) + k2k.method <- match_arg(k2k.method, c(matchit_distances(), "maximum", "manhattan", "canberra", "binary", "minkowski")) - } - else if (k2k.method %in% matchit_distances()) { - #X.match has been transformed - dist.mat <- eucdist_internal(X.match[in.sub,,drop = FALSE], treat[in.sub] == s) - } - else { - dist.mat <- dist_to_matrixC(dist(X.match[in.sub,,drop = FALSE], method = k2k.method, p = mpower)) + if (k2k.method == "minkowski") { + chk::chk_number(mpower) + chk::chk_gt(mpower, 0) - #Put smaller group on rows - d.rows <- which(rownames(dist.mat) %in% names(treat[in.sub])[treat[in.sub] == s]) - dist.mat <- dist.mat[d.rows, -d.rows, drop = FALSE] + if (mpower == 2) { + k2k.method <- "euclidean" } + } - #For each member of group on row, find closest remaining pair from cols - while (all(dim(dist.mat) > 0)) { - extra.sub <- extra.sub + 1 - - closest <- which.min(dist.mat[1,]) - subclass[c(rownames(dist.mat)[1], colnames(dist.mat)[closest])] <- extra.sub + X.match <- transform_covariates(data = X, s.weights = s.weights, treat = treat, + method = if (k2k.method %in% matchit_distances()) k2k.method else "euclidean") + distance <- NULL + } + else { + k2k.method <- "euclidean" + X.match <- NULL + distance <- rep.int(0.0, length(treat)) + } - #Drop already paired units from dist.mat - dist.mat <- dist.mat[-1,-closest, drop = FALSE] - } + reuse.max <- 1L + caliper.dist <- caliper.covs <- caliper.covs.mat <- antiexactcovs <- unit.id <- NULL - #If any unmatched units remain, give them NA subclass - if (any(dim(dist.mat) > 0)) is.na(subclass)[unlist(dimnames(dist.mat))] <- TRUE + if (k2k.method %in% matchit_distances()) { + discarded <- is.na(subclass) + ratio <- rep.int(1L, sum(treat == focal)) + mm <- nn_matchC_dispatch(treat, focal, ratio, discarded, reuse.max, distance, NULL, + subclass, caliper.dist, caliper.covs, caliper.covs.mat, X.match, + antiexactcovs, unit.id, m.order, verbose) + } + else { + mm <- matrix(NA_integer_, ncol = 1, nrow = sum(treat == 1), + dimnames = list(names(treat)[treat == 1], NULL)) + + for (s in levels(subclass)) { + .e <- which(subclass == s) + treat_ <- treat[.e] + discarded_ <- rep.int(FALSE, length(.e)) + ex_ <- NULL + ratio_ <- rep.int(1L, sum(treat_ == focal)) + distance_mat <- as.matrix(dist(X.match[.e,,drop = FALSE], + method = k2k.method, p = mpower))[treat_ == focal, treat_ != focal, drop = FALSE] + + mm_ <- nn_matchC_dispatch(treat_, focal, ratio_, discarded_, reuse.max, distance, distance_mat, + ex_, caliper.dist, caliper.covs, caliper.covs.mat, NULL, + antiexactcovs, unit.id, m.order, FALSE) + + #Ensure matched indices correspond to indices in full sample, not subgroup + mm_[] <- .e[mm_] + mm[rownames(mm_),] <- mm_ } } - subclass <- factor(subclass, nmax = extra.sub) - - setNames(subclass, names(treat)) -} + mm +} \ No newline at end of file diff --git a/R/matchit2exact.R b/R/matchit2exact.R index ecf3bbae..85170634 100644 --- a/R/matchit2exact.R +++ b/R/matchit2exact.R @@ -89,29 +89,30 @@ NULL matchit2exact <- function(treat, covs, data, estimand = "ATT", verbose = FALSE, ...){ - if(verbose) - cat("Exact matching... \n") + .cat_verbose("Exact matching...\n", verbose = verbose) - if (length(covs) == 0) .err("covariates must be specified in the input formula to use exact matching") + if (is_null(covs)) { + .err("covariates must be specified in the input formula to use exact matching") + } estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) xx <- exactify(covs, names(treat)) - cc <- do.call("intersect", lapply(unique(treat), function(t) xx[treat == t])) + cc <- Reduce("intersect", lapply(unique(treat), function(t) xx[treat == t])) - if (length(cc) == 0) { - .err("No exact matches were found") + if (is_null(cc)) { + .err("no exact matches were found") } psclass <- setNames(factor(match(xx, cc), nmax = length(cc)), names(treat)) - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) res <- list(subclass = psclass, weights = get_weights_from_subclass(psclass, treat, estimand)) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res diff --git a/R/matchit2full.R b/R/matchit2full.R index 8da98e25..94af3619 100644 --- a/R/matchit2full.R +++ b/R/matchit2full.R @@ -72,7 +72,7 @@ #' place when `distance` corresponds to a propensity score (e.g., for #' caliper matching or to discard units for common support). If specified, the #' distance measure will not be used in matching. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place. #' Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}. #' @param discard a string containing a method for discarding units outside a #' region of common support. Only allowed when `distance` corresponds to a @@ -230,12 +230,11 @@ matchit2full <- function(treat, formula, data, distance, discarded, rlang::check_installed("optmatch") - if (verbose) cat("Full matching... \n") - - A <- list(...) + .cat_verbose("Full matching... \n", verbose = verbose) fm.args <- c("omit.fraction", "mean.controls", "tol", "solver") - A[!names(A) %in% fm.args] <- NULL + A <- setNames(lapply(fm.args, ...get, ...), fm.args) + A[lengths(A) == 0L] <- NULL #Set max problem size to Inf and return to original value after match omps <- getOption("optmatch_max_problem_size") @@ -255,10 +254,10 @@ matchit2full <- function(treat, formula, data, distance, discarded, treat_ <- setNames(as.integer(treat[!discarded] == focal), names(treat)[!discarded]) - # if (!is.null(data)) data <- data[!discarded,] + # if (is_not_null(data)) data <- data[!discarded,] if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } @@ -269,20 +268,24 @@ matchit2full <- function(treat, formula, data, distance, discarded, max.controls <- attr(ratio, "max.controls") #Exact matching strata - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) - cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0) .err("No matches were found") + + cc <- Reduce("intersect", lapply(unique(treat_), function(t) unclass(ex)[treat_==t])) + + if (is_null(cc)) { + .err("No matches were found") + } } else { - ex <- factor(rep("_", length(treat_)), levels = "_") + ex <- gl(1, length(treat_), labels = "_") cc <- 1 } #Create distance matrix; note that Mahalanobis distance computed using entire #sample (minus discarded), like method2nearest, as opposed to within exact strata, like optmatch. - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" mahcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, @@ -307,7 +310,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, mo <- optmatch::as.InfinitySparseMatrix(mo) #Process antiexact - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) for (i in seq_len(ncol(antiexactcovs))) { mo <- mo + optmatch::antiExactMatch(antiexactcovs[[i]][!discarded], z = treat_) @@ -315,7 +318,7 @@ matchit2full <- function(treat, formula, data, distance, discarded, } #Process caliper - if (!is.null(caliper)) { + if (is_not_null(caliper)) { if (min.controls != 0) { .err("calipers cannot be used with `method = \"full\"` when `min.controls` is specified") } @@ -324,11 +327,12 @@ matchit2full <- function(treat, formula, data, distance, discarded, cov.cals <- setdiff(names(caliper), "") calcovs <- get.covs.matrix(reformulate(cov.cals, intercept = FALSE), data = data) } + for (i in seq_along(caliper)) { if (names(caliper)[i] != "") { mo_cal <- optmatch::match_on(setNames(calcovs[!discarded, names(caliper)[i]], names(treat_)), z = treat_) } - else if (is.null(mahvars) || is.matrix(distance)) { + else if (is_null(mahvars) || is.matrix(distance)) { mo_cal <- mo } else { @@ -337,45 +341,55 @@ matchit2full <- function(treat, formula, data, distance, discarded, mo <- mo + optmatch::caliper(mo_cal, caliper[i]) } + rm(mo_cal) } #Initialize pair membership; must include names - pair <- setNames(rep(NA_character_, length(treat)), names(treat)) + pair <- rep_with(NA_character_, treat) p <- setNames(vector("list", nlevels(ex)), levels(ex)) - t_df <- data.frame(treat) + A$data <- data.frame(treat) #just to get rownames; not actually used in matching + A$min.controls <- min.controls + A$max.controls <- max.controls for (e in levels(ex)[cc]) { - if (nlevels(ex) > 1) { - if (verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) - } + if (nlevels(ex) > 1L) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) + mo_ <- mo[ex[treat_==1] == e, ex[treat_==0] == e] } - else mo_ <- mo + else { + mo_ <- mo + } + + if (any(dim(mo_) == 0) || !any(is.finite(mo_))) { + next + } - if (any(dim(mo_) == 0) || !any(is.finite(mo_))) next - else if (all(dim(mo_) == 1) && all(is.finite(mo_))) { + if (all_equal_to(dim(mo_), 1) && all(is.finite(mo_))) { pair[ex == e] <- paste(1, e, sep = "|") next } + A$x <- mo_ + matchit_try({ - p[[e]] <- do.call(optmatch::fullmatch, - c(list(mo_, - min.controls = min.controls, - max.controls = max.controls, - data = t_df), #just to get rownames; not actually used in matching - A)) + p[[e]] <- do.call(optmatch::fullmatch, A) }, from = "optmatch") pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) .err("No matches were found") - if (length(p) == 1) p <- p[[1]] + if (all(is.na(pair))) { + .err("No matches were found") + } + + if (length(p) == 1L) { + p <- p[[1]] + } psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) @@ -384,13 +398,14 @@ matchit2full <- function(treat, formula, data, distance, discarded, #No match.matrix because treated units don't index matched strata (i.e., more than one #treated unit can be in the same stratum). Stratum information is contained in subclass. - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", + verbose = verbose) res <- list(subclass = psclass, weights = get_weights_from_subclass(psclass, treat, estimand), obj = p) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- c("matchit") res diff --git a/R/matchit2genetic.R b/R/matchit2genetic.R index 9c5c3927..13e61ffc 100644 --- a/R/matchit2genetic.R +++ b/R/matchit2genetic.R @@ -80,7 +80,7 @@ #' to the distance matrix. Use `mahvars` to only supply a subset. Even if #' `mahvars` is specified, balance will be optimized on all covariates in #' `formula`. See Details. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place. #' Anti-exact matching is processed using the `restrict` argument to #' `Matching::GenMatch()` and `Matching::Match()`. #' @param discard a string containing a method for discarding units outside a @@ -190,6 +190,10 @@ #' support the ATE as an estimand, `matchit()` only supports the ATT and #' ATC for genetic matching. #' +#' ## Reproducibility +#' +#' Genetic matching involves a random component, so a seed must be set using [set.seed()] to ensure reproducibility. When `cluster` is used for parallel processing, the seed must be compatible with parallel processing (e.g., by setting `type = "L'Ecuyer-CMRG"`). +#' #' @seealso [matchit()] for a detailed explanation of the inputs and outputs of #' a call to `matchit()`. #' @@ -255,9 +259,11 @@ matchit2genetic <- function(treat, data, distance, discarded, rlang::check_installed(c("Matching", "rgenoud")) - if (verbose) cat("Genetic matching... \n") + .cat_verbose("Genetic matching...\n", verbose = verbose) - A <- list(...) + args <- names(formals(Matching::GenMatch)) + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC")) @@ -273,7 +279,7 @@ matchit2genetic <- function(treat, data, distance, discarded, if (!replace) { if (sum(!discarded & treat != focal) < sum(!discarded & treat == focal)) { .wrn(sprintf("fewer %s units than %s units; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } else if (sum(!discarded & treat != focal) < sum(!discarded & treat == focal)*ratio) { .err(sprintf("not enough %s units for %s matches for each %s unit", @@ -286,11 +292,11 @@ matchit2genetic <- function(treat, data, distance, discarded, n.obs <- length(treat) n1 <- sum(treat == 1) - if (is.null(names(treat))) names(treat) <- seq_len(n.obs) + if (is_null(names(treat))) names(treat) <- seq_len(n.obs) m.order <- { - if (is.null(distance)) match_arg(m.order, c("data", "random")) - else if (!is.null(m.order)) match_arg(m.order, c("largest", "smallest", "random", "data")) + if (is_null(distance)) match_arg(m.order, c("data", "random")) + else if (is_not_null(m.order)) match_arg(m.order, c("largest", "smallest", "random", "data")) else if (estimand == "ATC") "smallest" else "largest" } @@ -304,7 +310,7 @@ matchit2genetic <- function(treat, data, distance, discarded, #Create X (matching variables) and covs_to_balance covs_to_balance <- get.covs.matrix(formula, data = data) - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { X <- get.covs.matrix.for.dist(mahvars, data = data) } else if (is.full.mahalanobis) { @@ -314,36 +320,46 @@ matchit2genetic <- function(treat, data, distance, discarded, X <- cbind(covs_to_balance, distance) } - if (ncol(covs_to_balance) == 0) { + if (ncol(covs_to_balance) == 0L) { .err("covariates must be specified in the input formula to use genetic matching") } #Process exact; exact.log will be supplied to GenMatch() and Match() - if (!is.null(exact)) { + if (is_not_null(exact)) { #Add covariates in exact not in X to X - ex <- as.integer(factor(exactify(model.frame(exact, data = data), names(treat), sep = ", ", include_vars = TRUE))) + ex <- unclass(exactify(model.frame(exact, data = data), names(treat), + sep = ", ", include_vars = TRUE)) cc <- intersect(ex[treat==1], ex[treat==0]) - if (length(cc) == 0) .err("No matches were found") + + if (is_null(cc)) { + .err("No matches were found") + } X <- cbind(X, ex) - exact.log <- c(rep(FALSE, ncol(X) - 1), TRUE) + exact.log <- c(rep.int(FALSE, ncol(X) - 1L), TRUE) + } + else { + exact.log <- ex <- NULL } - else exact.log <- ex <- NULL #Process caliper; cal will be supplied to GenMatch() and Match() - if (!is.null(caliper)) { + if (is_not_null(caliper)) { #Add covariates in caliper other than distance (cov.cals) not in X to X cov.cals <- setdiff(names(caliper), "") - if (length(cov.cals) > 0 && any(!cov.cals %in% colnames(X))) { + if (is_not_null(cov.cals) && any(!cov.cals %in% colnames(X))) { calcovs <- get.covs.matrix(reformulate(cov.cals[!cov.cals %in% colnames(X)]), data = data) X <- cbind(X, calcovs) #Expand exact.log for newly added covariates - if (!is.null(exact.log)) exact.log <- c(exact.log, rep(FALSE, ncol(calcovs))) + if (is_not_null(exact.log)) { + exact.log <- c(exact.log, rep.int(FALSE, ncol(calcovs))) + } + } + else { + cov.cals <- NULL } - else cov.cals <- NULL #Matching::Match multiplies calipers by pop SD, so we need to divide by pop SD to unstandardize pop.sd <- function(x) sqrt(sum((x-mean(x))^2)/length(x)) @@ -353,31 +369,32 @@ matchit2genetic <- function(treat, data, distance, discarded, }, numeric(1L)) #cal needs one value per variable in X - cal <- setNames(rep(Inf, ncol(X)), colnames(X)) + cal <- setNames(rep.int(Inf, ncol(X)), colnames(X)) #First put covariate calipers into cal - if (length(cov.cals) > 0) { + if (is_not_null(cov.cals)) { cal[intersect(cov.cals, names(cal))] <- caliper[intersect(cov.cals, names(cal))] } #Then put distance caliper into cal - if ("" %in% names(caliper)) { + if (hasName(caliper, "")) { dist.cal <- caliper[names(caliper) == ""] - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { #If mahvars specified, distance is not yet in X, so add it to X X <- cbind(X, distance) cal <- c(cal, dist.cal) #Expand exact.log for newly added distance - if (!is.null(exact.log)) exact.log <- c(exact.log, FALSE) + if (is_not_null(exact.log)) exact.log <- c(exact.log, FALSE) } else { #Otherwise, distance is in X at the specified index cal[ncol(covs_to_balance) + 1] <- dist.cal } } - else dist.cal <- NULL - + else { + dist.cal <- NULL + } } else { cal <- dist.cal <- cov.cals <- NULL @@ -389,9 +406,9 @@ matchit2genetic <- function(treat, data, distance, discarded, treat_ <- treat[ord] covs_to_balance <- covs_to_balance[ord,,drop = FALSE] X <- X[ord,,drop = FALSE] - if (!is.null(s.weights)) s.weights <- s.weights[ord] + if (is_not_null(s.weights)) s.weights <- s.weights[ord] - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data)[ord,,drop = FALSE] antiexact_restrict <- cbind(do.call("rbind", lapply(seq_len(ncol(antiexactcovs)), function(i) { unique.vals <- unique(antiexactcovs[,i]) @@ -399,14 +416,15 @@ matchit2genetic <- function(treat, data, distance, discarded, t(combn(which(antiexactcovs[,i] == u), 2)) })) })), -1) - if (!is.null(A[["restrict"]])) A[["restrict"]] <- rbind(A[["restrict"]], antiexact_restrict) + + if (is_not_null(A[["restrict"]])) A[["restrict"]] <- rbind(A[["restrict"]], antiexact_restrict) else A[["restrict"]] <- antiexact_restrict } else { antiexactcovs <- NULL } - if (is.null(A[["distance.tolerance"]])) { + if (is_null(A[["distance.tolerance"]])) { A[["distance.tolerance"]] <- 0 } @@ -418,7 +436,7 @@ matchit2genetic <- function(treat, data, distance, discarded, replace = replace, estimand = "ATT", ties = FALSE, CommonSupport = FALSE, verbose = verbose, weights = s.weights, print.level = 2*verbose), - A[names(A) %in% names(formals(Matching::GenMatch))])) + A[names(A) %in% args])) }, from = "Matching", dont_warn_if = "replace==FALSE, but there are more (weighted) treated obs than control obs.") } @@ -441,9 +459,11 @@ matchit2genetic <- function(treat, data, distance, discarded, replace = replace, estimand = "ATT", ties = FALSE, weights = s.weights, CommonSupport = FALSE, distance.tolerance = A[["distance.tolerance"]], Weight = 3, - Weight.matrix = if (use.genetic) g.out - else if (is.null(s.weights)) generalized_inverse(cor(X)) - else generalized_inverse(cov.wt(X, s.weights, cor = TRUE)$cor), + Weight.matrix = { + if (use.genetic) g.out + else if (is_null(s.weights)) generalized_inverse(cor(X)) + else generalized_inverse(cov.wt(X, s.weights, cor = TRUE)$cor) + }, restrict = A[["restrict"]], version = "fast") }, from = "Matching", dont_warn_if = "replace==FALSE, but there are more (weighted) treated obs than control obs.") @@ -464,10 +484,10 @@ matchit2genetic <- function(treat, data, distance, discarded, # else { # #Use nn_match() instead of Match() # ord1 <- ord[ord %in% which(treat == 1)] - # if (!is.null(cov.cals)) calcovs <- get.covs.matrix(reformulate(cov.cals), data = data) + # if (is_not_null(cov.cals)) calcovs <- get.covs.matrix(reformulate(cov.cals), data = data) # else calcovs <- NULL # - # if (is.null(g.out)) MWM <- generalized_inverse(cov(X)) + # if (is_null(g.out)) MWM <- generalized_inverse(cov(X)) # else MWM <- g.out$Weight.matrix %*% diag(1/apply(X, 2, var)) # # if (isFALSE(A$fast)) { @@ -483,21 +503,23 @@ matchit2genetic <- function(treat, data, distance, discarded, # dimnames(mm) <- list(lab1, seq_len(ratio)) # } - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) if (replace) { psclass <- NULL + weights <- get_weights_from_mm(mm, treat, 1L) } else { - psclass <- mm2subclass(mm, treat) + psclass <- mm2subclass(mm, treat, 1L) + weights <- get_weights_from_subclass(psclass, treat) } res <- list(match.matrix = nummm2charmm(mm, treat), subclass = psclass, - weights = get_weights_from_mm(mm, treat), + weights = weights, obj = g.out) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res diff --git a/R/matchit2nearest.R b/R/matchit2nearest.R index f8878cf1..1d68d439 100644 --- a/R/matchit2nearest.R +++ b/R/matchit2nearest.R @@ -56,12 +56,12 @@ #' by the argument to `distance`. #' @param estimand a string containing the desired estimand. Allowable options #' include `"ATT"` and `"ATC"`. See Details. -#' @param exact for which variables exact matching should take place. +#' @param exact for which variables exact matching should take place; two units with different values of an exact matching variable will not be paired. #' @param mahvars for which variables Mahalanobis distance matching should take #' place when `distance` corresponds to a propensity score (e.g., for #' caliper matching or to discard units for common support). If specified, the #' distance measure will not be used in matching. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place; two units with the same value of an anti-exact matching variable will not be paired. #' @param discard a string containing a method for discarding units outside a #' region of common support. Only allowed when `distance` corresponds to a #' propensity score. @@ -69,31 +69,31 @@ #' re-estimate the propensity score in the remaining sample prior to matching. #' @param s.weights the variable containing sampling weights to be incorporated #' into propensity score models and balance statistics. -#' @param replace whether matching should be done with replacement. +#' @param replace whether matching should be done with replacement (i.e., whether control units can be used as matches multiple times). See also the `reuse.max` argument below. Default is `FALSE` for matching without replacement. #' @param m.order the order that the matching takes place. Allowable options #' include `"largest"`, where matching takes place in descending order of #' distance measures; `"smallest"`, where matching takes place in ascending #' order of distance measures; `"closest"`, where matching takes place in -#' order of the distance between units; `"random"`, where matching takes place +#' ascending order of the smallest distance between units; `"farthest"`, where matching takes place in +#' descending order of the smallest distance between units; `"random"`, where matching takes place #' in a random order; and `"data"` where matching takes place based on the #' order of units in the data. When `m.order = "random"`, results may differ #' across different runs of the same code unless a seed is set and specified #' with [set.seed()]. The default of `NULL` corresponds to `"largest"` when a #' propensity score is estimated or supplied as a vector and `"data"` -#' otherwise. -#' @param caliper the width(s) of the caliper(s) used for caliper matching. See -#' Details and Examples. +#' otherwise. See Details for more information. +#' @param caliper the width(s) of the caliper(s) used for caliper matching. Two units with a difference on a caliper variable larger than the caliper will not be paired. See Details and Examples. #' @param std.caliper `logical`; when calipers are specified, whether they #' are in standard deviation units (`TRUE`) or raw units (`FALSE`). #' @param ratio how many control units should be matched to each treated unit #' for k:1 matching. For variable ratio matching, see section "Variable Ratio -#' Matching" in Details below. +#' Matching" in Details below. When `ratio` is greater than 1, all treated units will be attempted to be matched with a control unit before any treated unit is matched with a second control unit, etc. This reduces the possibility that control units will be used up before some treated units receive any matches. #' @param min.controls,max.controls for variable ratio matching, the minimum #' and maximum number of controls units to be matched to each treated unit. See #' section "Variable Ratio Matching" in Details below. #' @param verbose `logical`; whether information about the matching #' process should be printed to the console. When `TRUE`, a progress bar -#' implemented using *RcppProgress* will be displayed. +#' implemented using *RcppProgress* will be displayed along with an estimate of the time remaining. #' @param \dots additional arguments that control the matching specification: #' \describe{ #' \item{`reuse.max`}{ `numeric`; the maximum number of @@ -109,8 +109,7 @@ #' unit. Once a control observation has been matched, no other observation with #' the same unit ID can be used as matches. This ensures each control unit is #' used only once even if it has multiple observations associated with it. -#' Omitting this argument is the same as giving each observation a unique ID. -#' Ignored when `replace = TRUE`. } +#' Omitting this argument is the same as giving each observation a unique ID.} #' } #' #' @note Sometimes an error will be produced by *Rcpp* along the lines of @@ -189,8 +188,7 @@ #' #' ## Variable Ratio Matching #' -#' `matchit()` can perform variable -#' ratio "extremal" matching as described by Ming and Rosenbaum (2000). This +#' `matchit()` can perform variable ratio "extremal" matching as described by Ming and Rosenbaum (2000). This #' method tends to result in better balance than fixed ratio matching at the #' expense of some precision. When `ratio > 1`, rather than requiring all #' treated units to receive `ratio` matches, each treated unit is assigned @@ -219,16 +217,22 @@ #' `ratio`, and `max.controls` must be greater than `ratio`. See #' Examples below for an example of their use. #' -#' ## Using `m.order = "closest"` +#' ## Using `m.order = "closest"` or `"farthest"` #' -#' As of version 4.6.0, `m.order` can be set to `"closest"`, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to `m.order = "smallest"`. +#' `m.order` can be set to `"closest"` or `"farthest"`, which work regardless of how the distance measure is specified. This matches in order of the distance between units. First, all the closest match is found for all treated units and the pairwise distances computed; when `m.order = "closest"` the pair with the smallest of the distances is matched first, and when `m.order = "farthest"`, the pair with the largest of the distances is matched first. Then, the pair with the second smallest (or largest) is matched second. If the matched control is ineligible (i.e., because it has already been used in a prior match), a new match is found for the treated unit, the new pair's distance is re-computed, and the pairs are re-ordered by distance. +#' +#' Using `m.order = "closest"` ensures that the best possible matches are given priority, and in that sense should perform similarly to `m.order = "smallest"`. It can be used to ensure the best matches, especially when matching with a caliper. Using `m.order = "farthest"` ensures that the hardest units to match are given their best chance to find a close match, and in that sense should perform similarly to `m.order = "largest"`. It can be used to reduce the possibility of extreme imbalance when there are hard-to-match units competing for controls. Note that `m.order = "farthest"` **does not** implement "far matching" (i.e., finding the farthest control unit from each treated unit); it defines the order in which the closest matches are selected. +#' +#' ## Reproducibility +#' +#' Nearest neighbor matching involves a random component only when `m.order = "random"` (or when the propensity is estimated using a method with randomness; see [`distance`] for details), so a seed must be set in that case using [set.seed()] to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. #' #' @seealso [matchit()] for a detailed explanation of the inputs and outputs of #' a call to `matchit()`. #' #' [method_optimal()] for optimal pair matching, which is similar to -#' nearest neighbor matching except that an overall distance criterion is -#' minimized. +#' nearest neighbor matching without replacement except that an overall distance criterion is +#' minimized (i.e., as an alternative to specifying `m.order`). #' #' @references In a manuscript, you don't need to cite another package when #' using `method = "nearest"` because the matching is performed completely @@ -279,19 +283,17 @@ NULL matchit2nearest <- function(treat, data, distance, discarded, - ratio = 1, s.weights = NULL, replace = FALSE, m.order = NULL, - caliper = NULL, mahvars = NULL, exact = NULL, - formula = NULL, estimand = "ATT", verbose = FALSE, - is.full.mahalanobis, - antiexact = NULL, unit.id = NULL, ...){ - - if (verbose) { - rlang::check_installed("RcppProgress") - cat("Nearest neighbor matching... \n") - } + ratio = 1, s.weights = NULL, replace = FALSE, m.order = NULL, + caliper = NULL, mahvars = NULL, exact = NULL, + formula = NULL, estimand = "ATT", verbose = FALSE, + is.full.mahalanobis, + antiexact = NULL, unit.id = NULL, ...) { + + .cat_verbose("Nearest neighbor matching... \n", verbose = verbose) estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC")) + if (estimand == "ATC") { tc <- c("control", "treated") focal <- 0 @@ -304,7 +306,7 @@ matchit2nearest <- function(treat, data, distance, discarded, treat <- setNames(as.integer(treat == focal), names(treat)) if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } @@ -315,19 +317,16 @@ matchit2nearest <- function(treat, data, distance, discarded, n1 <- sum(treat == 1) n0 <- n.obs - n1 - lab <- names(treat) - lab1 <- lab[treat == 1] - - if (!is.null(distance)) { - names(distance) <- names(treat) - } - min.controls <- attr(ratio, "min.controls") max.controls <- attr(ratio, "max.controls") mahcovs <- distance_mat <- NULL - if (!is.null(mahvars)) { - transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" + if (is_not_null(mahvars)) { + transform <- { + if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") + else "mahalanobis" + } + mahcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, discarded = discarded) @@ -336,41 +335,60 @@ matchit2nearest <- function(treat, data, distance, discarded, distance_mat <- distance distance <- NULL - if (focal == 0) distance_mat <- t(distance_mat) + if (focal == 0) { + distance_mat <- t(distance_mat) + } } #Process caliper - if (!is.null(caliper)) { + caliper.dist <- caliper.covs <- NULL + caliper.covs.mat <- NULL + ex.caliper <- NULL + + if (is_not_null(caliper)) { if (any(names(caliper) != "")) { caliper.covs <- caliper[names(caliper) != ""] caliper.covs.mat <- get.covs.matrix(reformulate(names(caliper.covs)), data = data) - } - else { - caliper.covs.mat <- caliper.covs <- NULL + + ex.caliper.list <- setNames(lapply(names(caliper.covs), function(i) { + splits <- get_splitsC(as.numeric(caliper.covs.mat[,i]), + as.numeric(caliper.covs[i])) + + if (is_null(splits)) { + return(integer(0)) + } + + cut(caliper.covs.mat[,i], + breaks = splits, + include.lowest = TRUE) + }), names(caliper.covs)) + + ex.caliper.list <- ex.caliper.list[lengths(ex.caliper.list) > 0L] + + # ex.caliper.list <- Filter(ex.caliper.list, f = function(x) nlevels(x) > 1) + + if (is_not_null(ex.caliper.list)) { + for (i in seq_along(ex.caliper.list)) { + levels(ex.caliper.list[[i]]) <- paste(names(ex.caliper.list)[i], "\u2208", levels(ex.caliper.list[[i]])) + } + + ex.caliper <- exactify(ex.caliper.list, nam = names(treat), sep = ", ", + justify = NULL) + } } - if (any(names(caliper) == "")) { + if (hasName(caliper, "")) { caliper.dist <- caliper[names(caliper) == ""] } - else { - caliper.dist <- NULL - } - } - else { - caliper.dist <- caliper.covs <- NULL - caliper.covs.mat <- NULL } #Process antiexact - if (!is.null(antiexact)) { - antiexactcovs <- model.frame(antiexact, data) - antiexactcovs <- do.call("cbind", lapply(seq_len(ncol(antiexactcovs)), function(i) { - as.integer(as.factor(antiexactcovs[[i]])) + antiexactcovs <- NULL + if (is_not_null(antiexact)) { + antiexactcovs <- do.call("cbind", lapply(model.frame(antiexact, data), function(i) { + unclass(as.factor(i)) })) } - else { - antiexactcovs <- NULL - } reuse.max <- attr(replace, "reuse.max") @@ -378,97 +396,116 @@ matchit2nearest <- function(treat, data, distance, discarded, m.order <- "data" } - if (!is.null(unit.id) && reuse.max < n1) { + #unit.id + if (is_not_null(unit.id) && reuse.max < n1) { unit.id <- process.variable.input(unit.id, data) - unit.id <- factor(exactify(model.frame(unit.id, data = data), - nam = lab, sep = ", ", include_vars = TRUE)) + unit.id <- exactify(model.frame(unit.id, data = data), + nam = names(treat), sep = ", ", include_vars = TRUE) + num_ctrl_unit.ids <- length(unique(unit.id[treat == 0])) + num_trt_unit.ids <- length(unique(unit.id[treat == 1])) - #If each control unit is a unit.id, unit.ids are meaningless - if (num_ctrl_unit.ids == n0) unit.id <- NULL + #If each unit is a unit.id, unit.ids are meaningless + if (num_ctrl_unit.ids == n0 && num_trt_unit.ids == n1) { + unit.id <- NULL + } } else { unit.id <- NULL } - if (!is.null(exact)) { - ex <- factor(exactify(model.frame(exact, data = data), nam = lab, sep = ", ", include_vars = TRUE)) + #Process exact + ex <- NULL + if (is_not_null(exact)) { + ex <- exactify(model.frame(exact, data = data), + nam = names(treat), sep = ", ", include_vars = TRUE) - cc <- intersect(as.integer(ex)[treat==1], as.integer(ex)[treat==0]) - if (length(cc) == 0) .err("No matches were found") + cc <- Reduce("intersect", lapply(unique(treat), function(t) unclass(ex)[treat==t])) - if (reuse.max < n1) { + if (is_null(cc)) { + .err("no matches were found") + } + cc <- sort(cc) + } + + if (reuse.max < n1) { + if (is_not_null(ex)) { e_ratios <- vapply(levels(ex), function(e) { - if (is.null(unit.id)) sum(treat[ex == e] == 0)*(reuse.max/sum(treat[ex == e] == 1)) - else length(unique(unit.id[treat == 0 & ex == e]))*(reuse.max/sum(treat[ex == e] == 1)) + if (is_null(unit.id)) reuse.max * sum(treat[ex == e] == 0) / sum(treat[ex == e] == 1) + else reuse.max * length(unique(unit.id[treat == 0 & ex == e])) / sum(treat[ex == e] == 1) }, numeric(1L)) if (any(e_ratios < 1)) { .wrn(sprintf("fewer %s units than %s units in some `exact` strata; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } + if (ratio > 1 && any(e_ratios < ratio)) { - if (is.null(max.controls) || ratio == max.controls) + if (is_null(max.controls) || ratio == max.controls) .wrn(sprintf("not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("not enough %s units for an average of %s matches per %s unit in all `exact` strata", - tc[2], ratio, tc[1])) + tc[2], ratio, tc[1])) } } - } - else { - ex <- NULL - - if (reuse.max < n1) { - + else { e_ratios <- { - if (is.null(unit.id)) as.numeric(reuse.max)*n0/n1 - else as.numeric(reuse.max)*num_ctrl_unit.ids/n1 + if (is_null(unit.id)) as.numeric(reuse.max) * n0 / n1 + else as.numeric(reuse.max) * num_ctrl_unit.ids / num_trt_unit.ids } if (e_ratios < 1) { .wrn(sprintf("fewer %s %s than %s units; not all %s units will get a match", - tc[2], if (is.null(unit.id)) "units" else "unit IDs", tc[1], tc[1])) + tc[2], if (is_null(unit.id)) "units" else "unit IDs", tc[1], tc[1])) } else if (e_ratios < ratio) { - if (is.null(max.controls) || ratio == max.controls) + if (is_null(max.controls) || ratio == max.controls) .wrn(sprintf("not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("not enough %s %s for an average of %s matches per %s unit", - tc[2], if (is.null(unit.id)) "units" else "unit IDs", ratio, tc[1])) + tc[2], if (is_null(unit.id)) "units" else "unit IDs", ratio, tc[1])) } } } #Variable ratio (extremal matching), Ming & Rosenbaum (2000) #Each treated unit get its own value of ratio - if (!is.null(max.controls)) { - if (is.null(distance)) { - if (is.full.mahalanobis) .err(sprintf("`distance` cannot be \"%s\" for variable ratio matching", - transform)) - .err("`distance` cannot be supplied as a matrix for variable ratio matching") + if (is_not_null(max.controls)) { + if (is_null(distance)) { + if (is.full.mahalanobis) { + .err(sprintf("`distance` cannot be \"%s\" for variable ratio nearest neighbor matching", + transform)) + } + else { + .err("`distance` cannot be supplied as a matrix for variable ratio nearest neighbor matching") + } } m <- round(ratio * n1) # if (m > sum(treat == 0)) stop("'ratio' must be less than or equal to n0/n1") - kmax <- floor((m - min.controls*(n1-1)) / (max.controls - min.controls)) + kmax <- floor((m - min.controls * (n1 - 1)) / (max.controls - min.controls)) kmin <- n1 - kmax - 1 - kmed <- m - (min.controls*kmin + max.controls*kmax) + kmed <- m - (min.controls * kmin + max.controls * kmax) - ratio0 <- c(rep(min.controls, kmin), kmed, rep(max.controls, kmax)) + ratio0 <- c(rep.int(min.controls, kmin), kmed, rep.int(max.controls, kmax)) #Make sure no units are assigned 0 matches - if (any(ratio0 == 0)) { - ind <- which(ratio0 == 0) + while (any(ratio0 == 0)) { + ind <- which(ratio0 == 0)[1L] ratio0[ind] <- 1 - ratio0[ind + 1] <- ratio0[ind + 1] - 1 + + if (ind == length(ratio0)) { + break + } + + ratio0[ind + 1L] <- ratio0[ind + 1] - 1 } - ratio <- rep(NA_integer_, n1) + ratio <- rep.int(NA_integer_, n1) #Order by distance; treated are supposed to have higher values ratio[order(distance[treat == 1], @@ -476,150 +513,167 @@ matchit2nearest <- function(treat, data, distance, discarded, ratio <- as.integer(ratio) } else { - ratio <- as.integer(rep(ratio, n1)) + ratio <- as.integer(rep.int(ratio, n1)) } m.order <- { - if (is.null(distance)) match_arg(m.order, c("data", "random", "closest")) - else if (!is.null(m.order)) match_arg(m.order, c("largest", "smallest", "data", "random", "closest")) + if (is_null(distance)) match_arg(m.order, c("data", "random", "closest", "farthest")) + else if (is_not_null(m.order)) match_arg(m.order, c("largest", "smallest", "data", "random", "closest", "farthest")) else if (estimand == "ATC") "smallest" else "largest" } - if (is.null(ex) || !is.null(unit.id)) { - if (m.order == "closest") { - if (!is.null(mahcovs)) { - distance_mat <- eucdist_internal(mahcovs, treat) - - # If caliper on PS (distance), treat it as a covariate - if (!is.null(distance) && !is.null(caliper.dist)) { - caliper.covs.mat <- { - if (is.null(caliper.covs)) as.matrix(distance) - else cbind(caliper.covs.mat, distance) - } - caliper.covs <- c(caliper.covs, caliper.dist) - caliper.dist <- NULL - } - } - else if (is.null(distance_mat)) { - distance_mat <- eucdist_internal(distance, treat) - } + #If mahcovs is only 1 variable, use vec matching for speed + if (is.full.mahalanobis && is_null(distance) && is_null(caliper.dist) && + is_not_null(mahcovs) && ncol(mahcovs) == 1L) { + distance <- mahcovs[,1] + mahcovs <- NULL + } - ord <- NULL - } - else { - ord <- switch(m.order, - "largest" = order(distance[treat == 1], decreasing = TRUE), - "smallest" = order(distance[treat == 1], decreasing = FALSE), - "random" = sample.int(n1), - "data" = seq_len(n1)) + if (is_not_null(ex)) { + discarded[!ex %in% levels(ex)[cc]] <- TRUE + } + + if (is_null(ex) || is_not_null(unit.id) || (is_null(mahcovs) && is_null(distance_mat) && + !(m.order %in% c("closest", "farthest")))) { + if (is_not_null(ex.caliper)) { + ex <- exactify(list(ex, ex.caliper), + nam = names(treat), sep = ", ", include_vars = FALSE) } - mm <- nn_matchC_dispatch(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, + mm <- nn_matchC_dispatch(treat, 1L, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) - } else { - distance_ <- caliper.covs.mat_ <- mahcovs_ <- antiexactcovs_ <- distance_mat_ <- NULL mm_list <- lapply(levels(ex)[cc], function(e) { - if (verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) - } + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) .e <- which(ex == e) .e1 <- which(ex[treat==1] == e) treat_ <- treat[.e] + discarded_ <- discarded[.e] - if (!is.null(distance)) distance_ <- distance[.e] - if (!is.null(caliper.covs.mat)) caliper.covs.mat_ <- caliper.covs.mat[.e,,drop = FALSE] - if (!is.null(mahcovs)) mahcovs_ <- mahcovs[.e,,drop = FALSE] - if (!is.null(antiexactcovs)) antiexactcovs_ <- antiexactcovs[.e,,drop = FALSE] - if (!is.null(distance_mat)) { - .e0 <- which(ex[treat==0] == e) - distance_mat_ <- distance_mat[.e1, .e0, drop = FALSE] + + distance_ <- NULL + if (is_not_null(distance)) { + distance_ <- distance[.e] } - ratio_ <- ratio[.e1] - n1_ <- sum(treat_ == 1) + ex.caliper_ <- NULL + if (is_not_null(ex.caliper)) { + ex.caliper_ <- ex.caliper[.e] + } + caliper.covs.mat_ <- NULL + if (is_not_null(caliper.covs.mat)) { + caliper.covs.mat_ <- caliper.covs.mat[.e,, drop = FALSE] + } - if (m.order == "closest") { - if (!is.null(mahcovs)) { - distance_mat_ <- eucdist_internal(mahcovs_, treat_) - } - else if (is.null(distance_mat)) { - distance_mat_ <- eucdist_internal(distance_, treat_) - } + mahcovs_ <- NULL + if (is_not_null(mahcovs)) { + mahcovs_ <- mahcovs[.e,,drop = FALSE] + } - ord <- NULL + antiexactcovs_ <- NULL + if (is_not_null(antiexactcovs)) { + antiexactcovs_ <- antiexactcovs[.e,, drop = FALSE] } - else { - ord_ <- switch(m.order, - "largest" = order(distance_[treat_ == 1], decreasing = TRUE), - "smallest" = order(distance_[treat_ == 1], decreasing = FALSE), - "random" = sample.int(n1_), - "data" = seq_len(n1_)) + + distance_mat_ <- NULL + if (is_not_null(distance_mat)) { + .e0 <- which(ex[treat==0] == e) + distance_mat_ <- distance_mat[.e1, .e0, drop = FALSE] } - mm_ <- nn_matchC_dispatch(treat_, ord_, ratio_, discarded_, reuse.max, distance_, distance_mat_, - NULL, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, - antiexactcovs_, NULL, m.order, verbose) + ratio_ <- ratio[.e1] + + mm_ <- nn_matchC_dispatch(treat_, 1L, ratio_, discarded_, reuse.max, distance_, distance_mat_, + ex.caliper_, caliper.dist, caliper.covs, caliper.covs.mat_, mahcovs_, + antiexactcovs_, NULL, m.order, verbose) #Ensure matched indices correspond to indices in full sample, not subgroup - mm_[] <- seq_along(treat)[.e][mm_] + mm_[] <- .e[mm_] mm_ }) #Construct match.matrix - mm <- matrix(NA_integer_, nrow = length(lab1), + mm <- matrix(NA_integer_, nrow = n1, ncol = max(vapply(mm_list, ncol, numeric(1L))), - dimnames = list(lab1, NULL)) + dimnames = list(names(treat)[treat == 1], NULL)) for (m in mm_list) { mm[rownames(m), seq_len(ncol(m))] <- m } } - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", + verbose = verbose) if (reuse.max > 1) { psclass <- NULL + weights <- get_weights_from_mm(mm, treat, 1L) } else { - psclass <- mm2subclass(mm, treat) + psclass <- mm2subclass(mm, treat, 1L) + weights <- get_weights_from_subclass(psclass, treat) } res <- list(match.matrix = nummm2charmm(mm, treat), subclass = psclass, - weights = get_weights_from_mm(mm, treat)) + weights = weights) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res } -# Dispatches Rcpp function for NN matching -# nn_matchC_vec() if distance_mat and mahcovs are NULL -# nn_matchC() otherwise -nn_matchC_dispatch <- function(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, +# Dispatches Rcpp functions for NN matching +nn_matchC_dispatch <- function(treat, focal, ratio, discarded, reuse.max, distance, distance_mat, ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, antiexactcovs, unit.id, m.order, verbose) { - if (m.order == "closest") { - nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse.max, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) - } - else if (is.null(distance_mat) && is.null(mahcovs)) { - nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, distance, - ex, caliper.dist, caliper.covs, caliper.covs.mat, - antiexactcovs, unit.id, verbose) + if (m.order %in% c("closest", "farthest")) { + if (is_not_null(mahcovs)) { + nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse.max, mahcovs, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, m.order == "closest", verbose) + } + else if (is_not_null(distance_mat)) { + nn_matchC_distmat_closest(treat, ratio, discarded, reuse.max, distance_mat, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, m.order == "closest", verbose) + + } + else { + nn_matchC_vec_closest(treat, ratio, discarded, reuse.max, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, m.order == "closest", verbose) + } } else { - nn_matchC(treat, ord, ratio, discarded, reuse.max, distance, distance_mat, - ex, caliper.dist, caliper.covs, caliper.covs.mat, mahcovs, - antiexactcovs, unit.id, verbose) + ord <- switch(m.order, + "largest" = order(distance[treat == focal], decreasing = TRUE), + "smallest" = order(distance[treat == focal], decreasing = FALSE), + "random" = sample(which(!discarded[treat == focal])), + "data" = which(!discarded[treat == focal])) + + if (is_not_null(mahcovs)) { + nn_matchC_mahcovs(treat, ord, ratio, discarded, reuse.max, focal, mahcovs, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } + else if (is_not_null(distance_mat)) { + nn_matchC_distmat(treat, ord, ratio, discarded, reuse.max, focal, distance_mat, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } + else { + nn_matchC_vec(treat, ord, ratio, discarded, reuse.max, focal, distance, + ex, caliper.dist, caliper.covs, caliper.covs.mat, + antiexactcovs, unit.id, verbose) + } } } diff --git a/R/matchit2optimal.R b/R/matchit2optimal.R index 89c63b7a..0624a694 100644 --- a/R/matchit2optimal.R +++ b/R/matchit2optimal.R @@ -65,7 +65,7 @@ #' place when `distance` corresponds to a propensity score (e.g., for #' caliper matching or to discard units for common support). If specified, the #' distance measure will not be used in matching. -#' @param antiexact for which variables ant-exact matching should take place. +#' @param antiexact for which variables anti-exact matching should take place. #' Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}. #' @param discard a string containing a method for discarding units outside a #' region of common support. Only allowed when `distance` is not @@ -248,11 +248,11 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, rlang::check_installed("optmatch") - if (verbose) cat("Optimal matching... \n") + .cat_verbose("Optimal matching...\n", verbose = verbose) - A <- list(...) - pm.args <- c("tol", "solver") - A[!names(A) %in% pm.args] <- NULL + args <- c("tol", "solver") + A <- setNames(lapply(args, ...get, ...), args) + A[lengths(A) == 0L] <- NULL #Set max problem size to Inf and return to original value after match omps <- getOption("optmatch_max_problem_size") @@ -272,17 +272,17 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, treat_ <- setNames(as.integer(treat[!discarded] == focal), names(treat)[!discarded]) - # if (!is.null(data)) data <- data[!discarded,] + # if (is_not_null(data)) data <- data[!discarded,] if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } mahvars <- formula } - if (!is.null(caliper)) { + if (is_not_null(caliper)) { .wrn("calipers are currently not compatible with `method = \"optimal\"` and will be ignored") caliper <- NULL } @@ -290,56 +290,61 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, min.controls <- attr(ratio, "min.controls") max.controls <- attr(ratio, "max.controls") - if (is.null(max.controls)) { + if (is_null(max.controls)) { min.controls <- max.controls <- ratio } #Exact matching strata - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) - cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0) .err("No matches were found") + cc <- Reduce("intersect", lapply(unique(treat_), function(t) unclass(ex)[treat_==t])) - e_ratios <- vapply(levels(ex), function(e) sum(treat_[ex == e] == 0)/sum(treat_[ex == e] == 1), numeric(1L)) + if (is_null(cc) ) { + .err("no matches were found") + } + + e_ratios <- vapply(levels(ex), function(e) { + sum(treat_[ex == e] == 0)/sum(treat_[ex == e] == 1) + }, numeric(1L)) if (any(e_ratios < 1)) { .wrn(sprintf("Fewer %s units than %s units in some `exact` strata; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } if (ratio > 1 && any(e_ratios < ratio)) { if (ratio == max.controls) .wrn(sprintf("Not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("Not enough %s units for an average of %s matches per %s unit in all `exact` strata", - tc[2], ratio, tc[1])) + tc[2], ratio, tc[1])) } } else { - ex <- factor(rep("_", length(treat_)), levels = "_") + ex <- gl(1, length(treat_), labels = "_") cc <- 1 e_ratios <- setNames(sum(treat_ == 0)/sum(treat_ == 1), levels(ex)) if (e_ratios < 1) { .wrn(sprintf("Fewer %s units than %s units; not all %s units will get a match", - tc[2], tc[1], tc[1])) + tc[2], tc[1], tc[1])) } else if (e_ratios < ratio) { if (ratio == max.controls) .wrn(sprintf("Not all %s units will get %s matches", - tc[1], ratio)) + tc[1], ratio)) else .wrn(sprintf("Not enough %s units for an average of %s matches per %s unit", - tc[2], ratio, tc[1])) + tc[2], ratio, tc[1])) } } #Create distance matrix; note that Mahalanobis distance computed using entire #sample (minus discarded), like method2nearest, as opposed to within exact strata, like optmatch. - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" mahcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, @@ -364,7 +369,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, mo <- optmatch::as.InfinitySparseMatrix(mo) #Process antiexact - if (!is.null(antiexact)) { + if (is_not_null(antiexact)) { antiexactcovs <- model.frame(antiexact, data) for (i in seq_len(ncol(antiexactcovs))) { mo <- mo + optmatch::antiExactMatch(antiexactcovs[[i]][!discarded], z = treat_) @@ -372,23 +377,26 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, } #Initialize pair membership; must include names - pair <- setNames(rep(NA_character_, length(treat)), names(treat)) + pair <- rep_with(NA_character_, treat) p <- setNames(vector("list", nlevels(ex)), levels(ex)) t_df <- data.frame(treat_) for (e in levels(ex)[cc]) { - if (nlevels(ex) > 1) { - if (verbose) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) - } + if (nlevels(ex) > 1L) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) + mo_ <- mo[ex[treat_==1] == e, ex[treat_==0] == e] } else mo_ <- mo - if (any(dim(mo_) == 0) || !any(is.finite(mo_))) next - else if (all(dim(mo_) == 1) && all(is.finite(mo_))) { + if (any(dim(mo_) == 0) || !any(is.finite(mo_))) { + next + } + + if (all_equal_to(dim(mo_), 1) && all(is.finite(mo_))) { pair[ex == e] <- paste(1, e, sep = "|") next } @@ -412,21 +420,22 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, max.controls_ <- max.controls } + A$x <- mo_ + A$mean.controls <- ratio_ + A$min.controls <- min.controls_ + A$max.controls <- max.controls_ + A$data <- t_df[ex == e,, drop = FALSE] #just to get rownames; not actually used in matching + matchit_try({ - p[[e]] <- do.call(optmatch::fullmatch, - c(list(mo_, - mean.controls = ratio_, - min.controls = min.controls_, - max.controls = max.controls_, - data = t_df[ex == e,, drop = FALSE]), #just to get rownames; not actually used in matching - A)) + p[[e]] <- do.call(optmatch::fullmatch, A) }, from = "optmatch") pair[names(p[[e]])[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) .err("No matches were found") - if (length(p) == 1) p <- p[[1]] + if (length(p) == 1L) { + p <- p[[1]] + } psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) @@ -434,7 +443,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, mm <- nummm2charmm(subclass2mmC(psclass, treat, focal), treat) - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) ## calculate weights and return the results res <- list(match.matrix = mm, @@ -442,7 +451,7 @@ matchit2optimal <- function(treat, formula, data, distance, discarded, weights = get_weights_from_subclass(psclass, treat, estimand), obj = p) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- "matchit" res diff --git a/R/matchit2quick.R b/R/matchit2quick.R index cd310975..4a8a62de 100644 --- a/R/matchit2quick.R +++ b/R/matchit2quick.R @@ -143,7 +143,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, rlang::check_installed("quickmatch") - if (verbose) cat("Generalized full matching... \n") + .cat_verbose("Generalized full matching...\n", verbose = verbose) A <- list(...) @@ -165,7 +165,7 @@ matchit2quick <- function(treat, formula, data, distance, discarded, # treat_ <- setNames(as.integer(treat[!discarded] == focal), names(treat)[!discarded]) if (is.full.mahalanobis) { - if (length(attr(terms(formula, data = data), "term.labels")) == 0) { + if (is_null(attr(terms(formula, data = data), "term.labels"))) { .err(sprintf("covariates must be specified in the input formula when `distance = \"%s\"`", attr(is.full.mahalanobis, "transform"))) } @@ -173,21 +173,24 @@ matchit2quick <- function(treat, formula, data, distance, discarded, } #Exact matching strata - if (!is.null(exact)) { + if (is_not_null(exact)) { ex <- factor(exactify(model.frame(exact, data = data), sep = ", ", include_vars = TRUE)[!discarded]) - cc <- intersect(as.integer(ex)[treat_==1], as.integer(ex)[treat_==0]) - if (length(cc) == 0) .err("no matches were found") + cc <- Reduce("intersect", lapply(unique(treat_), function(t) unclass(ex)[treat_==t])) + + if (is_null(cc)) { + .err("no matches were found") + } } else { - ex <- factor(rep("_", length(treat_)), levels = "_") + ex <- gl(1, length(treat_), labels = "_") cc <- 1 } #Create distance matrix; note that Mahalanobis distance computed using entire #sample (minus discarded), like method2nearest, as opposed to within exact strata, like optmatch. - if (!is.null(mahvars)) { + if (is_not_null(mahvars)) { transform <- if (is.full.mahalanobis) attr(is.full.mahalanobis, "transform") else "mahalanobis" distcovs <- transform_covariates(mahvars, data = data, method = transform, s.weights = s.weights, treat = treat, @@ -202,40 +205,42 @@ matchit2quick <- function(treat, formula, data, distance, discarded, rownames(distcovs) <- names(treat_) #Process caliper - if (!is.null(caliper)) { - if (!is.null(mahvars)) { + if (is_not_null(caliper)) { + if (is_not_null(mahvars)) { .err("with `method = \"quick\"`, a caliper can only be used when `distance` is a propensity score or vector and `mahvars` is not specified") } - if (length(caliper) > 1 || !identical(names(caliper), "")) { + + if (length(caliper) > 1L || !identical(names(caliper), "")) { .err("with `method = \"quick\"`, calipers cannot be placed on covariates") } } + A$caliper <- caliper + #Initialize pair membership; must include names - pair <- setNames(rep(NA_character_, length(treat)), names(treat)) + pair <- rep_with(NA_character_, treat) p <- setNames(vector("list", nlevels(ex)), levels(ex)) for (e in levels(ex)[cc]) { - if (verbose && nlevels(ex) > 1) { - cat(sprintf("Matching subgroup %s/%s: %s...\n", - match(e, levels(ex)[cc]), length(cc), e)) + if (nlevels(ex) > 1L) { + .cat_verbose(sprintf("Matching subgroup %s/%s: %s...\n", + match(e, levels(ex)[cc]), length(cc), e), + verbose = verbose) } - distcovs_ <- distcovs[ex == e,, drop = FALSE] + A$distances <- distcovs[ex == e,, drop = FALSE] + A$treatments <- treat_[ex == e] matchit_try({ - p[[e]] <- do.call(quickmatch::quickmatch, - c(list(distcovs_, - treatments = treat_[ex == e], - caliper = caliper), - A)) + p[[e]] <- do.call(quickmatch::quickmatch, A) }, from = "quickmatch") pair[which(ex == e)[!is.na(p[[e]])]] <- paste(as.character(p[[e]][!is.na(p[[e]])]), e, sep = "|") } - if (all(is.na(pair))) .err("no matches were found") - if (length(p) == 1) p <- p[[1]] + if (length(p) == 1L) { + p <- p[[1]] + } psclass <- factor(pair) levels(psclass) <- seq_len(nlevels(psclass)) @@ -244,13 +249,13 @@ matchit2quick <- function(treat, formula, data, distance, discarded, #No match.matrix because treated units don't index matched strata (i.e., more than one #treated unit can be in the same stratum). Stratum information is contained in subclass. - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) res <- list(subclass = psclass, weights = get_weights_from_subclass(psclass, treat, estimand), obj = p) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- c("matchit") res diff --git a/R/matchit2subclass.R b/R/matchit2subclass.R index aca5a792..018c3eb3 100644 --- a/R/matchit2subclass.R +++ b/R/matchit2subclass.R @@ -158,61 +158,54 @@ NULL matchit2subclass <- function(treat, distance, discarded, replace = FALSE, exact = NULL, estimand = "ATT", verbose = FALSE, + subclass = 6L, min.n = 1L, ...) { - if(verbose) - cat("Subclassifying... \n") - - A <- list(...) - subclass <- A[["subclass"]] - sub.by <- A[["sub.by"]] - min.n <- A[["min.n"]] + .cat_verbose("Subclassifying...\n", verbose = verbose) #Checks - if (is.null(subclass)) subclass <- 6 chk::chk_numeric(subclass) - if (length(subclass) == 1) { + if (length(subclass) == 1L) { chk::chk_gt(subclass, 1) } - else if (!all(subclass <= 1 & subclass >= 0)) { - .err("When specifying `subclass` as a vector of quantiles, all values must be between 0 and 1") + else if (any(subclass > 1) || any(subclass < 0)) { + .err("when specifying `subclass` as a vector of quantiles, all values must be between 0 and 1") } - if (!is.null(sub.by)) { + if (is_not_null(...get("sub.by", ...))) { .err("`sub.by` is defunct and has been replaced with `estimand`") } estimand <- toupper(estimand) estimand <- match_arg(estimand, c("ATT", "ATC", "ATE")) - if (is.null(min.n)) min.n <- 1 chk::chk_count(min.n) n.obs <- length(treat) ## Setting Cut Points - if (length(subclass) == 1) { + if (length(subclass) == 1L) { sprobs <- seq(0, 1, length.out = round(subclass) + 1) } else { sprobs <- sort(subclass) if (sprobs[1] != 0) sprobs <- c(0, sprobs) if (sprobs[length(sprobs)] != 1) sprobs <- c(sprobs, 1) - subclass <- length(sprobs) - 1 + subclass <- length(sprobs) - 1L } q <- switch(estimand, - "ATT" = quantile(distance[treat==1], probs = sprobs, na.rm = TRUE), - "ATC" = quantile(distance[treat==0], probs = sprobs, na.rm = TRUE), + "ATT" = quantile(distance[treat == 1], probs = sprobs, na.rm = TRUE), + "ATC" = quantile(distance[treat == 0], probs = sprobs, na.rm = TRUE), quantile(distance, probs = sprobs, na.rm = TRUE)) ## Calculating Subclasses - psclass <- setNames(rep(NA_integer_, n.obs), names(treat)) + psclass <- rep_with(NA_integer_, treat) psclass[!discarded] <- as.integer(findInterval(distance[!discarded], q, all.inside = TRUE)) - if (length(unique(na.omit(psclass))) != subclass){ - .wrn("Due to discreteness in the distance measure, fewer subclasses were generated than were requested") + if (!has_n_unique(na.omit(psclass), subclass)) { + .wrn("due to discreteness in the distance measure, fewer subclasses were generated than were requested") } if (min.n == 0) { @@ -220,22 +213,26 @@ matchit2subclass <- function(treat, distance, discarded, is.na(psclass)[!discarded & !psclass %in% intersect(psclass[!discarded & treat == 1], psclass[!discarded & treat == 0])] <- TRUE } - else if (any(table(treat, psclass) < min.n)) { + else { ## If any subclasses don't have members of a treatment group, fill them ## by "scooting" units from nearby subclasses until each subclass has a unit ## from each treatment group - psclass[!discarded] <- subclass_scoot(psclass[!discarded], treat[!discarded], distance[!discarded], min.n) + psclass[!discarded] <- subclass_scoot(psclass[!discarded], + treat[!discarded], + distance[!discarded], + min.n) } psclass <- setNames(factor(psclass, nmax = length(q)), names(treat)) levels(psclass) <- as.character(seq_len(nlevels(psclass))) - if (verbose) cat("Calculating matching weights... ") + .cat_verbose("Calculating matching weights... ", verbose = verbose) - res <- list(subclass = psclass, q.cut = q, + res <- list(subclass = psclass, + q.cut = q, weights = get_weights_from_subclass(psclass, treat, estimand)) - if (verbose) cat("Done.\n") + .cat_verbose("Done.\n", verbose = verbose) class(res) <- c("matchit.subclass", "matchit") res diff --git a/R/plot.matchit.R b/R/plot.matchit.R index 464c8deb..8cffd29a 100644 --- a/R/plot.matchit.R +++ b/R/plot.matchit.R @@ -149,16 +149,16 @@ plot.matchit <- function(x, type = "qq", interactive = TRUE, which.xs = NULL, da which.xs = which.xs, data = data, ...) } else if (type == "jitter") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"jitter\"` cannot be used if a distance measure is not estimated or supplied. No plots generated") } - jitter.pscore(x, interactive = interactive,...) + jitter_pscore(x, interactive = interactive,...) } else if (type == "histogram") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"hist\"` cannot be used if a distance measure is not estimated or supplied. No plots generated") } - hist.pscore(x,...) + hist_pscore(x,...) } invisible(x) } @@ -192,15 +192,22 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = #If subclass = NULL, if interactive, use to choose subclass, else display aggregate across subclass subclasses <- levels(x$subclass) - miss.sub <- missing(subclass) || is.null(subclass) - if (miss.sub || isFALSE(subclass)) which.subclass <- NULL - else if (isTRUE(subclass)) which.subclass <- subclasses - else if (!is.atomic(subclass) || !all(subclass %in% seq_along(subclasses))) { + miss.sub <- missing(subclass) || is_null(subclass) + + if (miss.sub || isFALSE(subclass)) { + which.subclass <- NULL + } + else if (isTRUE(subclass)) { + which.subclass <- subclasses + } + else if (is.atomic(subclass) && all(subclass %in% seq_along(subclasses))) { + which.subclass <- subclasses[subclass] + } + else { .err("`subclass` should be `TRUE`, `FALSE`, or a vector of subclass indices for which subclass balance is to be displayed") } - else which.subclass <- subclasses[subclass] - if (!is.null(which.subclass)) { + if (is_not_null(which.subclass)) { matchit.covplot.subclass(x, type = type, which.subclass = which.subclass, interactive = interactive, which.xs = which.xs, ...) } @@ -226,16 +233,16 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = } } else if (type=="jitter") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"jitter\"` cannot be used when no distance variable was estimated or supplied") } - jitter.pscore(x, interactive = interactive, ...) + jitter_pscore(x, interactive = interactive, ...) } else if (type == "histogram") { - if (is.null(x$distance)) { + if (is_null(x$distance)) { .err("`type = \"histogram\"` cannot be used when no distance variable was estimated or supplied") } - hist.pscore(x,...) + hist_pscore(x,...) } invisible(x) } @@ -243,25 +250,26 @@ plot.matchit.subclass <- function(x, type = "qq", interactive = TRUE, which.xs = ## plot helper functions matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = NULL, data = NULL, ...) { - if (is.null(which.xs)) { - if (length(object$X) == 0) { + if (is_null(which.xs)) { + if (is_null(object$X)) { .wrn("No covariates to plot") return(invisible(NULL)) } + X <- object$X - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) X <- cbind(X, Xexact[setdiff(names(Xexact), names(X))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) X <- cbind(X, Xmahvars[setdiff(names(Xmahvars), names(X))]) } } else { - if (!is.null(data)) { + if (is_not_null(data)) { if (!is.data.frame(data) || nrow(data) != length(object$treat)) { .err("`data` must be a data frame with as many rows as there are units in the supplied `matchit` object") } @@ -271,19 +279,19 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = data <- object$X } - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) data <- cbind(data, Xexact[setdiff(names(Xexact), names(data))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) data <- cbind(data, Xmahvars[setdiff(names(Xmahvars), names(data))]) } if (is.character(which.xs)) { - if (!all(which.xs %in% names(data))) { - .err("All variables in `which.xs` must be in the supplied `matchit` object or in `data`") + if (!all(hasName(data, which.xs))) { + .err("all variables in `which.xs` must be in the supplied `matchit` object or in `data`") } X <- data[which.xs] } @@ -304,13 +312,16 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = for (i in seq_len(k)) { if (anyNA(X[[i]]) || (is.numeric(X[[i]]) && any(!is.finite(X[[i]])))) { covariates.with.missingness <- names(X)[i:k][vapply(i:k, function(j) anyNA(X[[j]]) || - (is.numeric(X[[j]]) && - any(!is.finite(X[[j]]))), - logical(1L))] + (is.numeric(X[[j]]) && + any(!is.finite(X[[j]]))), + logical(1L))] .err(paste0("Missing and non-finite values are not allowed in the covariates named in `which.xs`. Variables with missingness or non-finite values:\n\t", paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) } - if (is.character(X[[i]])) X[[i]] <- factor(X[[i]]) + + if (is.character(X[[i]])) { + X[[i]] <- factor(X[[i]]) + } } } @@ -318,9 +329,13 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = t <- object$treat - sw <- if (is.null(object$s.weights)) rep(1, length(t)) else object$s.weights + sw <- { + if (is_null(object$s.weights)) rep(1, length(t)) + else object$s.weights + } + w <- object$weights * sw - if (is.null(w)) w <- rep(1, length(t)) + if (is_null(w)) w <- rep(1, length(t)) w <- .make_sum_to_1(w, by = t) sw <- .make_sum_to_1(sw, by = t) @@ -391,6 +406,7 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = devAskNewPage(ask = interactive) } + devAskNewPage(ask = FALSE) invisible(NULL) @@ -399,25 +415,27 @@ matchit.covplot <- function(object, type = "qq", interactive = TRUE, which.xs = matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, interactive = TRUE, which.xs = NULL, data = NULL, ...) { - if (is.null(which.xs)) { - if (length(object$X) == 0) { - .wrn("No covariates to plot") + if (is_null(which.xs)) { + if (is_null(object$X)) { + .wrn("no covariates to plot") + return(invisible(NULL)) } + X <- object$X - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) X <- cbind(X, Xexact[setdiff(names(Xexact), names(X))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) X <- cbind(X, Xmahvars[setdiff(names(Xmahvars), names(X))]) } } else { - if (!is.null(data)) { + if (is_not_null(data)) { if (!is.data.frame(data) || nrow(data) != length(object$treat)) { .err("`data` must be a data frame with as many rows as there are units in the supplied `matchit` object") } @@ -427,19 +445,19 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, data <- object$X } - if (!is.null(object$exact)) { + if (is_not_null(object$exact)) { Xexact <- model.frame(object$exact, data = object$X) data <- cbind(data, Xexact[setdiff(names(Xexact), names(data))]) } - if (!is.null(object$mahvars)) { + if (is_not_null(object$mahvars)) { Xmahvars <- model.frame(object$mahvars, data = object$X) data <- cbind(data, Xmahvars[setdiff(names(Xmahvars), names(data))]) } if (is.character(which.xs)) { - if (!all(which.xs %in% names(data))) { - .err("All variables in `which.xs` must be in the supplied `matchit` object or in `data`") + if (!all(hasName(data, which.xs))) { + .err("all variables in `which.xs` must be in the supplied `matchit` object or in `data`") } X <- data[which.xs] } @@ -447,9 +465,7 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, which.xs <- update(terms(which.xs, data = data), NULL ~ .) X <- model.frame(which.xs, data, na.action = "na.pass") - if (anyNA(X)) { - .err("Missing values are not allowed in the covariates named in `which.xs`") - } + chk::chk_not_any_na(X, "the covariates named in `which.xs`") } else { .err("`which.xs` must be supplied as a character vector of names or a one-sided formula") @@ -457,17 +473,20 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, } chars.in.X <- vapply(X, is.character, logical(1L)) - X[chars.in.X] <- lapply(X[chars.in.X], factor) + for (i in which(chars.in.X)) { + X[[i]] <- factor(X[[i]]) + } X <- droplevels(X) t <- object$treat if (!is.atomic(which.subclass)) { - .err("The argument to `subclass` must be NULL or the indices of the subclasses for which to display covariate distributions") + .err("the argument to `subclass` must be `NULL` or the indices of the subclasses for which to display covariate distributions") } + if (!all(which.subclass %in% object$subclass[!is.na(object$subclass)])) { - .err("The argument supplied to `subclass` is not the index of any subclass in the matchit object") + .err("the argument supplied to `subclass` is not the index of any subclass in the `matchit` object") } if (type == "density") { @@ -491,39 +510,43 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, opar <- par(mfrow = c(3, 3), mar = c(1.5,.5,1.5,.5), oma = oma) } - sw <- if (is.null(object$s.weights)) rep(1, length(t)) else object$s.weights - w <- sw*(!is.na(object$subclass) & object$subclass == s) + sw <- { + if (is_null(object$s.weights)) rep.int(1, length(t)) + else object$s.weights + } + + w <- sw * (!is.na(object$subclass) & object$subclass == s) w <- .make_sum_to_1(w, by = t) sw <- .make_sum_to_1(sw, by = t) for (i in seq_along(varnames)){ - x <- if (type == "density") X[[i]] else X[,i] + x <- switch(type, "density" = X[[i]], X[,i]) plot.new() - if (((i-1)%%3)==0) { + if (((i-1) %% 3) == 0) { if (type == "qq") { - htext <- paste0("eQQ Plots (Subclass ", s,")") - mtext(htext, 3, 2, TRUE, 0.5, cex=1.1, font = 2) - mtext("All", 3, .25, TRUE, 0.5, cex=1, font = 1) - mtext("Matched", 3, .25, TRUE, 0.83, cex=1, font = 1) - mtext("Control Units", 1, 0, TRUE, 2/3, cex=1, font = 1) - mtext("Treated Units", 4, 0, TRUE, 0.5, cex=1, font = 1) + htext <- sprintf("eQQ Plots (Subclass %s)", s) + mtext(htext, 3, 2, TRUE, 0.5, cex = 1.1, font = 2) + mtext("All", 3, .25, TRUE, 0.5, cex = 1, font = 1) + mtext("Matched", 3, .25, TRUE, 0.83, cex = 1, font = 1) + mtext("Control Units", 1, 0, TRUE, 2/3, cex = 1, font = 1) + mtext("Treated Units", 4, 0, TRUE, 0.5, cex = 1, font = 1) } else if (type == "ecdf") { - htext <- paste0("eCDF Plots (Subclass ", s,")") - mtext(htext, 3, 2, TRUE, 0.5, cex=1.1, font = 2) - mtext("All", 3, .25, TRUE, 0.5, cex=1, font = 1) - mtext("Matched", 3, .25, TRUE, 0.83, cex=1, font = 1) + htext <- sprintf("eCDF Plots (Subclass %s)", s) + mtext(htext, 3, 2, TRUE, 0.5, cex = 1.1, font = 2) + mtext("All", 3, .25, TRUE, 0.5, cex = 1, font = 1) + mtext("Matched", 3, .25, TRUE, 0.83, cex = 1, font = 1) } else if (type == "density") { - htext <- paste0("Density Plots (Subclass ", s,")") - mtext(htext, 3, 2, TRUE, 0.5, cex=1.1, font = 2) - mtext("All", 3, .25, TRUE, 0.5, cex=1, font = 1) - mtext("Matched", 3, .25, TRUE, 0.83, cex=1, font = 1) + htext <- sprintf("Density Plots (Subclass %s)", s) + mtext(htext, 3, 2, TRUE, 0.5, cex = 1.1, font = 2) + mtext("All", 3, .25, TRUE, 0.5, cex = 1, font = 1) + mtext("Matched", 3, .25, TRUE, 0.83, cex = 1, font = 1) } } @@ -547,9 +570,10 @@ matchit.covplot.subclass <- function(object, type = "qq", which.subclass = NULL, devAskNewPage(ask = interactive) } } + devAskNewPage(ask = FALSE) - invisible(NULL) + invisible(NULL) } qqplot_match <- function(x, t, w, sw, discrete.cutoff = 5, ...) { @@ -608,8 +632,8 @@ qqplot_match <- function(x, t, w, sw, discrete.cutoff = 5, ...) { rr <- range(c(x0, x1)) plot(x0, x1, xlab = "", ylab = "", xlim = rr, ylim = rr, axes = FALSE, ...) abline(a = 0, b = 1) - abline(a = (rr[2]-rr[1])*0.1, b = 1, lty = 2) - abline(a = -(rr[2]-rr[1])*0.1, b = 1, lty = 2) + abline(a = (rr[2]-rr[1]) * 0.1, b = 1, lty = 2) + abline(a = -(rr[2]-rr[1]) * 0.1, b = 1, lty = 2) axis(2) box() @@ -658,8 +682,8 @@ qqplot_match <- function(x, t, w, sw, discrete.cutoff = 5, ...) { plot(x0, x1, xlab = "", ylab = "", xlim = rr, ylim = rr, axes = FALSE, ...) abline(a = 0, b = 1) - abline(a = (rr[2]-rr[1])*0.1, b = 1, lty = 2) - abline(a = -(rr[2]-rr[1])*0.1, b = 1, lty = 2) + abline(a = (rr[2]-rr[1]) * 0.1, b = 1, lty = 2) + abline(a = -(rr[2]-rr[1]) * 0.1, b = 1, lty = 2) box() } @@ -670,7 +694,8 @@ ecdfplot_match <- function(x, t, w, sw, ...) { x.range <- x.max - x.min #Unmatched samples - plot(x = x, y = w, type= "n" , xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), + plot(x = x, y = w, type = "n" , + xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), ylim = c(0, 1), axes = TRUE, ...) for (tr in 0:1) { @@ -686,8 +711,10 @@ ecdfplot_match <- function(x, t, w, sw, ...) { box() #Matched sample - plot(x = x, y = w, type= "n" , xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), + plot(x = x, y = w, type = "n" , + xlim = c(x.min - .02 * x.range, x.max + .02 * x.range), ylim = c(0, 1), axes = FALSE, ...) + for (tr in 0:1) { in.tr <- t[ord] == tr ordt <- ord[in.tr] @@ -703,43 +730,49 @@ ecdfplot_match <- function(x, t, w, sw, ...) { box() } -densityplot_match <- function(x, t, w, sw, ...) { +densityplot_match <- function(x, t, w, sw, bw = NULL, cut = 3, ...) { - if (length(unique(x)) == 2L) x <- factor(x) + if (has_n_unique(x, 2L)) { + x <- factor(x, nmax = 2) + } + + u <- unique(t) if (!is.factor(x)) { #Density plot for continuous variable - small.tr <- (0:1)[which.min(c(sum(t==0), sum(t==1)))] - x_small <- x[t==small.tr] + small.tr <- u[which.min(vapply(u, function(tr) sum(t == tr), numeric(1L)))] + x_small <- x[t == small.tr] x.min <- min(x) x.max <- max(x) - # - A <- list(...) - bw <- A[["bw"]] - if (is.null(bw)) A[["bw"]] <- bw.nrd0(x_small) + if (is_null(bw)) { + bw <- bw.nrd0(x_small) + } else if (is.character(bw)) { bw <- tolower(bw) bw <- match_arg(bw, c("nrd0", "nrd", "ucv", "bcv", "sj", "sj-ste", "sj-dpi")) - A[["bw"]] <- switch(bw, nrd0 = bw.nrd0(x_small), nrd = bw.nrd(x_small), - ucv = bw.ucv(x_small), bcv = bw.bcv(x_small), sj = , - `sj-ste` = bw.SJ(x_small, method = "ste"), - `sj-dpi` = bw.SJ(x_small, method = "dpi")) - } - if (is.null(A[["cut"]])) A[["cut"]] <- 3 - - d_unmatched <- do.call("rbind", lapply(0:1, function(tr) { - cbind(as.data.frame(do.call("density", c(list(x[t==tr], weights = sw[t==tr], - from = x.min - A[["cut"]]*A[["bw"]], - to = x.max + A[["cut"]]*A[["bw"]]), A))[1:2]), + bw <- switch(bw, nrd0 = bw.nrd0(x_small), nrd = bw.nrd(x_small), + ucv = bw.ucv(x_small), bcv = bw.bcv(x_small), sj = , + `sj-ste` = bw.SJ(x_small, method = "ste"), + `sj-dpi` = bw.SJ(x_small, method = "dpi")) + } + + d_unmatched <- do.call("rbind", lapply(u, function(tr) { + cbind(as.data.frame(density(x[t == tr], + weights = sw[t==tr], + from = x.min - cut * bw, + to = x.max + cut * bw, + bw = bw, cut = cut, ...)[1:2]), t = tr) })) - d_matched <- do.call("rbind", lapply(0:1, function(tr) { - cbind(as.data.frame(do.call("density", c(list(x[t==tr], weights = w[t==tr], - from = x.min - A[["cut"]]*A[["bw"]], - to = x.max + A[["cut"]]*A[["bw"]]), A))[1:2]), + d_matched <- do.call("rbind", lapply(u, function(tr) { + cbind(as.data.frame(density(x[t == tr], + weights = w[t == tr], + from = x.min - cut * bw, + to = x.max + cut * bw, + bw = bw, cut = cut, ...)[1:2]), t = tr) })) @@ -747,8 +780,9 @@ densityplot_match <- function(x, t, w, sw, ...) { #Unmatched samples plot(x = d_unmatched$x, y = d_unmatched$y, type = "n", - xlim = c(x.min - A[["cut"]]*A[["bw"]], x.max + A[["cut"]]*A[["bw"]]), - ylim = c(0, 1.1*y.max), axes = TRUE, ...) + xlim = c(x.min - cut * bw, x.max + cut * bw), + ylim = c(0, 1.1 * y.max), + axes = TRUE, ...) for (tr in 0:1) { in.tr <- d_unmatched$t == tr @@ -763,8 +797,9 @@ densityplot_match <- function(x, t, w, sw, ...) { #Matched sample plot(x = d_matched$x, y = d_matched$y, type = "n", - xlim = c(x.min - A[["cut"]]*A[["bw"]], x.max + A[["cut"]]*A[["bw"]]), - ylim = c(0, 1.1*y.max), axes = FALSE, ...) + xlim = c(x.min - cut * bw, x.max + cut * bw), + ylim = c(0, 1.1 * y.max), + axes = FALSE, ...) for (tr in 0:1) { in.tr <- d_matched$t == tr @@ -782,15 +817,15 @@ densityplot_match <- function(x, t, w, sw, ...) { #Bar plot for binary variable x_t_un <- lapply(sort(unique(t)), function(t_) { vapply(levels(x), function(i) { - wm(x[t==t_] == i, sw[t==t_]) - }, numeric(1L))}) + wm(x[t == t_] == i, sw[t == t_]) + }, numeric(1L))}) x_t_m <- lapply(sort(unique(t)), function(t_) { vapply(levels(x), function(i) { - wm(x[t==t_] == i, w[t==t_]) + wm(x[t == t_] == i, w[t == t_]) }, numeric(1L))}) - ylim <- c(0, 1.1*max(unlist(x_t_un), unlist(x_t_m))) + ylim <- c(0, 1.1 * max(unlist(x_t_un), unlist(x_t_m))) borders <- c("grey60", "black") for (i in seq_along(x_t_un)) { @@ -813,22 +848,24 @@ densityplot_match <- function(x, t, w, sw, ...) { } } -hist.pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...){ +hist_pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...) { .pardefault <- par(no.readonly = TRUE) on.exit(par(.pardefault)) treat <- x$treat pscore <- x$distance[!is.na(x$distance)] - s.weights <- if (is.null(x$s.weights)) rep(1, length(treat)) else x$s.weights + + s.weights <- { + if (is_null(x$s.weights)) rep(1, length(treat)) + else x$s.weights + } + weights <- x$weights * s.weights matched <- weights != 0 q.cut <- x$q.cut minp <- min(pscore) maxp <- max(pscore) - ratio <- x$call$ratio - - if (is.null(ratio)) ratio <- 1 if (freq) { weights <- .make_sum_to_n(weights, by = treat) @@ -847,11 +884,9 @@ hist.pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...){ xlim <- range(breaks) for (n in c("Raw Treated", "Matched Treated", "Raw Control", "Matched Control")) { - if (startsWith(n, "Raw")) w <- s.weights - else w <- weights + w <- if (startsWith(n, "Raw")) s.weights else weights - if (endsWith(n, "Treated")) t <- 1 - else t <- 0 + t <- if (endsWith(n, "Treated")) 1 else 0 #Create histogram using weights #Manually assign density, which is used as height of the bars. The scaling @@ -862,35 +897,41 @@ hist.pscore <- function(x, xlab = "Propensity Score", freq = FALSE, ...){ pm[["density"]] <- vapply(seq_len(length(pm$breaks) - 1), function(i) { sum(w[treat == t & pscore >= pm$breaks[i] & pscore < pm$breaks[i+1]]) }, numeric(1L)) + plot(pm, xlim = xlim, xlab = xlab, main = n, ylab = ylab, freq = FALSE, col = "lightgray", ...) - if (!startsWith(n, "Raw") && !is.null(q.cut)) abline(v = q.cut, lty=2) + if (!startsWith(n, "Raw") && is_not_null(q.cut)) { + abline(v = q.cut, lty = 2) + } } } -jitter.pscore <- function(x, interactive, pch = 1, ...){ +jitter_pscore <- function(x, interactive, pch = 1, ...) { .pardefault <- par(no.readonly = TRUE) on.exit(par(.pardefault)) treat <- x$treat pscore <- x$distance - s.weights <- if (is.null(x$s.weights)) rep(1, length(treat)) else x$s.weights + s.weights <- if (is_null(x$s.weights)) rep.int(1, length(treat)) else x$s.weights weights <- x$weights * s.weights matched <- weights > 0 q.cut <- x$q.cut - jitp <- jitter(rep(1,length(treat)), factor=6)+(treat==1)*(weights==0)-(treat==0) - (weights==0)*(treat==0) + jitp <- jitter(rep.int(1, length(treat)), factor = 6) + (treat==1) * (weights == 0) - (treat==0) - (weights==0) * (treat==0) cswt <- sqrt(s.weights) cwt <- sqrt(weights) minp <- min(pscore, na.rm = TRUE) maxp <- max(pscore, na.rm = TRUE) - plot(pscore, xlim = c(minp - 0.05*(maxp-minp), maxp + 0.05*(maxp-minp)), ylim = c(-1.5,2.5), + plot(pscore, xlim = c(minp - 0.05*(maxp-minp), maxp + 0.05 * (maxp - minp)), ylim = c(-1.5, 2.5), type = "n", ylab = "", xlab = "Propensity Score", axes = FALSE, main = "Distribution of Propensity Scores", ...) - if (!is.null(q.cut)) abline(v = q.cut, col = "grey", lty = 1) + + if (is_not_null(q.cut)) { + abline(v = q.cut, col = "grey", lty = 1) + } #Matched treated points(pscore[treat==1 & matched], jitp[treat==1 & matched], diff --git a/R/plot.summary.matchit.R b/R/plot.summary.matchit.R index 6d2e4dab..d3df1178 100644 --- a/R/plot.summary.matchit.R +++ b/R/plot.summary.matchit.R @@ -85,8 +85,8 @@ plot.summary.matchit <- function(x, on.exit(par(.pardefault)) sub <- inherits(x, "summary.matchit.subclass") - matched <- sub || !is.null(x[["sum.matched"]]) - un <- !is.null(x[["sum.all"]]) + matched <- sub || is_not_null(x[["sum.matched"]]) + un <- is_not_null(x[["sum.all"]]) standard.sum <- if (un) x[["sum.all"]] else x[[if (sub) "sum.across" else "sum.matched"]] @@ -131,7 +131,7 @@ plot.summary.matchit <- function(x, bg = NA, color = NA, ...) abline(v = 0) - if (sub && length(x$sum.subclass) > 0) { + if (sub && is_not_null(x$sum.subclass)) { for (i in seq_along(x$sum.subclass)) { sd.sub <- x$sum.subclass[[i]][,"Std. Mean Diff."] if (abs) sd.sub <- abs(sd.sub) @@ -149,7 +149,7 @@ plot.summary.matchit <- function(x, pch = 21, bg = "black", col = "black") } - if (!is.null(threshold)) { + if (is_not_null(threshold)) { if (abs) { abline(v = threshold, lty = seq_along(threshold)) } @@ -159,7 +159,7 @@ plot.summary.matchit <- function(x, } } - if (sum(matched, un) > 1 && !is.null(position)) { + if (sum(matched, un) > 1 && is_not_null(position)) { position <- match_arg(position, c("bottomright", "bottom", "bottomleft", "left", "topleft", "top", "topright", "right", "center")) legend(position, legend = c("All", "Matched"), diff --git a/R/rbind.matchdata.R b/R/rbind.matchdata.R index 9089155a..db6a3ba5 100644 --- a/R/rbind.matchdata.R +++ b/R/rbind.matchdata.R @@ -73,7 +73,7 @@ rbind.matchdata <- function(..., deparse.level = 1) { allargs <- list(...) allargs <- allargs[lengths(allargs) > 0L] - if (is.null(names(allargs))) { + if (is_null(names(allargs))) { md_list <- allargs allargs <- list() } @@ -84,20 +84,31 @@ rbind.matchdata <- function(..., deparse.level = 1) { allargs$deparse.level <- deparse.level type <- intersect(c("matchdata", "getmatches"), unlist(lapply(md_list, class))) - if (length(type) == 0) .err("A `matchdata` or `getmatches` object must be supplied") - if (length(type) == 2) .err("Supplied objects must be all `matchdata` objects or all `getmatches` objects") + + if (is_null(type)) { + .err("A `matchdata` or `getmatches` object must be supplied") + } + + if (length(type) == 2L) { + .err("Supplied objects must be all `matchdata` objects or all `getmatches` objects") + } attrs <- c("distance", "weights", "subclass", "id") attr_list <- setNames(vector("list", length(attrs)), attrs) - key_attrs <- setNames(rep(NA_character_, length(attrs)), attrs) + key_attrs <- setNames(rep.int(NA_character_, length(attrs)), attrs) for (i in attrs) { attr_list[[i]] <- unlist(lapply(md_list, function(m) { a <- attr(m, i) - if (length(a) == 0) NA_character_ else a + if (is_null(a)) NA_character_ else a })) - if (all(is.na(attr_list[[i]]))) attr_list[[i]] <- NULL - else key_attrs[i] <- attr_list[[i]][which(!is.na(attr_list[[i]]))[1]] + + if (all(is.na(attr_list[[i]]))) { + attr_list[[i]] <- NULL + } + else { + key_attrs[i] <- Find(Negate(is.na), attr_list[[i]]) + } } attrs <- names(attr_list) key_attrs <- key_attrs[attrs] @@ -106,8 +117,10 @@ rbind.matchdata <- function(..., deparse.level = 1) { other_col_list <- lapply(seq_along(md_list), function(d) { setdiff(names(md_list[[d]]), unlist(lapply(attr_list, `[`, d))) }) + for (d in seq_along(md_list)[-1]) { - if (length(other_col_list[[d]]) != length(other_col_list[[1]]) || !all(other_col_list[[d]] %in% other_col_list[[1]])) { + if (length(other_col_list[[d]]) != length(other_col_list[[1]]) || + !all(other_col_list[[d]] %in% other_col_list[[1]])) { .err(sprintf("the %s inputs must come from the same dataset", switch(type, "matchdata" = "`match.data()`", "`get_matches()`"))) } @@ -116,13 +129,21 @@ rbind.matchdata <- function(..., deparse.level = 1) { for (d in seq_along(md_list)) { for (i in attrs) { #Rename columns of each attribute the same across datasets - if (is.null(attr(md_list[[d]], i))) md_list[[d]] <- setNames(cbind(md_list[[d]], NA), c(names(md_list[[d]]), key_attrs[i])) - else names(md_list[[d]])[names(md_list[[d]]) == attr_list[[i]][d]] <- key_attrs[i] + if (is_null(attr(md_list[[d]], i))) { + md_list[[d]] <- setNames(cbind(md_list[[d]], NA), c(names(md_list[[d]]), key_attrs[i])) + } + else { + names(md_list[[d]])[names(md_list[[d]]) == attr_list[[i]][d]] <- key_attrs[i] + } #Give subclasses unique values across datasets if (i == "subclass") { - if (all(is.na(md_list[[d]][[key_attrs[i]]]))) md_list[[d]][[key_attrs[i]]] <- factor(md_list[[d]][[key_attrs[i]]], levels = NA) - else levels(md_list[[d]][[key_attrs[i]]]) <- paste(d, levels(md_list[[d]][[key_attrs[i]]]), sep = "_") + if (all(is.na(md_list[[d]][[key_attrs[i]]]))) { + md_list[[d]][[key_attrs[i]]] <- factor(md_list[[d]][[key_attrs[i]]], levels = NA) + } + else { + levels(md_list[[d]][[key_attrs[i]]]) <- paste(d, levels(md_list[[d]][[key_attrs[i]]]), sep = "_") + } } } @@ -130,7 +151,8 @@ rbind.matchdata <- function(..., deparse.level = 1) { if (d > 1) { md_list[[d]] <- md_list[[d]][names(md_list[[1]])] } - class(md_list[[d]]) <- class(md_list[[d]])[class(md_list[[d]]) != type] + + class(md_list[[d]]) <- setdiff(class(md_list[[d]]), type) } out <- do.call("rbind", c(md_list, allargs)) diff --git a/R/summary.matchit.R b/R/summary.matchit.R index 590a4a8b..c40f0b3e 100644 --- a/R/summary.matchit.R +++ b/R/summary.matchit.R @@ -208,9 +208,12 @@ summary.matchit <- function(object, treat <- object$treat weights <- object$weights - s.weights <- if (is.null(object$s.weights)) rep(1, length(weights)) else object$s.weights + s.weights <- { + if (is_null(object$s.weights)) rep(1, length(weights)) + else object$s.weights + } - no_x <- length(X) == 0 + no_x <- is_null(X) if (no_x) { X <- matrix(1, nrow = length(treat), ncol = 1, @@ -227,7 +230,7 @@ summary.matchit <- function(object, kk <- ncol(X) - matched <- !is.null(object$info$method) + matched <- is_not_null(object$info$method) un <- un || !matched chk::chk_flag(interactions) @@ -272,8 +275,10 @@ summary.matchit <- function(object, if (!no_x && interactions) { n.int <- kk*(kk+1)/2 - if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), dimnames = list(NULL, names(aa.all[[1]]))) - if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), dimnames = list(NULL, names(aa.matched[[1]]))) + if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), + dimnames = list(NULL, names(aa.all[[1]]))) + if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), + dimnames = list(NULL, names(aa.matched[[1]]))) to.remove <- rep(FALSE, n.int) int.names <- character(n.int) @@ -312,13 +317,14 @@ summary.matchit <- function(object, rownames(sum.all.int) <- int.names sum.all <- rbind(sum.all, sum.all.int[!to.remove,,drop = FALSE]) } + if (matched) { rownames(sum.matched.int) <- int.names sum.matched <- rbind(sum.matched, sum.matched.int[!to.remove,,drop = FALSE]) } } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { if (un) { ad.all <- bal1var(object$distance, tt = treat, ww = NULL, s.weights = s.weights, standardize = standardize, s.d.denom = s.d.denom) @@ -397,7 +403,10 @@ summary.matchit.subclass <- function(object, which.subclass <- subclass treat <- object$treat weights <- object$weights - s.weights <- if (is.null(object$s.weights)) rep(1, length(weights)) else object$s.weights + s.weights <- { + if (is_null(object$s.weights)) rep(1, length(weights)) + else object$s.weights + } subclass <- object$subclass nam <- colnames(X) @@ -411,13 +420,13 @@ summary.matchit.subclass <- function(object, chk::chk_flag(un) chk::chk_flag(improvement) - if (standardize) { - s.d.denom <- switch(object$estimand, - "ATT" = "treated", - "ATC" = "control", - "ATE" = "pooled") + s.d.denom <- { + if (standardize) switch(object$estimand, + "ATT" = "treated", + "ATC" = "control", + "ATE" = "pooled") + else NULL } - else s.d.denom <- NULL if (isTRUE(which.subclass)) which.subclass <- subclasses else if (isFALSE(which.subclass)) which.subclass <- NULL @@ -427,7 +436,7 @@ summary.matchit.subclass <- function(object, else which.subclass <- subclasses[which.subclass] matched <- TRUE #always compute aggregate balance so plot.summary can use it - subs <- !is.null(which.subclass) + subs <- is_not_null(which.subclass) ## Aggregate Subclass #Use the estimated weights to compute aggregate balance. @@ -458,13 +467,15 @@ summary.matchit.subclass <- function(object, if (interactions) { n.int <- kk*(kk+1)/2 - if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), dimnames = list(NULL, names(aa.all[[1]]))) - if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), dimnames = list(NULL, names(aa.matched[[1]]))) + if (un) sum.all.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.all[[1]]), + dimnames = list(NULL, names(aa.all[[1]]))) + if (matched) sum.matched.int <- matrix(NA_real_, nrow = n.int, ncol = length(aa.matched[[1]]), + dimnames = list(NULL, names(aa.matched[[1]]))) to.remove <- rep(FALSE, n.int) int.names <- character(n.int) k <- 1 - for (i in 1:kk) { + for (i in seq_len(kk)) { for (j in i:kk) { x2 <- X[,i] * X[,j] if (all(abs(x2) < sqrt(.Machine$double.eps)) || @@ -502,7 +513,7 @@ summary.matchit.subclass <- function(object, } } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { if (un) { ad.all <- bal1var(object$distance, tt = treat, ww = NULL, s.weights = s.weights, standardize = standardize, s.d.denom = s.d.denom) @@ -535,7 +546,9 @@ summary.matchit.subclass <- function(object, #bal1var.subclass only returns unmatched stats, which is all we need within #subclasses. Otherwise, identical to matched stats. aa <- setNames(lapply(seq_len(kk), function(i) { - bal1var.subclass(X[,i], tt = treat, s.weights = s.weights, subclass = subclass, s.d.denom = s.d.denom, standardize = standardize, which.subclass = s) + bal1var.subclass(X[,i], tt = treat, s.weights = s.weights, + subclass = subclass, s.d.denom = s.d.denom, + standardize = standardize, which.subclass = s) }), colnames(X)) sum.sub <- matrix(NA_real_, nrow = kk, ncol = ncol(aa[[1]]), dimnames = list(nam, colnames(aa[[1]]))) @@ -545,7 +558,8 @@ summary.matchit.subclass <- function(object, sum.sub[i,] <- aa[[i]] } if (interactions) { - sum.sub.int <- matrix(NA_real_, nrow = kk*(kk+1)/2, ncol = length(aa[[1]]), dimnames = list(NULL, names(aa[[1]]))) + sum.sub.int <- matrix(NA_real_, nrow = kk*(kk+1)/2, ncol = length(aa[[1]]), + dimnames = list(NULL, names(aa[[1]]))) to.remove <- rep(FALSE, nrow(sum.sub.int)) int.names <- character(nrow(sum.sub.int)) k <- 1 @@ -553,7 +567,9 @@ summary.matchit.subclass <- function(object, for (j in i:kk) { if (!to.remove[k]) { #to.remove defined above x2 <- X[,i] * X[,j] - jqoi <- bal1var.subclass(x2, tt = treat, s.weights = s.weights, subclass = subclass, s.d.denom = s.d.denom, standardize = standardize, which.subclass = s) + jqoi <- bal1var.subclass(x2, tt = treat, s.weights = s.weights, + subclass = subclass, s.d.denom = s.d.denom, + standardize = standardize, which.subclass = s) sum.sub.int[k,] <- jqoi if (i == j) { int.names[k] <- paste0(nam[i], "\u00B2") @@ -570,7 +586,7 @@ summary.matchit.subclass <- function(object, sum.sub <- rbind(sum.sub, sum.sub.int[!to.remove,,drop = FALSE]) } - if (!is.null(object$distance)) { + if (is_not_null(object$distance)) { ad <- bal1var.subclass(object$distance, tt = treat, s.weights = s.weights, subclass = subclass, s.d.denom = s.d.denom, standardize = standardize, which.subclass = s) sum.sub <- rbind(ad, sum.sub) @@ -599,28 +615,30 @@ summary.matchit.subclass <- function(object, #' @exportS3Method print summary.matchit #' @rdname summary.matchit print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), - ...){ + ...) { - if (!is.null(x$call)) cat("\nCall:", deparse(x$call), sep = "\n") + if (is_not_null(x$call)) { + cat("\nCall:", deparse(x$call), sep = "\n") + } - if (!is.null(x$sum.all)) { + if (is_not_null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") print(round_df_char(x$sum.all[,-7, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$sum.matched)) { + if (is_not_null(x$sum.matched)) { cat("\nSummary of Balance for Matched Data:\n") if (all(is.na(x$sum.matched[,7]))) x$sum.matched <- x$sum.matched[,-7,drop = FALSE] #Remove pair dist if empty print(round_df_char(x$sum.matched, digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$reduction)) { + if (is_not_null(x$reduction)) { cat("\nPercent Balance Improvement:\n") print(round_df_char(x$reduction[,-5, drop = FALSE], 1, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$nn)) { + if (is_not_null(x$nn)) { cat("\nSample Sizes:\n") nn <- x$nn if (isTRUE(all.equal(nn["All (ESS)",], nn["All",]))) { @@ -641,41 +659,43 @@ print.summary.matchit <- function(x, digits = max(3, getOption("digits") - 3), #' @exportS3Method print summary.matchit.subclass print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits") - 3), ...){ - if (!is.null(x$call)) cat("\nCall:", deparse(x$call), sep = "\n") + if (is_not_null(x$call)) { + cat("\nCall:", deparse(x$call), sep = "\n") + } - if (!is.null(x$sum.all)) { + if (is_not_null(x$sum.all)) { cat("\nSummary of Balance for All Data:\n") print(round_df_char(x$sum.all[,-7, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (length(x$sum.subclass) > 0) { + if (is_not_null(x$sum.subclass)) { cat("\nSummary of Balance by Subclass:\n") for (s in seq_along(x$sum.subclass)) { cat(paste0("\n- ", names(x$sum.subclass)[s], "\n")) print(round_df_char(x$sum.subclass[[s]][,-7, drop = FALSE], digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$qn)) { + if (is_not_null(x$qn)) { cat("\nSample Sizes by Subclass:\n") print(round_df_char(x$qn, 2, pad = " ", na_vals = "."), right = TRUE, quote = FALSE) } } else { - if (!is.null(x$sum.across)) { + if (is_not_null(x$sum.across)) { cat("\nSummary of Balance Across Subclasses\n") if (all(is.na(x$sum.across[,7]))) x$sum.across <- x$sum.across[,-7,drop = FALSE] print(round_df_char(x$sum.across, digits, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$reduction)) { + if (is_not_null(x$reduction)) { cat("\nPercent Balance Improvement:\n") print(round_df_char(x$reduction[,-5, drop = FALSE], 1, pad = "0", na_vals = "."), right = TRUE, quote = FALSE) } - if (!is.null(x$nn)) { + if (is_not_null(x$nn)) { cat("\nSample Sizes:\n") nn <- x$nn if (isTRUE(all.equal(nn["All (ESS)",], nn["All",]))) { @@ -696,98 +716,97 @@ print.summary.matchit.subclass <- function(x, digits = max(3, getOption("digits" .process_X <- function(object, addlvariables = NULL, data = NULL) { X <- { - if (length(object$X) == 0) matrix(nrow = length(object$treat), ncol = 0) + if (is_null(object$X)) matrix(nrow = length(object$treat), ncol = 0) else get.covs.matrix(data = object$X) } - if (!is.null(addlvariables)) { - - #Attempt to extrct data from matchit object; same as match.data() - data.fram.matchit <- FALSE - if (is.null(data)) { - env <- environment(object$formula) - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - env <- parent.frame() - data <- try(eval(object$call$data, envir = env), silent = TRUE) - if (length(data) == 0 || inherits(data, "try-error") || length(dim(data)) != 2 || nrow(data) != length(object[["treat"]])) { - data <- object[["model"]][["data"]] - if (length(data) == 0 || nrow(data) != length(object[["treat"]])) { - data <- NULL - } - else data.fram.matchit <- TRUE - } - else data.fram.matchit <- TRUE - } - else data.fram.matchit <- TRUE - } + if (is_null(addlvariables)) { + return(X) + } - if (is.character(addlvariables)) { - if (!is.null(data) && is.data.frame(data)) { - if (all(addlvariables %in% names(data))) { - addlvariables <- data[addlvariables] - } - else { - .err("All variables in `addlvariables` must be in `data`") - } - } - else { - .err("If `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") - } + #Attempt to extract data from matchit object; same as match.data() + data.found <- FALSE + for (i in 1:4) { + if (i == 2L) { + data <- try(eval(object$call$data, envir = environment(object$formula)), silent = TRUE) } - else if (inherits(addlvariables, "formula")) { - vars.in.formula <- all.vars(addlvariables) - if (!is.null(data) && is.data.frame(data)) { - data <- data.frame(data[names(data) %in% vars.in.formula], - object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) - } - else data <- object$X - - # addlvariables <- get.covs.matrix(addlvariables, data = data) + else if (i == 3L) { + data <- try(eval(object$call$data, envir = parent.frame()), silent = TRUE) } - else if (!is.matrix(addlvariables) && !is.data.frame(addlvariables)) { - .err("The argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") + else if (i == 4L) { + data <- object[["model"]][["data"]] } + if (!null_or_error(data) && length(dim(data)) == 2L && nrow(data) == length(object[["treat"]])) { + data.found <- TRUE + break + } + } - if (af <- inherits(addlvariables, "formula")) { - addvariables_f <- addlvariables - addlvariables <- model.frame(addvariables_f, data = data, na.action = "na.pass") + if (is.character(addlvariables)) { + if (is_null(data) || !is.data.frame(data)) { + .err("if `addlvariables` is specified as a string, a data frame argument must be supplied to `data`") } - if (nrow(addlvariables) != length(object$treat)) { - if (is.null(data) || data.fram.matchit) { - .err("Variables specified in `addlvariables` must have the same number of units as are present in the original call to `matchit()`") - } - else { - .err("`data` must have the same number of units as are present in the original call to `matchit()`") - } + if (!all(hasName(data, addlvariables))) { + .err("all variables in `addlvariables` must be in `data`") } - k <- ncol(addlvariables) - for (i in seq_len(k)) { - if (anyNA(addlvariables[[i]]) || (is.numeric(addlvariables[[i]]) && any(!is.finite(addlvariables[[i]])))) { - covariates.with.missingness <- names(addlvariables)[i:k][vapply(i:k, function(j) anyNA(addlvariables[[j]]) || - (is.numeric(addlvariables[[j]]) && - any(!is.finite(addlvariables[[j]]))), - logical(1L))] - .err(paste0("Missing and non-finite values are not allowed in `addlvariables`. Variables with missingness or non-finite values:\n\t", - paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) - } - if (is.character(addlvariables[[i]])) addlvariables[[i]] <- factor(addlvariables[[i]]) + addlvariables <- data[addlvariables] + } + else if (rlang::is_formula(addlvariables)) { + if (is_not_null(data) && is.data.frame(data)) { + vars.in.formula <- all.vars(addlvariables) + data <- cbind(data[names(data) %in% vars.in.formula], + object$X[names(object$X) %in% setdiff(vars.in.formula, names(data))]) + } + else { + data <- object$X } + } + else if (!is.matrix(addlvariables) && !is.data.frame(addlvariables)) { + .err("the argument to `addlvariables` must be in one of the accepted forms. See `?summary.matchit` for details") + } - if (af) { - addlvariables <- get.covs.matrix(addvariables_f, data = data) + + if (af <- rlang::is_formula(addlvariables)) { + addvariables_f <- addlvariables + addlvariables <- model.frame(addvariables_f, data = data, na.action = "na.pass") + } + + if (nrow(addlvariables) != length(object$treat)) { + if (is_null(data) || data.found) { + .err("variables specified in `addlvariables` must have the same number of units as are present in the original call to `matchit()`") } else { - addlvariables <- get.covs.matrix(data = addlvariables) + .err("`data` must have the same number of units as are present in the original call to `matchit()`") + } + } + + k <- ncol(addlvariables) + for (i in seq_len(k)) { + if (anyNA(addlvariables[[i]]) || (is.numeric(addlvariables[[i]]) && any(!is.finite(addlvariables[[i]])))) { + covariates.with.missingness <- names(addlvariables)[i:k][vapply(i:k, function(j) anyNA(addlvariables[[j]]) || + (is.numeric(addlvariables[[j]]) && + any(!is.finite(addlvariables[[j]]))), + logical(1L))] + .err(paste0("Missing and non-finite values are not allowed in `addlvariables`. Variables with missingness or non-finite values:\n\t", + paste(covariates.with.missingness, collapse = ", ")), tidy = FALSE) } + if (is.character(addlvariables[[i]])) { + addlvariables[[i]] <- factor(addlvariables[[i]]) + } + } - # addl_assign <- get_assign(addlvariables) - X <- cbind(X, addlvariables[, setdiff(colnames(addlvariables), colnames(X)), drop = FALSE]) + if (af) { + addlvariables <- get.covs.matrix(addvariables_f, data = data) + } + else { + addlvariables <- get.covs.matrix(data = addlvariables) } - X -} \ No newline at end of file + # addl_assign <- get_assign(addlvariables) + cbind(X, addlvariables[, setdiff(colnames(addlvariables), colnames(X)), drop = FALSE]) + +} diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 00000000..d9f6fc42 --- /dev/null +++ b/R/utils.R @@ -0,0 +1,488 @@ +#Function to turn a vector into a string with "," and "and" or "or" for clean messages. 'and.or' +#controls whether words are separated by "and" or "or"; 'is.are' controls whether the list is +#followed by "is" or "are" (to avoid manually figuring out if plural); quotes controls whether +#quotes should be placed around words in string. From WeightIt. +word_list <- function(word.list = NULL, and.or = "and", is.are = FALSE, quotes = FALSE) { + #When given a vector of strings, creates a string of the form "a and b" + #or "a, b, and c" + #If is.are, adds "is" or "are" appropriately + + word.list <- setdiff(word.list, c(NA_character_, "")) + + if (is_null(word.list)) { + out <- "" + attr(out, "plural") <- FALSE + return(out) + } + + word.list <- add_quotes(word.list, quotes) + + L <- length(word.list) + + if (L == 1L) { + out <- word.list + if (is.are) out <- paste(out, "is") + attr(out, "plural") <- FALSE + return(out) + } + + if (is_null(and.or) || isFALSE(and.or)) { + out <- paste(word.list, collapse = ", ") + } + else { + and.or <- match_arg(and.or, c("and", "or")) + + if (L == 2L) { + out <- sprintf("%s %s %s", + word.list[1L], + and.or, + word.list[2L]) + } + else { + out <- sprintf("%s, %s %s", + paste(word.list[-L], collapse = ", "), + and.or, + word.list[L]) + } + } + + if (is.are) { + out <- sprintf("%s are", out) + } + + attr(out, "plural") <- TRUE + + out +} + +#Add quotes to a string +add_quotes <- function(x, quotes = 2L) { + if (isFALSE(quotes)) { + return(x) + } + + if (isTRUE(quotes)) + quotes <- '"' + + if (chk::vld_string(quotes)) { + return(paste0(quotes, x, quotes)) + } + + if (!chk::vld_count(quotes) || quotes > 2) { + stop("`quotes` must be boolean, 1, 2, or a string.") + } + + if (quotes == 0L) { + return(x) + } + + x <- { + if (quotes == 1) sprintf("'%s'", x) + else sprintf('"%s"', x) + } + + x +} + +#More informative and cleaner version of base::match.arg(). Uses chk. +match_arg <- function(arg, choices, several.ok = FALSE) { + #Replaces match.arg() but gives cleaner error message and processing + #of arg. + if (missing(arg)) { + stop("No argument was supplied to match_arg.") + } + + arg.name <- deparse1(substitute(arg), width.cutoff = 500L) + + if (missing(choices)) { + formal.args <- formals(sys.function(sysP <- sys.parent())) + choices <- eval(formal.args[[as.character(substitute(arg))]], + envir = sys.frame(sysP)) + } + + if (is_null(arg)) { + return(choices[1L]) + } + + if (several.ok) { + chk::chk_character(arg, x_name = add_quotes(arg.name, "`")) + } + else { + chk::chk_string(arg, x_name = add_quotes(arg.name, "`")) + + if (identical(arg, choices)) { + return(arg[1L]) + } + } + + i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE) + + if (all_equal_to(i, 0L)) { + .err(sprintf("the argument to `%s` should be %s%s", + arg.name, + ngettext(length(choices), "", if (several.ok) "at least one of " else "one of "), + word_list(choices, and.or = "or", quotes = 2))) + } + + i <- i[i > 0L] + + choices[i] +} + +# Version of interaction(., drop = TRUE) that doesn't succumb to vector limit reached by +# avoiding Cartesian expansion. Falls back to interaction() for small problems. +interaction2 <- function(..., sep = ".", lex.order = TRUE) { + + narg <- ...length() + + if (narg == 0L) { + stop("No factors specified") + } + + if (narg == 1L && is.list(..1)) { + args <- ..1 + narg <- length(args) + } + else { + args <- list(...) + } + + for (i in seq_len(narg)) { + args[[i]] <- as.factor(args[[i]]) + } + + if (do.call("prod", lapply(args, nlevels)) <= 1e6) { + return(interaction(args, drop = TRUE, sep = sep, + lex.order = if (is.null(lex.order)) TRUE else lex.order)) + } + + out <- do.call(function(...) paste(..., sep = sep), args) + + args_char <- lapply(args, function(x) { + x <- unclass(x) + formatC(x, format = "d", flag = "0", width = ceiling(log10(max(x)))) + }) + + lev <- { + if (is.null(lex.order)) unique(out) + else if (lex.order) unique(out[order(do.call("paste", c(args_char, list(sep = sep))))]) + else unique(out[order(do.call("paste", c(rev(args_char), list(sep = sep))))]) + } + + factor(out, levels = lev) +} + +#Turn a vector into a 0/1 vector. 'zero' and 'one' can be supplied to make it clear which is +#which; otherwise, a guess is used. From WeightIt. +binarize <- function(variable, zero = NULL, one = NULL) { + var.name <- deparse1(substitute(variable)) + + if (has_n_unique(variable, 1L)) { + return(rep_with(1L, variable)) + } + + if (!has_n_unique(variable, 2L)) { + .err(sprintf("cannot binarize %s: more than two levels", var.name)) + } + + if (is.character(variable) || is.factor(variable)) { + variable <- factor(variable, nmax = 2L) + unique.vals <- levels(variable) + } + else { + unique.vals <- unique(variable, nmax = 2L) + } + + if (is_not_null(zero)) { + if (!zero %in% unique.vals) { + .err(sprintf("the argument to `zero` is not the name of a level of %s", var.name)) + } + + return(setNames(as.integer(variable != zero), names(variable))) + } + + if (is_not_null(one)) { + if (!one %in% unique.vals) { + .err(sprintf("the argument to `one` is not the name of a level of %s", var.name)) + } + + return(setNames(as.integer(variable == one), names(variable))) + } + + if (is.logical(variable)) { + return(setNames(as.integer(variable), names(variable))) + } + + if (is.numeric(variable)) { + zero <- { + if (any(unique.vals == 0)) 0 + else min(unique.vals, na.rm = TRUE) + } + + return(setNames(as.integer(variable != zero), names(variable))) + } + + variable.numeric <- { + if (can_str2num(unique.vals)) setNames(str2num(unique.vals), unique.vals)[variable] + else unclass(factor(variable, levels = unique.vals)) + } + + zero <- { + if (0 %in% variable.numeric) 0 + else min(variable.numeric, na.rm = TRUE) + } + + setNames(as.integer(variable.numeric != zero), names(variable)) +} + +is_null <- function(x) length(x) == 0L +is_not_null <- function(x) !is_null(x) + +null_or_error <- function(x) {is_null(x) || inherits(x, "try-error")} + +#Determine whether a character vector can be coerced to numeric +can_str2num <- function(x) { + if (is.numeric(x) || is.logical(x)) { + return(TRUE) + } + + nas <- is.na(x) + suppressWarnings(x_num <- as.numeric(as.character(x[!nas]))) + + !anyNA(x_num) +} + +#Cleanly coerces a character vector to numeric; best to use after can_str2num() +str2num <- function(x) { + nas <- is.na(x) + if (!is.numeric(x) && !is.logical(x)) x <- as.character(x) + suppressWarnings(x_num <- as.numeric(x)) + is.na(x_num)[nas] <- TRUE + x_num +} + +#Capitalize first letter of string +firstup <- function(x) { + substr(x, 1, 1) <- toupper(substr(x, 1, 1)) + x +} + +#Capitalize first letter of each word +capwords <- function(s, strict = FALSE) { + cap <- function(s) paste0(toupper(substring(s, 1, 1)), + {s <- substring(s, 2) + if (strict) tolower(s) else s}, + collapse = " ") + sapply(strsplit(s, split = " "), cap, USE.NAMES = is_not_null(names(s))) +} + +#Clean printing of data frames with numeric and NA elements. +round_df_char <- function(df, digits, pad = "0", na_vals = "") { + if (NROW(df) == 0L || NCOL(df) == 0L) { + return(df) + } + + if (!is.data.frame(df)) { + df <- as.data.frame.matrix(df, stringsAsFactors = FALSE) + } + + rn <- rownames(df) + cn <- colnames(df) + + infs <- o.negs <- array(FALSE, dim = dim(df)) + nas <- is.na(df) + nums <- vapply(df, is.numeric, logical(1)) + + for (i in which(nums)) { + infs[,i] <- !nas[,i] & !is.finite(df[[i]]) + } + + for (i in which(!nums)) { + if (can_str2num(df[[i]])) { + df[[i]] <- str2num(df[[i]]) + nums[i] <- TRUE + } + } + + o.negs[,nums] <- !nas[,nums] & df[nums] < 0 & round(df[nums], digits) == 0 + df[nums] <- round(df[nums], digits = digits) + + pad0 <- identical(as.character(pad), "0") + + for (i in which(nums)) { + df[[i]] <- format(df[[i]], scientific = FALSE, justify = "none", trim = TRUE, + drop0trailing = !pad0) + + if (!pad0 && any(grepl(".", df[[i]], fixed = TRUE))) { + s <- strsplit(df[[i]], ".", fixed = TRUE) + lengths <- lengths(s) + digits.r.of.. <- rep.int(0, NROW(df)) + digits.r.of..[lengths > 1] <- nchar(vapply(s[lengths > 1], `[[`, character(1L), 2)) + + dots <- rep.int("", length(s)) + dots[lengths <= 1] <- if (as.character(pad) != "") "." else pad + + pads <- vapply(max(digits.r.of..) - digits.r.of.., + function(n) paste(rep.int(pad, n), collapse = ""), + character(1L)) + + df[[i]] <- paste0(df[[i]], dots, pads) + } + } + + df[o.negs] <- paste0("-", df[o.negs]) + + # Insert NA placeholders + df[nas] <- na_vals + df[infs] <- "N/A" + + if (length(rn) > 0) rownames(df) <- rn + if (length(cn) > 0) names(df) <- cn + + df +} + +#Generalized inverse; port of MASS::ginv() +generalized_inverse <- function(sigma, tol = 1e-8) { + sigmasvd <- svd(sigma) + + pos <- sigmasvd$d > max(tol * sigmasvd$d[1L], 0) + + sigmasvd$v[, pos, drop = FALSE] %*% (sigmasvd$d[pos]^-1 * t(sigmasvd$u[, pos, drop = FALSE])) +} + +#(Weighted) variance that uses special formula for binary variables +wvar <- function(x, bin.var = NULL, w = NULL) { + if (is_null(w)) w <- rep.int(1, length(x)) + if (is_null(bin.var)) bin.var <- all(x == 0 | x == 1) + + w <- w / sum(w) #weights normalized to sum to 1 + mx <- sum(w * x) #weighted mean + + if (bin.var) { + return(mx * (1 - mx)) + } + + #Reliability weights variance; same as cov.wt() + sum(w * (x - mx)^2)/(1 - sum(w^2)) +} + +#Weighted mean faster than weighted.mean() +wm <- function(x, w = NULL, na.rm = TRUE) { + if (is_null(w)) { + if (anyNA(x)) { + if (!na.rm) return(NA_real_) + nas <- which(is.na(x)) + x <- x[-nas] + } + return(sum(x)/length(x)) + } + + if (anyNA(x) || anyNA(w)) { + if (!na.rm) return(NA_real_) + nas <- which(is.na(x) | is.na(w)) + x <- x[-nas] + w <- w[-nas] + } + + sum(x*w)/sum(w) +} + +#Faster diff() +diff1 <- function(x) { + x[-1] - x[-length(x)] +} + +#cumsum() for probabilities to ensure they are between 0 and 1 +.cumsum_prob <- function(x) { + s <- cumsum(x) + s / s[length(s)] +} + +#Make vector sum to 1, optionally by group +.make_sum_to_1 <- function(x, by = NULL) { + if (is_null(by)) { + return(x / sum(x)) + } + + for (i in unique(by)) { + in_i <- which(by == i) + x[in_i] <- x[in_i] / sum(x[in_i]) + } + + x +} + +#Make vector sum to n (average of 1), optionally by group +.make_sum_to_n <- function(x, by = NULL) { + if (is_null(by)) { + return(length(x) * x / sum(x)) + } + + for (i in unique(by)) { + in_i <- which(by == i) + x[in_i] <- length(in_i) * x[in_i] / sum(x[in_i]) + } + + x +} + +#Extract variables from ..., similar to ...elt(), by name without evaluating list(...) +...get <- function(x, ...) { + m <- match(x, ...names(), 0L) + + if (m == 0L) { + return(NULL) + } + + ...elt(m) +} + +#Helper function to fill named vectors with x and given names of y +rep_with <- function(x, y) { + setNames(rep.int(x, length(y)), names(y)) +} + +#cat() if verbose = TRUE (default sep = "", line wrapping) +.cat_verbose <- function(..., verbose = TRUE, sep = "") { + if (!verbose) { + return(invisible(NULL)) + } + + m <- do.call(function(...) paste(..., sep = sep), list(...)) + + cat(paste(strwrap(m), collapse = "\n")) +} + +#Functions for error handling; based on chk and rlang +pkg_caller_call <- function(start = 1) { + pn <- utils::packageName() + package.funs <- c(getNamespaceExports(pn), + .getNamespaceInfo(asNamespace(pn), "S3methods")[, 3]) + k <- start #skip checking pkg_caller_call() + e_max <- start + while (is_not_null(e <- rlang::caller_call(k))) { + if (is_not_null(n <- rlang::call_name(e)) && + n %in% package.funs) e_max <- k + k <- k + 1 + } + rlang::caller_call(e_max) +} + +.err <- function(..., n = NULL, tidy = TRUE) { + m <- chk::message_chk(..., n = n, tidy = tidy) + rlang::abort(paste(strwrap(m), collapse = "\n"), + call = pkg_caller_call(start = 2)) +} +.wrn <- function(..., n = NULL, tidy = TRUE, immediate = TRUE) { + if (immediate && isTRUE(all.equal(0, getOption("warn")))) { + op <- options(warn = 1) + on.exit(options(op)) + } + m <- chk::message_chk(..., n = n, tidy = tidy) + rlang::warn(paste(strwrap(m), collapse = "\n")) +} +.msg <- function(..., n = NULL, tidy = TRUE) { + m <- chk::message_chk(..., n = n, tidy = tidy) + rlang::inform(paste(strwrap(m), collapse = "\n"), tidy = FALSE) +} \ No newline at end of file diff --git a/README.Rmd b/README.Rmd index f19f5818..ee518502 100644 --- a/README.Rmd +++ b/README.Rmd @@ -49,7 +49,7 @@ We can check covariate balance for the original and matched samples using `summa summary(m.out) ``` -At the top is balance for the original sample. Below that is balance in the matched sample, followed by the percent reduction in imbalance and the sample sizes before and after matching. Smaller values for the balance statistics indicate better balance. (In this case, good balance was not achieved and other matching methods should be tried). We can plot the standardized mean differences in a Love plot for a clean, visual display of balance across the sample: +At the top is balance for the original sample. Below that is balance in the matched sample. Smaller values for the balance statistics indicate better balance. (In this case, fairly good balance was achieved, but other matching methods should be tried). We can plot the standardized mean differences in a Love plot for a clean, visual display of balance across the sample: ```{r, fig.alt ="Love plot of balance before and after matching."} #Plot balance @@ -77,9 +77,9 @@ install.packages("MatchIt") To install a development version, which may have a bug fixed or a new feature, run the following: ```{r, eval=F} -install.packages("remotes") #If not yet installed +install.packages("pak") #If not yet installed -remotes::install_github("ngreifer/MatchIt") +pak::pkg_install("ngreifer/MatchIt") ``` -This will require R to compile C++ code, which might require additional software be installed on your computer. If you need the development version but can't compile the package, ask the maintainer for a binary version of the package. \ No newline at end of file +This will require R to compile C++ code, which might require additional software to be installed on your computer. If you need the development version but can't compile the package, ask the maintainer for a binary version of the package. \ No newline at end of file diff --git a/README.md b/README.md index 4d65126e..98a54446 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,9 @@ performed. m.out ``` - #> A matchit object + #> A `matchit` object #> - method: 1:1 nearest neighbor matching with replacement - #> - distance: Mahalanobis - #> - number of obs.: 614 (original), 261 (matched) + #> - distance: Mahalanobis - number of obs.: 614 (original), 264 (matched) #> - target estimand: ATT #> - covariates: age, educ, race, married, nodegree, re74, re75 @@ -72,31 +71,30 @@ summary(m.out) #> #> Summary of Balance for Matched Data: #> Means Treated Means Control Std. Mean Diff. Var. Ratio eCDF Mean eCDF Max Std. Pair Dist. - #> age 25.8162 25.5405 0.0385 0.6524 0.0466 0.1892 0.4827 - #> educ 10.3459 10.4270 -0.0403 1.1636 0.0077 0.0378 0.1963 + #> age 25.8162 25.5405 0.0385 0.6531 0.0466 0.1892 0.4827 + #> educ 10.3459 10.4270 -0.0403 1.1649 0.0077 0.0378 0.1963 #> raceblack 0.8432 0.8432 0.0000 . 0.0000 0.0000 0.0000 #> racehispan 0.0595 0.0595 0.0000 . 0.0000 0.0000 0.0000 #> racewhite 0.0973 0.0973 0.0000 . 0.0000 0.0000 0.0000 #> married 0.1892 0.1784 0.0276 . 0.0108 0.0108 0.0276 #> nodegree 0.7081 0.7081 0.0000 . 0.0000 0.0000 0.0000 - #> re74 2095.5737 1788.6941 0.0628 1.5690 0.0311 0.1730 0.2494 - #> re75 1532.0553 1087.7420 0.1380 2.1221 0.0330 0.0865 0.2360 + #> re74 2095.5737 1788.6941 0.0628 1.5707 0.0311 0.1730 0.2494 + #> re75 1532.0553 1087.7420 0.1380 2.1244 0.0330 0.0865 0.2360 #> #> Sample Sizes: #> Control Treated - #> All 429 185 - #> Matched (ESS) 33 185 - #> Matched 76 185 - #> Unmatched 353 0 - #> Discarded 0 0 + #> All 429. 185 + #> Matched (ESS) 34.19 185 + #> Matched 79. 185 + #> Unmatched 350. 0 + #> Discarded 0. 0 At the top is balance for the original sample. Below that is balance in -the matched sample, followed by the percent reduction in imbalance and -the sample sizes before and after matching. Smaller values for the -balance statistics indicate better balance. (In this case, good balance -was not achieved and other matching methods should be tried). We can -plot the standardized mean differences in a Love plot for a clean, -visual display of balance across the sample: +the matched sample. Smaller values for the balance statistics indicate +better balance. (In this case, fairly good balance was achieved, but +other matching methods should be tried). We can plot the standardized +mean differences in a Love plot for a clean, visual display of balance +across the sample: ``` r #Plot balance @@ -147,12 +145,12 @@ To install a development version, which may have a bug fixed or a new feature, run the following: ``` r -install.packages("remotes") #If not yet installed +install.packages("pak") #If not yet installed -remotes::install_github("ngreifer/MatchIt") +pak::pkg_install("ngreifer/MatchIt") ``` This will require R to compile C++ code, which might require additional -software be installed on your computer. If you need the development +software to be installed on your computer. If you need the development version but can’t compile the package, ask the maintainer for a binary version of the package. diff --git a/_archive/MatchIt_A3_estimating_effects2.Rmd b/_archive/MatchIt_A3_estimating_effects2.Rmd index ae5cc82a..5a1847a5 100644 --- a/_archive/MatchIt_A3_estimating_effects2.Rmd +++ b/_archive/MatchIt_A3_estimating_effects2.Rmd @@ -8,7 +8,7 @@ output: vignette: > %\VignetteIndexEntry{Estimating Effects} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib --- diff --git a/man/distance.Rd b/man/distance.Rd index c234b760..f4c45075 100644 --- a/man/distance.Rd +++ b/man/distance.Rd @@ -12,12 +12,13 @@ defining calipers. This page documents the options that can be supplied to the \code{distance} argument to \code{\link[=matchit]{matchit()}}. } \note{ -In versions of \emph{MatchIt} prior to 4.0.0, \code{distance} was -specified in a slightly different way. When specifying arguments using the -old syntax, they will automatically be converted to the corresponding method -in the new syntax but a warning will be thrown. \code{distance = "logit"}, -the old default, will still work in the new syntax, though \verb{distance = "glm", link = "logit"} is preferred (note that these are the default -settings and don't need to be made explicit). +In versions of \emph{MatchIt} prior to 4.0.0, \code{distance} was specified in a +slightly different way. When specifying arguments using the old syntax, they +will automatically be converted to the corresponding method in the new syntax +but a warning will be thrown. \code{distance = "logit"}, the old default, will +still work in the new syntax, though \verb{distance = "glm", link = "logit"} is +preferred (note that these are the default settings and don't need to be made +explicit). } \section{Allowable options}{ @@ -104,8 +105,11 @@ defaults as used in \emph{WeightIt} and \emph{twang}, except for generalized boosted modeling as in \emph{twang}; here, the number of trees is chosen based on cross-validation or out-of-bag error, rather than based on optimizing balance. \pkg{twang} should not be cited when using this method -to estimate propensity scores. } -\item{\code{"lasso"}, \code{"ridge"}, \code{"elasticnet"}}{ The propensity +to estimate propensity scores. Note that because there is a random component to choosing the tuning +parameter, results will vary across runs unless a \link[=set.seed]{seed} is +set.} +\item{\code{"lasso"}, \code{"ridge"}, \code{"elasticnet"}}{ +The propensity scores are estimated using a lasso, ridge, or elastic net model, respectively. The \code{formula} supplied to \code{matchit()} is processed with \code{\link[=model.matrix]{model.matrix()}} and passed to \pkgfun{glmnet}{cv.glmnet}, and @@ -136,7 +140,8 @@ random forest. The \code{formula} supplied to \code{matchit()} is passed directly to \pkgfun{randomForest}{randomForest}, and \pkgfun{randomForest}{predict.randomForest} is used to compute the propensity scores. The \code{link} argument is ignored, and predicted probabilities are -always returned as the distance measure.} +always returned as the distance measure. Note that because there is a random component, results will vary across runs unless a \link[=set.seed]{seed} is +set. } \item{\code{"nnet"}}{ The propensity scores are estimated using a single-hidden-layer neural network. The \code{formula} supplied to \code{matchit()} is passed directly to @@ -164,36 +169,35 @@ scores. The \code{link} argument can be specified as \code{"linear"} to use the linear predictor instead of the predicted probabilities. When \code{s.weights} is supplied to \code{matchit()}, it will not be passed to \code{bart2} because the \code{weights} argument in \code{bart2} does not -correspond to sampling weights. } +correspond to sampling weights. Note that because there is a random component to choosing the tuning +parameter, results will vary across runs unless the \code{seed} argument is supplied to \code{distance.options}. Note that setting a seed using \code{\link[=set.seed]{set.seed()}} is not sufficient to guarantee reproducibility unless single-threading is used. See \pkgfun{dbarts}{bart2} for details.} } } \subsection{Methods for computing distances from covariates}{ -The following methods involve computing a distance matrix from the covariates themselves -without estimating a propensity score. Calipers on the distance measure and -common support restrictions cannot be used, and the \code{distance} -component of the output object will be empty because no propensity scores -are estimated. The \code{link} and \code{distance.options} arguments are -ignored with these methods. See the individual matching methods pages for -whether these distances are allowed and how they are used. Each of these -distance measures can also be calculated outside \code{matchit()} using its -\link[=euclidean_dist]{corresponding function}. +The following methods involve computing a distance matrix from the covariates +themselves without estimating a propensity score. Calipers on the distance +measure and common support restrictions cannot be used, and the \code{distance} +component of the output object will be empty because no propensity scores are +estimated. The \code{link} and \code{distance.options} arguments are ignored with these +methods. See the individual matching methods pages for whether these +distances are allowed and how they are used. Each of these distance measures +can also be calculated outside \code{matchit()} using its \link[=euclidean_dist]{corresponding function}. \describe{ \item{\code{"euclidean"}}{ The Euclidean distance is the raw -distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - -x_j)'}} It is sensitive to the scale of the covariates, so covariates with +distance between units, computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)(x_i - x_j)'}} It is sensitive to the scale of the covariates, so covariates with larger scales will take higher priority. } -\item{\code{"scaled_euclidean"}}{ The scaled Euclidean distance is the +\item{\code{"scaled_euclidean"}}{ +The scaled Euclidean distance is the Euclidean distance computed on the scaled (i.e., standardized) covariates. This ensures the covariates are on the same scale. The covariates are standardized using the pooled within-group standard deviations, computed by treatment group-mean centering each covariate before computing the standard -deviation in the full sample. } -\item{\code{"mahalanobis"}}{ The -Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - -x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group +deviation in the full sample. +} +\item{\code{"mahalanobis"}}{ The Mahalanobis distance is computed as \deqn{d_{ij} = \sqrt{(x_i - x_j)\Sigma^{-1}(x_i - x_j)'}} where \eqn{\Sigma} is the pooled within-group covariance matrix of the covariates, computed by treatment group-mean centering each covariate before computing the covariance in the full sample. This ensures the variables are on the same scale and accounts for the @@ -206,30 +210,29 @@ that handles outliers and rare categories better than the standard Mahalanobis distance but is not affinely invariant. } } -To perform Mahalanobis distance matching \emph{and} estimate propensity -scores to be used for a purpose other than matching, the \code{mahvars} -argument should be used along with a different specification to -\code{distance}. See the individual matching method pages for details on how -to use \code{mahvars}. +To perform Mahalanobis distance matching \emph{and} estimate propensity scores to +be used for a purpose other than matching, the \code{mahvars} argument should be +used along with a different specification to \code{distance}. See the individual +matching method pages for details on how to use \code{mahvars}. } \subsection{Distances supplied as a numeric vector or matrix}{ -\code{distance} can also be supplied as a numeric vector whose values will be taken to -function like propensity scores; their pairwise difference will define the -distance between units. This might be useful for supplying propensity scores -computed outside \code{matchit()} or resupplying \code{matchit()} with -propensity scores estimated previously without having to recompute them. +\code{distance} can also be supplied as a numeric vector whose values will be +taken to function like propensity scores; their pairwise difference will +define the distance between units. This might be useful for supplying +propensity scores computed outside \code{matchit()} or resupplying \code{matchit()} +with propensity scores estimated previously without having to recompute them. \code{distance} can also be supplied as a matrix whose values represent the pairwise distances between units. The matrix should either be a square, with a row and column for each unit (e.g., as the output of a call to -\verb{as.matrix(}\code{\link{dist}}\verb{(.))}), or have as many rows as there are treated -units and as many columns as there are control units (e.g., as the output of -a call to \code{\link[=mahalanobis_dist]{mahalanobis_dist()}} or \pkgfun{optmatch}{match_on}). Distance values -of \code{Inf} will disallow the corresponding units to be matched. When -\code{distance} is a supplied as a numeric vector or matrix, \code{link} and -\code{distance.options} are ignored. +\verb{as.matrix(}\code{\link{dist}}\verb{(.))}), or have as many rows as there are treated units +and as many columns as there are control units (e.g., as the output of a call +to \code{\link[=mahalanobis_dist]{mahalanobis_dist()}} or \pkgfun{optmatch}{match_on}). Distance values of +\code{Inf} will disallow the corresponding units to be matched. When \code{distance} is +a supplied as a numeric vector or matrix, \code{link} and \code{distance.options} are +ignored. } } diff --git a/man/mahalanobis_dist.Rd b/man/mahalanobis_dist.Rd index 3871069a..0e10934f 100644 --- a/man/mahalanobis_dist.Rd +++ b/man/mahalanobis_dist.Rd @@ -177,8 +177,7 @@ table(lalonde$treat) } \references{ -Rosenbaum, P. R. (2010). \emph{Design of observational studies}. -Springer. +Rosenbaum, P. R. (2010). \emph{Design of observational studies}. Springer. Rosenbaum, P. R., & Rubin, D. B. (1985). Constructing a Control Group Using Multivariate Matched Sampling Methods That Incorporate the Propensity Score. @@ -189,7 +188,7 @@ Rubin, D. B. (1980). Bias Reduction Using Mahalanobis-Metric Matching. } \seealso{ \code{\link{distance}}, \code{\link[=matchit]{matchit()}}, \code{\link[=dist]{dist()}} (which is used -internally to compute Euclidean distances) +internally to compute some Euclidean distances) \pkgfun{optmatch}{match_on}, which provides similar functionality but with fewer options and a focus on efficient storage of the output. diff --git a/man/matchit.Rd b/man/matchit.Rd index 3579b734..5f83cd6a 100644 --- a/man/matchit.Rd +++ b/man/matchit.Rd @@ -2,7 +2,6 @@ % Please edit documentation in R/matchit.R \name{matchit} \alias{matchit} -\alias{print.matchit} \title{Matching for Causal Inference} \usage{ matchit( @@ -26,10 +25,9 @@ matchit( ratio = 1, verbose = FALSE, include.obj = FALSE, + normalize = TRUE, ... ) - -\method{print}{matchit}(x, ...) } \arguments{ \item{formula}{a two-sided \code{\link{formula}} object containing the treatment and @@ -47,11 +45,12 @@ will be sought in the environment.} \code{\link[=method_nearest]{"nearest"}} for nearest neighbor matching (on the propensity score by default), \code{\link[=method_optimal]{"optimal"}} for optimal pair matching, \code{\link[=method_full]{"full"}} for optimal +full matching, \code{\link[=method_quick]{"quick"}} for generalized (quick) full matching, \code{\link[=method_genetic]{"genetic"}} for genetic matching, \code{\link[=method_cem]{"cem"}} for coarsened exact matching, \code{\link[=method_exact]{"exact"}} for exact matching, \code{\link[=method_cardinality]{"cardinality"}} for cardinality and -template matching, and \code{\link[=method_subclass]{"subclass"}} for +profile matching, and \code{\link[=method_subclass]{"subclass"}} for subclassification. When set to \code{NULL}, no matching will occur, but propensity score estimation and common support restrictions will still occur if requested. See the linked pages for each method for more details on what @@ -72,17 +71,15 @@ used.} argument controlling the link function used in estimating the distance measure. Allowable options depend on the specific \code{distance} value specified. See \code{\link{distance}} for allowable options with each -option. The default is \code{"logit"}, which, along with \code{distance = "glm"}, identifies the default measure as logistic regression propensity -scores.} +option. The default is \code{"logit"}, which, along with \code{distance = "glm"}, identifies the default measure as logistic regression propensity scores.} \item{distance.options}{a named list containing additional arguments supplied to the function that estimates the distance measure as determined -by the argument to \code{distance}. See \link{distance} for an +by the argument to \code{distance}. See \code{\link{distance}} for an example of its use.} \item{estimand}{a string containing the name of the target estimand desired. -Can be one of \code{"ATT"} or \code{"ATC"}. Some methods accept \code{"ATE"} -as well. Default is \code{"ATT"}. See Details and the individual methods +Can be one of \code{"ATT"}, \code{"ATC"}, or \code{"ATE"}. Default is \code{"ATT"}. See Details and the individual methods pages for information on how this argument is used.} \item{exact}{for methods that allow it, for which variables exact matching @@ -98,8 +95,7 @@ propensity scores. Usually used to perform Mahalanobis distance matching within propensity score calipers, where the propensity scores are computed using \code{formula} and \code{distance}. Can be specified as a string containing the names of variables in \code{data} to be used or a one-sided -formula with the desired variables on the right-hand side (e.g., \code{~ X3 + X4}). See the individual methods pages for information on whether and how -this argument is used.} +formula with the desired variables on the right-hand side (e.g., \code{~ X3 + X4}). See the individual methods pages for information on whether and how this argument is used.} \item{antiexact}{for methods that allow it, for which variables anti-exact matching should take place. Anti-exact matching ensures paired individuals @@ -129,7 +125,7 @@ incorporated into propensity score models and balance statistics. Can also be specified as a string containing the name of variable in \code{data} to be used or a one-sided formula with the variable on the right-hand side (e.g., \code{~ SW}). Not all propensity score models accept sampling -weights; see \link{distance} for information on which do and do not, +weights; see \code{\link{distance}} for information on which do and do not, and see \code{vignette("sampling-weights")} for details on how to use sampling weights in a matching analysis.} @@ -172,39 +168,36 @@ the matching process in the output, i.e., by the functions from other packages \code{matchit()} calls. What is included depends on the matching method. Default is \code{FALSE}.} +\item{normalize}{\code{logical}; whether to rescale the nonzero weights in each treatment group to have an average of 1. Default is \code{TRUE}. See "How Matching Weights Are Computed" below for more details.} + \item{\dots}{additional arguments passed to the functions used in the matching process. See the individual methods pages for information on what -additional arguments are allowed for each method. Ignored for \code{print()}.} - -\item{x}{a \code{matchit} object.} +additional arguments are allowed for each method.} } \value{ When \code{method} is something other than \code{"subclass"}, a \code{matchit} object with the following components: -\item{match.matrix}{a matrix containing the matches. The rownames correspond +\item{match.matrix}{a matrix containing the matches. The row names correspond to the treated units and the values in each row are the names (or indices) of the control units matched to each treated unit. When treated units are -matched to different numbers of control units (e.g., with exact matching or +matched to different numbers of control units (e.g., with variable ratio matching or matching with a caliper), empty spaces will be filled with \code{NA}. Not -included when \code{method} is \code{"full"}, \code{"cem"} (unless \code{k2k = TRUE}), \code{"exact"}, or \code{"cardinality"}.} +included when \code{method} is \code{"full"}, \code{"cem"} (unless \code{k2k = TRUE}), \code{"exact"}, \code{"quick"}, or \code{"cardinality"} (unless \code{mahvars} is supplied and \code{ratio} is an integer).} \item{subclass}{a factor containing matching pair/stratum membership for each unit. Unmatched units -will have a value of \code{NA}. Not included when \code{replace = TRUE}.} +will have a value of \code{NA}. Not included when \code{replace = TRUE} or when \code{method = "cardinality"} unless \code{mahvars} is supplied and \code{ratio} is an integer.} \item{weights}{a numeric vector of estimated matching weights. Unmatched and discarded units will have a weight of zero.} \item{model}{the fit object of the model used to estimate propensity scores when \code{distance} is -specified and not \code{"mahalanobis"} or a numeric vector. When +specified as a method of estimating propensity scores. When \code{reestimate = TRUE}, this is the model estimated after discarding units.} -\item{X}{a data frame of covariates mentioned in \code{formula}, -\code{exact}, \code{mahvars}, and \code{antiexact}.} +\item{X}{a data frame of covariates mentioned in \code{formula}, \code{exact}, \code{mahvars}, \code{caliper}, and \code{antiexact}.} \item{call}{the \code{matchit()} call.} -\item{info}{information on the matching method and -distance measures used.} -\item{estimand}{the argument supplied to -\code{estimand}.} +\item{info}{information on the matching method and distance measures used.} +\item{estimand}{the argument supplied to \code{estimand}.} \item{formula}{the \code{formula} supplied.} \item{treat}{a vector of treatment status converted to zeros (0) and ones (1) if not already in that format.} @@ -212,16 +205,11 @@ distance measures used.} values (i.e., propensity scores) when \code{distance} is supplied as a method of estimating propensity scores or a numeric vector.} \item{discarded}{a logical vector denoting whether each observation was -discarded (\code{TRUE}) or not (\code{FALSE}) by the argument to -\code{discard}.} -\item{s.weights}{the vector of sampling weights supplied to -the \code{s.weights} argument, if any.} -\item{exact}{a one-sided formula -containing the variables, if any, supplied to \code{exact}.} -\item{mahvars}{a one-sided formula containing the variables, if any, -supplied to \code{mahvars}.} -\item{obj}{when \code{include.obj = TRUE}, an -object containing the intermediate results of the matching procedure. See +discarded (\code{TRUE}) or not (\code{FALSE}) by the argument to \code{discard}.} +\item{s.weights}{the vector of sampling weights supplied to the \code{s.weights} argument, if any.} +\item{exact}{a one-sided formula containing the variables, if any, supplied to \code{exact}.} +\item{mahvars}{a one-sided formula containing the variables, if any, supplied to \code{mahvars}.} +\item{obj}{when \code{include.obj = TRUE}, an object containing the intermediate results of the matching procedure. See the individual methods pages for what this component will contain.} When \code{method = "subclass"}, a \code{matchit.subclass} object with the same @@ -249,10 +237,11 @@ pages: \item \code{\link{method_nearest}} for nearest neighbor matching \item \code{\link{method_optimal}} for optimal pair matching \item \code{\link{method_full}} for optimal full matching +\item \code{\link{method_quick}} for generalized (quick) full matching \item \code{\link{method_genetic}} for genetic matching \item \code{\link{method_cem}} for coarsened exact matching \item \code{\link{method_exact}} for exact matching -\item \code{\link{method_cardinality}} for cardinality and template matching +\item \code{\link{method_cardinality}} for cardinality and profile matching \item \code{\link{method_subclass}} for subclassification } @@ -276,7 +265,7 @@ matching on any of the included covariates and on the propensity score if specified. All arguments other than \code{distance}, \code{discard}, and \code{reestimate} will be ignored. -See \link{distance} for details on the several ways to +See \code{\link{distance}} for details on the several ways to specify the \code{distance}, \code{link}, and \code{distance.options} arguments to estimate propensity scores and create distance measures. @@ -285,7 +274,7 @@ to one and returned as such in the \code{matchit()} output (see section Value, below). The following rules are used: 1) if \code{0} is one of the values, it will be considered the control and the other value the treated; 2) otherwise, if the variable is a factor, \code{levels(treat)[1]} will be -considered control and the other variable the treated; 3) otherwise, +considered control and the other value the treated; 3) otherwise, \code{sort(unique(treat))[1]} will be considered control and the other value the treated. It is safest to ensure the treatment variable is a \code{0/1} variable. @@ -320,22 +309,32 @@ stratification weights. Matching weights are computed in one of two ways depending on whether matching was done with replacement or not. +\subsection{Matching without replacement and subclassification}{ -For matching \emph{without} replacement (except for cardinality matching), each +For matching \emph{without} replacement (except for cardinality matching), including subclassification, each unit is assigned to a subclass, which represents the pair they are a part of (in the case of k:1 matching) or the stratum they belong to (in the case of exact matching, coarsened exact matching, full matching, or subclassification). The formula for computing the weights depends on the argument supplied to \code{estimand}. A new "stratum propensity score" -(\code{sp}) is computed as the proportion of units in each stratum that are -in the treated group, and all units in that stratum are assigned that +(\eqn{p^s_i}) is computed for each unit \eqn{i} as \eqn{p^s_i = \frac{1}{n_s}\sum_{j: s_j =s_i}{I(A_j=1)}} where \eqn{n_s} is the size of subclass \eqn{s} and \eqn{I(A_j=1)} is 1 if unit \eqn{j} is treated and 0 otherwise. That is, the stratum propensity score for stratum \eqn{s} is the proportion of units in stratum \eqn{s} that are +in the treated group, and all units in stratum \eqn{s} are assigned that stratum propensity score. This is distinct from the propensity score used for matching, if any. Weights are then computed using the standard formulas for -inverse probability weights with the stratum propensity score inserted: for the ATT, weights are 1 for the treated -units and \code{sp/(1-sp)} for the control units; for the ATC, weights are -\code{(1-sp)/sp} for the treated units and 1 for the control units; for the -ATE, weights are \code{1/sp} for the treated units and \code{1/(1-sp)} for the -control units. For cardinality matching, all matched units receive a weight +inverse probability weights with the stratum propensity score inserted: +\itemize{ +\item for the ATT, weights are 1 for the treated +units and \eqn{\frac{p^s}{1-p^s}} for the control units +\item for the ATC, weights are +\eqn{\frac{1-p^s}{p^s}} for the treated units and 1 for the control units +\item for the ATE, weights are \eqn{\frac{1}{p^s}} for the treated units and \eqn{\frac{1}{1-p^s}} for the +control units. +} + +For cardinality matching, all matched units receive a weight of 1. +} + +\subsection{Matching witht replacement}{ For matching \emph{with} replacement, units are not assigned to unique strata. For the ATT, each treated unit gets a weight of 1. Each control unit is weighted @@ -343,18 +342,27 @@ as the sum of the inverse of the number of control units matched to the same treated unit across its matches. For example, if a control unit was matched to a treated unit that had two other control units matched to it, and that same control was matched to a treated unit that had one other control unit -matched to it, the control unit in question would get a weight of 1/3 + 1/2 -= 5/6. For the ATC, the same is true with the treated and control labels +matched to it, the control unit in question would get a weight of \eqn{1/3 + 1/2 = 5/6}. For the ATC, the same is true with the treated and control labels switched. The weights are computed using the \code{match.matrix} component of the \code{matchit()} output object. +} + +\subsection{Normalized weights}{ -In each treatment group, weights are divided by the mean of the nonzero +When \code{normalize = TRUE} (the default), in each treatment group, weights are divided by the mean of the nonzero weights in that treatment group to make the weights sum to the number of -units in that treatment group. If sampling weights are included through the +units in that treatment group (i.e., to have an average of 1). +} + +\subsection{Sampling weights}{ + +If sampling weights are included through the \code{s.weights} argument, they will be included in the \code{matchit()} output object but not incorporated into the matching weights. \code{\link[=match.data]{match.data()}}, which extracts the matched set from a \code{matchit} object, combines the matching weights and sampling weights. +} + } } \examples{ @@ -408,7 +416,6 @@ s.out1 <- matchit(treat ~ age + educ + race + nodegree + discard = "control", subclass = 10) s.out1 summary(s.out1, un = TRUE) - } \references{ Ho, D. E., Imai, K., King, G., & Stuart, E. A. (2007). Matching @@ -421,20 +428,14 @@ Statistical Software}, 42(8). \doi{10.18637/jss.v042.i08} } \seealso{ \code{\link[=summary.matchit]{summary.matchit()}} for balance assessment after matching, \code{\link[=plot.matchit]{plot.matchit()}} for plots of covariate balance and propensity score overlap after matching. - -\code{vignette("MatchIt")} for an introduction to matching with -\emph{MatchIt}; \code{vignette("matching-methods")} for descriptions of the -variety of matching methods and options available; -\code{vignette("assessing-balance")} for information on assessing the -quality of a matching specification; \code{vignette("estimating-effects")} -for instructions on how to estimate treatment effects after matching; and -\code{vignette("sampling-weights")} for a guide to using \emph{MatchIt} with -sampling weights. +\itemize{ +\item \code{vignette("MatchIt")} for an introduction to matching with \emph{MatchIt} +\item \code{vignette("matching-methods")} for descriptions of the variety of matching methods and options available +\item \code{vignette("assessing-balance")} for information on assessing the quality of a matching specification +\item \code{vignette("estimating-effects")} for instructions on how to estimate treatment effects after matching +\item \code{vignette("sampling-weights")} for a guide to using \emph{MatchIt} with sampling weights. +} } \author{ -Daniel Ho (\email{dho@law.stanford.edu}); Kosuke Imai -(\email{imai@harvard.edu}); Gary King (\email{king@harvard.edu}); -Elizabeth Stuart (\email{estuart@jhsph.edu}) - -Version 4.0.0 update by Noah Greifer (\email{noah.greifer@gmail.com}) +Daniel Ho, Kosuke Imai, Gary King, and Elizabeth Stuart wrote the original package. Starting with version 4.0.0, Noah Greifer is the primary maintainer and developer. } diff --git a/man/method_cem.Rd b/man/method_cem.Rd index aa337d15..50318511 100644 --- a/man/method_cem.Rd +++ b/man/method_cem.Rd @@ -54,18 +54,29 @@ units will be dropped (e.g., if there are more treated than control units in the stratum, the treated units without a match will be dropped). The \code{k2k.method} argument controls how the distance between units is calculated. } -\item{\code{k2k.method}}{ \code{character}; how the distance +\item{\code{k2k.method}}{\code{character}; how the distance between units should be calculated if \code{k2k = TRUE}. Allowable arguments include \code{NULL} (for random matching), any argument to \code{\link[=distance]{distance()}} for computing a distance matrix from covariates (e.g., \code{"mahalanobis"}), or any allowable argument to \code{method} in \code{\link[=dist]{dist()}}. Matching will take place on the original -(non-coarsened) variables. The default is \code{"mahalanobis"}. } -\item{\code{mpower}}{ if \code{k2k.method = "minkowski"}, the power used in -creating the distance. This is passed to the \code{p} argument of \code{\link[=dist]{dist()}}.} +(non-coarsened) variables. The default is \code{"mahalanobis"}. +} +\item{\code{mpower}}{if \code{k2k.method = "minkowski"}, the power used in +creating the distance. This is passed to the \code{p} argument of \code{\link[=dist]{dist()}}. +} +\item{\code{m.order}}{\code{character}; the order that the matching takes place when \code{k2k = TRUE}. Allowable options +include \code{"closest"}, where matching takes place in +ascending order of the smallest distance between units; \code{"farthest"}, where matching takes place in +descending order of the smallest distance between units; \code{"random"}, where matching takes place +in a random order; and \code{"data"} where matching takes place based on the +order of units in the data. When \code{m.order = "random"}, results may differ +across different runs of the same code unless a seed is set and specified +with \code{\link[=set.seed]{set.seed()}}. The default of \code{NULL} corresponds to \code{"data"}. See \code{\link{method_nearest}} for more information. +} } -The arguments \code{distance} (and related arguments), \code{exact}, \code{mahvars}, \code{discard} (and related arguments), \code{replace}, \code{m.order}, \code{caliper} (and related arguments), and \code{ratio} are ignored with a warning.} +The arguments \code{distance} (and related arguments), \code{exact}, \code{mahvars}, \code{discard} (and related arguments), \code{replace}, \code{caliper} (and related arguments), and \code{ratio} are ignored with a warning.} } \description{ In \code{\link[=matchit]{matchit()}}, setting \code{method = "cem"} performs coarsened exact @@ -109,7 +120,7 @@ matches, the covariates can be manually coarsened outside of Setting \code{k2k = TRUE} is equivalent to first doing coarsened exact matching with \code{k2k = FALSE} and then supplying stratum membership as an exact matching variable (i.e., in \code{exact}) to another call to -\code{matchit()} with \code{method = "nearest"}, \code{distance = "mahalanobis"} and an argument to \code{discard} denoting unmatched units. +\code{matchit()} with \code{method = "nearest"}. It is also equivalent to performing nearest neighbor matching supplying coarsened versions of the variables to \code{exact}, except that \code{method = "cem"} automatically coarsens the continuous variables. The @@ -246,7 +257,6 @@ m.out2 <- matchit(treat ~ age + race + married + educ, data = lalonde, k2k = TRUE, k2k.method = "mahalanobis") m.out2 summary(m.out2, un = FALSE) - } \references{ In a manuscript, you don't need to cite another package when diff --git a/man/method_full.Rd b/man/method_full.Rd index 3dfb0e35..28aab09f 100644 --- a/man/method_full.Rd +++ b/man/method_full.Rd @@ -39,7 +39,7 @@ place when \code{distance} corresponds to a propensity score (e.g., for caliper matching or to discard units for common support). If specified, the distance measure will not be used in matching.} -\item{antiexact}{for which variables ant-exact matching should take place. +\item{antiexact}{for which variables anti-exact matching should take place. Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}.} \item{discard}{a string containing a method for discarding units outside a diff --git a/man/method_genetic.Rd b/man/method_genetic.Rd index 23aed566..13060c7d 100644 --- a/man/method_genetic.Rd +++ b/man/method_genetic.Rd @@ -50,7 +50,7 @@ to the distance matrix. Use \code{mahvars} to only supply a subset. Even if \code{mahvars} is specified, balance will be optimized on all covariates in \code{formula}. See Details.} -\item{antiexact}{for which variables ant-exact matching should take place. +\item{antiexact}{for which variables anti-exact matching should take place. Anti-exact matching is processed using the \code{restrict} argument to \code{Matching::GenMatch()} and \code{Matching::Match()}.} @@ -200,6 +200,11 @@ considered "focal". Note that while \code{GenMatch()} and \code{Match()} support the ATE as an estimand, \code{matchit()} only supports the ATT and ATC for genetic matching. } + +\subsection{Reproducibility}{ + +Genetic matching involves a random component, so a seed must be set using \code{\link[=set.seed]{set.seed()}} to ensure reproducibility. When \code{cluster} is used for parallel processing, the seed must be compatible with parallel processing (e.g., by setting \code{type = "L'Ecuyer-CMRG"}). +} } \section{Outputs}{ diff --git a/man/method_nearest.Rd b/man/method_nearest.Rd index b211d9b9..e25b70a4 100644 --- a/man/method_nearest.Rd +++ b/man/method_nearest.Rd @@ -28,14 +28,14 @@ by the argument to \code{distance}.} \item{estimand}{a string containing the desired estimand. Allowable options include \code{"ATT"} and \code{"ATC"}. See Details.} -\item{exact}{for which variables exact matching should take place.} +\item{exact}{for which variables exact matching should take place; two units with different values of an exact matching variable will not be paired.} \item{mahvars}{for which variables Mahalanobis distance matching should take place when \code{distance} corresponds to a propensity score (e.g., for caliper matching or to discard units for common support). If specified, the distance measure will not be used in matching.} -\item{antiexact}{for which variables ant-exact matching should take place.} +\item{antiexact}{for which variables anti-exact matching should take place; two units with the same value of an anti-exact matching variable will not be paired.} \item{discard}{a string containing a method for discarding units outside a region of common support. Only allowed when \code{distance} corresponds to a @@ -47,29 +47,29 @@ re-estimate the propensity score in the remaining sample prior to matching.} \item{s.weights}{the variable containing sampling weights to be incorporated into propensity score models and balance statistics.} -\item{replace}{whether matching should be done with replacement.} +\item{replace}{whether matching should be done with replacement (i.e., whether control units can be used as matches multiple times). See also the \code{reuse.max} argument below. Default is \code{FALSE} for matching without replacement.} \item{m.order}{the order that the matching takes place. Allowable options include \code{"largest"}, where matching takes place in descending order of distance measures; \code{"smallest"}, where matching takes place in ascending order of distance measures; \code{"closest"}, where matching takes place in -order of the distance between units; \code{"random"}, where matching takes place +ascending order of the smallest distance between units; \code{"farthest"}, where matching takes place in +descending order of the smallest distance between units; \code{"random"}, where matching takes place in a random order; and \code{"data"} where matching takes place based on the order of units in the data. When \code{m.order = "random"}, results may differ across different runs of the same code unless a seed is set and specified with \code{\link[=set.seed]{set.seed()}}. The default of \code{NULL} corresponds to \code{"largest"} when a propensity score is estimated or supplied as a vector and \code{"data"} -otherwise.} +otherwise. See Details for more information.} -\item{caliper}{the width(s) of the caliper(s) used for caliper matching. See -Details and Examples.} +\item{caliper}{the width(s) of the caliper(s) used for caliper matching. Two units with a difference on a caliper variable larger than the caliper will not be paired. See Details and Examples.} \item{std.caliper}{\code{logical}; when calipers are specified, whether they are in standard deviation units (\code{TRUE}) or raw units (\code{FALSE}).} \item{ratio}{how many control units should be matched to each treated unit for k:1 matching. For variable ratio matching, see section "Variable Ratio -Matching" in Details below.} +Matching" in Details below. When \code{ratio} is greater than 1, all treated units will be attempted to be matched with a control unit before any treated unit is matched with a second control unit, etc. This reduces the possibility that control units will be used up before some treated units receive any matches.} \item{min.controls, max.controls}{for variable ratio matching, the minimum and maximum number of controls units to be matched to each treated unit. See @@ -77,7 +77,7 @@ section "Variable Ratio Matching" in Details below.} \item{verbose}{\code{logical}; whether information about the matching process should be printed to the console. When \code{TRUE}, a progress bar -implemented using \emph{RcppProgress} will be displayed.} +implemented using \emph{RcppProgress} will be displayed along with an estimate of the time remaining.} \item{\dots}{additional arguments that control the matching specification: \describe{ @@ -94,8 +94,7 @@ observation, i.e., in case multiple observations correspond to the same unit. Once a control observation has been matched, no other observation with the same unit ID can be used as matches. This ensures each control unit is used only once even if it has multiple observations associated with it. -Omitting this argument is the same as giving each observation a unique ID. -Ignored when \code{replace = TRUE}. } +Omitting this argument is the same as giving each observation a unique ID.} }} } \description{ @@ -198,8 +197,7 @@ switch to trigger which treatment group is considered "focal". \subsection{Variable Ratio Matching}{ -\code{matchit()} can perform variable -ratio "extremal" matching as described by Ming and Rosenbaum (2000). This +\code{matchit()} can perform variable ratio "extremal" matching as described by Ming and Rosenbaum (2000). This method tends to result in better balance than fixed ratio matching at the expense of some precision. When \code{ratio > 1}, rather than requiring all treated units to receive \code{ratio} matches, each treated unit is assigned @@ -229,9 +227,16 @@ specified, it is set to 1 by default. \code{min.controls} must be less than Examples below for an example of their use. } -\subsection{Using \code{m.order = "closest"}}{ +\subsection{Using \code{m.order = "closest"} or \code{"farthest"}}{ -As of version 4.6.0, \code{m.order} can be set to \code{"closest"}, which works regardless of how the distance measure is specified. This matches in order of the distance between units. The closest pair of units across all potential pairs of units will be matched first; the second closest pair of all potential pairs will be matched second, etc. This ensures that the best possible matches are given priority, and in that sense performs similarly to \code{m.order = "smallest"}. +\code{m.order} can be set to \code{"closest"} or \code{"farthest"}, which work regardless of how the distance measure is specified. This matches in order of the distance between units. First, all the closest match is found for all treated units and the pairwise distances computed; when \code{m.order = "closest"} the pair with the smallest of the distances is matched first, and when \code{m.order = "farthest"}, the pair with the largest of the distances is matched first. Then, the pair with the second smallest (or largest) is matched second. If the matched control is ineligible (i.e., because it has already been used in a prior match), a new match is found for the treated unit, the new pair's distance is re-computed, and the pairs are re-ordered by distance. + +Using \code{m.order = "closest"} ensures that the best possible matches are given priority, and in that sense should perform similarly to \code{m.order = "smallest"}. It can be used to ensure the best matches, especially when matching with a caliper. Using \code{m.order = "farthest"} ensures that the hardest units to match are given their best chance to find a close match, and in that sense should perform similarly to \code{m.order = "largest"}. It can be used to reduce the possibility of extreme imbalance when there are hard-to-match units competing for controls. Note that \code{m.order = "farthest"} \strong{does not} implement "far matching" (i.e., finding the farthest control unit from each treated unit); it defines the order in which the closest matches are selected. +} + +\subsection{Reproducibility}{ + +Nearest neighbor matching involves a random component only when \code{m.order = "random"} (or when the propensity is estimated using a method with randomness; see \code{\link{distance}} for details), so a seed must be set in that case using \code{\link[=set.seed]{set.seed()}} to ensure reproducibility. Otherwise, it is purely deterministic, and any ties are broken based on the order in which the data appear. } } \note{ @@ -304,6 +309,6 @@ within \emph{MatchIt}. For example, a sentence might read: a call to \code{matchit()}. \code{\link[=method_optimal]{method_optimal()}} for optimal pair matching, which is similar to -nearest neighbor matching except that an overall distance criterion is -minimized. +nearest neighbor matching without replacement except that an overall distance criterion is +minimized (i.e., as an alternative to specifying \code{m.order}). } diff --git a/man/method_optimal.Rd b/man/method_optimal.Rd index 3c7c825e..9fc08157 100644 --- a/man/method_optimal.Rd +++ b/man/method_optimal.Rd @@ -37,7 +37,7 @@ place when \code{distance} corresponds to a propensity score (e.g., for caliper matching or to discard units for common support). If specified, the distance measure will not be used in matching.} -\item{antiexact}{for which variables ant-exact matching should take place. +\item{antiexact}{for which variables anti-exact matching should take place. Anti-exact matching is processed using \pkgfun{optmatch}{antiExactMatch}.} \item{discard}{a string containing a method for discarding units outside a diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 4adbd0b2..7203f657 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,76 +13,163 @@ Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif -// dist_to_matrixC -NumericMatrix dist_to_matrixC(const NumericVector& d); -RcppExport SEXP _MatchIt_dist_to_matrixC(SEXP dSEXP) { +// all_equal_to +bool all_equal_to(RObject x, RObject y); +RcppExport SEXP _MatchIt_all_equal_to(SEXP xSEXP, SEXP ySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const NumericVector& >::type d(dSEXP); - rcpp_result_gen = Rcpp::wrap(dist_to_matrixC(d)); + Rcpp::traits::input_parameter< RObject >::type x(xSEXP); + Rcpp::traits::input_parameter< RObject >::type y(ySEXP); + rcpp_result_gen = Rcpp::wrap(all_equal_to(x, y)); return rcpp_result_gen; END_RCPP } -// nn_matchC -IntegerMatrix nn_matchC(const IntegerVector& treat_, const IntegerVector& ord_, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const Nullable& distance_, const Nullable& distance_mat_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& mah_covs_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC(SEXP treat_SEXP, SEXP ord_SEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distance_SEXP, SEXP distance_mat_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP mah_covs_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +// eucdistC_N1xN0 +NumericVector eucdistC_N1xN0(const NumericMatrix& x, const IntegerVector& t); +RcppExport SEXP _MatchIt_eucdistC_N1xN0(SEXP xSEXP, SEXP tSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const NumericMatrix& >::type x(xSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type t(tSEXP); + rcpp_result_gen = Rcpp::wrap(eucdistC_N1xN0(x, t)); + return rcpp_result_gen; +END_RCPP +} +// get_splitsC +NumericVector get_splitsC(const NumericVector& x, const double& caliper); +RcppExport SEXP _MatchIt_get_splitsC(SEXP xSEXP, SEXP caliperSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const NumericVector& >::type x(xSEXP); + Rcpp::traits::input_parameter< const double& >::type caliper(caliperSEXP); + rcpp_result_gen = Rcpp::wrap(get_splitsC(x, caliper)); + return rcpp_result_gen; +END_RCPP +} +// has_n_unique +bool has_n_unique(const SEXP& x, const int& n); +RcppExport SEXP _MatchIt_has_n_unique(SEXP xSEXP, SEXP nSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const SEXP& >::type x(xSEXP); + Rcpp::traits::input_parameter< const int& >::type n(nSEXP); + rcpp_result_gen = Rcpp::wrap(has_n_unique(x, n)); + return rcpp_result_gen; +END_RCPP +} +// nn_matchC_distmat +IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const NumericMatrix& distance_mat, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_distmat(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP distance_matSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type ord_(ord_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ord(ordSEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); - Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type distance_mat_(distance_mat_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type distance_mat(distance_matSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); - Rcpp::traits::input_parameter< const Nullable& >::type mah_covs_(mah_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC(treat_, ord_, ratio, discarded, reuse_max, distance_, distance_mat_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, mah_covs_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_distmat(treat_, ord, ratio, discarded, reuse_max, focal_, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); return rcpp_result_gen; END_RCPP } -// nn_matchC_closest -IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC_closest(SEXP distance_matSEXP, SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +// nn_matchC_distmat_closest +IntegerMatrix nn_matchC_distmat_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericMatrix& distance_mat, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& close, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_distmat_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distance_matSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP closeSEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); + Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); Rcpp::traits::input_parameter< const NumericMatrix& >::type distance_mat(distance_matSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type close(closeSEXP); + Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); + rcpp_result_gen = Rcpp::wrap(nn_matchC_distmat_closest(treat, ratio, discarded, reuse_max, distance_mat, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog)); + return rcpp_result_gen; +END_RCPP +} +// nn_matchC_mahcovs +IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const NumericMatrix& mah_covs, const Nullable& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_mahcovs(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP mah_covsSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ord(ordSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); + Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type mah_covs(mah_covsSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); + rcpp_result_gen = Rcpp::wrap(nn_matchC_mahcovs(treat_, ord, ratio, discarded, reuse_max, focal_, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + return rcpp_result_gen; +END_RCPP +} +// nn_matchC_mahcovs_closest +IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericMatrix& mah_covs, const Nullable& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& close, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_mahcovs_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP mah_covsSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP closeSEXP, SEXP disl_progSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const NumericMatrix& >::type mah_covs(mah_covsSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type distance_(distance_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type close(closeSEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC_closest(distance_mat, treat, ratio, discarded, reuse_max, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_mahcovs_closest(treat, ratio, discarded, reuse_max, mah_covs, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog)); return rcpp_result_gen; END_RCPP } // nn_matchC_vec -IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, const IntegerVector& ord_, const IntegerVector& ratio_, const LogicalVector& discarded_, const int& reuse_max, const NumericVector& distance_, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); -RcppExport SEXP _MatchIt_nn_matchC_vec(SEXP treat_SEXP, SEXP ord_SEXP, SEXP ratio_SEXP, SEXP discarded_SEXP, SEXP reuse_maxSEXP, SEXP distance_SEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { +IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, const IntegerVector& ord, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const int& focal_, const NumericVector& distance, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_vec(SEXP treat_SEXP, SEXP ordSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP focal_SEXP, SEXP distanceSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP disl_progSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type ord_(ord_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type ratio_(ratio_SEXP); - Rcpp::traits::input_parameter< const LogicalVector& >::type discarded_(discarded_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ord(ordSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); - Rcpp::traits::input_parameter< const NumericVector& >::type distance_(distance_SEXP); + Rcpp::traits::input_parameter< const int& >::type focal_(focal_SEXP); + Rcpp::traits::input_parameter< const NumericVector& >::type distance(distanceSEXP); Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); @@ -90,34 +177,83 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); - rcpp_result_gen = Rcpp::wrap(nn_matchC_vec(treat_, ord_, ratio_, discarded_, reuse_max, distance_, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + rcpp_result_gen = Rcpp::wrap(nn_matchC_vec(treat_, ord, ratio, discarded, reuse_max, focal_, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, disl_prog)); + return rcpp_result_gen; +END_RCPP +} +// nn_matchC_vec_closest +IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, const IntegerVector& ratio, const LogicalVector& discarded, const int& reuse_max, const NumericVector& distance, const Nullable& exact_, const Nullable& caliper_dist_, const Nullable& caliper_covs_, const Nullable& caliper_covs_mat_, const Nullable& antiexact_covs_, const Nullable& unit_id_, const bool& close, const bool& disl_prog); +RcppExport SEXP _MatchIt_nn_matchC_vec_closest(SEXP treatSEXP, SEXP ratioSEXP, SEXP discardedSEXP, SEXP reuse_maxSEXP, SEXP distanceSEXP, SEXP exact_SEXP, SEXP caliper_dist_SEXP, SEXP caliper_covs_SEXP, SEXP caliper_covs_mat_SEXP, SEXP antiexact_covs_SEXP, SEXP unit_id_SEXP, SEXP closeSEXP, SEXP disl_progSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type ratio(ratioSEXP); + Rcpp::traits::input_parameter< const LogicalVector& >::type discarded(discardedSEXP); + Rcpp::traits::input_parameter< const int& >::type reuse_max(reuse_maxSEXP); + Rcpp::traits::input_parameter< const NumericVector& >::type distance(distanceSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type exact_(exact_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_dist_(caliper_dist_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_(caliper_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type caliper_covs_mat_(caliper_covs_mat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type antiexact_covs_(antiexact_covs_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type unit_id_(unit_id_SEXP); + Rcpp::traits::input_parameter< const bool& >::type close(closeSEXP); + Rcpp::traits::input_parameter< const bool& >::type disl_prog(disl_progSEXP); + rcpp_result_gen = Rcpp::wrap(nn_matchC_vec_closest(treat, ratio, discarded, reuse_max, distance, exact_, caliper_dist_, caliper_covs_, caliper_covs_mat_, antiexact_covs_, unit_id_, close, disl_prog)); return rcpp_result_gen; END_RCPP } // pairdistsubC -double pairdistsubC(const NumericVector& x_, const IntegerVector& t_, const IntegerVector& s_, const int& num_sub); -RcppExport SEXP _MatchIt_pairdistsubC(SEXP x_SEXP, SEXP t_SEXP, SEXP s_SEXP, SEXP num_subSEXP) { +double pairdistsubC(const NumericVector& x, const IntegerVector& t, const IntegerVector& s); +RcppExport SEXP _MatchIt_pairdistsubC(SEXP xSEXP, SEXP tSEXP, SEXP sSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const NumericVector& >::type x_(x_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type t_(t_SEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type s_(s_SEXP); - Rcpp::traits::input_parameter< const int& >::type num_sub(num_subSEXP); - rcpp_result_gen = Rcpp::wrap(pairdistsubC(x_, t_, s_, num_sub)); + Rcpp::traits::input_parameter< const NumericVector& >::type x(xSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type t(tSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type s(sSEXP); + rcpp_result_gen = Rcpp::wrap(pairdistsubC(x, t, s)); return rcpp_result_gen; END_RCPP } // subclass2mmC -IntegerMatrix subclass2mmC(const IntegerVector& subclass, const IntegerVector& treat, const int& focal); -RcppExport SEXP _MatchIt_subclass2mmC(SEXP subclassSEXP, SEXP treatSEXP, SEXP focalSEXP) { +IntegerMatrix subclass2mmC(const IntegerVector& subclass_, const IntegerVector& treat, const int& focal); +RcppExport SEXP _MatchIt_subclass2mmC(SEXP subclass_SEXP, SEXP treatSEXP, SEXP focalSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const IntegerVector& >::type subclass(subclassSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type subclass_(subclass_SEXP); Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); Rcpp::traits::input_parameter< const int& >::type focal(focalSEXP); - rcpp_result_gen = Rcpp::wrap(subclass2mmC(subclass, treat, focal)); + rcpp_result_gen = Rcpp::wrap(subclass2mmC(subclass_, treat, focal)); + return rcpp_result_gen; +END_RCPP +} +// mm2subclassC +IntegerVector mm2subclassC(const IntegerMatrix& mm, const IntegerVector& treat, const Nullable& focal); +RcppExport SEXP _MatchIt_mm2subclassC(SEXP mmSEXP, SEXP treatSEXP, SEXP focalSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerMatrix& >::type mm(mmSEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); + Rcpp::traits::input_parameter< const Nullable& >::type focal(focalSEXP); + rcpp_result_gen = Rcpp::wrap(mm2subclassC(mm, treat, focal)); + return rcpp_result_gen; +END_RCPP +} +// subclass_scootC +IntegerVector subclass_scootC(const IntegerVector& subclass_, const IntegerVector& treat_, const NumericVector& x_, const int& min_n); +RcppExport SEXP _MatchIt_subclass_scootC(SEXP subclass_SEXP, SEXP treat_SEXP, SEXP x_SEXP, SEXP min_nSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const IntegerVector& >::type subclass_(subclass_SEXP); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const NumericVector& >::type x_(x_SEXP); + Rcpp::traits::input_parameter< const int& >::type min_n(min_nSEXP); + rcpp_result_gen = Rcpp::wrap(subclass_scootC(subclass_, treat_, x_, min_n)); return rcpp_result_gen; END_RCPP } @@ -134,14 +270,15 @@ BEGIN_RCPP END_RCPP } // weights_matrixC -NumericVector weights_matrixC(const IntegerMatrix& mm, const IntegerVector& treat); -RcppExport SEXP _MatchIt_weights_matrixC(SEXP mmSEXP, SEXP treatSEXP) { +NumericVector weights_matrixC(const IntegerMatrix& mm, const IntegerVector& treat_, const Nullable& focal); +RcppExport SEXP _MatchIt_weights_matrixC(SEXP mmSEXP, SEXP treat_SEXP, SEXP focalSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const IntegerMatrix& >::type mm(mmSEXP); - Rcpp::traits::input_parameter< const IntegerVector& >::type treat(treatSEXP); - rcpp_result_gen = Rcpp::wrap(weights_matrixC(mm, treat)); + Rcpp::traits::input_parameter< const IntegerVector& >::type treat_(treat_SEXP); + Rcpp::traits::input_parameter< const Nullable& >::type focal(focalSEXP); + rcpp_result_gen = Rcpp::wrap(weights_matrixC(mm, treat_, focal)); return rcpp_result_gen; END_RCPP } @@ -161,14 +298,22 @@ RcppExport SEXP _MatchIt_RcppExport_registerCCallable() { } static const R_CallMethodDef CallEntries[] = { - {"_MatchIt_dist_to_matrixC", (DL_FUNC) &_MatchIt_dist_to_matrixC, 1}, - {"_MatchIt_nn_matchC", (DL_FUNC) &_MatchIt_nn_matchC, 15}, - {"_MatchIt_nn_matchC_closest", (DL_FUNC) &_MatchIt_nn_matchC_closest, 12}, - {"_MatchIt_nn_matchC_vec", (DL_FUNC) &_MatchIt_nn_matchC_vec, 13}, - {"_MatchIt_pairdistsubC", (DL_FUNC) &_MatchIt_pairdistsubC, 4}, + {"_MatchIt_all_equal_to", (DL_FUNC) &_MatchIt_all_equal_to, 2}, + {"_MatchIt_eucdistC_N1xN0", (DL_FUNC) &_MatchIt_eucdistC_N1xN0, 2}, + {"_MatchIt_get_splitsC", (DL_FUNC) &_MatchIt_get_splitsC, 2}, + {"_MatchIt_has_n_unique", (DL_FUNC) &_MatchIt_has_n_unique, 2}, + {"_MatchIt_nn_matchC_distmat", (DL_FUNC) &_MatchIt_nn_matchC_distmat, 14}, + {"_MatchIt_nn_matchC_distmat_closest", (DL_FUNC) &_MatchIt_nn_matchC_distmat_closest, 13}, + {"_MatchIt_nn_matchC_mahcovs", (DL_FUNC) &_MatchIt_nn_matchC_mahcovs, 15}, + {"_MatchIt_nn_matchC_mahcovs_closest", (DL_FUNC) &_MatchIt_nn_matchC_mahcovs_closest, 14}, + {"_MatchIt_nn_matchC_vec", (DL_FUNC) &_MatchIt_nn_matchC_vec, 14}, + {"_MatchIt_nn_matchC_vec_closest", (DL_FUNC) &_MatchIt_nn_matchC_vec_closest, 13}, + {"_MatchIt_pairdistsubC", (DL_FUNC) &_MatchIt_pairdistsubC, 3}, {"_MatchIt_subclass2mmC", (DL_FUNC) &_MatchIt_subclass2mmC, 3}, + {"_MatchIt_mm2subclassC", (DL_FUNC) &_MatchIt_mm2subclassC, 3}, + {"_MatchIt_subclass_scootC", (DL_FUNC) &_MatchIt_subclass_scootC, 4}, {"_MatchIt_tabulateC", (DL_FUNC) &_MatchIt_tabulateC, 2}, - {"_MatchIt_weights_matrixC", (DL_FUNC) &_MatchIt_weights_matrixC, 2}, + {"_MatchIt_weights_matrixC", (DL_FUNC) &_MatchIt_weights_matrixC, 3}, {"_MatchIt_RcppExport_registerCCallable", (DL_FUNC) &_MatchIt_RcppExport_registerCCallable, 0}, {NULL, NULL, 0} }; diff --git a/src/all_equal_to.cpp b/src/all_equal_to.cpp new file mode 100644 index 00000000..9a72a771 --- /dev/null +++ b/src/all_equal_to.cpp @@ -0,0 +1,35 @@ +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// Templated function to check if a vector has exactly n unique values +template +bool all_equal_to_(Vector x, + typename traits::storage_type::type y) { + + for (auto xi : x) { + if (xi != y) { + return false; + } + } + + return true; +} + +// Wrapper function to handle different types of R vectors +// [[Rcpp::export]] +bool all_equal_to(RObject x, + RObject y) { + + switch (TYPEOF(x)) { + case INTSXP: + return all_equal_to_(as(x), as(y)); + case REALSXP: + return all_equal_to_(as(x), as(y)); + case LGLSXP: + return all_equal_to_(as(x), as(y)); + default: + stop("Unsupported vector type"); + } +} \ No newline at end of file diff --git a/src/dist_to_matrix.cpp b/src/dist_to_matrix.cpp deleted file mode 100644 index 4884023f..00000000 --- a/src/dist_to_matrix.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include -using namespace Rcpp; -//Faster alternative to stats:::as.matrix.dist(). - -// [[Rcpp::export]] -NumericMatrix dist_to_matrixC(const NumericVector& d) { - int n = d.attr("Size"); - - NumericMatrix m(n,n); - - double dk; - int i, j; - int k = 0; - - for (i = 0; i < n; i++) { - for (j = i + 1; j < n; j++) { - dk = d[k]; - m(i, j) = dk; - m(j, i) = dk; - k++; - } - } - - if (d.hasAttribute("Labels")) { - CharacterVector lab = d.attr("Labels"); - rownames(m) = lab; - colnames(m) = lab; - } - return m; -} diff --git a/src/eta_progress_bar.h b/src/eta_progress_bar.h new file mode 100644 index 00000000..590338e7 --- /dev/null +++ b/src/eta_progress_bar.h @@ -0,0 +1,259 @@ +/* + * eta_progress_bar.h + * + * A custom ProgressBar class to display a progress bar with time estimation + * + * Author: clemens@nevrome.de + * + * Copied from https://github.com/kforner/rcpp_progress/blob/master/inst/examples/RcppProgressETA/src/eta_progress_bar.hpp with modifications by NHG + * + */ +#ifndef _RcppProgress_ETA_PROGRESS_BAR_HPP +#define _RcppProgress_ETA_PROGRESS_BAR_HPP + +#include +#include +#include +#include +#include + +#include +#include "progress_bar.hpp" + +// for unices only +#if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) +#include +#endif + +class ETAProgressBar: public ProgressBar{ + public: // ====== LIFECYCLE ===== + + /** + * Main constructor + */ + ETAProgressBar() { + _max_ticks = 50; + _finalized = false; + _timer_flag = true; + } + + ~ETAProgressBar() { + } + + public: // ===== main methods ===== + + void display() { + REprintf("0%% 10 20 30 40 50 60 70 80 90 100%%\n"); + REprintf("[----|----|----|----|----|----|----|----|----|----|\n"); + flush_console(); + } + + // update display + void update(float progress) { + + // stop if already finalized + if (_finalized) return; + + // measure current time + time(¤t_time); + + // start time measurement when update() is called the first time + if (_timer_flag) { + _timer_flag = false; + // measure start time + time_at_start = current_time; + + time_at_last_refresh = current_time; + + progress_at_last_refresh = progress; + + _num_ticks = _compute_nb_ticks(progress); + + time_string = "calculating..."; + + // create progress bar string + std::string progress_bar_string = _current_ticks_display(_num_ticks); + + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| ETA: " << time_string; + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); + + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + } else { + + double time_since_start = std::difftime(current_time, time_at_start); + + if (progress != 1) { + // ensure overwriting of old time info + int empty_length = time_string.length(); + + int _num_ticks_current = _compute_nb_ticks(progress); + + bool update_bar = (_num_ticks_current != _num_ticks); + + _num_ticks = _num_ticks_current; + + if (progress > 0 && time_since_start > 1) { + double time_since_last_refresh = std::difftime(current_time, time_at_last_refresh); + + if (time_since_last_refresh >= .5) { + update_bar = true; + + double progress_since_last_refresh = progress - progress_at_last_refresh; + + double total_rate = progress / time_since_start; + + if (progress_since_last_refresh == 0) { + progress_since_last_refresh = .0000001; + } + + double current_rate = progress_since_last_refresh / time_since_last_refresh; + + //alpha weights average rate against recent recent (current) rate; + //alpha = 1 => estimate based on on total_rate (treats as constant) + //alpha = 0 => estimate based on recent rate (high fluctuation) + double alpha = .8; + + double eta = (1 - progress) * (alpha / total_rate + (1 - alpha) / current_rate); + + // convert seconds to time string + time_string = "~"; + time_string += _time_to_string(eta); + + time_at_last_refresh = current_time; + progress_at_last_refresh = progress; + } + } + + if (update_bar) { + // create progress bar string + std::string progress_bar_string = _current_ticks_display(_num_ticks); + + std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); + + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| ETA: " << time_string << empty_space; + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); + + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + } + + } else { + // ensure overwriting of old time info + int empty_length = time_string.length(); + + // finalize display when ready + + // convert seconds to time string + std::string time_string = _time_to_string(time_since_start); + + std::string empty_space = std::string(std::fdim(empty_length, time_string.length()), ' '); + + // create progress bar string + _num_ticks = _compute_nb_ticks(progress); + + std::string progress_bar_string = _current_ticks_display(_num_ticks); + + // merge progress bar and time string + std::stringstream strs; + strs << "|" << progress_bar_string << "| " << "Elapsed: " << time_string << empty_space; + + std::string temp_str = strs.str(); + char const* char_type = temp_str.c_str(); + + // print: remove old and replace with new + REprintf("\r"); + REprintf("%s", char_type); + + _finalize_display(); + } + } + } + + void end_display() { + update(1); + } + + protected: // ==== other instance methods ===== + + // convert double with seconds to time string + std::string _time_to_string(double seconds) { + + int time = (int) seconds; + + int hour = 0; + int min = 0; + int sec = 0; + + hour = time / 3600; + time = time % 3600; + min = time / 60; + time = time % 60; + sec = time; + + std::stringstream time_strs; + if (hour != 0) time_strs << hour << "h "; + if (min != 0) time_strs << min << "min "; + if (sec != 0 || (hour == 0 && min == 0)) time_strs << sec << "s "; + std::string time_str = time_strs.str(); + + return time_str; + } + + // update the ticks display corresponding to progress + std::string _current_ticks_display(int nb) { + + std::stringstream ticks_strs; + for (int i = 0; i < (_max_ticks - 1); ++i) { + if (i < nb) { + ticks_strs << "="; + } else { + ticks_strs << " "; + } + } + std::string tick_space_string = ticks_strs.str(); + + return tick_space_string; + } + + // finalize + void _finalize_display() { + if (_finalized) return; + + REprintf("\n"); + flush_console(); + _finalized = true; + } + + // compute number of ticks according to progress + int _compute_nb_ticks(float progress) { + return int(progress * _max_ticks); + } + + // N.B: does nothing on windows + void flush_console() { +#if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__) + R_FlushConsole(); +#endif + } + + private: // ===== INSTANCE VARIABLES ==== + int _max_ticks; // the total number of ticks to print + int _num_ticks; + bool _finalized; + bool _timer_flag; + time_t time_at_start, current_time, time_at_last_refresh; + float progress_at_last_refresh; + std::string time_string; + +}; + +#endif \ No newline at end of file diff --git a/src/eucdistC.cpp b/src/eucdistC.cpp new file mode 100644 index 00000000..0cc68869 --- /dev/null +++ b/src/eucdistC.cpp @@ -0,0 +1,34 @@ +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +NumericVector eucdistC_N1xN0(const NumericMatrix& x, + const IntegerVector& t) { + + IntegerVector ind0 = which(t == 0); + IntegerVector ind1 = which(t == 1); + int p = x.ncol(); + int i; + double d, di; + + NumericVector dist(ind1.size() * ind0.size()); + + int k = 0; + for (double i0 : ind0) { + for (double i1 : ind1) { + d = 0; + for (i = 0; i < p; i++) { + di = x(i0, i) - x(i1, i); + d += di * di; + } + dist[k] = sqrt(d); + k++; + } + } + + dist.attr("dim") = Dimension(ind1.size(), ind0.size()); + + return dist; +} diff --git a/src/get_splitsC.cpp b/src/get_splitsC.cpp new file mode 100644 index 00000000..fde059d6 --- /dev/null +++ b/src/get_splitsC.cpp @@ -0,0 +1,32 @@ +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +NumericVector get_splitsC(const NumericVector& x, + const double& caliper) { + + NumericVector splits; + + NumericVector x_ = unique(x); + NumericVector x_sorted = x_.sort(); + + R_xlen_t n = x_sorted.size(); + + if (n <= 1) { + return splits; + } + + splits = x_sorted[0]; + + for (int i = 1; i < x_sorted.length(); i++) { + if (x_sorted[i] - x_sorted[i - 1] <= caliper) continue; + + splits.push_back((x_sorted[i] + x_sorted[i - 1]) / 2); + } + + splits.push_back(x_sorted[n - 1]); + + return splits; +} \ No newline at end of file diff --git a/src/has_n_unique.cpp b/src/has_n_unique.cpp new file mode 100644 index 00000000..714569a9 --- /dev/null +++ b/src/has_n_unique.cpp @@ -0,0 +1,64 @@ +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// Templated function to check if a vector has exactly n unique values +template +bool has_n_unique_(Vector x, + const int& n) { + + Vector seen(n); + seen[0] = x[0]; + int n_seen = 1; + + int j; + bool was_seen; + + // Iterate over the vector and add elements to the unordered set + for (auto it = x.begin() + 1; it != x.end(); ++it) { + if (*it == *(it - 1)) { + continue; + } + + was_seen = false; + + for (j = 0; j < n_seen; j++) { + if (*it == seen[j]) { + was_seen = true; + break; + } + } + + if (!was_seen) { + n_seen++; + + if (n_seen > n) { + return false; + } + + seen[n_seen - 1] = *it; + } + } + + // Check if the number of unique elements is exactly n + return n_seen == n; +} + +// Wrapper function to handle different types of R vectors +// [[Rcpp::export]] +bool has_n_unique(const SEXP& x, + const int& n) { + switch (TYPEOF(x)) { + case INTSXP: + return has_n_unique_(x, n); + case REALSXP: + return has_n_unique_(x, n); + case STRSXP: + return has_n_unique_(x, n); + case LGLSXP: + return has_n_unique_(x, n); + default: + stop("Unsupported vector type"); + } +} \ No newline at end of file diff --git a/src/internal.cpp b/src/internal.cpp index 3c98d77b..a401f7e2 100644 --- a/src/internal.cpp +++ b/src/internal.cpp @@ -1,6 +1,9 @@ #include +#include using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + // Rcpp internal functions //C implementation of tabulate(). Faster than base::tabulate(), but real @@ -8,17 +11,20 @@ using namespace Rcpp; // [[Rcpp::interfaces(cpp)]] IntegerVector tabulateC_(const IntegerVector& bins, - const Nullable& nbins = R_NilValue) { + const int& nbins = 0) { int max_bin; - if (nbins.isNotNull()) max_bin = as(nbins); + + if (nbins > 0) max_bin = nbins; else max_bin = max(na_omit(bins)); IntegerVector counts(max_bin); int n = bins.size(); for (int i = 0; i < n; i++) { - if (bins[i] > 0 && bins[i] <= max_bin) + if (bins[i] > 0 && bins[i] <= max_bin) { counts[bins[i] - 1]++; + } } + return counts; } @@ -28,4 +34,770 @@ IntegerVector tabulateC_(const IntegerVector& bins, IntegerVector which(const LogicalVector& x) { IntegerVector ind = Range(0, x.size() - 1); return ind[x]; +} + +// [[Rcpp::interfaces(cpp)]] +bool antiexact_okay(const int& aenc, + const int& i, + const int& j, + const IntegerMatrix& antiexact_covs) { + if (aenc == 0) { + return true; + } + + for (int k = 0; k < aenc; k++) { + if (antiexact_covs(i, k) == antiexact_covs(j, k)) { + return false; + } + } + + return true; +} + +// [[Rcpp::interfaces(cpp)]] +bool caliper_covs_okay(const int& ncc, + const int& i, + const int& j, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs) { + if (ncc == 0) { + return true; + } + + for (int k = 0; k < ncc; k++) { + if (caliper_covs[k] >= 0) { + if (std::abs(caliper_covs_mat(i, k) - caliper_covs_mat(j, k)) > caliper_covs[k]) { + return false; + } + } + else { + if (std::abs(caliper_covs_mat(i, k) - caliper_covs_mat(j, k)) <= -caliper_covs[k]) { + return false; + } + } + } + + return true; +} + +// [[Rcpp::interfaces(cpp)]] +bool caliper_dist_okay(const bool& use_caliper_dist, + const int& i, + const int& j, + const NumericVector& distance, + const double& caliper_dist) { + if (!use_caliper_dist) { + return true; + } + + return std::abs(distance[i] - distance[j]) <= caliper_dist; +} + +// [[Rcpp::interfaces(cpp)]] +bool mm_okay(const int& r, + const int& i, + const IntegerVector& mm_rowi) { + + if (r > 1) { + for (int j : na_omit(mm_rowi)) { + if (i == j) { + return false; + } + } + } + + return true; +} + +// [[Rcpp::interfaces(cpp)]] +bool exact_okay(const bool& use_exact, + const int& i, + const int& j, + const IntegerVector& exact) { + + if (!use_exact) { + return true; + } + + return exact[i] == exact[j]; +} + +// [[Rcpp::interfaces(cpp)]] +std::vector find_control_vec(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi_, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const IntegerVector& first_control, + const IntegerVector& last_control, + const int& ratio = 1, + const int& prev_start = -1) { + + int ii = match_d_ord[t_id]; + + IntegerVector mm_rowi; + std::vector possible_starts; + + + if (r > 1) { + mm_rowi = na_omit(mm_rowi_); + mm_rowi = mm_rowi[as(treat[mm_rowi]) == gi]; + possible_starts.reserve(mm_rowi.size() + 2); + + for (int mmi : mm_rowi) { + possible_starts.push_back(match_d_ord[mmi]); + } + } + else { + possible_starts.reserve(2); + } + + possible_starts.push_back(ii); + + if (prev_start >= 0) { + possible_starts.push_back(match_d_ord[prev_start]); + } + + int iil, iir; + double min_dist; + + if (possible_starts.size() == 1) { + iil = ii; + iir = ii; + min_dist = 0; + } + else { + iil = *std::min_element(possible_starts.begin(), possible_starts.end()); + iir = *std::max_element(possible_starts.begin(), possible_starts.end()); + + if (iil == ii) { + min_dist = std::abs(distance[t_id] - distance[ind_d_ord[iir]]); + } + else if (iir == ii) { + min_dist = std::abs(distance[t_id] - distance[ind_d_ord[iil]]); + } + else { + min_dist = std::max(std::abs(distance[t_id] - distance[ind_d_ord[iil]]), + std::abs(distance[t_id] - distance[ind_d_ord[iir]])); + } + } + + int min_ii = first_control[gi]; + int max_ii = last_control[gi]; + + double di = distance[t_id]; + + bool l_stop = false; + bool r_stop = false; + + double dist_c; + + std::vector potential_matches_id; + potential_matches_id.reserve(2 * ratio); + std::vector potential_matches_dist; + potential_matches_dist.reserve(2 * ratio); + + int num_matches_l = 0; + int num_matches_r = 0; + + int iz; + int z = 1; + int num_closer_than_dist_c; + + while (!l_stop || !r_stop) { + if (l_stop) { + z = 1; + } + else if (r_stop) { + z = -1; + } + else { + z *= -1; + } + + if (z == -1) { + if (iil <= min_ii || num_matches_l == ratio) { + l_stop = true; + continue; + } + + iil += z; + iz = ind_d_ord[iil]; + } + else { + if (iir >= max_ii || num_matches_r == ratio) { + r_stop = true; + continue; + } + + iir += z; + iz = ind_d_ord[iir]; + } + + if (!eligible[iz]) { + continue; + } + + if (treat[iz] != gi) { + continue; + } + + if (!mm_okay(r, iz, mm_rowi)) { + continue; + } + + dist_c = std::abs(di - distance[iz]); + + //If current dist is worse than ratio dists, continue + if (potential_matches_id.size() >= ratio) { + num_closer_than_dist_c = 0; + for (double d : potential_matches_dist) { + if (d < dist_c) { + num_closer_than_dist_c++; + if (num_closer_than_dist_c == ratio) { + break; + } + } + } + + if (num_closer_than_dist_c >= ratio) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + continue; + } + } + + if (dist_c > caliper_dist) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + continue; + } + + if (dist_c < min_dist) { + continue; + } + + if (!exact_okay(use_exact, t_id, iz, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id, iz, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, iz, caliper_covs_mat, caliper_covs)) { + continue; + } + + potential_matches_id.push_back(iz); + potential_matches_dist.push_back(dist_c); + + if (z == -1) { + num_matches_l++; + if (num_matches_l == ratio) { + l_stop = true; + } + } + else { + num_matches_r++; + if (num_matches_r == ratio) { + r_stop = true; + } + } + } + + int n_potential_matches = potential_matches_id.size(); + + if (n_potential_matches <= 1) { + return potential_matches_id; + } + + if (n_potential_matches <= ratio && + std::is_sorted(potential_matches_dist.begin(), + potential_matches_dist.end())) { + return potential_matches_id; + } + + std::vector ind(n_potential_matches); + std::iota(ind.begin(), ind.end(), 0); + + std::vector matches_out; + + if (n_potential_matches > ratio) { + std::partial_sort(ind.begin(), ind.begin() + ratio, ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(ratio); + + for (auto it = ind.begin(); it != ind.begin() + ratio; ++it) { + matches_out.push_back(potential_matches_id[*it]); + } + } + else { + std::sort(ind.begin(), ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(n_potential_matches); + + for (auto it = ind.begin(); it != ind.end(); ++it) { + matches_out.push_back(potential_matches_id[*it]); + } + } + + return matches_out; +} + +// [[Rcpp::interfaces(cpp)]] +std::vector find_control_mahcovs(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const NumericVector& match_var, + const double& match_var_caliper, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const NumericMatrix& mah_covs, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const bool& use_caliper_dist, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1) { + + int ii = match_d_ord[t_id]; + + int iil, iir; + + iil = ii; + iir = ii; + + int min_ii = 0; + int max_ii = match_d_ord.size() - 1; + + bool l_stop = false; + bool r_stop = false; + + double dist_c; + + std::vector> potential_matches; + potential_matches.reserve(ratio); + + std::pair new_match; + + int num_matches_l = 0; + int num_matches_r = 0; + + double mv_i = match_var[t_id]; + double mv_dist; + + int iz; + int z = 1; + + auto dist_comp = [](std::pair a, std::pair b) { + return a.second < b.second; + }; + + while (!l_stop || !r_stop) { + if (l_stop) { + z = 1; + } + else if (r_stop) { + z = -1; + } + else { + z *= -1; + } + + if (z == -1) { + if (iil <= min_ii || num_matches_l == ratio) { + l_stop = true; + continue; + } + + iil += z; + iz = ind_d_ord[iil]; + } + else { + if (iir >= max_ii || num_matches_r == ratio) { + r_stop = true; + continue; + } + + iir += z; + iz = ind_d_ord[iir]; + } + + if (!eligible[iz]) { + continue; + } + + if (treat[iz] != gi) { + continue; + } + + if (!mm_okay(r, iz, mm_rowi)) { + continue; + } + + mv_dist = pow(mv_i - match_var[iz], 2); + + if (mv_dist > match_var_caliper) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + + continue; + } + + //If current dist is worse than ratio dists, continue + if (potential_matches.size() == ratio) { + if (potential_matches.back().second < mv_dist) { + if (z == -1) { + l_stop = true; + } + else { + r_stop = true; + } + + continue; + } + } + + if (!exact_okay(use_exact, t_id, iz, exact)) { + continue; + } + + if (!caliper_dist_okay(use_caliper_dist, t_id, iz, distance, caliper_dist)) { + continue; + } + + if (!antiexact_okay(aenc, t_id, iz, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, iz, caliper_covs_mat, caliper_covs)) { + continue; + } + + dist_c = sum(pow(mah_covs.row(t_id) - mah_covs.row(iz), 2.0)); + + if (!std::isfinite(dist_c)) { + continue; + } + + new_match = std::pair(iz, dist_c); + + if (potential_matches.empty()) { + potential_matches.push_back(new_match); + } + else if (dist_c > potential_matches.back().second) { + if (potential_matches.size() == ratio) { + continue; + } + + potential_matches.push_back(new_match); + } + else if (ratio == 1) { + potential_matches[0] = new_match; + } + else { + if (potential_matches.size() == ratio) { + potential_matches.pop_back(); + } + + if (dist_c > potential_matches.back().second) { + potential_matches.push_back(new_match); + } + else { + potential_matches.insert(std::lower_bound(potential_matches.begin(), potential_matches.end(), + new_match, dist_comp), + new_match); + } + } + } + + std::vector matches_out; + matches_out.reserve(potential_matches.size()); + + for (auto p : potential_matches) { + matches_out.push_back(p.first); + } + + return matches_out; +} + +// [[Rcpp::interfaces(cpp)]] +std::vector find_control_mat(const int& t_id, + const IntegerVector& treat, + const IntegerVector& ind_non_focal, + const NumericVector& distance_mat_row_i, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1) { + + int c_id_i; + double dist_c; + + std::vector potential_matches_id; + + if (ratio < 1) { + return potential_matches_id; + } + + std::vector potential_matches_dist; + double max_dist; + + R_xlen_t nc = distance_mat_row_i.size(); + + potential_matches_id.reserve(nc); + potential_matches_dist.reserve(nc); + + for (R_xlen_t c = 0; c < nc; c++) { + + dist_c = distance_mat_row_i[c]; + + if (potential_matches_id.size() == ratio) { + if (dist_c > max_dist) { + continue; + } + } + + if (dist_c > caliper_dist) { + continue; + } + + if (!std::isfinite(dist_c)) { + continue; + } + + c_id_i = ind_non_focal[c]; + + if (!eligible[c_id_i]) { + continue; + } + + if (treat[c_id_i] != gi) { + continue; + } + + if (!mm_okay(r, c_id_i, mm_rowi)) { + continue; + } + + if (!exact_okay(use_exact, t_id, c_id_i, exact)) { + continue; + } + + if (!antiexact_okay(aenc, t_id, c_id_i, antiexact_covs)) { + continue; + } + + if (!caliper_covs_okay(ncc, t_id, c_id_i, caliper_covs_mat, caliper_covs)) { + continue; + } + + potential_matches_id.push_back(c_id_i); + potential_matches_dist.push_back(dist_c); + + if (potential_matches_id.size() == 1) { + max_dist = dist_c; + } + else if (dist_c > max_dist) { + max_dist = dist_c; + } + } + + int n_potential_matches = potential_matches_id.size(); + + if (n_potential_matches <= 1) { + return potential_matches_id; + } + + if (n_potential_matches <= ratio && + std::is_sorted(potential_matches_dist.begin(), + potential_matches_dist.end())) { + return potential_matches_id; + } + + std::vector ind(n_potential_matches); + std::iota(ind.begin(), ind.end(), 0); + + std::vector matches_out; + + if (n_potential_matches > ratio) { + std::partial_sort(ind.begin(), ind.begin() + ratio, ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(ratio); + + for (auto it = ind.begin(); it != ind.begin() + ratio; ++it) { + matches_out.push_back(potential_matches_id[*it]); + } + } + else { + std::sort(ind.begin(), ind.end(), + [&potential_matches_dist](int a, int b){ + return potential_matches_dist[a] < potential_matches_dist[b]; + }); + + matches_out.reserve(n_potential_matches); + + for (auto it = ind.begin(); it != ind.end(); ++it) { + matches_out.push_back(potential_matches_id[*it]); + } + } + + return matches_out; +} + +// [[Rcpp::interfaces(cpp)]] +double max_finite(const NumericVector& x) { + double m = NA_REAL; + + R_xlen_t n = x.size(); + R_xlen_t i; + bool found = false; + + //Find first finite value + for (i = 0; i < n; i++) { + if (std::isfinite(x[i])) { + m = x[i]; + found = true; + break; + } + } + + //If none found, return NA + if (!found) { + return m; + } + + //Find largest finite value + for (R_xlen_t j = i + 1; j < n; j++) { + if (!std::isfinite(x[j])) { + continue; + } + + if (x[j] > m) { + m = x[j]; + } + } + + return m; +} + +// [[Rcpp::interfaces(cpp)]] +double min_finite(const NumericVector& x) { + double m = NA_REAL; + + R_xlen_t n = x.size(); + R_xlen_t i; + bool found = false; + + //Find first finite value + for (i = 0; i < n; i++) { + if (std::isfinite(x[i])) { + m = x[i]; + found = true; + break; + } + } + + //If none found, return NA + if (!found) { + return m; + } + + //Find smallest finite value + for (R_xlen_t j = i + 1; j < n; j++) { + if (!std::isfinite(x[j])) { + continue; + } + + if (x[j] < m) { + m = x[j]; + } + } + + return m; +} + +// [[Rcpp::interfaces(cpp)]] +void update_first_and_last_control(IntegerVector first_control, + IntegerVector last_control, + const IntegerVector& ind_d_ord, + const LogicalVector& eligible, + const IntegerVector& treat, + const int& gi) { + R_xlen_t c; + + // Update first_control + if (!eligible[ind_d_ord[first_control[gi]]]) { + for (c = first_control[gi] + 1; c <= last_control[gi]; c++) { + if (eligible[ind_d_ord[c]]) { + if (treat[ind_d_ord[c]] == gi) { + first_control[gi] = c; + break; + } + } + } + } + + // Update last_control + if (!eligible[ind_d_ord[last_control[gi]]]) { + for (c = last_control[gi] - 1; c >= first_control[gi]; c--) { + if (eligible[ind_d_ord[c]]) { + if (treat[ind_d_ord[c]] == gi) { + last_control[gi] = c; + break; + } + } + } + } } \ No newline at end of file diff --git a/src/internal.h b/src/internal.h index 7ee17ce9..86c08b19 100644 --- a/src/internal.h +++ b/src/internal.h @@ -2,12 +2,116 @@ #define INTERNAL_H #include +#include +#include +#include +#include +#include using namespace Rcpp; IntegerVector tabulateC_(const IntegerVector& bins, - const Nullable& nbins = R_NilValue); + const int& nbins = 0); IntegerVector which(const LogicalVector& x); +std::vector find_control_vec(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi_, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const IntegerVector& first_control, + const IntegerVector& last_control, + const int& ratio = 1, + const int& prev_start = -1); -#endif \ No newline at end of file +std::vector find_control_mahcovs(const int& t_id, + const IntegerVector& ind_d_ord, + const IntegerVector& match_d_ord, + const NumericVector& match_var, + const double& match_var_caliper, + const IntegerVector& treat, + const NumericVector& distance, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const NumericMatrix& mah_covs, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const bool& use_caliper_dist, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1); + +std::vector find_control_mat(const int& t_id, + const IntegerVector& treat, + const IntegerVector& ind_non_focal, + const NumericVector& distance_mat_row_i, + const LogicalVector& eligible, + const int& gi, + const int& r, + const IntegerVector& mm_rowi, + const int& ncc, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs, + const double& caliper_dist, + const bool& use_exact, + const IntegerVector& exact, + const int& aenc, + const IntegerMatrix& antiexact_covs, + const int& ratio = 1); + +bool antiexact_okay(const int& aenc, + const int& i, + const int& j, + const IntegerMatrix& antiexact_covs); + +bool caliper_covs_okay(const int& ncc, + const int& i, + const int& j, + const NumericMatrix& caliper_covs_mat, + const NumericVector& caliper_covs); + +bool caliper_dist_okay(const bool& use_caliper_dist, + const int& i, + const int& j, + const NumericVector& distance, + const double& caliper_dist); + +bool mm_okay(const int& r, + const int& i, + const IntegerVector& mm_rowi); + +bool exact_okay(const bool& use_exact, + const int& i, + const int& j, + const IntegerVector& exact); + +double max_finite(const NumericVector& x); + +double min_finite(const NumericVector& x); + +void update_first_and_last_control(IntegerVector first_control, + IntegerVector last_control, + const IntegerVector& ind_d_ord, + const LogicalVector& eligible, + const IntegerVector& treat, + const int& gi); + +#endif diff --git a/src/nn_matchC.cpp b/src/nn_matchC.cpp deleted file mode 100644 index ccb0dac3..00000000 --- a/src/nn_matchC.cpp +++ /dev/null @@ -1,314 +0,0 @@ -// [[Rcpp::depends(RcppProgress)]] -#include -#include -using namespace Rcpp; - -// [[Rcpp::plugins(cpp11)]] - -// [[Rcpp::export]] -IntegerMatrix nn_matchC(const IntegerVector& treat_, - const IntegerVector& ord_, - const IntegerVector& ratio, - const LogicalVector& discarded, - const int& reuse_max, - const Nullable& distance_ = R_NilValue, - const Nullable& distance_mat_ = R_NilValue, - const Nullable& exact_ = R_NilValue, - const Nullable& caliper_dist_ = R_NilValue, - const Nullable& caliper_covs_ = R_NilValue, - const Nullable& caliper_covs_mat_ = R_NilValue, - const Nullable& mah_covs_ = R_NilValue, - const Nullable& antiexact_covs_ = R_NilValue, - const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) - { - - // Initialize - - NumericVector distance, caliper_covs; - double caliper_dist; - NumericMatrix distance_mat, caliper_covs_mat, mah_covs, mah_covs_c; - IntegerMatrix antiexact_covs; - IntegerVector exact, exact_c, antiexact_col, unit_id, units_with_id_of_chosen_unit; - int id_of_chosen_unit; - - bool use_dist_mat = false; - bool use_exact = false; - bool use_caliper_dist = false; - bool use_caliper_covs = false; - bool use_mah_covs = false; - bool use_antiexact = false; - bool use_reuse_max = false; - - // Info about original treat - int n_ = treat_.size(); - IntegerVector ind_ = Range(0, n_ - 1); - IntegerVector ind1_ = ind_[treat_ == 1]; - IntegerVector ind0_ = ind_[treat_ == 0]; - int n1_ = ind1_.size(); - int n0_ = n_ - n1_; - CharacterVector lab = treat_.names(); - - // Output matrix with sample indices of C units - int max_rat = max(ratio); - IntegerMatrix mm(n1_, max_rat); - mm.fill(NA_INTEGER); - // rownames(mm) = lab[ind1_]; - - // Store who has been matched - IntegerVector matched = rep(0, n_); - matched[discarded] = n1_; //discarded are unmatchable - - // After discarding - - IntegerVector ind = ind_[!discarded]; - IntegerVector treat = treat_[!discarded]; - IntegerVector ind0 = ind[treat == 0]; - int n0 = ind0.size(); - - int t, t_ind, min_ind, c_chosen, num_eligible, cal_len, t_rat, n_anti, antiexact_t; - double dt, cal_var_t; - - NumericVector cal_var, cal_diff, ps_diff, diff, dist_t, mah_covs_t, mah_covs_row, - match_distance(n0); - - IntegerVector c_eligible(n0), indices(n0); - LogicalVector finite_match_distance(n0); - - if (distance_.isNotNull()) { - distance = distance_; - } - if (exact_.isNotNull()) { - exact = exact_; - use_exact = true; - } - if (caliper_dist_.isNotNull()) { - caliper_dist = as(caliper_dist_); - use_caliper_dist = true; - ps_diff = NumericVector(n_); - } - if (caliper_covs_.isNotNull()) { - caliper_covs = caliper_covs_; - use_caliper_covs = true; - cal_len = caliper_covs.size(); - cal_diff = NumericVector(n0); - } - if (caliper_covs_mat_.isNotNull()) { - caliper_covs_mat = as(caliper_covs_mat_); - } - if (mah_covs_.isNotNull()) { - mah_covs = as(mah_covs_); - NumericVector mah_covs_row(mah_covs.ncol()); - use_mah_covs = true; - } else { - if (distance_mat_.isNotNull()) { - distance_mat = as(distance_mat_); - - // IntegerVector ind0_ = ind_[treat_ == 0]; - NumericVector dist_t(n0_); - use_dist_mat = true; - } - ps_diff = NumericVector(n_); - } - if (antiexact_covs_.isNotNull()) { - antiexact_covs = as(antiexact_covs_); - n_anti = antiexact_covs.ncol(); - use_antiexact = true; - } - if (reuse_max < n1_) { - use_reuse_max = true; - } - if (unit_id_.isNotNull()) { - unit_id = as(unit_id_); - } - else { - unit_id = ind_; - } - - bool ps_diff_assigned; - - //progress bar - int prog_length; - if (!use_reuse_max) prog_length = n1_ + 1; - else prog_length = max_rat*n1_ + 1; - Progress p(prog_length, disl_prog); - - //Counters - int rat, i, x, j, j_, a, k; - k = -1; - - //Matching - for (rat = 0; rat < max_rat; ++rat) { - for (i = 0; i < n1_; ++i) { - - k++; - if (k % 500 == 0) Rcpp::checkUserInterrupt(); - - p.increment(); - - if (all(as(matched[ind0]) >= reuse_max).is_true()){ - break; - } - - t = ord_[i] - 1; // index among treated - t_ind = ind1_[t]; // index among sample - - // Skip discarded units (only discarded treated units have matched > 0) - if (matched[t_ind] > 0) { - continue; - } - - //Check if unit has enough matches - t_rat = ratio[t]; - - if (t_rat < rat + 1) { - continue; - } - - c_eligible = ind0; // index among sample - - //Make ineligible control units that have been matched too many times - c_eligible = c_eligible[as(matched[c_eligible]) < reuse_max]; - - //Prevent control units being matched to same treated unit again - if (rat > 0) { - c_eligible = as(c_eligible[!in(c_eligible, na_omit(mm.row(t)))]); - } - - if (use_exact) { - exact_c = exact[c_eligible]; - c_eligible = c_eligible[exact_c == exact[t_ind]]; - } - - if (c_eligible.size() == 0) { - continue; - } - - if (use_antiexact) { - for (a = 0; (a < n_anti) && (c_eligible.size() > 0); ++a) { - antiexact_col = antiexact_covs(_, a); - antiexact_t = antiexact_col[t_ind]; - antiexact_col = antiexact_col[c_eligible]; - c_eligible = c_eligible[antiexact_col != antiexact_t]; - } - if (c_eligible.size() == 0) { - continue; - } - } - - ps_diff_assigned = false; - - if (use_caliper_dist) { - if (use_dist_mat) { - dist_t = distance_mat.row(t); - diff = dist_t[match(c_eligible, ind0_) - 1]; - } else { - dt = distance[t_ind]; - diff = Rcpp::abs(as(distance[c_eligible]) - dt); - } - - ps_diff[c_eligible] = diff; - ps_diff_assigned = true; - - c_eligible = c_eligible[diff <= caliper_dist]; - - if (c_eligible.size() == 0) { - continue; - } - } - - if (use_caliper_covs) { - for (x = 0; (x < cal_len) && (c_eligible.size() > 0); ++x) { - cal_var = caliper_covs_mat( _ , x ); - - cal_var_t = cal_var[t_ind]; - - diff = Rcpp::abs(as(cal_var[c_eligible]) - cal_var_t); - - cal_diff = diff; - - c_eligible = c_eligible[cal_diff <= caliper_covs[x]]; - } - - if (c_eligible.size() == 0) { - continue; - } - } - - //Compute distances among eligible - num_eligible = c_eligible.size(); - - //If replace and few eligible controls, assign all and move on - if (!use_reuse_max && (num_eligible <= t_rat)) { - for (j = 0; j < num_eligible; ++j) { - mm( t , j ) = c_eligible[j]; - } - continue; - } - - if (use_mah_covs) { - - match_distance = rep(0.0, num_eligible); - mah_covs_t = mah_covs.row(t_ind); - - for (j = 0; j < num_eligible; j++) { - j_ = c_eligible[j]; - mah_covs_row = mah_covs.row(j_); - match_distance[j] = sqrt(sum(pow(mah_covs_t - mah_covs_row, 2.0))); - } - - } else if (ps_diff_assigned) { - match_distance = ps_diff[c_eligible]; //c_eligible might have shrunk since previous assignment - } else if (use_dist_mat) { - dist_t = distance_mat.row(t); - match_distance = dist_t[match(c_eligible, ind0_) - 1]; - } else { - dt = distance[t_ind]; - match_distance = Rcpp::abs(as(distance[c_eligible]) - dt); - } - - //Remove infinite distances - finite_match_distance = is_finite(match_distance); - c_eligible = c_eligible[finite_match_distance]; - num_eligible = c_eligible.size(); - if (num_eligible == 0) { - continue; - } - match_distance = match_distance[finite_match_distance]; - - if (!use_reuse_max) { - //When matching w/ replacement, get t_rat closest control units - indices = Range(0, num_eligible - 1); - - std::partial_sort(indices.begin(), indices.begin() + t_rat, indices.end(), - [&match_distance](int k, int j) {return match_distance[k] < match_distance[j];}); - - for (j = 0; j < t_rat; ++j) { - min_ind = indices[j]; - mm( t , j ) = c_eligible[min_ind]; - } - } - else { - min_ind = which_min(match_distance); - c_chosen = c_eligible[min_ind]; - - mm( t , rat ) = c_chosen; - - id_of_chosen_unit = unit_id[c_chosen]; - units_with_id_of_chosen_unit = ind_[unit_id == id_of_chosen_unit]; - for (j = 0; j < units_with_id_of_chosen_unit.size(); ++j) { - matched[units_with_id_of_chosen_unit[j]] = matched[units_with_id_of_chosen_unit[j]] + 1; - } - } - } - - if (!use_reuse_max) break; - } - - p.update(prog_length); - - mm = mm + 1; // + 1 because C indexing starts at 0 but mm is sent to R - rownames(mm) = lab[ind1_]; - - return mm; -} diff --git a/src/nn_matchC_closest.cpp b/src/nn_matchC_closest.cpp deleted file mode 100644 index 22ce7282..00000000 --- a/src/nn_matchC_closest.cpp +++ /dev/null @@ -1,197 +0,0 @@ -// [[Rcpp::depends(RcppProgress)]] -#include -#include -#include "internal.h" -using namespace Rcpp; - -// [[Rcpp::plugins(cpp11)]] - -// [[Rcpp::export]] -IntegerMatrix nn_matchC_closest(const NumericMatrix& distance_mat, - const IntegerVector& treat, - const IntegerVector& ratio, - const LogicalVector& discarded, - const int& reuse_max, - const Nullable& exact_ = R_NilValue, - const Nullable& caliper_dist_ = R_NilValue, - const Nullable& caliper_covs_ = R_NilValue, - const Nullable& caliper_covs_mat_ = R_NilValue, - const Nullable& antiexact_covs_ = R_NilValue, - const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) -{ - - int r = distance_mat.nrow(); - int c = distance_mat.ncol(); - - IntegerMatrix mm(r, max(ratio)); - mm.fill(NA_INTEGER); - - CharacterVector lab = treat.names(); - - IntegerVector matched_t = rep(0, r); - IntegerVector matched_c = rep(0, c); - - // IntegerVector ind = seq(0, treat.size() - 1); - IntegerVector ind0 = which(treat == 0); - IntegerVector ind1 = which(treat == 1); - - //caliper_dist - bool use_caliper_dist = false; - double caliper_dist; - if (caliper_dist_.isNotNull()) { - caliper_dist = as(caliper_dist_); - use_caliper_dist = true; - } - - //caliper_covs - NumericVector caliper_covs; - NumericMatrix caliper_covs_mat; - bool use_caliper_covs = false; - double n_cal_covs; - if (caliper_covs_.isNotNull()) { - caliper_covs = as(caliper_covs_); - caliper_covs_mat = as(caliper_covs_mat_); - n_cal_covs = caliper_covs_mat.ncol(); - use_caliper_covs = true; - } - - //exact - bool use_exact = false; - IntegerVector exact; - if (exact_.isNotNull()) { - exact = as(exact_); - use_exact = true; - } - - //antiexact - IntegerMatrix antiexact_covs; - bool use_antiexact = false; - int n_anti; - if (antiexact_covs_.isNotNull()) { - antiexact_covs = as(antiexact_covs_); - n_anti = antiexact_covs.ncol(); - use_antiexact = true; - } - - //unit_id - IntegerVector unit_id, ck_; - bool use_unit_id = false; - if (unit_id_.isNotNull()) { - unit_id = as(unit_id_); - use_unit_id = true; - } - - //progress bar - int prog_length; - prog_length = sum(ratio) + 1; - Progress p(prog_length, disl_prog); - p.increment(); - - Function o("order"); - - IntegerVector d_ord = o(distance_mat); - d_ord = d_ord - 1; //Because R uses 1-indexing - - int rj, cj, dj, i, ind0i, ind1i; - bool okay; - - for (int j = 0; j < d_ord.size(); j++) { - - dj = d_ord[j]; - - // If distance is greater tha distance caliper, stop the whole thing because - // no remaining distance will be smaller - if (use_caliper_dist) { - if (distance_mat[dj] > caliper_dist) break; - } - - // Get row and column index of potential pair - rj = dj % r; - cj = dj / r; - - // Get sample indices of members of potential pair - ind1i = ind1[rj]; - ind0i = ind0[cj]; - - // If either member is discarded, move on - if (discarded[ind1i]) continue; - if (discarded[ind0i]) continue; - - // If either member has been matched enough times, move on - if (matched_t[rj] >= ratio[rj]) continue; - if (matched_c[cj] >= reuse_max) continue; - - // Exact matching criterion - if (use_exact) { - if (exact[ind1i] != exact[ind0i]) { - continue; - } - } - - // Covariate caliper criterion - if (use_caliper_covs) { - i = 0; - okay = true; - while (okay && (i < n_cal_covs)) { - if (std::abs(caliper_covs_mat(ind1i, i) - caliper_covs_mat(ind0i, i)) > caliper_covs[i]) { - okay = false; - } - i++; - } - if (!okay) continue; - } - - // Antiexact criterion - if (use_antiexact) { - i = 0; - okay = true; - while (okay && (i < n_anti)) { - if (antiexact_covs(ind1i, i) == antiexact_covs(ind0i, i)) { - okay = false; - } - i++; - } - if (!okay) continue; - } - - // If all criteria above are satisfied, potential pair becomes a pair! - - // If unit_id used, increase match count of all units with that ID - if (use_unit_id) { - ck_ = which(as(unit_id[ind1]) == unit_id[ind1i]); - - for (i = 0; i < ck_.size(); i++) { - matched_t[ck_[i]]++; - } - - ck_ = which(as(unit_id[ind0]) == unit_id[ind0i]); - - for (i = 0; i < ck_.size(); i++) { - matched_c[ck_[i]]++; - } - } - else { - matched_t[rj]++; - matched_c[cj]++; - } - - mm(rj, matched_t[rj] - 1) = ind0i; - - p.increment(); - - if (matched_t[rj] >= ratio[rj]) { - if (all(matched_t >= ratio).is_true()) break; - } - if (matched_c[cj] >= reuse_max) { - if (all(matched_c >= reuse_max).is_true()) break; - } - } - - p.update(prog_length); - - mm = mm + 1; - rownames(mm) = lab[treat == 1]; - - return mm; -} diff --git a/src/nn_matchC_distmat.cpp b/src/nn_matchC_distmat.cpp new file mode 100644 index 00000000..fb2d879a --- /dev/null +++ b/src/nn_matchC_distmat.cpp @@ -0,0 +1,322 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_distmat(const IntegerVector& treat_, + const IntegerVector& ord, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const int& focal_, + const NumericMatrix& distance_mat, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { + + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; + } + } + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + LogicalVector eligible = !discarded; + + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + std::vector times_matched(n, 0); + + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } + + int max_ratio = max(ratio); + + IntegerVector ind_non_focal = which(treat != focal); + + for (i = 0; i < n - nf; i++) { + ind_match[ind_non_focal[i]] = i; + } + + for (i = 0; i < nf; i++) { + ind_match[ind_focal[i]] = i; + } + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //caliper_dist + double caliper_dist; + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + } + else { + caliper_dist = max_finite(distance_mat) + .1; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //reuse_max + bool use_reuse_max = (reuse_max < nf); + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + use_reuse_max = true; + } + + IntegerVector matches_i(1 + max_ratio * (g - 1)); + int k_total; + + //progress bar + int prog_length; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + R_xlen_t c; + int r, t_id_t_i, t_id_i; + IntegerVector ck_; + std::vector k(max_ratio); + + int counter = 0; + + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (auto it = ord.begin(); it != ord.end() && max(as(n_eligible[g_c])) > 0; ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + if (r > times_matched_allowed[t_id_i]) { + continue; + } + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mat(t_id_i, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + if (k.empty()) { + if (r == 1) { + k_total = 0; + break; + } + continue; + } + + matches_i[k_total] = k[0]; + k_total++; + } + + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + } + } + } + else { + for (auto it = ord.begin(); it != ord.end(); ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mat(t_id_i, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + 1, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + ratio[t_id_t_i]); + + if (k.empty()) { + k_total = 0; + break; + } + + for (int cc : k) { + matches_i[k_total] = cc; + k_total++; + } + } + + if (k_total == 0) { + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + } + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/nn_matchC_distmat_closest.cpp b/src/nn_matchC_distmat_closest.cpp new file mode 100644 index 00000000..c4651556 --- /dev/null +++ b/src/nn_matchC_distmat_closest.cpp @@ -0,0 +1,319 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_distmat_closest(const IntegerVector& treat, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const NumericMatrix& distance_mat, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& close = true, + const bool& disl_prog = false) { + + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + IntegerVector times_matched(n); + times_matched.fill(0); + LogicalVector eligible = !discarded; + + for (gi = 0; gi < g; gi++) { + nt[gi] = sum(treat == gi); + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + } + } + + IntegerVector ind_non_focal = which(treat != focal); + IntegerVector ind_focal = which(treat == focal); + + for (i = 0; i < n - nf; i++) { + ind_match[ind_non_focal[i]] = i; + } + + for (i = 0; i < nf; i++) { + ind_match[ind_focal[i]] = i; + } + + IntegerVector times_matched_allowed(n); + times_matched_allowed.fill(reuse_max); + times_matched_allowed[ind_focal] = ratio; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + Function o("order"); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //caliper_dist + double caliper_dist; + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + } + else { + caliper_dist = max_finite(distance_mat) + .1; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + } + + //storing closeness + std::vector t_id, c_id; + std::vector dist; + t_id.reserve(n_eligible[focal]); + c_id.reserve(n_eligible[focal]); + dist.reserve(n_eligible[focal]); + + //progress bar + R_xlen_t prog_length = n_eligible[focal] + sum(ratio) + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + gi = 0; + + IntegerVector ck_; + + int c_id_i, t_id_t_i, t_id_i; + + int counter = 0; + int r = 1; + + IntegerVector heap_ord(n_eligible[focal]); + std::vector k; + k.reserve(1); + R_xlen_t hi; + + IntegerVector::iterator ci; + + std::function cmp; + if (close) { + cmp = [&dist](const int& a, const int& b) {return dist[a] < dist[b];}; + } + else { + cmp = [&dist](const int& a, const int& b) {return dist[a] >= dist[b];}; + } + + for (r = 1; r <= max_ratio; r++) { + for (int ti : ind_focal) { + if (!eligible[ti]) { + continue; + } + + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = ind_match[ti]; + + k = find_control_mat(ti, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + p.increment(); + + if (k.empty()) { + eligible[ti] = false; + n_eligible[focal]--; + continue; + } + + t_id.push_back(ti); + c_id.push_back(k[0]); + dist.push_back(distance_mat(t_id_t_i, ind_match[k[0]])); + } + + nf = dist.size(); + + //Order the list + heap_ord = o(dist, _["decreasing"] = !close); + heap_ord = heap_ord - 1; + + i = 0; + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + hi = heap_ord[i]; + + t_id_i = t_id[hi]; + + if (!eligible[t_id_i]) { + i++; + continue; + } + + t_id_t_i = ind_match[t_id_i]; + + c_id_i = c_id[hi]; + + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + + k = find_control_mat(t_id_i, + treat, + ind_non_focal, + distance_mat.row(t_id_t_i), + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + //If no matches... + if (k.empty()) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + c_id[hi] = k[0]; + dist[hi] = distance_mat(t_id_t_i, ind_match[k[0]]); + + // Find new position of pair in heap + ci = std::lower_bound(heap_ord.begin() + i, heap_ord.end(), hi, cmp); + + if (ci != heap_ord.begin() + i) { + std::rotate(heap_ord.begin() + i, heap_ord.begin() + i + 1, ci); + } + + continue; + } + + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + + ck_ = {c_id_i, t_id_i}; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + + p.increment(); + + i++; + } + + t_id.clear(); + c_id.clear(); + dist.clear(); + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/nn_matchC_mahcovs.cpp b/src/nn_matchC_mahcovs.cpp new file mode 100644 index 00000000..911f3046 --- /dev/null +++ b/src/nn_matchC_mahcovs.cpp @@ -0,0 +1,346 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_mahcovs(const IntegerVector& treat_, + const IntegerVector& ord, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const int& focal_, + const NumericMatrix& mah_covs, + const Nullable& distance_ = R_NilValue, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; + } + } + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + LogicalVector eligible = !discarded; + + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + std::vector times_matched(n, 0); + + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); + + Function o("order"); + + NumericVector match_var = mah_covs.column(0); + double match_var_caliper = R_PosInf; + + IntegerVector ind_d_ord = o(match_var); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //distance & caliper_dist + bool use_caliper_dist = false; + double caliper_dist; + NumericVector distance; + if (caliper_dist_.isNotNull() && distance_.isNotNull()) { + distance = as(distance_); + caliper_dist = as(caliper_dist_); + use_caliper_dist = true; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + + //Find if caliper placed on match_var + for (int cci = 0; cci < ncc; cci++) { + if (std::equal(caliper_covs_mat.column(cci).begin(), + caliper_covs_mat.column(cci).end(), + match_var.begin(), + match_var.end())) { + match_var_caliper = caliper_covs[cci]; + break; + } + } + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //reuse_max + bool use_reuse_max = (reuse_max < nf); + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + use_reuse_max = true; + } + + IntegerVector matches_i(1 + max_ratio * (g - 1)); + int k_total; + + //progress bar + int prog_length; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + R_xlen_t c; + int r, t_id_t_i, t_id_i; + IntegerVector ck_; + std::vector k(max_ratio); + + int counter = 0; + + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (auto it = ord.begin(); it != ord.end() && max(as(n_eligible[g_c])) > 0; ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + if (r > times_matched_allowed[t_id_i]) { + continue; + } + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mahcovs(t_id_i, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + if (k.empty()) { + if (r == 1) { + k_total = 0; + break; + } + continue; + } + + matches_i[k_total] = k[0]; + k_total++; + } + + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + } + } + } + else { + for (auto it = ord.begin(); it != ord.end(); ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + p.increment(); + + if (!eligible[t_id_i]) { + continue; + } + + k_total = 0; + + for (int gi : g_c) { + k = find_control_mahcovs(t_id_i, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + 1, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + ratio[t_id_t_i]); + + if (k.empty()) { + k_total = 0; + break; + } + + for (int cc : k) { + matches_i[k_total] = cc; + k_total++; + } + } + + if (k_total == 0) { + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + } + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/nn_matchC_mahcovs_closest.cpp b/src/nn_matchC_mahcovs_closest.cpp new file mode 100644 index 00000000..cf09ff0c --- /dev/null +++ b/src/nn_matchC_mahcovs_closest.cpp @@ -0,0 +1,347 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_mahcovs_closest(const IntegerVector& treat, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const NumericMatrix& mah_covs, + const Nullable& distance_ = R_NilValue, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& close = true, + const bool& disl_prog = false) { + + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + LogicalVector eligible = !discarded; + + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + std::vector times_matched(n, 0); + + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + Function o("order"); + + NumericVector match_var = mah_covs.column(0); + double match_var_caliper = R_PosInf; + + IntegerVector ind_d_ord = o(match_var); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting + + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //distance & caliper_dist + bool use_caliper_dist = false; + double caliper_dist; + NumericVector distance; + if (caliper_dist_.isNotNull() && distance_.isNotNull()) { + distance = as(distance_); + caliper_dist = as(caliper_dist_); + use_caliper_dist = true; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + + //Find if caliper placed on match_var + for (int cci = 0; cci < ncc; cci++) { + if (std::equal(caliper_covs_mat.column(cci).begin(), + caliper_covs_mat.column(cci).end(), + match_var.begin(), + match_var.end())) { + match_var_caliper = caliper_covs[cci]; + break; + } + } + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + } + + //storing closeness + std::vector t_id, c_id; + std::vector dist; + t_id.reserve(n_eligible[focal]); + c_id.reserve(n_eligible[focal]); + dist.reserve(n_eligible[focal]); + + //progress bar + R_xlen_t prog_length = n_eligible[focal] + sum(ratio) + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + gi = 0; + + IntegerVector ck_; + + int c_id_i, t_id_t_i, t_id_i; + + int counter = 0; + int r = 1; + + IntegerVector heap_ord; + std::vector k; + k.reserve(1); + R_xlen_t hi; + + std::function cmp; + if (close) { + cmp = [&dist](const int& a, const int& b) {return dist[a] < dist[b];}; + } + else { + cmp = [&dist](const int& a, const int& b) {return dist[a] >= dist[b];}; + } + + IntegerVector::iterator ci; + + for (r = 1; r <= max_ratio; r++) { + //Find closest control unit to each treated unit + for (int ti : ind_focal) { + + if (!eligible[ti]) { + continue; + } + + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = ind_match[ti]; + + k = find_control_mahcovs(ti, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + p.increment(); + + if (k.empty()) { + eligible[ti] = false; + n_eligible[focal]--; + continue; + } + + t_id.push_back(ti); + c_id.push_back(k[0]); + dist.push_back(sum(pow(mah_covs.row(ti) - mah_covs.row(k[0]), 2.0))); + } + + nf = dist.size(); + + //Order the list + heap_ord = o(dist, _["decreasing"] = !close); + heap_ord = heap_ord - 1; + + i = 0; + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + hi = heap_ord[i]; + + t_id_i = t_id[hi]; + + if (!eligible[t_id_i]) { + i++; + continue; + } + + t_id_t_i = ind_match[t_id_i]; + + c_id_i = c_id[hi]; + + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + + k = find_control_mahcovs(t_id_i, + ind_d_ord, + match_d_ord, + match_var, + match_var_caliper, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + mah_covs, + ncc, + caliper_covs_mat, + caliper_covs, + use_caliper_dist, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs); + + //If no matches... + if (k.empty()) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + c_id[hi] = k[0]; + dist[hi] = sum(pow(mah_covs.row(t_id_i) - mah_covs.row(k[0]), 2.0)); + + // Find new position of pair in heap + ci = std::lower_bound(heap_ord.begin() + i, heap_ord.end(), hi, cmp); + + if (ci != heap_ord.begin() + i) { + std::rotate(heap_ord.begin() + i, heap_ord.begin() + i + 1, ci); + } + + continue; + } + + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + + ck_ = {c_id_i, t_id_i}; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + + p.increment(); + + i++; + } + + t_id.clear(); + c_id.clear(); + dist.clear(); + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/nn_matchC_vec.cpp b/src/nn_matchC_vec.cpp index 43c04096..15e7c572 100644 --- a/src/nn_matchC_vec.cpp +++ b/src/nn_matchC_vec.cpp @@ -1,474 +1,340 @@ // [[Rcpp::depends(RcppProgress)]] -#include -#include +#include "eta_progress_bar.h" #include "internal.h" using namespace Rcpp; -// Version of nn_matchC that works when `distance` is a vector. -// Doesn't accept `distance_mat_` or `mah_covs_`. +// [[Rcpp::plugins(cpp11)]] -bool check_in(int x, - IntegerVector table) { - int t = table.size(); - if (t < 1) return false; +// [[Rcpp::export]] +IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, + const IntegerVector& ord, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const int& focal_, + const NumericVector& distance, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& disl_prog = false) { - for (int j = 0; j < t; j++) { - if (x == table[j]) return true; + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; + int focal; + for (focal = 0; focal < g; focal++) { + if (unique_treat[focal] == focal_) { + break; + } } - return false; -} -int find_right(int ii, - int last_control, - IntegerVector treat, - LogicalVector can_be_matched, - int r, - IntegerVector row, - IntegerVector d_ord, - NumericVector distance, - bool use_caliper_dist, - double caliper_dist, - bool use_caliper_covs, - NumericVector caliper_covs, - NumericMatrix caliper_covs_mat, - bool use_exact, - IntegerVector exact, - bool use_antiexact, - IntegerMatrix antiexact_covs) { - - int k = ii + 1; - bool found = false, okay; - int i; - - int n_anti; - if (use_antiexact) n_anti = antiexact_covs.ncol(); - - int n_cal_covs; - if (use_caliper_covs) n_cal_covs = caliper_covs_mat.ncol(); - - while (!found && k <= last_control) { - if (treat[k] == 1) { - k++; //if unit is treated, move right - continue; - } + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); - if (!can_be_matched[k]) { - k++; //if unit is matched, move right - continue; - } + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); - //if unit has already been matched to unit i, skip - if (r > 0) { - if (check_in(d_ord[k], row)) { - k++; - continue; - } - } + LogicalVector eligible = !discarded; - if (use_caliper_dist) { - if (std::abs(distance[ii] - distance[k]) > caliper_dist) { - //if closest is outside caliper, break; none can be found - break; - } - } + IntegerVector g_c = Range(0, g - 1); + g_c = g_c[g_c != focal]; - if (use_exact) { - if (exact[ii] != exact[k]) { - k++; //if not exact match, move right - continue; - } - } + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; - if (use_antiexact) { - i = 0; - okay = true; - while (okay && (i < n_anti)) { - if (antiexact_covs(ii, i) == antiexact_covs(k, i)) { - okay = false; - } - i++; - } - if (!okay) { - k++; //if any antiexact checks failed, move right - continue; - } - } - - if (use_caliper_covs) { - i = 0; - okay = true; - while (okay && (i < n_cal_covs)) { - if (std::abs(caliper_covs_mat(ii, i) - caliper_covs_mat(k, i)) > caliper_covs[i]) { - okay = false; - } - i++; - } - if (!okay) { - k++; //if any cov caliper checks failed, move right - continue; - } + if (eligible[i]) { + n_eligible[treat[i]]++; } - - found = true; } - if (!found) k = -1; + int nf = nt[focal]; - return k; -} + indt_sep[0] = 0; -int find_left(int ii, - int first_control, - IntegerVector treat, - LogicalVector can_be_matched, - int r, - IntegerVector row, - IntegerVector d_ord, - NumericVector distance, - bool use_caliper_dist, - double caliper_dist, - bool use_caliper_covs, - NumericVector caliper_covs, - NumericMatrix caliper_covs_mat, - bool use_exact, - IntegerVector exact, - bool use_antiexact, - IntegerMatrix antiexact_covs) { - - int k = ii - 1; - bool found = false, okay; - int i; - - int n_anti; - if (use_antiexact) n_anti = antiexact_covs.ncol(); - - int n_cal_covs; - if (use_caliper_covs) n_cal_covs = caliper_covs_mat.ncol(); - - while (!found && k >= first_control) { - if (treat[k] == 1) { - k--; //if unit is treated, move left - continue; - } + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; - if (!can_be_matched[k]) { - k--; //if unit is matched, move left - continue; - } - - //if unit has already been matched to unit i, skip - if (r > 0) { - if (check_in(d_ord[k], row)) { - k--; - continue; - } - } - - if (use_caliper_dist) { - if (std::abs(distance[ii] - distance[k]) > caliper_dist) { - //if closest is outside caliper, break - break; - } - } - - if (use_exact) { - if (exact[ii] != exact[k]) { - k--; //if not exact match, move left - continue; - } - } - - if (use_antiexact) { - i = 0; - okay = true; - while (okay && (i < n_anti)) { - if (antiexact_covs(ii, i) == antiexact_covs(k, i)) { - okay = false; - } - i++; - } - if (!okay) { - k--; //if any antiexact checks failed, move left - continue; - } - } + indt_tmp = ind[treat == gi]; - if (use_caliper_covs) { - i = 0; - okay = true; - while (okay && (i < n_cal_covs)) { - if (std::abs(caliper_covs_mat(ii, i) - caliper_covs_mat(k, i)) > caliper_covs[i]) { - okay = false; - } - i++; - } - if (!okay) { - k--; //if any cov caliper checks failed, move left - continue; - } + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; } - - found = true; } - if (!found) k = -1; + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; - return k; -} + std::vector times_matched(n, 0); -// [[Rcpp::plugins(cpp11)]] - -// [[Rcpp::export]] -IntegerMatrix nn_matchC_vec(const IntegerVector& treat_, - const IntegerVector& ord_, - const IntegerVector& ratio_, - const LogicalVector& discarded_, - const int& reuse_max, - const NumericVector& distance_, - const Nullable& exact_ = R_NilValue, - const Nullable& caliper_dist_ = R_NilValue, - const Nullable& caliper_covs_ = R_NilValue, - const Nullable& caliper_covs_mat_ = R_NilValue, - const Nullable& antiexact_covs_ = R_NilValue, - const Nullable& unit_id_ = R_NilValue, - const bool& disl_prog = false) { + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } - int n = treat_.size(); + int max_ratio = max(ratio); - CharacterVector lab_ = treat_.names(); + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat_.names(); //Use base::order() because faster than Rcpp implementation of order() Function o("order"); - IntegerVector d_ord = o(distance_, Named("decreasing") = false); - d_ord = d_ord - 1; - - IntegerVector treat = treat_[d_ord]; - NumericVector distance = distance_[d_ord]; - CharacterVector lab = lab_[d_ord]; - LogicalVector discarded = discarded_[d_ord]; - - IntegerVector ratio_tmp(n); - ratio_tmp[treat_ == 1] = ratio_; - IntegerVector ratio = ratio_tmp[d_ord]; + IntegerVector ind_d_ord = o(distance); + ind_d_ord = ind_d_ord - 1; //location of each unit after sorting - IntegerVector ord = ord_ - 1; + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; - int max_ratio = max(ratio); - - IntegerVector ind = Range(0, n - 1); - IntegerVector ind0 = ind[treat == 0]; - IntegerVector ind1 = ind[treat == 1]; - int n0 = ind0.size(); - int n1 = ind1.size(); - ind1.names() = lab[ind1]; - - IntegerVector t(n); - IntegerVector t0(n0), t1(n1); - int i; - - //ind: 1 2 3 4 5 6 7 8 - //ind1: 2 3 5 7 - - IntegerMatrix mm(n1, max_ratio); - mm.fill(NA_INTEGER); - CharacterVector lab1 = lab[ind1]; - CharacterVector mm_nm = lab_[treat_ == 1]; + IntegerVector last_control(g); + last_control.fill(n - 1); + IntegerVector first_control(g); + first_control.fill(0); + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } //caliper_dist double caliper_dist; - bool use_caliper_dist = false; if (caliper_dist_.isNotNull()) { caliper_dist = as(caliper_dist_); - use_caliper_dist = true; + } + else { + caliper_dist = max_finite(distance) - min_finite(distance) + 1; } //caliper_covs NumericVector caliper_covs; NumericMatrix caliper_covs_mat; - bool use_caliper_covs = false; + int ncc = 0; if (caliper_covs_.isNotNull()) { caliper_covs = as(caliper_covs_); caliper_covs_mat = as(caliper_covs_mat_); - NumericVector tmp_cc(caliper_covs_mat.nrow()); - for (int i = 0; i < caliper_covs_mat.ncol(); i++) { - tmp_cc = caliper_covs_mat(_, i); - tmp_cc = tmp_cc[d_ord]; - caliper_covs_mat(_, i) = tmp_cc; - } - use_caliper_covs = true; - } - - //exact - bool use_exact = false; - IntegerVector exact; - if (exact_.isNotNull()) { - exact = as(exact_)[d_ord]; - use_exact = true; + ncc = caliper_covs_mat.ncol(); } //antiexact IntegerMatrix antiexact_covs; - bool use_antiexact = false; + int aenc = 0; if (antiexact_covs_.isNotNull()) { antiexact_covs = as(antiexact_covs_); - NumericVector tmp_ae(antiexact_covs.nrow()); - for (int i = 0; i < antiexact_covs.ncol(); i++) { - tmp_ae = antiexact_covs(_, i); - tmp_ae = tmp_ae[d_ord]; - antiexact_covs(_, i) = tmp_ae; - } - use_antiexact = true; + aenc = antiexact_covs.ncol(); } + //reuse_max + bool use_reuse_max = (reuse_max < nf); + //unit_id IntegerVector unit_id; bool use_unit_id = false; if (unit_id_.isNotNull()) { - unit_id = as(unit_id_)[d_ord]; + unit_id = as(unit_id_); use_unit_id = true; + use_reuse_max = true; } - IntegerVector times_matched = rep(0, n); - LogicalVector can_be_matched = !as(discarded); - - IntegerVector ind_cbm = ind[can_be_matched]; - int first_control = ind_cbm[0]; - int last_control = ind_cbm[ind_cbm.size() - 1]; + IntegerVector matches_i(1 + max_ratio * (g - 1)); + int k_total; //progress bar int prog_length; - prog_length = max_ratio*n1 + 1; - Progress p(prog_length, disl_prog); - - int ii, k, j, ck, r, row_to_fill; - int k_left, k_right; - double dti; - String labi; - IntegerVector mm_row, ck_; - bool done = false; - - for (r = 0; r < max_ratio; r++) { - for (i = 0; i < n1; i++) { - p.increment(); - - row_to_fill = ord[i]; - - labi = mm_nm[row_to_fill]; - - ii = ind1[labi]; //ii'th unit overall - - if (!can_be_matched[ii]) continue; - - mm_row = na_omit(mm(row_to_fill, _)); - - //find control unit to left and right - k_left = find_left(ii, first_control, - treat, - can_be_matched, - r, mm_row, - d_ord, - distance, use_caliper_dist, caliper_dist, - use_caliper_covs, caliper_covs, caliper_covs_mat, - use_exact, exact, - use_antiexact, antiexact_covs); - - k_right = find_right(ii, last_control, - treat, - can_be_matched, - r, mm_row, - d_ord, - distance, use_caliper_dist, caliper_dist, - use_caliper_covs, caliper_covs, caliper_covs_mat, - use_exact, exact, - use_antiexact, antiexact_covs); - - if ((k_left >= 0) && (k_right >= 0)) { - dti = distance[ii]; - if (std::abs(distance[k_left] - dti) <= std::abs(distance[k_right] - dti)) { - k = k_left; + if (use_reuse_max) prog_length = sum(ratio) + 1; + else prog_length = nf + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + R_xlen_t c; + int r, t_id_t_i, t_id_i; + IntegerVector ck_; + std::vector k(max_ratio); + + int counter = 0; + + if (use_reuse_max) { + for (r = 1; r <= max_ratio; r++) { + for (auto it = ord.begin(); it != ord.end() && max(as(n_eligible[g_c])) > 0; ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); } - else { - k = k_right; + + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + if (r > times_matched_allowed[t_id_i]) { + continue; } - } - else if (k_left >= 0) { - k = k_left; - } - else if (k_right >= 0) { - k = k_right; - } - else { - can_be_matched[ii] = false; - continue; - } - mm( row_to_fill, r ) = d_ord[k]; + p.increment(); - if (use_unit_id) { - ck_ = ind[unit_id == unit_id[ii]]; + if (!eligible[t_id_i]) { + continue; + } - for (j = 0; j < ck_.size(); j++) { - ck = ck_[j]; - times_matched[ck]++; - if (times_matched[ck] >= ratio[ck]) { - can_be_matched[ck] = false; + k_total = 0; + + for (int gi : g_c) { + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); + + k = find_control_vec(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control); + + if (k.empty()) { + if (r == 1) { + k_total = 0; + break; + } + continue; } + + matches_i[k_total] = k[0]; + k_total++; + } + + if (k_total == 0) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; + } + + matches_i[k_total] = t_id_i; + + ck_ = matches_i[Range(0, k_total)]; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); } - ck_ = ind[unit_id == unit_id[k]]; + for (int ck : ck_) { + if (!eligible[ck]) { + continue; + } - for (j = 0; j < ck_.size(); j++) { - ck = ck_[j]; times_matched[ck]++; - if (times_matched[ck] >= reuse_max) { - can_be_matched[ck] = false; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; } } } - else { - times_matched[ii]++; - if (times_matched[ii] >= ratio[ii]) { - can_be_matched[ii] = false; - } + } + } + else { + for (auto it = ord.begin(); it != ord.end(); ++it) { + // i: generic looping index + // t_id_t_i; index of treated unit to match among treated units + // t_id_i: index of treated unit to match among all units + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } - times_matched[k]++; - if (times_matched[k] >= reuse_max) { - can_be_matched[k] = false; - } + t_id_t_i = *it - 1; + t_id_i = ind_focal[t_id_t_i]; + + p.increment(); + + if (!eligible[t_id_i]) { + continue; } - if (!can_be_matched[ii]) { - if (any(as(can_be_matched[treat == 1])).is_false()) { - done = true; + k_total = 0; + + for (int gi : g_c) { + k = find_control_vec(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + 1, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control, + ratio[t_id_t_i]); + + if (k.empty()) { + k_total = 0; break; } - } - if (!can_be_matched[k]) { - if (any(as(can_be_matched[treat == 0])).is_false()) { - done = true; - break; + for (int cc : k) { + matches_i[k_total] = cc; + k_total++; } } - ind_cbm = ind[can_be_matched & (treat == 0)]; - if (!can_be_matched[first_control]) { - first_control = ind_cbm[0]; + if (k_total == 0) { + continue; } - if (!can_be_matched[last_control]) { - last_control = ind_cbm[ind_cbm.size() - 1]; + + for (c = 0; c < k_total; c++) { + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = matches_i[c]; } } - - if (done) break; } p.update(prog_length); mm = mm + 1; - rownames(mm) = mm_nm; + rownames(mm) = as(lab[ind_focal]); return mm; } diff --git a/src/nn_matchC_vec_closest.cpp b/src/nn_matchC_vec_closest.cpp new file mode 100644 index 00000000..10748795 --- /dev/null +++ b/src/nn_matchC_vec_closest.cpp @@ -0,0 +1,348 @@ +// [[Rcpp::depends(RcppProgress)]] +#include "eta_progress_bar.h" +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerMatrix nn_matchC_vec_closest(const IntegerVector& treat, + const IntegerVector& ratio, + const LogicalVector& discarded, + const int& reuse_max, + const NumericVector& distance, + const Nullable& exact_ = R_NilValue, + const Nullable& caliper_dist_ = R_NilValue, + const Nullable& caliper_covs_ = R_NilValue, + const Nullable& caliper_covs_mat_ = R_NilValue, + const Nullable& antiexact_covs_ = R_NilValue, + const Nullable& unit_id_ = R_NilValue, + const bool& close = true, + const bool& disl_prog = false) { + + IntegerVector unique_treat = {0, 1}; + int g = unique_treat.size(); + int focal = 1; + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + + R_xlen_t i; + int gi; + IntegerVector indt(n); + IntegerVector indt_sep(g + 1); + IntegerVector indt_tmp; + IntegerVector nt(g); + IntegerVector ind_match(n); + ind_match.fill(NA_INTEGER); + + LogicalVector eligible = !discarded; + + // IntegerVector g_c = Range(0, g - 1); + // g_c = g_c[g_c != focal]; + + IntegerVector n_eligible(g); + for (i = 0; i < n; i++) { + nt[treat[i]]++; + + if (eligible[i]) { + n_eligible[treat[i]]++; + } + } + + int nf = nt[focal]; + + indt_sep[0] = 0; + + for (gi = 0; gi < g; gi++) { + indt_sep[gi + 1] = indt_sep[gi] + nt[gi]; + + indt_tmp = ind[treat == gi]; + + for (i = 0; i < nt[gi]; i++) { + indt[indt_sep[gi] + i] = indt_tmp[i]; + ind_match[indt_tmp[i]] = i; + } + } + + IntegerVector ind_focal = indt[Range(indt_sep[focal], indt_sep[focal + 1] - 1)]; + + std::vector times_matched(n, 0); + + std::vector times_matched_allowed(n, reuse_max); + for (i = 0; i < nf; i++) { + times_matched_allowed[ind_focal[i]] = ratio[i]; + } + + int max_ratio = max(ratio); + + // Output matrix with sample indices of control units + IntegerMatrix mm(nf, max_ratio); + mm.fill(NA_INTEGER); + CharacterVector lab = treat.names(); + + //Use base::order() because faster than C++ std::sort() + Function o("order"); + + IntegerVector ind_d_ord = o(distance); + ind_d_ord = ind_d_ord - 1; + + IntegerVector match_d_ord = o(ind_d_ord); + match_d_ord = match_d_ord - 1; + + IntegerVector last_control(g); + last_control.fill(n - 1); + IntegerVector first_control(g); + first_control.fill(0); + + //exact + bool use_exact = false; + IntegerVector exact; + if (exact_.isNotNull()) { + exact = as(exact_); + use_exact = true; + } + + //caliper_dist + double caliper_dist; + if (caliper_dist_.isNotNull()) { + caliper_dist = as(caliper_dist_); + } + else { + caliper_dist = max_finite(distance) - min_finite(distance) + 1; + } + + //caliper_covs + NumericVector caliper_covs; + NumericMatrix caliper_covs_mat; + int ncc = 0; + if (caliper_covs_.isNotNull()) { + caliper_covs = as(caliper_covs_); + caliper_covs_mat = as(caliper_covs_mat_); + ncc = caliper_covs_mat.ncol(); + } + + //antiexact + IntegerMatrix antiexact_covs; + int aenc = 0; + if (antiexact_covs_.isNotNull()) { + antiexact_covs = as(antiexact_covs_); + aenc = antiexact_covs.ncol(); + } + + //unit_id + IntegerVector unit_id; + bool use_unit_id = false; + if (unit_id_.isNotNull()) { + unit_id = as(unit_id_); + use_unit_id = true; + } + + //storing closeness + std::vector t_id, c_id; + std::vector dist; + t_id.reserve(n_eligible[focal]); + c_id.reserve(n_eligible[focal]); + dist.reserve(n_eligible[focal]); + + gi = 0; + + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); + + IntegerVector ck_; + + int c_id_i, t_id_t_i, t_id_i; + + int counter = 0; + int r = 1; + + IntegerVector heap_ord; + std::vector k(1); + R_xlen_t hi; + + //progress bar + R_xlen_t prog_length = sum(ratio) + 1; + ETAProgressBar pb; + Progress p(prog_length, disl_prog, pb); + + IntegerVector::iterator ci; + + std::function cmp; + if (close) { + cmp = [&dist](const int& a, const int& b) {return dist[a] < dist[b];}; + } + else { + cmp = [&dist](const int& a, const int& b) {return dist[a] >= dist[b];}; + } + + for (r = 1; r <= max_ratio; r++) { + //Find closest control unit to each treated unit + for (int ti : ind_focal) { + + if (!eligible[ti]) { + continue; + } + + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + t_id_t_i = ind_match[ti]; + + k = find_control_vec(ti, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control); + + if (k.empty()) { + eligible[ti] = false; + n_eligible[focal]--; + continue; + } + + t_id.push_back(ti); + c_id.push_back(k[0]); + dist.push_back(std::abs(distance[ti] - distance[k[0]])); + } + + nf = dist.size(); + + //Order the list + heap_ord = o(dist, _["decreasing"] = !close); + heap_ord = heap_ord - 1; + + i = 0; + + //Go through ordered list and assign matches, re-matching when necessary + while (min(n_eligible) > 0 && i < nf) { + counter++; + if (counter == 200) { + counter = 0; + Rcpp::checkUserInterrupt(); + } + + hi = heap_ord[i]; + + t_id_i = t_id[hi]; + + if (!eligible[t_id_i]) { + i++; + continue; + } + + t_id_t_i = ind_match[t_id_i]; + + c_id_i = c_id[hi]; + + if (!eligible[c_id_i]) { + // If control isn't eligible, find new control and try again + update_first_and_last_control(first_control, + last_control, + ind_d_ord, + eligible, + treat, + gi); + + k = find_control_vec(t_id_i, + ind_d_ord, + match_d_ord, + treat, + distance, + eligible, + gi, + r, + mm.row(t_id_t_i), + ncc, + caliper_covs_mat, + caliper_covs, + caliper_dist, + use_exact, + exact, + aenc, + antiexact_covs, + first_control, + last_control, + 1, + c_id_i); + + //If no matches... + if (k.empty()) { + eligible[t_id_i] = false; + n_eligible[focal]--; + continue; + } + + c_id[hi] = k[0]; + dist[hi] = std::abs(distance[t_id_i] - distance[k[0]]); + + // Find new position of pair in heap + ci = std::lower_bound(heap_ord.begin() + i, heap_ord.end(), hi, + cmp); + + if (ci != heap_ord.begin() + i) { + std::rotate(heap_ord.begin() + i, heap_ord.begin() + i + 1, ci); + } + + continue; + } + + mm(t_id_t_i, sum(!is_na(mm(t_id_t_i, _)))) = c_id_i; + + ck_ = {c_id_i, t_id_i}; + + if (use_unit_id) { + ck_ = which(!is_na(match(unit_id, as(unit_id[ck_])))); + } + + for (int ck : ck_) { + + if (!eligible[ck]) { + continue; + } + + times_matched[ck]++; + if (times_matched[ck] >= times_matched_allowed[ck]) { + eligible[ck] = false; + n_eligible[treat[ck]]--; + } + } + + p.increment(); + + i++; + } + + t_id.clear(); + c_id.clear(); + dist.clear(); + } + + p.update(prog_length); + + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + + return mm; +} \ No newline at end of file diff --git a/src/pairdistC.cpp b/src/pairdistC.cpp index 500b65a0..2625ba12 100644 --- a/src/pairdistC.cpp +++ b/src/pairdistC.cpp @@ -1,44 +1,44 @@ -#include #include "internal.h" using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + // [[Rcpp::export]] -double pairdistsubC(const NumericVector& x_, - const IntegerVector& t_, - const IntegerVector& s_, - const int& num_sub) { +double pairdistsubC(const NumericVector& x, + const IntegerVector& t, + const IntegerVector& s) { double dist = 0; - LogicalVector not_na_sub = !is_na(s_); - NumericVector x = x_[not_na_sub]; - IntegerVector t = t_[not_na_sub]; - IntegerVector s = s_[not_na_sub]; + R_xlen_t i, j; + int s_i, ord_i, ord_j; + int k = 0; - int n = t.size(); - LogicalVector in_s_i(n); - NumericVector x_t0(n); - IntegerVector t_ind_s(n), c_ind_s(n); + Function o("order"); + IntegerVector ord = o(s); + ord = ord - 1; - int k = 0; - int i, i1, n1_s; - for (i = 1; i <= num_sub; ++i) { - in_s_i = (s == i); + R_xlen_t n = sum(!is_na(s)); - t_ind_s = which((t == 1) & in_s_i); - c_ind_s = which((t == 0) & in_s_i); + for (i = 0; i < n; i++) { + ord_i = ord[i]; + s_i = s[ord_i]; - n1_s = t_ind_s.size(); + for (j = i + 1; j < n; j++) { + ord_j = ord[j]; - x_t0 = x[c_ind_s]; + if (s[ord_j] != s_i) { + break; + } - for (i1 = 0; i1 < n1_s; ++i1) { - dist += sum(Rcpp::abs(x[t_ind_s[i1]] - x_t0)); + if (t[ord_j] == t[ord_i]) { + continue; + } + + k++; + dist += (std::abs(x[ord_j] - x[ord_i]) - dist) / k; } - k += n1_s * c_ind_s.size(); } - dist /= k; - return dist; } \ No newline at end of file diff --git a/src/subclass2mm.cpp b/src/subclass2mm.cpp index bdc5622a..07db755d 100644 --- a/src/subclass2mm.cpp +++ b/src/subclass2mm.cpp @@ -1,47 +1,129 @@ -#include #include "internal.h" using namespace Rcpp; +// [[Rcpp::plugins(cpp11)]] + //Turns subclass vector given as a factor into a numeric match.matrix. //focal is the treatment level (0/1) that corresponds to the rownames. // [[Rcpp::export]] -IntegerMatrix subclass2mmC(const IntegerVector& subclass, +IntegerMatrix subclass2mmC(const IntegerVector& subclass_, const IntegerVector& treat, const int& focal) { - IntegerVector tab = tabulateC_(subclass); - int mm_col = max(tab) - 1; + LogicalVector na_sub = is_na(subclass_); + IntegerVector unique_sub = unique(as(subclass_[!na_sub])); + IntegerVector subclass = match(subclass_, unique_sub) - 1; + + R_xlen_t nsub = unique_sub.size(); + + R_xlen_t n = treat.size(); + IntegerVector ind = Range(0, n - 1); + IntegerVector ind_focal = ind[treat == focal]; + R_xlen_t n1 = ind_focal.size(); + + IntegerVector subtab(nsub); + subtab.fill(-1); + + R_xlen_t i; + for (i = 0; i < n; i++) { + if (na_sub[i]) { + continue; + } + + subtab[subclass[i]]++; + } - IntegerVector ind = Range(0, treat.size() - 1); - IntegerVector ind1 = ind[treat == focal]; - int n1 = ind1.size(); + int mm_col = max(subtab); IntegerMatrix mm(n1, mm_col); mm.fill(NA_INTEGER); - rownames(mm) = as(treat.names())[ind1]; + CharacterVector lab = treat.names(); - IntegerVector in_sub = ind[!is_na(subclass)]; - IntegerVector ind_in_sub = ind[in_sub]; - IntegerVector ind0_in_sub = ind_in_sub[as(treat[in_sub]) != focal]; - IntegerVector sub0_in_sub = subclass[ind0_in_sub]; - - int i, t, s, nmc, mci; - IntegerVector mc(mm_col); + IntegerVector ss(n1); + ss.fill(NA_INTEGER); + int s, si; for (i = 0; i < n1; i++) { - t = ind1[i]; - s = subclass[t]; - - if (s != NA_INTEGER) { - mc = ind0_in_sub[sub0_in_sub == s]; - nmc = mc.size(); - for (mci = 0; mci < nmc; mci++) { - mm(i, mci) = mc[mci] + 1; + if (na_sub[ind_focal[i]]) { + continue; + } + + ss[i] = subclass[ind_focal[i]]; + } + + for (i = 0; i < n; i++) { + if (treat[i] == focal) { + continue; + } + + if (na_sub[i]) { + continue; + } + + si = subclass[i]; + + for (s = 0; s < n1; s++) { + if (!std::isfinite(ss[s])) { + continue; + } + + if (si != ss[s]) { + continue; } + + mm(s, sum(!is_na(mm(s, _)))) = i; + break; } } + mm = mm + 1; + rownames(mm) = as(lab[ind_focal]); + return mm; } +// [[Rcpp::export]] +IntegerVector mm2subclassC(const IntegerMatrix& mm, + const IntegerVector& treat, + const Nullable& focal = R_NilValue) { + + CharacterVector lab = treat.names(); + + R_xlen_t n1 = treat.size(); + + IntegerVector subclass(n1); + subclass.fill(NA_INTEGER); + subclass.names() = lab; + + IntegerVector ind1; + if (focal.isNotNull()) { + ind1 = which(treat == as(focal)); + } + else { + ind1 = match(as(rownames(mm)), lab) - 1; + } + + R_xlen_t r = mm.nrow(); + R_xlen_t ki = 0; + + for (R_xlen_t i : which(!is_na(mm))) { + if (i / r == 0) { + //If first entry in row, increment ki and assign subclass of treated + ki++; + subclass[ind1[i % r]] = ki; + } + + subclass[mm[i] - 1] = ki; + } + + CharacterVector levs(ki); + for (R_xlen_t j = 0; j < ki; j++){ + levs[j] = std::to_string(j + 1); + } + + subclass.attr("class") = "factor"; + subclass.attr("levels") = levs; + + return subclass; +} \ No newline at end of file diff --git a/src/subclass_scootC.cpp b/src/subclass_scootC.cpp new file mode 100644 index 00000000..67c7669d --- /dev/null +++ b/src/subclass_scootC.cpp @@ -0,0 +1,152 @@ +#include "internal.h" +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] + +// [[Rcpp::export]] +IntegerVector subclass_scootC(const IntegerVector& subclass_, + const IntegerVector& treat_, + const NumericVector& x_, + const int& min_n) { + + if (min_n == 0) { + return subclass_; + } + + int m, i, s, s2; + int best_i; + double best_x, score; + R_xlen_t nt; + + LogicalVector na_sub = is_na(subclass_); + + IntegerVector subclass = subclass_[!na_sub]; + IntegerVector treat = treat_[!na_sub]; + NumericVector x = x_[!na_sub]; + + R_xlen_t n = subclass.size(); + + IntegerVector unique_sub = unique(subclass); + std::sort(unique_sub.begin(), unique_sub.end()); + + subclass = match(subclass, unique_sub) - 1; + + R_xlen_t nsub = unique_sub.size(); + + NumericVector subtab(nsub); + IntegerVector indt; + bool left = false; + + IntegerVector ut = unique(treat); + + for (int t : ut) { + indt = which(treat == t); + nt = indt.size(); + + //Tabulate + subtab.fill(0.0); + + for (int i : indt) { + subtab[subclass[i]]++; + } + + for (m = 0; m < min_n; m++) { + while (min(subtab) <= 0) { + for (s = 0; s < nsub; s++) { + if (subtab[s] == 0) { + break; + } + } + + //Find which way to look for new member + if (s == nsub - 1) { + left = true; + } + else if (s == 0) { + left = false; + } + else { + score = 0.; + + for (s2 = 0; s2 < nsub; s2++) { + if (subtab[s2] <= 1) { + continue; + } + + if (s2 == s) { + continue; + } + + score += (subtab[s2] - 1) / static_cast(s2 - s); + } + + left = (score <= 0); + } + + //Find which subclass to take from (s2) + if (left) { + for (s2 = s - 1; s2 >= 0; s2--) { + if (subtab[s2] > 0) { + break; + } + } + } + else { + for (s2 = s + 1; s2 < nsub; s2++) { + if (subtab[s2] > 0) { + break; + } + } + } + + //Find unit with closest x in that subclass to take + for (i = 0; i < nt; i++) { + if (subclass[indt[i]] == s2) { + best_i = i; + best_x = x[indt[i]]; + break; + } + } + + for (i = best_i + 1; i < nt; i++) { + if (subclass[indt[i]] != s2) { + continue; + } + + if (left) { + if (x[indt[i]] < best_x) { + continue; + } + } + else { + if (x[indt[i]] > best_x) { + continue; + } + } + + best_i = i; + best_x = x[indt[i]]; + } + + subclass[indt[best_i]] = s; + subtab[s]++; + subtab[s2]--; + } + + for (s = 0; s < nsub; s++) { + subtab[s]--; + } + } + } + + for (i = 0; i < n; i++) { + subclass[i] = unique_sub[subclass[i]]; + } + + IntegerVector sub_out(subclass_.size()); + sub_out.fill(NA_INTEGER); + + sub_out[!na_sub] = subclass; + + return sub_out; +} \ No newline at end of file diff --git a/src/tabulateC.cpp b/src/tabulateC.cpp index e1d06c50..5ec5419f 100644 --- a/src/tabulateC.cpp +++ b/src/tabulateC.cpp @@ -5,5 +5,9 @@ using namespace Rcpp; // [[Rcpp::export]] IntegerVector tabulateC(const IntegerVector& bins, const Nullable& nbins = R_NilValue) { - return tabulateC_(bins, nbins); + + int nbins_ = 0; + if (nbins.isNotNull()) nbins_ = as(nbins); + + return tabulateC_(bins, nbins_); } \ No newline at end of file diff --git a/src/weights_matrixC.cpp b/src/weights_matrixC.cpp index dbbaf85f..10a923e3 100644 --- a/src/weights_matrixC.cpp +++ b/src/weights_matrixC.cpp @@ -1,54 +1,60 @@ -#include +#include "internal.h" using namespace Rcpp; -// Computes matching weights from match.matrix +// [[Rcpp::plugins(cpp11)]] +// Computes matching weights from match.matrix // [[Rcpp::export]] NumericVector weights_matrixC(const IntegerMatrix& mm, - const IntegerVector& treat) { - int n = treat.size(); - IntegerVector ind = Range(0, n - 1); - IntegerVector ind0 = ind[treat == 0]; - IntegerVector ind1 = ind[treat == 1]; - - NumericVector weights = rep(0., n); - // weights.fill(0); - - int nr = mm.nrow(); - int nc = mm.ncol(); - - int r, c, row_not_na, which_c, t_ind; - double weights_c, add_w; - IntegerVector row_r(nc); - - for (r = 0; r < nr; r++) { - row_r = na_omit(mm(r, _)); - row_not_na = row_r.size(); - if (row_not_na == 0) { - continue; - } - add_w = 1.0/static_cast(row_not_na); + const IntegerVector& treat_, + const Nullable& focal = R_NilValue) { - for (c = 0; c < row_not_na; c++) { - which_c = row_r[c] - 1; - weights_c = weights[which_c]; - weights[which_c] = weights_c + add_w; - } + CharacterVector lab = treat_.names(); + IntegerVector unique_treat = unique(treat_); + std::sort(unique_treat.begin(), unique_treat.end()); + int g = unique_treat.size(); + IntegerVector treat = match(treat_, unique_treat) - 1; - t_ind = ind1[r]; - weights[t_ind] = 1; + R_xlen_t n = treat.size(); + int gi; + + NumericVector weights(n); + weights.fill(0.0); + weights.names() = lab; + + IntegerVector row_ind; + if (focal.isNotNull()) { + row_ind = which(treat == as(focal)); } + else { + row_ind = match(as(rownames(mm)), lab) - 1; + } + + NumericVector matches_g = rep(0.0, g); - NumericVector c_weights = weights[ind0]; - double sum_c_w = sum(c_weights); - double sum_matched_c = sum(c_weights > 0); - int n0 = ind0.size(); + IntegerVector row_r(mm.ncol()); - if (sum_c_w > 0) { - for (int i = 0; i < n0; i++ ) { - which_c = ind0[i]; - weights[which_c] = c_weights[i] * sum_matched_c / sum_c_w; + for (int r : which(!is_na(mm(_, 0)))) { + + row_r = na_omit(mm.row(r)); + + for (gi = 0; gi < g; gi++) { + matches_g[gi] = 0.0; + } + + for (int i : row_r - 1) { + matches_g[treat[i]] += 1.0; } + + for (int i : row_r - 1) { + if (matches_g[treat[i]] == 0.0) { + continue; + } + + weights[i] += 1.0/matches_g[treat[i]]; + } + + weights[row_ind[r]] += 1.0; } return weights; diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 72656448..236ceeb3 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -12,11 +12,14 @@ expect_good_matchit <- function(m, expect_subclass = NULL, expect_distance = NUL expect_s3_class(m, "matchit") + n <- length(m$treat) + n1 <- sum(m$treat == 1) + #Related to subclass if (!is.null(expect_subclass)) { if (expect_subclass) { expect_false(is.null(m$subclass)) - expect_length(m$sunclass, n) + expect_length(m$subclass, n) expect_true(is.factor(m$subclass)) expect_false(is.null(names(m$subclass))) } @@ -29,23 +32,27 @@ expect_good_matchit <- function(m, expect_subclass = NULL, expect_distance = NUL if (!is.null(expect_match.matrix)) { if (expect_match.matrix) { expect_false(is.null(m$match.matrix)) - expect_type(m$match.matrix, "matrix") + expect_true(is.matrix(m$match.matrix)) expect_true(is.character(m$match.matrix)) - expect_equal(nrow(m$match.matrix), n) - expect_equal(ncol(m$match.matrix), ratio) + expect_equal(nrow(m$match.matrix), n1) + expect_equal(ncol(m$match.matrix), max(ratio[1], attr(ratio, "max.controls"))) expect_false(is.null(rownames(m$match.matrix))) expect_false(any(rownames(m$match.matrix) %in% m$match.matrix)) #Check no duplicates within each row - expect_true(all(apply(m$match.matrix, 1, function(i) anyDuplicated(i[!is.na(i)]) == 0))) + expect_true(all(apply(m$match.matrix, 1, function(i) anyDuplicated(na.omit(i)) == 0))) if (!is.null(replace)) { if (replace) { - #May not be duplicates incidentially; make sure examples induce duplicates - expect_true(!isTRUE(all.equal(anyDuplicated(m$match.matrix), 0))) + #May not be duplicates incidentally; make sure examples induce duplicates + expect_true(!isTRUE(all.equal(anyDuplicated(na.omit(m$match.matrix)), 0))) + + if (!is.null(attr(replace, "reuse.max"))) { + expect_true(max(table(m$match.matrix)) <= attr(replace, "reuse.max")) + } } else { - expect_equal(anyDuplicated(m$match.matrix), 0) + expect_equal(anyDuplicated(na.omit(m$match.matrix)), 0) } } } @@ -55,7 +62,7 @@ expect_good_matchit <- function(m, expect_subclass = NULL, expect_distance = NUL } #Related to distance - if (!is.null(distance)) { + if (!is.null(expect_distance)) { if (expect_distance) { expect_false(is.null(m$distance)) expect_length(m$distance, n) diff --git a/tests/testthat/test-method_cem.R b/tests/testthat/test-method_cem.R new file mode 100644 index 00000000..4ac1a419 --- /dev/null +++ b/tests/testthat/test-method_cem.R @@ -0,0 +1,137 @@ +test_that("Coarsened exact matching works", { + set.seed(123) + k <- 6 + n <- 1e4 + + d <- as.data.frame(matrix(rnorm(k * n), nrow = n)) + + d[[1]] <- factor(cut(d[[1]], 4, labels = FALSE)) + d[[2]] <- factor(cut(d[[2]], 10, labels = FALSE)) + + d$a <- rbinom(n, 1, .3) + + m <- matchit(a ~ ., data = d, method = "cem") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = FALSE) + + #Categories are exactly matched by default + expect_true(all(sapply(levels(m$subclass), function(s) length(unique(d[[1]][which(m$subclass == s)])) == 1))) + expect_true(all(sapply(levels(m$subclass), function(s) length(unique(d[[2]][which(m$subclass == s)])) == 1))) + expect_false(all(sapply(levels(m$subclass), function(s) length(unique(d[[3]][which(m$subclass == s)])) == 1))) + + #k2k didn't accidentally activate + expect_true(length(unique(sapply(unique(m$treat), function(t) { + sum(m$weights[m$treat == t] > 0) + }))) > 1L) + + #Groupings: V1 into 2 categories, no grouping of V2 + m <- matchit(a ~ ., data = d, method = "cem", + grouping = list(V1 = list(c("1", "2"), c("3", "4")), + V2 = list(levels(d$V2)))) + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = FALSE) + + + #Each subclass has V1 in 1,2 or 3,4 + expect_true(all(sapply(levels(m$subclass), function(s) all(d[[1]][which(m$subclass == s)] %in% c("1", "2")) || + all(d[[1]][which(m$subclass == s)] %in% c("3", "4"))))) + + #No restriction on bins for V2 + expect_false(all(sapply(levels(m$subclass), function(s) length(unique(d[[2]][which(m$subclass == s)])) == 1))) + + + m <- matchit(a ~ ., data = d, method = "cem", + grouping = list(V1 = list(c("1", "2"), c("3", "4")), + V2 = list(levels(d$V2))), + cutpoints = list(V3 = c(-1.5, 1.5), + V4 = 1)) + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = FALSE) + + #Each subclass has V1 in 1,2 or 3,4 + expect_true(all(sapply(levels(m$subclass), function(s) all(d[[1]][which(m$subclass == s)] %in% c("1", "2")) || + all(d[[1]][which(m$subclass == s)] %in% c("3", "4"))))) + + #No restriction on bins for V2 + expect_false(all(sapply(levels(m$subclass), function(s) length(unique(d[[2]][which(m$subclass == s)])) == 1))) + + #V3 correctly split into defined bins + expect_true(all(sapply(levels(m$subclass), function(s) { + all(d[[3]][which(m$subclass == s)] < -1.5) || + all(d[[3]][which(m$subclass == s)] > -1.5 | d[[3]][which(m$subclass == s)] < 1.5) || + all(d[[3]][which(m$subclass == s)] > 1.5) + }))) + + #Setting V1 = 1 in cutpoints same as omitting it + m1 <- matchit(a ~ . - V4, data = d, method = "cem", + grouping = list(V1 = list(c("1", "2"), c("3", "4")), + V2 = list(levels(d$V2))), + cutpoints = list(V3 = c(-1.5, 1.5))) + + expect_equal(m$subclass, m1$subclass) + + m <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE) + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + #1:1 matched + expect_true(length(unique(sapply(unique(m$treat), function(t) { + sum(m$weights[m$treat == t] > 0) + }))) == 1L) + + #Default is using Mahalanobis + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, k2k.method = "mahalanobis") + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_equal(m$match.matrix, m1$match.matrix) + expect_equal(m$subclass, m1$subclass) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, k2k.method = "scaled_euclidean") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_failure(expect_equal(m$match.matrix, m1$match.matrix)) + expect_failure(expect_equal(m$subclass, m1$subclass)) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, k2k.method = "manhattan") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_failure(expect_equal(m$match.matrix, m1$match.matrix)) + expect_failure(expect_equal(m$subclass, m1$subclass)) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, m.order = "data") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_equal(m$match.matrix, m1$match.matrix) + expect_equal(m$subclass, m1$subclass) + + m1 <- matchit(a ~ ., data = d, method = "cem", k2k = TRUE, m.order = "closest") + + expect_good_matchit(m, expect_subclass = TRUE, expect_distance = FALSE, + expect_match.matrix = TRUE, ratio = 1) + + expect_failure(expect_equal(m$match.matrix, m1$match.matrix)) + expect_failure(expect_equal(m$subclass, m1$subclass)) + + m2 <- matchit(a ~ ., data = d, method = "cem") + m2$subclass[is.na(m2$subclass)] <- m2$subclass[!is.na(m2$subclass)][1] + + # Equivalent to NN matching with exact matching on subclass + suppressWarnings({ + m1 <- matchit(a ~ ., data = d, method = "nearest", distance = "mahalanobis", + discard = m$weights == 0, exact = ~m2$subclass) + }) + + expect_equal(m$match.matrix, m1$match.matrix) + expect_equal(m$subclass, m1$subclass) +}) diff --git a/tests/testthat/test-method_nearest.R b/tests/testthat/test-method_nearest.R new file mode 100644 index 00000000..51c2a962 --- /dev/null +++ b/tests/testthat/test-method_nearest.R @@ -0,0 +1,120 @@ +test_that("distance vector, mah vars, and distance matrix yield identical results", { + set.seed(1234) + n <- 1e3 + p <- runif(n, 0, .4) + x <- runif(n) + g <- sample(1:5, n, TRUE) + a <- rbinom(n, 1, p) + u <- 1:n; u[a == 0] <- sample(u[a == 0][1:round(sum(a == 0)/5)], sum(a == 0), replace = TRUE) + dis <- as.logical(rbinom(n, 1, .1)) + d <- data.frame(p, x, a, g, u, dis) + d$p_ <- d$p + + dd <- euclidean_dist(a ~ p, data = d) + + test_all <- function(..., which = 1:4) { + + M <- list() + if (any(which == 1)) { + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = d$p, + ...) + }, dont_warn_if = "Fewer control") + expect_good_matchit(m, expect_distance = TRUE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + if (any(which == 2)) { + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = "euclidean", + ...) + }, dont_warn_if = "Fewer control") + expect_good_matchit(m, expect_distance = FALSE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + if (any(which == 3)) { + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = d$p, + mahvars = ~p + p_, + ...) + }, dont_warn_if = "Fewer control") + expect_good_matchit(m, expect_distance = TRUE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + if (any(which == 4)) { + matchit_try({ + m <- matchit(a ~ p + p_, data = d, + distance = dd, + ...) + }, dont_warn_if = "Fewer control") + expect_good_matchit(m, expect_distance = FALSE, expect_match.matrix = TRUE, + expect_subclass = !m$info$replace, replace = m$info$replace, + ratio = m$info$ratio) + M <- c(M, list(m)) + } + + all(unlist(lapply(M[-1], function(m) isTRUE(all.equal(M[[1]]$match.matrix, + m$match.matrix))))) + } + + expect_true(test_all(m.order = "data")) + expect_true(test_all(m.order = "closest")) + expect_true(test_all(m.order = "farthest")) + expect_true(test_all(m.order = "largest", which = c(1, 3))) + expect_true(test_all(m.order = "smallest", which = c(1, 3))) + + expect_true(test_all(m.order = "data", ratio = 2)) + expect_true(test_all(m.order = "closest", ratio = 2)) + + expect_true(test_all(m.order = "data", ratio = 2, max.controls = 3, which = c(1, 3))) + expect_true(test_all(m.order = "closest", ratio = 2, max.controls = 3, which = c(1, 3))) + expect_true(test_all(m.order = "largest", ratio = 2, max.controls = 3, which = c(1, 3))) + + expect_true(test_all(m.order = "data", ratio = 2, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, replace = TRUE)) + + expect_true(test_all(m.order = "data", ratio = 2, reuse.max = 3)) + expect_true(test_all(m.order = "closest", ratio = 2, reuse.max = 3)) + + expect_true(test_all(m.order = "data", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) + expect_true(test_all(m.order = "closest", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) + expect_true(test_all(m.order = "largest", ratio = 2, caliper = .001, std.caliper = FALSE, which = c(1, 3))) + + expect_true(test_all(m.order = "data", ratio = 2, caliper = c(p = .001), std.caliper = FALSE)) + expect_true(test_all(m.order = "closest", ratio = 2, caliper = c(p = .001), std.caliper = FALSE)) + + expect_true(test_all(m.order = "data", ratio = 2, caliper = c(p = .001), std.caliper = FALSE, reuse.max = 3)) + expect_true(test_all(m.order = "closest", ratio = 2, caliper = c(p = .001), std.caliper = FALSE, reuse.max = 3)) + + expect_true(test_all(m.order = "data", ratio = 2, exact = ~g)) + expect_true(test_all(m.order = "closest", ratio = 2, exact = ~g)) + + expect_true(test_all(m.order = "data", ratio = 2, exact = ~g, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, exact = ~g, replace = TRUE)) + + expect_true(test_all(m.order = "data", ratio = 2, antiexact = ~g)) + expect_true(test_all(m.order = "closest", ratio = 2, antiexact = ~g)) + + expect_true(test_all(m.order = "data", ratio = 2, antiexact = ~g, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, antiexact = ~g, replace = TRUE)) + + expect_true(test_all(m.order = "data", ratio = 2, discard = dis)) + expect_true(test_all(m.order = "closest", ratio = 2, discard = dis)) + + expect_true(test_all(m.order = "data", ratio = 2, unit.id = ~u)) + expect_true(test_all(m.order = "closest", ratio = 2, unit.id = ~u)) + + expect_true(test_all(m.order = "data", ratio = 2, unit.id = ~u, reuse.max = 3)) + expect_true(test_all(m.order = "closest", ratio = 2, unit.id = ~u, reuse.max = 3)) + + expect_true(test_all(m.order = "data", ratio = 2, unit.id = ~u, replace = TRUE)) + expect_true(test_all(m.order = "closest", ratio = 2, unit.id = ~u, replace = TRUE)) +}) diff --git a/vignettes/MatchIt.Rmd b/vignettes/MatchIt.Rmd index eff196bf..08fda507 100644 --- a/vignettes/MatchIt.Rmd +++ b/vignettes/MatchIt.Rmd @@ -7,7 +7,7 @@ output: toc: yes vignette: | %\VignetteIndexEntry{MatchIt: Getting Started} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true diff --git a/vignettes/assessing-balance.Rmd b/vignettes/assessing-balance.Rmd index a227673c..dd4358d8 100644 --- a/vignettes/assessing-balance.Rmd +++ b/vignettes/assessing-balance.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Assessing Balance} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true @@ -281,7 +281,7 @@ love.plot(m.out, stats = c("m", "ks"), poly = 2, abs = TRUE, position = "bottom") ``` -The `love.plot()` documentation explains what each of these arguments do and the several other ones available. See `vignette("cobalt_A4_love.plot", package = "cobalt")` for other advanced customization of `love.plot()`. +The `love.plot()` documentation explains what each of these arguments do and the several other ones available. See `vignette("love.plot", package = "cobalt")` for other advanced customization of `love.plot()`. ### `bal.plot()` diff --git a/vignettes/estimating-effects.Rmd b/vignettes/estimating-effects.Rmd index ed0fded9..63714d0e 100644 --- a/vignettes/estimating-effects.Rmd +++ b/vignettes/estimating-effects.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Estimating Effects After Matching} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true @@ -33,6 +33,7 @@ me_ok <- requireNamespace("marginaleffects", quietly = TRUE) && su_ok <- requireNamespace("survival", quietly = TRUE) boot_ok <- requireNamespace("boot", quietly = TRUE) ``` + ```{r, include = FALSE} #Generating data similar to Austin (2009) for demonstrating treatment effect estimation gen_X <- function(n) { @@ -108,7 +109,7 @@ This guide is structured as follows: first, information on the concepts related Before an effect is estimated, the estimand must be specified and clarified. Although some aspects of the estimand depend not only on how the effect is estimated after matching but also on the matching method itself, other aspects must be considered at the time of effect estimation and interpretation. Here, we consider three aspects of the estimand: the population the effect is meant to generalize to (the target population), the effect measure, and whether the effect is marginal or conditional. -**The target population.** Different matching methods allow you to estimate effects that can generalize to different target populations. The most common estimand in matching is the average treatment effect in the treated (ATT), which is the average effect of treatment for those who receive treatment. This estimand is estimable for matching methods that do not change the treated units (i.e., by weighting or discarding units) and is requested in `matchit()` by setting `estimand = "ATT"` (which is the default). The average treatment effect in the population (ATE) is the average effect of treatment for the population from which the sample is a random sample. This estimand is estimable only for methods that allow the ATE and either do not discard units from the sample or explicit target full sample balance, which in `MatchIt` is limited to full matching, subclassification, and template matching when setting `estimand = "ATE"`. When treated units are discarded (e.g., through the use of common support restrictions, calipers, cardinality matching, or [coarsened] exact matching), the estimand corresponds to neither the population ATT nor the population ATE, but rather to an average treatment effect in the remaining matched sample (ATM), which may not correspond to any specific target population. See @greiferChoosingEstimandWhen2021 for a discussion on the substantive considerations involved when choosing the target population of the estimand. +**The target population.** Different matching methods allow you to estimate effects that can generalize to different target populations. The most common estimand in matching is the average treatment effect in the treated (ATT), which is the average effect of treatment for those who receive treatment. This estimand is estimable for matching methods that do not change the treated units (i.e., by weighting or discarding units) and is requested in `matchit()` by setting `estimand = "ATT"` (which is the default). The average treatment effect in the population (ATE) is the average effect of treatment for the population from which the sample is a random sample. This estimand is estimable only for methods that allow the ATE and either do not discard units from the sample or explicit target full sample balance, which in `MatchIt` is limited to full matching, subclassification, and profile matching when setting `estimand = "ATE"`. When treated units are discarded (e.g., through the use of common support restrictions, calipers, cardinality matching, or [coarsened] exact matching), the estimand corresponds to neither the population ATT nor the population ATE, but rather to an average treatment effect in the remaining matched sample (ATM), which may not correspond to any specific target population. See @greiferChoosingEstimandWhen2021 for a discussion on the substantive considerations involved when choosing the target population of the estimand. **Marginal and conditional effects.** A marginal effect is a comparison between the expected potential outcome under treatment and the expected potential outcome under control. This is the same quantity estimated in randomized trials without blocking or covariate adjustment and is particularly useful for quantifying the overall effect of a policy or population-wide intervention. A conditional effect is the comparison between the expected potential outcomes in the treatment groups within strata. This is useful for identifying the effect of a treatment for an individual patient or a subset of the population. @@ -190,7 +191,7 @@ library("marginaleffects") All effect estimates will be computed using `marginaleffects::avg_comparions()`, even when its use may be superfluous (e.g., for performing a t-test in the matched set). As previously mentioned, this is because it is useful to have a single workflow that works no matter the situation, perhaps with very slight modifications to accommodate different contexts. Using `avg_comparions()` has several advantages, even when the alternatives are simple: it only provides the effect estimate, and not other coefficients; it automatically incorporates robust and cluster-robust SEs if requested; and it always produces average marginal effects for the correct population if requested. -Other packages may be of use but are not used here. There are alternatives to the `marginaleffects` package for computing average marginal effects, including `margins` and `stdReg`. The `survey` package can be used to estimate robust SEs incorporating weights and provides functions for survey-weighted generalized linear models and Cox-proportional hazards models. It is often used with propensity score weighting. +Other packages may be of use but are not used here. There are alternatives to the `marginaleffects` package for computing average marginal effects, including `margins` and `stdReg`. The `survey` package can be used to estimate robust SEs incorporating weights and provides functions for survey-weighted generalized linear models and Cox-proportional hazards models. ### The Standard Case @@ -198,7 +199,7 @@ For almost all matching methods, whether a caliper, common support restriction, [^3]: The matching weights are not necessary when performing 1:1 matching, but we include them here for generality. When weights are not necessary, including them does not affect the estimates. Because it may not always be clear when weights are required, we recommend always including them. -There are a few adjustments that need to be made for certain scenarios, which we describe in the section "Adjustments to the Standard Case". These adjustments include for the following cases: when matching for the ATE rather than the ATT, for matching with replacement, for matching with a method that doesn't involve creating pairs (e.g., cardinality and template matching and coarsened exact matching), for subclassification, for estimating effects with binary outcomes, and for estimating effects with survival outcomes. You must read the Standard Case to understand the basic procedure before reading about these special scenarios. +There are a few adjustments that need to be made for certain scenarios, which we describe in the section "Adjustments to the Standard Case". These adjustments include for the following cases: when matching for the ATE rather than the ATT, for matching with replacement, for matching with a method that doesn't involve creating pairs (e.g., cardinality and profile matching and coarsened exact matching), for subclassification, for estimating effects with binary outcomes, and for estimating effects with survival outcomes. You must read the Standard Case to understand the basic procedure before reading about these special scenarios. Here, we demonstrate the faster analytic approach to estimating confidence intervals; for the bootstrap approach, see the section "Using Bootstrapping to Estimate Confidence Intervals" below. @@ -235,24 +236,22 @@ Next, we use `marginaleffects::avg_comparisons()` to estimate the ATT. ```{r, eval=me_ok} avg_comparisons(fit1, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights") + newdata = subset(A == 1)) ``` -Let's break down the call to `avg_comparisons()`: to the first argument, we supply the model fit, `fit1`; to the `variables` argument, the name of the treatment (`"A"`); to the `vcov` argument, a formula with subclass membership (`~subclass`) to request cluster-robust SEs; to the `newdata` argument, a version of the matched dataset containing only the treated units (`subset(md, A == 1)`) to request the ATT; and to the `wts` argument, the names of the variable in `md` containing the matching weights (`"weights"`) to ensure they are included in the analysis. Some of these arguments differ depending on the specifics of the matching method and outcome type; see the sections below for information. +Let's break down the call to `avg_comparisons()`: to the first argument, we supply the model fit, `fit1`; to the `variables` argument, the name of the treatment (`"A"`); to the `vcov` argument, a formula with subclass membership (`~subclass`) to request cluster-robust SEs; and to the `newdata` argument, a version of the matched dataset containing only the treated units (`subset(A == 1)`) to request the ATT. Some of these arguments differ depending on the specifics of the matching method and outcome type; see the sections below for information. If, in addition to the effect estimate, we want the average estimated potential outcomes, we can use `marginaleffects::avg_predictions()`, which we demonstrate below. Note the interpretation of the resulting estimates as the expected potential outcomes is only valid if all covariates present in the outcome model (if any) are interacted with the treatment. ```{r, eval=me_ok && packageVersion("marginaleffects") >= "0.11.0"} avg_predictions(fit1, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights") + newdata = subset(A == 1)) ``` We can see that the difference in potential outcome means is equal to the average treatment effect computed previously[^4]. All of the arguments to `avg_predictions()` are the same as those to `avg_comparisons()`. -[^4]: To verify that they are equal, supply the output of `avg_predictions()` to `hypotheses(), e.g., `avg_predictions(...) |> hypotheses("revpairwise")`; this explicitly compares the average potential outcomes and should yield identical estimates to the `avg_comparisons()` call. +[^4]: To verify that they are equal, supply the output of `avg_predictions()` to `hypotheses(), e.g.,`avg_predictions(...) \|\> hypotheses("revpairwise")`; this explicitly compares the average potential outcomes and should yield identical estimates to the`avg_comparisons()\` call. ### Adjustments to the Standard Case @@ -260,7 +259,7 @@ This section explains how the procedure might differ if any of the following spe #### Matching for the ATE -When matching for the ATE (including [coarsened] exact matching, full matching, subclassification, and cardinality matching), everything is identical to the Standard Case except that in the calls to `avg_comparisons()` and `avg_predictions()`, the `newdata` argument is omitted (or can be replaced with `newdata = md`). This is because the estimated potential outcomes are computed for the full matched sample rather than just the treated units. +When matching for the ATE (including [coarsened] exact matching, full matching, subclassification, and cardinality matching), everything is identical to the Standard Case except that in the calls to `avg_comparisons()` and `avg_predictions()`, the `newdata` argument is omitted. This is because the estimated potential outcomes are computed for the full sample rather than just the treated units. #### Matching with replacement @@ -270,7 +269,7 @@ Because control units do not belong to unique pairs, there is no pair membership #### Matching without pairing -Some matching methods do not involve creating pairs; these include cardinality and template matching with `mahvars = NULL` (the default), exact matching, and coarsened exact matching with `k2k = FALSE` (the default). The only change that needs to be made to the Standard Case is that one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()` to use robust SEs instead of cluster-robust SEs. Remember that if matching is done for the ATE (even if units are dropped), the `newdata` argument should be dropped. +Some matching methods do not involve creating pairs; these include cardinality and profile matching with `mahvars = NULL` (the default), exact matching, and coarsened exact matching with `k2k = FALSE` (the default). The only change that needs to be made to the Standard Case is that one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()` to use robust SEs instead of cluster-robust SEs. Remember that if matching is done for the ATE (even if units are dropped), the `newdata` argument should be dropped. #### Propensity score subclassification @@ -280,7 +279,7 @@ There are two natural ways to estimate marginal effects after subclassification: All of the methods described above for the Standard Case also work with MMWS because the formation of the weights is the same; the only difference is that it is not appropriate to use cluster-robust SEs with MMWS because of how few clusters are present, so one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()` to use robust SEs instead of cluster-robust SEs. The subclasses can optionally be included in the outcome model (optionally interacting with treatment) as an alternative to including the propensity score. -The subclass-specific approach omits the weights and uses the subclasses directly. It is only appropriate when there are a small number of subclasses relative to the sample size. In the outcome model, `subclass` should interact with all other predictors in the model (including the treatment, covariates, and interactions, if any), and the `weights` argument should be omitted. In the calls to `avg_comparisons()` and `avg_predictions()`, the `wts` argument should be omitted. As with MMWS, one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()`. See an example below: +The subclass-specific approach omits the weights and uses the subclasses directly. It is only appropriate when there are a small number of subclasses relative to the sample size. In the outcome model, `subclass` should interact with all other predictors in the model (including the treatment, covariates, and interactions, if any), and the `weights` argument should be omitted. As with MMWS, one should change `vcov = ~subclass` to `vcov = "HC3"` in the calls to `avg_comparisons()` and `avg_predictions()`. See an example below: ```{r, eval=me_ok} #Subclassification on the PS for the ATT @@ -298,7 +297,7 @@ fitS <- lm(Y_C ~ subclass * (A * (X1 + X2 + X3 + X4 + X5 + avg_comparisons(fitS, variables = "A", vcov = "HC3", - newdata = subset(md, A == 1)) + newdata = subset(A == 1)) ``` A model with fewer terms may be required when subclasses are small; removing covariates or their interactions with treatment may be required and can increase precision in smaller datasets. Remember that if subclassification is done for the ATE (even if units are dropped), the `newdata` argument should be dropped. @@ -311,7 +310,7 @@ To fit a logistic regression model, change `lm()` to `glm()` and set `family = q [^6]: We use `quasibinomial()` instead of `binomial()` simply to avoid a spurious warning that can occur with certain kinds of matching; the results will be identical regardless. -[^7]: Note that for low or high average expected risks computed with `predictions()`, the confidence intervals may go below 0 or above 1; this is because an approximation is used. To avoid this problem, bootstrapping or simulation-based inference can be used instead. +[^7]: Note that for low or high average expected risks computed with `avg_predictions()`, the confidence intervals may go below 0 or above 1; this is because an approximation is used. To avoid this problem, bootstrapping or simulation-based inference can be used instead. To compute the marginal RR, we need to add `comparison = "lnratioavg"` to `avg_comparisons()`; this computes the marginal log RR. To get the marginal RR, we need to add `transform = "exp"` to `avg_comparisons()`, which exponentiates the marginal log RR and its confidence interval. The code below computes the effects and displays the statistics of interest: @@ -326,8 +325,7 @@ fit2 <- glm(Y_B ~ A * (X1 + X2 + X3 + X4 + X5 + avg_comparisons(fit2, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights", + newdata = subset(A == 1), comparison = "lnratioavg", transform = "exp") ``` @@ -340,7 +338,7 @@ For the marginal OR, the only thing that needs to change is that `comparison` sh There are several measures of effect size for survival outcomes. When using the Cox proportional hazards model, the quantity of interest is the hazard ratio (HR) between the treated and control groups. As with the OR, the HR is non-collapsible, which means the estimated HR will only be a valid estimate of the marginal HR when no other covariates are included in the model. Other effect measures, such as the difference in mean survival times or probability of survival after a given time, can be treated just like continuous and binary outcomes as previously described. -For the HR, we cannot compute average marginal effects and must use the coefficient on treatment in a Cox model fit without covariates[^8]. This means that we cannot use the procedures from the Standard Case. Here we describe estimating the marginal HR using `coxph()` from the `survival` package. (See `help("coxph", package = "survival")` for more information on this model.) To request cluster-robust SEs as recommended by @austin2013a, we need to supply pair membership (stored in the `subclass` column of `md`) to the `cluster` argument and set `robust = TRUE`. For matching methods that don't involve pairing (e.g., cardinality and template matching and [coarsened] exact matching), we can omit the `cluster` argument (but keep `robust = TRUE`)[^9]. +For the HR, we cannot compute average marginal effects and must use the coefficient on treatment in a Cox model fit without covariates[^8]. This means that we cannot use the procedures from the Standard Case. Here we describe estimating the marginal HR using `coxph()` from the `survival` package. (See `help("coxph", package = "survival")` for more information on this model.) To request cluster-robust SEs as recommended by @austin2013a, we need to supply pair membership (stored in the `subclass` column of `md`) to the `cluster` argument and set `robust = TRUE`. For matching methods that don't involve pairing (e.g., cardinality and profile matching and [coarsened] exact matching), we can omit the `cluster` argument (but keep `robust = TRUE`)[^9]. [^8]: It is not immediately clear how to estimate a marginal HR when covariates are included in the outcome model; though @austin2020 describe several ways of including covariates in a model to estimate the marginal HR, they do not develop SEs and little research has been done on this method, so we will not present it here. Instead, we fit a simple Cox model with the treatment as the sole predictor. @@ -424,12 +422,12 @@ boot_fun <- function(data, i) { #Estimated potential outcomes under treatment p1 <- predict(fit, type = "response", newdata = transform(md1, A = 1)) - Ep1 <- weighted.mean(p1, md1$weights) + Ep1 <- mean(p1) #Estimated potential outcomes under control p0 <- predict(fit, type = "response", newdata = transform(md1, A = 0)) - Ep0 <- weighted.mean(p0, md1$weights) + Ep0 <- mean(p0) #Risk ratio return(Ep1 / Ep0) @@ -449,7 +447,7 @@ boot.ci(boot_out, type = "perc") ```{r, include = FALSE} b <- { - if (boot_ok) boot.ci(boot_out, type = "perc") + if (boot_ok) boot::boot.ci(boot_out, type = "perc") else list(t0 = 1.347, percent = c(0, 0, 0, 1.144, 1.891)) } ``` @@ -500,12 +498,12 @@ cluster_boot_fun <- function(pairs, i) { #Estimated potential outcomes under treatment p1 <- predict(fit, type = "response", newdata = transform(md1, A = 1)) - Ep1 <- weighted.mean(p1, md1$weights) + Ep1 <- mean(p1) #Estimated potential outcomes under control p0 <- predict(fit, type = "response", newdata = transform(md1, A = 0)) - Ep0 <- weighted.mean(p0, md1$weights) + Ep0 <- mean(p0) #Risk ratio return(Ep1 / Ep0) @@ -526,7 +524,7 @@ boot.ci(cluster_boot_out, type = "perc") ```{r, include = FALSE} b <- { - if (boot_ok) boot.ci(cluster_boot_out, type = "perc") + if (boot_ok) boot::boot.ci(cluster_boot_out, type = "perc") else list(t0 = 1.588, percent = c(0,0,0, 1.348, 1.877)) } ``` @@ -565,8 +563,7 @@ To estimate the subgroup ATTs, we can use `avg_comparisons()`, this time specify ```{r, eval=me_ok} avg_comparisons(fitP, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights", + newdata = subset(A == 1), by = "X5") ``` @@ -575,8 +572,7 @@ We can see that the subgroup mean differences are quite similar to each other. F ```{r, eval=me_ok} avg_comparisons(fitP, variables = "A", vcov = ~subclass, - newdata = subset(md, A == 1), - wts = "weights", + newdata = subset(A == 1), by = "X5", hypothesis = "pairwise") ``` diff --git a/vignettes/matching-methods.Rmd b/vignettes/matching-methods.Rmd index ac6483a3..2d667937 100644 --- a/vignettes/matching-methods.Rmd +++ b/vignettes/matching-methods.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Matching Methods} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true @@ -32,7 +32,7 @@ Matching is nonparametric in the sense that the estimated weights and pruning of It is important to note that this implementation of matching differs from the methods described by Abadie and Imbens [-@abadie2006; -@abadie2016] and implemented in the `Matching` R package and `teffects` routine in Stata. That form of matching is *matching imputation*, where the missing potential outcomes for each unit are imputed using the observed outcomes of paired units. This is a critical distinction because matching imputation is a specific estimation method with its own effect and standard error estimators, in contrast to subset selection, which is a preprocessing method that does not require specific estimators and is broadly compatible with other parametric and nonparametric analyses. The benefits of matching imputation are that its theoretical properties (i.e., the rate of convergence and asymptotic variance of the estimator) are well understood, it can be used in a straightforward way to estimate not just the average treatment effect in the treated (ATT) but also the average treatment effect in the population (ATE), and additional effective matching methods can be used in the imputation (e.g., kernel matching). The benefits of matching as nonparametric preprocessing are that it is far more flexible with respect to the types of effects that can be estimated because it does not involve any specific estimator, its empirical and finite-sample performance has been examined in depth and is generally well understood, and it aligns well with the design of experiments, which are more familiar to non-technical audiences. -In addition to subset selection, matching often (though not always) involves a form of *stratification*, the assignment of units to pairs or strata containing multiple units. The distinction between subset selection and stratification is described by @zubizarreta2014, who separate them into two separate steps. In `MatchIt`, with almost all matching methods, subset selection is performed by stratification; for example, treated units are paired with control units, and unpaired units are then dropped from the matched sample. With some methods, subclasses are used to assign matching or stratification weights to individual units, which increase or decrease each unit's leverage in a subsequent analysis. There has been some debate about the importance of stratification after subset selection; while some authors have argued that, with some forms of matching, pair membership is incidental [@stuart2008; @schafer2008], others have argued that correctly incorporating pair membership into effect estimation can improve the quality of inferences [@austin2014a; @wan2019]. For methods that allow it, `MatchIt` includes stratum membership as an additional output of each matching specification. How these strata can be used is detailed in `vignette("Estimating Effects")`. +In addition to subset selection, matching often (though not always) involves a form of *stratification*, the assignment of units to pairs or strata containing multiple units. The distinction between subset selection and stratification is described by @zubizarreta2014, who separate them into two separate steps. In `MatchIt`, with almost all matching methods, subset selection is performed by stratification; for example, treated units are paired with control units, and unpaired units are then dropped from the matched sample. With some methods, subclasses are used to assign matching or stratification weights to individual units, which increase or decrease each unit's leverage in a subsequent analysis. There has been some debate about the importance of stratification after subset selection; while some authors have argued that, with some forms of matching, pair membership is incidental [@stuart2008; @schafer2008], others have argued that correctly incorporating pair membership into effect estimation can improve the quality of inferences [@austin2014a; @wan2019]. For methods that allow it, `MatchIt` includes stratum membership as an additional output of each matching specification. How these strata can be used is detailed in `vignette("estimating-effects")`. At the heart of `MatchIt` are three classes of methods: distance matching, stratum matching, and pure subset selection. *Distance matching* involves considering a focal group (usually the treated group) and selecting members of the non-focal group (i.e., the control group) to pair with each member of the focal group based on the *distance* between units, which can be computed in one of several ways. Members of either group that are not paired are dropped from the sample. Nearest neighbor matching (`method = "nearest"`), optimal pair matching (`method = "optimal"`), optimal full matching (`method = "full"`), generalized full matching (`method = "quick"`), and genetic matching (`method = "genetic"`) are the methods of distance matching implemented in `MatchIt`. Typically, only the average treatment in the treated (ATT) or average treatment in the control (ATC), if the control group is the focal group, can be estimated after distance matching in `MatchIt` (full matching is an exception, described later). @@ -52,6 +52,8 @@ Nearest neighbor matching requires the specification of a distance measure to de When using a matching ratio greater than 1 (i.e., when more than 1 control units are requested to be matched to each treated unit), matching occurs in a cycle, where each treated unit is first paired with one control unit, and then each treated unit is paired with a second control unit, etc. Ties are broken deterministically based on the order of the units in the dataset to ensure that multiple runs of the same specification yield the same result (unless the matching order is requested to be random). +Nearest neighbor matching is implemented in `MatchIt` using internal C++ code through `Rcpp`. When matching on a propensity score, this makes matching extremely fast, even for large datasets. Using a caliper on the propensity score (described below) makes it even faster. Run times may be a bit longer when matching on other distance measures (e.g., the Mahalanobis distance). In contrast to optimal pair matching (described below), nearest neighbor matching does not require computing the full distance matrix between units, which makes it more applicable to large datasets. + ### Optimal Pair Matching (`method = "optimal"`) Optimal pair matching (often just called optimal matching) is very similar to nearest neighbor matching in that it attempts to pair each treated unit with one or more control units. Unlike nearest neighbor matching, however, it is "optimal" rather than greedy; it is optimal in the sense that it attempts to choose matches that collectively optimize an overall criterion [@hansen2006; @gu1993]. The criterion used is the sum of the absolute pair distances in the matched sample. See `?method_optimal` for the documentation for `matchit()` with `method = "optimal"`. Optimal pair matching in `MatchIt` depends on the `fullmatch()` function in the `optmatch` package [@hansen2006]. @@ -121,7 +123,7 @@ Cardinality and profile matching are pure subset selection methods that involve Subset selection is performed by solving a mixed integer programming optimization problem with linear constraints. The problem involves maximizing the size of the matched sample subject to constraints on balance and sample size. For cardinality matching, the balance constraints refer to the mean difference for each covariate between the matched treated and control groups, and the sample size constraints require the matched treated and control groups to be the same size (or differ by a user-supplied factor). For profile matching, the balance constraints refer to the mean difference for each covariate between each treatment group and the target distribution; for the ATE, this requires the mean of each covariate in each treatment group to be within a given tolerance of the mean of the covariate in the full sample, and for the ATT, this requires the mean of each covariate in the control group to be within a given tolerance of the mean of the covariate in the treated group, which is left intact. The balance tolerances are controlled by the `tols` and `std.tols` arguments. One can also create pairs in the matched sample by using the `mahvars` argument, which requests that optimal Mahalanobis matching be done after subset selection; doing so can add additional precision and robustness [@zubizarretaMatchingBalancePairing2014]. -The optimization problem requires a special solver to solve. Currently, the available options in `MatchIt` are the GLPK solver (through the `Rglpk` package), the SYMPHONY solver (through the `Rsymphony` package), and the Gurobi solver (through the `gurobi` package). The differences among the solvers are in performance; Gurobi is by far the best (fastest, least likely to fail to find a solution), but it is proprietary (though has a free trial and academic license) and is a bit more complicated to install. The `designmatch` package also provides an implementation of cardinality matching with more options than `MatchIt` offers. +The optimization problem requires a special solver to solve. Currently, the available options in `MatchIt` are the HiGHS solver (through the `highs` package), the GLPK solver (through the `Rglpk` package), the SYMPHONY solver (through the `Rsymphony` package), and the Gurobi solver (through the `gurobi` package). The differences among the solvers are in performance; Gurobi is by far the best (fastest, least likely to fail to find a solution), but it is proprietary (though has a free trial and academic license) and is a bit more complicated to install. HiGHS is the default due to being open source, easily installed, and with performance comparable to Gurobi. The `designmatch` package also provides an implementation of cardinality matching with more options than `MatchIt` offers. ## Customizing the Matching Specification @@ -163,19 +165,21 @@ Anti-exact matching adds a restriction such that a treated and control unit with ### Matching with replacement (`replace`) -Nearest neighbor matching and genetic matching have the option of matching with or without replacement, and this is controlled by the `replace` argument. Matching without replacement means that each control unit is matched to only one treated unit, while matching with replacement means that control units can be reused and matched to multiple treated units. Matching without replacement carries certain statistical benefits in that weights for each unit can be omitted or are more straightforward to include and dependence between units depends only on pair membership. Special standard error estimators are sometimes required for estimating effects after matching with replacement [@austin2020a], and methods for accounting for uncertainty are not well understood for non-continuous outcomes. Matching with replacement will tend to yield better balance though, because the problem of "running out" of close control units to match to treated units is avoided, though the reuse of control units will decrease the effect sample size, thereby worsening precision [@austin2013b]. (This problem occurs in the Lalonde dataset used in `vignette("MatchIt")`, which is why nearest neighbor matching without replacement is not very effective there.) After matching with replacement, control units are assigned to more than one subclass, so the `get_matches()` function should be used instead of `match.data()` after matching with replacement if subclasses are to be used in follow-up analyses; see `vignette("estimating-effects")` for details. +Nearest neighbor matching and genetic matching have the option of matching with or without replacement, and this is controlled by the `replace` argument. Matching without replacement means that each control unit is matched to only one treated unit, while matching with replacement means that control units can be reused and matched to multiple treated units. Matching without replacement carries certain statistical benefits in that weights for each unit can be omitted or are more straightforward to include and dependence between units depends only on pair membership. However, it is not asymptotically consistent unless the propensity scores for all treated units are below .5 and there are many more control units than treated units [@savjeInconsistencyMatchingReplacement2022]. Special standard error estimators are sometimes required for estimating effects after matching with replacement [@austin2020a], and methods for accounting for uncertainty are not well understood for non-continuous outcomes. Matching with replacement will tend to yield better balance though, because the problem of "running out" of close control units to match to treated units is avoided, though the reuse of control units will decrease the effect sample size, thereby worsening precision [@austin2013b]. (This problem occurs in the Lalonde dataset used in `vignette("MatchIt")`, which is why nearest neighbor matching without replacement is not very effective there.) After matching with replacement, control units are assigned to more than one subclass, so the `get_matches()` function should be used instead of `match.data()` after matching with replacement if subclasses are to be used in follow-up analyses; see `vignette("estimating-effects")` for details. The `reuse.max` argument can also be used with `method = "nearest"` to control how many times each control unit can be reused as a match. Setting `reuse.max = 1` is equivalent to requiring matching without replacement (i.e., because each control can be used only once). Other values allow control units to be matched more than once, though only up to the specified number of times. Higher values will tend to improve balance at the cost of precision. ### $k$:1 matching (`ratio`) -The most common form of matching, 1:1 matching, involves pairing one control unit with each treated unit. To perform $k$:1 matching (e.g., 2:1 or 3:1), which pairs (up to) $k$ control units with each treated unit, the `ratio` argument can be specified. Performing $k$:1 matching can preserve precision by preventing too many control units from being unmatched and dropped from the matched sample, though the gain in precision by increasing $k$ diminishes rapidly after 4 [@rosenbaum2020]. Importantly, for $k>1$, the matches after the first match will generally be worse than the first match in terms of closeness to the treated unit, so increasing $k$ can also worsen balance. @austin2010a found that 1:1 or 1:2 matching generally performed best in terms of mean squared error. In general, it makes sense to use higher values of $k$ while ensuring that balance is satisfactory. +The most common form of matching, 1:1 matching, involves pairing one control unit with each treated unit. To perform $k$:1 matching (e.g., 2:1 or 3:1), which pairs (up to) $k$ control units with each treated unit, the `ratio` argument can be specified. Performing $k$:1 matching can preserve precision by preventing too many control units from being unmatched and dropped from the matched sample, though the gain in precision by increasing $k$ diminishes rapidly after 4 [@rosenbaum2020]. Importantly, for $k>1$, the matches after the first match will generally be worse than the first match in terms of closeness to the treated unit, so increasing $k$ can also worsen balance [@rassenOnetomanyPropensityScore2012]. @austin2010a found that 1:1 or 1:2 matching generally performed best in terms of mean squared error. In general, it makes sense to use higher values of $k$ while ensuring that balance is satisfactory. With nearest neighbor and optimal pair matching, variable $k$:1 matching, in which the number of controls matched to each treated unit varies, can also be used; this can have improved performance over "fixed" $k$:1 matching [@ming2000; @rassenOnetomanyPropensityScore2012]. See `?method_nearest` and `?method_optimal` for information on implementing variable $k$:1 matching. ### Matching order (`m.order`) -For nearest neighbor matching (including genetic matching), units are matched in an order, and that order can affect the quality of individual matches and of the resulting matched sample. With `method = "nearest"`, the allowable options to `m.order` to control the matching order are `"largest"`, `"smallest"`, `"closest"`, `"random"`, and `"data"`. With `method = "genetic"`, all but `"closest"` can be used. Requesting `"largest"` means that treated units with the largest propensity scores, i.e., those least like the control units, will be matched first, which prevents them from having bad matches after all the close control units have been used up. `"smallest"` means that treated units with the smallest propensity scores are matched first. `"closest"` means that potential pairs with the smallest distance between units will be matched first, which ensures that the best possible matches are included in the matched sample but can yield poor matches for units whose best match is far from them; this makes it particularly useful when matching with a caliper. `"random"` matches in a random order and `"data"` matches in order of the data. A propensity score is required for `"largest"` and `"smallest"` but not for the other options. @rubin1973 recommends using `"largest"` or `"random"`, though @austin2013b recommends against `"largest"` and instead favors `"closest"` or `"random"`. +For nearest neighbor matching (including genetic matching), units are matched in an order, and that order can affect the quality of individual matches and of the resulting matched sample. With `method = "nearest"`, the allowable options to `m.order` to control the matching order are `"largest"`, `"smallest"`, `"closest"`, `"farthest"`, `"random"`, and `"data"`. With `method = "genetic"`, all but `"closest"` and `"farthest"` can be used. Requesting `"largest"` means that treated units with the largest propensity scores, i.e., those least like the control units, will be matched first, which prevents them from having bad matches after all the close control units have been used up. `"smallest"` means that treated units with the smallest propensity scores are matched first. `"closest"` means that potential pairs with the smallest distance between units will be matched first, which ensures that the best possible matches are included in the matched sample but can yield poor matches for units whose best match is far from them; this makes it particularly useful when matching with a caliper. `"farthest"` means that closest pairs with the largest distance between them will be matched first, which ensures the hardest units to match are given the best chance to find matches. `"random"` matches in a random order, and `"data"` matches in order of the data. A propensity score is required for `"largest"` and `"smallest"` but not for the other options. + +@rubin1973 recommends using `"largest"` or `"random"`, though @austin2013b recommends against `"largest"` and instead favors `"closest"` or `"random"`. `"closest"` and `"smallest"` are best for prioritizing the best possible matches, while `"farthest"` and `"largest"` are best for preventing extreme pairwise distances between matched units. ## Choosing a Matching Method @@ -187,7 +191,7 @@ If the target of inference is the ATE, optimal or generalized full matching, sub Because exact and coarsened exact matching aim to balance the entire joint distribution of covariates, they are the most powerful methods. If it is possible to perform exact matching, this method should be used. If continuous covariates are present, coarsened exact matching can be tried. Care should be taken with retaining the target population and ensuring enough matched units remain; unless the control pool is much larger than the treated pool, it is likely some (or many) treated units will be discarded, thereby changing the estimand and possibly dramatically reducing precision. These methods are typically only available in the most optimistic of circumstances, but they should be used first when those circumstances arise. It may also be useful to combine exact or coarsened exact matching on some covariates with another form of matching on the others (i.e., by using the `exact` argument). -When estimating the ATE, either subclassification, full matching, or profile matching can be used. Optimal and generalized full matching can be effective because they optimize a balance criterion, often leading to better balance. With full matching, it's also possible to exact match on some variables and match using the Mahalanobis distance, eliminating the need to estimate propensity scores. Profile matching also ensures good balance, but because units are only given weights of zero or one, a solution may not be feasible and many units may have to be discarded. For large datasets, neither optimal full matching nor profile matching may be possible, in which case generalized full matching and subclassification are faster solutions. When using subclassification, the number of subclasses should be varied. With large samples, higher numbers of subclasses tend to yield better performance; one should not immediately settle for the default (6) or the often-cited recommendation of 5 without trying several other numbers. +When estimating the ATE, either subclassification, full matching, or profile matching can be used. Optimal and generalized full matching can be effective because they optimize a balance criterion, often leading to better balance. With full matching, it's also possible to exact match on some variables and match using the Mahalanobis distance, eliminating the need to estimate propensity scores. Profile matching also ensures good balance, but because units are only given weights of zero or one, a solution may not be feasible and many units may have to be discarded. For large datasets, neither optimal full matching nor profile matching may be possible, in which case generalized full matching and subclassification are faster solutions. When using subclassification, the number of subclasses should be varied. With large samples, higher numbers of subclasses tend to yield better performance; one should not immediately settle for the default (6) or the often-cited recommendation of 5 without trying several other numbers. The documentation for `cobalt::bal.compute()` contains an example of using balance to select the optimal number of subclasses. When estimating the ATT, a variety of methods can be tried. Genetic matching can perform well at achieving good balance because it directly optimizes covariate balance. With larger datasets, it may take a long time to reach a good solution (though that solution will tend to be good as well). Profile matching also will achieve good balance if a solution is feasible because balance is controlled by the user. Optimal pair matching and nearest neighbor matching without replacement tend to perform similarly to each other; nearest neighbor matching may be preferable for large datasets that cannot be handled by optimal matching. Nearest neighbor, optimal, and genetic matching allow some customizations like including covariates on which to exactly match, using the Mahalanobis distance instead of a propensity score difference, and performing $k$:1 matching with $k>1$. Nearest neighbor matching with replacement, full matching, and subclassification all involve weighting the control units with nonuniform weights, which often allows for improved balancing capabilities but can be accompanied by a loss in effective sample size, even when all units are retained. There is no reason not to try many of these methods, varying parameters here and there, in search of good balance and high remaining sample size. As previously mentioned, no single method can be recommended above all others because the optimal specification depends on the unique qualities of each dataset. @@ -195,7 +199,7 @@ When the target population is less important, for example, when engaging in trea It is important not to rely excessively on theoretical or simulation-based findings or specific recommendations when making choices about the best matching method to use. For example, although nearest neighbor matching without replacement balance covariates better than did subclassification with five or ten subclasses in Austin's [-@austin2009c] simulation, this does not imply it will be superior in all datasets. Likewise, though @rosenbaum1985a and @austin2011a both recommend using a caliper of .2 standard deviations of the logit of the propensity score, this does not imply that caliper will be optimal in all scenarios, and other widths should be tried, though it should be noted that tightening the caliper on the propensity score can sometimes degrade performance [@king2019]. -For large datasets (i.e., in 10,000s to millions), some matching methods will be too slow to be used at scale. Instead, users should consider generalized full matching, subclassification, or coarsened exact matching, which are all very fast and designed to work with large datasets. +For large datasets (i.e., in 10,000s to millions), some matching methods will be too slow to be used at scale. Instead, users should consider generalized full matching, subclassification, or coarsened exact matching, which are all very fast and designed to work with large datasets. Nearest neighbor matching on the propensity score has been optimized to run quickly for large datasets as well. ## Reporting the Matching Specification diff --git a/vignettes/references.bib b/vignettes/references.bib index d2cc7fb1..f6a9929e 100644 --- a/vignettes/references.bib +++ b/vignettes/references.bib @@ -1655,3 +1655,17 @@ @article{rassenOnetomanyPropensityScore2012 note = {{\_}eprint: https://onlinelibrary.wiley.com/doi/pdf/10.1002/pds.3263}, langid = {en} } + +@article{savjeInconsistencyMatchingReplacement2022, + title = {On the inconsistency of matching without replacement}, + author = {{Sävje}, F}, + year = {2022}, + month = {06}, + date = {2022-06-01}, + journal = {Biometrika}, + pages = {551--558}, + volume = {109}, + number = {2}, + doi = {10.1093/biomet/asab035}, + url = {https://doi.org/10.1093/biomet/asab035} +} diff --git a/vignettes/sampling-weights.Rmd b/vignettes/sampling-weights.Rmd index 32da472f..0d5bbbd5 100644 --- a/vignettes/sampling-weights.Rmd +++ b/vignettes/sampling-weights.Rmd @@ -7,7 +7,7 @@ output: toc: true vignette: > %\VignetteIndexEntry{Matching with Sampling Weights} - %\VignetteEngine{knitr::rmarkdown} + %\VignetteEngine{knitr::rmarkdown_notangle} %\VignetteEncoding{UTF-8} bibliography: references.bib link-citations: true @@ -67,6 +67,10 @@ SW <- gen_SW(X) Y_C <- gen_Y_C(A, X) d <- data.frame(A, X, Y_C, SW) + +eval_est <- (requireNamespace("optmatch", quietly = TRUE) && + requireNamespace("marginaleffects", quietly = TRUE) && + requireNamespace("sandwich", quietly = TRUE)) ``` ## Introduction @@ -86,17 +90,11 @@ library("MatchIt") When using sampling weights with propensity score matching, one has the option of including the sampling weights in the model used to estimate the propensity scores. Although evidence is mixed on whether this is required [@austin2016; @lenis2019], it can be a good idea. The choice should depend on whether including the sampling weights improves the quality of the matches. Specifications including and excluding sampling weights should be tried to determine which is preferred. To supply sampling weights to the propensity score-estimating function in `matchit()`, the sampling weights variable should be supplied to the `s.weights` argument. It can be supplied either as a numerical vector containing the sampling weights, or a string or one-sided formula with the name of the sampling weights variable in the supplied dataset. Below we demonstrate including sampling weights into propensity scores estimated using logistic regression for optimal full matching for the average treatment effect in the population (ATE) (note that all methods and steps apply the same way to all forms of matching and all estimands). -```{asis, echo = !requireNamespace("optmatch", quietly = TRUE)} +```{asis, echo = eval_est} Note: if the `optmatch`, `marginaleffects`, or `sandwich` packages are not available, the subsequent lines will not run. ``` -```{r, include=FALSE} -#In case packages goes offline, don't run lines below -if (!requireNamespace("optmatch", quietly = TRUE) || - !requireNamespace("marginaleffects", quietly = TRUE) || - !requireNamespace("sandwich", quietly = TRUE)) knitr::opts_chunk$set(eval = FALSE) -``` -```{r} +```{r, eval = eval_est} mF_s <- matchit(A ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9, data = d, method = "full", distance = "glm", @@ -108,7 +106,7 @@ Notice that the description of the matching specification when the `matchit` obj Now let's perform full matching on a propensity score that does not include the sampling weights in its estimation. Here we use the same specification as was used in `vignette("estimating-effects")`. -```{r} +```{r, eval = eval_est} mF <- matchit(A ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9, data = d, method = "full", distance = "glm", @@ -118,7 +116,7 @@ mF Notice that there is no mention of sampling weights in the description of the matching specification. However, to properly assess balance and estimate effects, we need the sampling weights to be included in the `matchit` object, even if they were not used at all in the matching. To do so, we use the function `add_s.weights()`, which adds sampling weights to the supplied `matchit` objects. -```{r} +```{r, eval = eval_est} mF <- add_s.weights(mF, ~SW) mF @@ -134,7 +132,7 @@ Now we need to decide which matching specification is the best to use for effect We'll use `summary()` to examine balance on the two matching specifications. With sampling weights included, the balance statistics for the unmatched data are weighted by the sampling weights. The balance statistics for the matched data are weighted by the product of the sampling weights and the matching weights. It is the product of these weights that will be used in estimating the treatment effect. Below we use `summary()` to display balance for the two matching specifications. No additional arguments to `summary()` are required for it to use the sampling weights; as long as they are in the `matchit` object (either due to being supplied with the `s.weights` argument in the call to `matchit()` or to being added afterward by `add_s.weights()`), they will be correctly incorporated into the balance statistics. -```{r} +```{r, eval = eval_est} #Balance before matching and for the SW propensity score full matching summary(mF_s) @@ -148,11 +146,11 @@ Note that had we not added sampling weights to `mF`, the matching specification ## Estimating the Effect -Estimating the treatment effect after matching is straightforward when using sampling weights. Effects are estimated in the same way as when sampling weights are excluded, except that the matching weights must be multiplied by the sampling weights to yield accurate, generalizable estimates. `match.data()` and `get_matches()` do this automatically, so the weights produced by these functions already are a product of the matching weights and the sampling weights. Note this will only be true if sampling weights are incorporated into the `matchit` object. +Estimating the treatment effect after matching is straightforward when using sampling weights. Effects are estimated in the same way as when sampling weights are excluded, except that the matching weights must be multiplied by the sampling weights for use in the outcome model to yield accurate, generalizable estimates. `match.data()` and `get_matches()` do this automatically, so the weights produced by these functions already are a product of the matching weights and the sampling weights. Note this will only be true if sampling weights are incorporated into the `matchit` object. With `avg_comparisons()`, only the sampling weights should be included when estimating the treatment effect. Below we estimate the effect of `A` on `Y_C` in the matched and sampling weighted sample, adjusting for the covariates to improve precision and decrease bias. -```{r} +```{r, eval = eval_est} md_F_s <- match.data(mF_s) fit <- lm(Y_C ~ A * (X1 + X2 + X3 + X4 + X5 + @@ -163,15 +161,12 @@ library("marginaleffects") avg_comparisons(fit, variables = "A", vcov = ~subclass, - newdata = subset(md_F_s, A == 1), - wts = "weights") + newdata = subset(A == 1), + wts = "SW") ``` -Note that `match.data()` and `get_weights()` have the option `include.s.weights`, which, when set to `FALSE`, makes it so the returned weights do not incorporate the sampling weights and are simply the matching weights. Because one might to forget to multiply the two sets of weights together, it is easier to just use the default of `include.s.weights = TRUE` and ignore the sampling weights in the rest of the analysis (because they are already included in the returned weights). `avg_comparisons()` also works more smoothly when the weights supplied to `weights` is a single variable rather than the product of two. +Note that `match.data()` and `get_weights()` have the option `include.s.weights`, which, when set to `FALSE`, makes it so the returned weights do not incorporate the sampling weights and are simply the matching weights. Because one might to forget to multiply the two sets of weights together, it is easier to just use the default of `include.s.weights = TRUE` and ignore the sampling weights in the rest of the analysis (because they are already included in the returned weights). -```{r, include=FALSE, eval=TRUE} -knitr::opts_chunk$set(eval = TRUE) -``` ## Code to Generate Data used in Examples ```{r, eval = FALSE}