Skip to content

Commit

Permalink
val/test split: update function to reflect real mechanism using var_n…
Browse files Browse the repository at this point in the history
…ames
  • Loading branch information
cblaauw committed Nov 19, 2024
1 parent 1348a0b commit b11048a
Showing 1 changed file with 4 additions and 20 deletions.
24 changes: 4 additions & 20 deletions src/crested/pp/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def _split_by_chromosome_auto(
"""
chrom_count = defaultdict(int)
for region in regions:
if not region.contains(":"):
raise ValueError(f"Region names should start with the chromosome name, bound by a colon (:). Offending region: {region}")
chrom = region.split(":")[0]
chrom_count[chrom] += 1

Expand Down Expand Up @@ -188,7 +190,6 @@ def train_val_test_split(
test_size: float = 0.1,
val_chroms: list[str] = None,
test_chroms: list[str] = None,
chr_var_key: str = "chr",
shuffle: bool = True,
random_state: None | int = None,
) -> None:
Expand All @@ -208,9 +209,8 @@ def train_val_test_split(
adata
AnnData object to which the 'train/val/test' split column will be added.
strategy
strategy of split. Either 'region', 'chr' or 'chr_auto'. If 'chr' or 'chr_auto', the "target" df should
have a column "chr" with the chromosome names.
strategy of split. Either 'region', 'chr' or 'chr_auto'. If 'chr' or 'chr_auto', the anndata's var_names should
contain the chromosome name at the start, followed by a `:` (e.g. I:2000-2500 or chr3:10-20:+).
region: Split randomly on region indices.
chr: Split based on provided chromosomes.
Expand All @@ -227,8 +227,6 @@ def train_val_test_split(
List of chromosomes to include in the validation set. Required if strategy='chr'.
test_chroms
List of chromosomes to include in the test set. Required if strategy='chr'.
chr_var_key
Key in `.var` for chromosome.
shuffle
Whether or not to shuffle the data before splitting (when strategy='region').
random_state
Expand Down Expand Up @@ -265,11 +263,6 @@ def train_val_test_split(
raise ValueError("`val_size` should be a float between 0 and 1.")
if strategy in ["region", "chr_auto"] and not 0 <= test_size <= 1:
raise ValueError("`test_size` should be a float between 0 and 1.")
if strategy in ["chr", "chr_auto"] and chr_var_key not in adata.var.columns:
raise ValueError(
f"Column '{chr_var_key}' not found in `.var`. "
"Make sure to add the chromosome information to the `.var` DataFrame."
)
if (strategy == "region") and (val_chroms is not None or test_chroms is not None):
logger.warning(
"`val_chroms` and `test_chroms` provided but splitting strategy is 'region'. Will use 'chr' strategy instead."
Expand All @@ -280,15 +273,6 @@ def train_val_test_split(
raise ValueError(
"If `strategy` is 'chr', `val_chroms` and `test_chroms` should be provided."
)
unique_chr = adata.var[chr_var_key].unique()
if not set(val_chroms).issubset(unique_chr):
raise ValueError(
"Some chromosomes in `val_chroms` are not present in the dataset."
)
if not set(test_chroms).issubset(unique_chr):
raise ValueError(
"Some chromosomes in `test_chroms` are not present in the dataset."
)

# Split
regions = list(adata.var_names)
Expand Down

0 comments on commit b11048a

Please sign in to comment.