diff --git a/R/collapseCohorts.R b/R/collapseCohorts.R index 2f4d7a5a..48423e26 100644 --- a/R/collapseCohorts.R +++ b/R/collapseCohorts.R @@ -53,17 +53,14 @@ collapseCohorts <- function(cohort, dplyr::select(!"observation_period_id") } else if (gap > 0) { newCohort <- newCohort |> - PatientProfiles::addObservationPeriodId() |> joinOverlap( name = tmpNewCohort, gap = gap, by = c( "cohort_definition_id", - "subject_id", - "observation_period_id" + "subject_id" ) - ) |> - dplyr::select(!"observation_period_id") + ) } if (!all(ids %in% cohortId)) { newCohort <- unchangedCohort |> @@ -83,5 +80,3 @@ collapseCohorts <- function(cohort, return(newCohort) } - - diff --git a/R/intersectCohorts.R b/R/intersectCohorts.R index 98b3d3d9..1dd45f56 100644 --- a/R/intersectCohorts.R +++ b/R/intersectCohorts.R @@ -134,6 +134,7 @@ intersectCohorts <- function(cohort, "cohort_start_date", "cohort_end_date") %>% dplyr::compute(name = tblName, temporary = FALSE) if (cohortOut |> dplyr::tally() |> dplyr::pull("n") > 0) { + class(cohortOut) <- c(class(cohortOut), "cohort_table") cohortOut <- cohortOut %>% dplyr::compute(name = tblName, temporary = FALSE) |> joinOverlap(name = tblName, gap = gap) @@ -333,10 +334,12 @@ joinOverlap <- function(cohort, cdm <- omopgenerics::cdmReference(cohort) start <- cohort |> - dplyr::select(by, "date" := !!startDate) |> + PatientProfiles::addObservationPeriodId() |> + dplyr::select(by, "date" := !!startDate, "observation_period_id") |> dplyr::mutate("date_id" = -1) end <- cohort |> - dplyr::select(by, "date" := !!endDate) |> + PatientProfiles::addObservationPeriodId() |> + dplyr::select(by, "date" := !!endDate, "observation_period_id") |> dplyr::mutate("date_id" = 1) if (gap > 0) { end <- end %>% @@ -354,7 +357,7 @@ joinOverlap <- function(cohort, dplyr::compute(temporary = FALSE, name = workingTbl) x <- x |> - dplyr::group_by(dplyr::pick(by)) |> + dplyr::group_by(dplyr::pick(by), .data$observation_period_id) |> dplyr::arrange(.data$date, .data$date_id) |> dplyr::mutate( "cum_id" = cumsum(.data$date_id), @@ -366,10 +369,10 @@ joinOverlap <- function(cohort, dplyr::mutate("era_id" = cumsum(as.numeric(.data$era_id))) |> dplyr::ungroup() |> dplyr::arrange() |> - dplyr::select(dplyr::all_of(c(by, "era_id", "name", "date"))) |> + dplyr::select(dplyr::all_of(c(by, "observation_period_id", "era_id", "name", "date"))) |> dplyr::compute(temporary = FALSE, name = name) |> tidyr::pivot_wider(names_from = "name", values_from = "date") |> - dplyr::select(-"era_id") |> + dplyr::select(-"era_id", -"observation_period_id") |> dplyr::compute(temporary = FALSE, name = name) if (gap > 0) { x <- x %>% diff --git a/tests/testthat/test-intersectCohorts.R b/tests/testthat/test-intersectCohorts.R index 23c2ef70..5f003351 100644 --- a/tests/testthat/test-intersectCohorts.R +++ b/tests/testthat/test-intersectCohorts.R @@ -536,6 +536,53 @@ test_that("codelist", { PatientProfiles::mockDisconnect(cdm) }) +test_that("multiple observation periods", { + cdm_local <- omock::mockCdmReference() |> + omock::mockPerson(n = 4) |> + omock::mockObservationPeriod() + cdm_local$observation_period <- cdm_local$observation_period |> + dplyr::filter(person_id != 4) |> + dplyr::union_all(dplyr::tibble( + observation_period_id = c(4L,5L, 6L), + person_id = 4L, + observation_period_start_date = c(as.Date("1989-12-09"), as.Date("2003-01-01"), as.Date("2009-02-04")), + observation_period_end_date = c(as.Date("2002-12-31"), as.Date("2009-02-03"),as.Date("2013-12-31")), + period_type_concept_id = NA + ) + ) + cdm_local <- cdm_local |> + omock::mockCohort(name = c("cohort"), numberCohorts = 3, seed = 11) + cdm_local$cohort <- cdm_local$cohort |> + dplyr::union_all(dplyr::tibble( + cohort_definition_id = c(1L,2L, 3L), + subject_id = 4L, + cohort_start_date = c(as.Date("2009-04-05"), as.Date("2009-06-07"), as.Date("2009-01-01")), + cohort_end_date = c(as.Date("2010-01-01"), as.Date("2009-12-12"), as.Date("2009-02-01")) + ) + ) + cdm <- cdm_local |> copyCdm() + + cdm$cohort2 <- intersectCohorts( + cohort = cdm$cohort, name = "cohort2" , + keepOriginalCohorts = TRUE , + gap = 1000) + + cdm$cohort3 <- intersectCohorts( + cohort = cdm$cohort, name = "cohort3" , + keepOriginalCohorts = TRUE) + + expect_true(cdm$cohort2 |> + dplyr::tally() |> + dplyr::pull() == + 5) + expect_true(cdm$cohort3 |> + dplyr::tally() |> + dplyr::pull() == + 6) + + PatientProfiles::mockDisconnect(cdm) +}) + test_that("records combined for gap must be in the same observation period", { cdm_local <- omock::mockCdmReference() |> omock::mockPerson(n = 2) diff --git a/tests/testthat/test-unionCohorts.R b/tests/testthat/test-unionCohorts.R index d9812969..35242c78 100644 --- a/tests/testthat/test-unionCohorts.R +++ b/tests/testthat/test-unionCohorts.R @@ -297,3 +297,38 @@ test_that("keep original cohorts", { PatientProfiles::mockDisconnect(cdm) }) + +test_that("multiple observation periods", { + cdm_local <- omock::mockCdmReference() |> + omock::mockPerson(n = 4) |> + omock::mockObservationPeriod() + cdm_local$observation_period <- cdm_local$observation_period |> + dplyr::filter(person_id != 4) |> + dplyr::union_all(dplyr::tibble( + observation_period_id = c(4L,5L, 6L), + person_id = 4L, + observation_period_start_date = c(as.Date("1989-12-09"), as.Date("2003-01-01"), as.Date("2009-02-04")), + observation_period_end_date = c(as.Date("2002-12-31"), as.Date("2009-02-03"),as.Date("2013-12-31")), + period_type_concept_id = NA + ) + ) + cdm_local <- cdm_local |> + omock::mockCohort(name = c("cohort"), numberCohorts = 3, seed = 11) + + cdm <- cdm_local |> copyCdm() + + cdm$cohort2 <- unionCohorts(cdm$cohort, name = "cohort2", gap = 10000) + + expect_true(cdm$cohort2 |> + dplyr::filter(subject_id == 4 & cohort_start_date < as.Date("2003-01-01") & + cohort_end_date > as.Date("2003-01-01")) |> + dplyr::tally() |> + dplyr::pull() == 0) + expect_true(cdm$cohort2 |> + dplyr::filter(subject_id == 4 & cohort_start_date < as.Date("2009-02-04") & + cohort_end_date > as.Date("2009-02-04")) |> + dplyr::tally() |> + dplyr::pull() == 0) + PatientProfiles::mockDisconnect(cdm) + +})