diff --git a/CHANGES.md b/CHANGES.md index 7ac015a76..1ece03de7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,12 +7,14 @@ * A new command, `augur merge`, now allows for generalized merging of two or more metadata tables. [#1563][] (@tsibley) * Two new commands, `augur read-file` and `augur write-file`, now allow external programs to do i/o like Augur by piping from/to these new commands. They provide handling of compression formats and newlines consistent with the rest of Augur. [#1562][] (@tsibley) * A new debugging mode can be enabled by setting the `AUGUR_DEBUG` environment variable to `1` (or any non-empty value). Currently the only effect is to print more information about handled (i.e. anticipated) errors. For example, stack traces and parent exceptions in an exception chain are normally omitted for handled errors, but setting this env var includes them. Future debugging and troubleshooting features, like verbose operation logging, will likely also condition on this new debugging mode. [#1577][] (@tsibley) +* filter: Added the ability to use weights in subsampling. See help text of `--group-by-weights` for more information. [#1454][] (@victorlin) ### Bug Fixes * Embedded newlines in quoted field values of metadata files read/written by many commands, annotation files read by `augur curate apply-record-annotations`, and index files written by `augur index` are now properly handled. [#1561][] [#1564][] (@tsibley) * Output written to stderr (e.g. informational messages, warnings, errors, etc.) is now always line-buffered regardless of the Python version in use. This helps with interleaved stderr and stdout. Previously, stderr was block-buffered on Python 3.8 and line-buffered on 3.9 and higher. [#1563][] (@tsibley) +[#1454]: https://github.com/nextstrain/augur/pull/1454 [#1561]: https://github.com/nextstrain/augur/pull/1561 [#1562]: https://github.com/nextstrain/augur/pull/1562 [#1563]: https://github.com/nextstrain/augur/pull/1563 diff --git a/augur/dates/__init__.py b/augur/dates/__init__.py index 046638826..119954027 100644 --- a/augur/dates/__init__.py +++ b/augur/dates/__init__.py @@ -143,5 +143,9 @@ def get_numerical_dates(metadata:pd.DataFrame, name_col = None, date_col='date', dates = metadata[date_col].astype(float) return dict(zip(strains, dates)) -def get_iso_year_week(year, month, day): - return datetime.date(year, month, day).isocalendar()[:2] +def get_year_month(year, month): + return f"{year}-{str(month).zfill(2)}" + +def get_year_week(year, month, day): + year, week = datetime.date(year, month, day).isocalendar()[:2] + return f"{year}-{str(week).zfill(2)}" diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index 947875258..2c6c9d7db 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -67,9 +67,34 @@ def register_arguments(parser): subsample_limits_group = subsample_group.add_mutually_exclusive_group() subsample_limits_group.add_argument('--sequences-per-group', type=int, help="subsample to no more than this number of sequences per category") subsample_limits_group.add_argument('--subsample-max-sequences', type=int, help="subsample to no more than this number of sequences; can be used without the group_by argument") - group_size_options = subsample_group.add_mutually_exclusive_group() - group_size_options.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. 
This option only applies when `--subsample-max-sequences` is provided.")
-    group_size_options.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
+    probabilistic_sampling_group = subsample_group.add_mutually_exclusive_group()
+    probabilistic_sampling_group.add_argument('--probabilistic-sampling', action='store_true', help="Allow probabilistic sampling during subsampling. This is useful when there are more groups than requested sequences. This option only applies when `--subsample-max-sequences` is provided.")
+    probabilistic_sampling_group.add_argument('--no-probabilistic-sampling', action='store_false', dest='probabilistic_sampling')
+    subsample_group.add_argument('--group-by-weights', type=str, metavar="FILE", help="""
+        TSV file defining weights for grouping. Requirements:
+
+        (1) Lines starting with '#' are treated as comment lines.
+        (2) The first non-comment line must be a header row.
+        (3) There must be a numeric ``weight`` column (weights can take on any
+        non-negative values).
+        (4) Other columns must be a subset of columns used in ``--group-by``,
+        with combinations of values covering all combinations present in the
+        metadata.
+        (5) This option only applies when ``--group-by`` and
+        ``--subsample-max-sequences`` are provided.
+        (6) This option cannot be used with ``--no-probabilistic-sampling``.
+
+        Notes:
+
+        (1) Any ``--group-by`` columns absent from this file will be given equal
+        weighting across all values *within* groups defined by the other
+        weighted columns.
+        (2) An entry with the value ``default`` under all columns will be
+        treated as the default weight for specific groups present in the
+        metadata but missing from the weights file. If there is no default
+        weight and the metadata contains rows that are not covered by the
+        given weights, augur filter will exit with an error.
+        """)
     subsample_group.add_argument('--priority', type=str, help="""tab-delimited file with list of priority scores for strains (e.g., "\\t") and no header.
         When scores are provided, Augur converts scores to floating point values, sorts strains within each subsampling group from highest to lowest priority, and selects the top N strains per group where N is the calculated or requested number of strains per group.
         Higher numbers indicate higher priority.
@@ -81,6 +106,7 @@ def register_arguments(parser):
     output_group.add_argument('--output-metadata', help="metadata for strains that passed filters")
     output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)")
     output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.")
+    output_group.add_argument('--output-group-by-sizes', help="tab-delimited file with one row per group and its target size.")
     output_group.add_argument(
         '--empty-output-reporting',
         type=EmptyOutputReportingMethod.argtype,
diff --git a/augur/filter/_run.py b/augur/filter/_run.py
index f015b8c9a..cca387f13 100644
--- a/augur/filter/_run.py
+++ b/augur/filter/_run.py
@@ -23,7 +23,7 @@ from . 
import include_exclude_rules from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs from .include_exclude_rules import apply_filters, construct_filters -from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling +from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling, get_weighted_group_sizes def run(args): @@ -264,32 +264,43 @@ def run(args): # group. Then, we need to make a second pass through the metadata to find # the requested number of records. if args.subsample_max_sequences and records_per_group is not None: - # Calculate sequences per group. If there are more groups than maximum - # sequences requested, sequences per group will be a floating point - # value and subsampling will be probabilistic. - try: - sequences_per_group, probabilistic_used = calculate_sequences_per_group( - args.subsample_max_sequences, - records_per_group.values(), - args.probabilistic_sampling, - ) - except TooManyGroupsError as error: - raise AugurError(error) - if queues_by_group is None: # We know all of the possible groups now from the first pass through # the metadata, so we can create queues for all groups at once. - if (probabilistic_used): - print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") - group_sizes = get_probabilistic_group_sizes( - records_per_group.keys(), - sequences_per_group, - random_seed=args.subsample_seed, + if args.group_by_weights: + print_err(f"Sampling with weights defined by {args.group_by_weights}.") + group_sizes = get_weighted_group_sizes( + records_per_group, + group_by, + args.group_by_weights, + args.subsample_max_sequences, + args.output_group_by_sizes, + args.subsample_seed, ) else: - print_err(f"Sampling at {sequences_per_group} per group.") - assert type(sequences_per_group) is int - group_sizes = {group: sequences_per_group for group in records_per_group.keys()} + # Calculate sequences per group. If there are more groups than maximum + # sequences requested, sequences per group will be a floating point + # value and subsampling will be probabilistic. 
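+                # For example, requesting 100 sequences across 150 well-filled
+                # groups works out to roughly 2/3 of a sequence per group,
+                # which can only be honored as a per-group probability rather
+                # than a fixed whole-number size.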
+ try: + sequences_per_group, probabilistic_used = calculate_sequences_per_group( + args.subsample_max_sequences, + records_per_group.values(), + args.probabilistic_sampling, + ) + except TooManyGroupsError as error: + raise AugurError(error) + + if (probabilistic_used): + print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") + group_sizes = get_probabilistic_group_sizes( + records_per_group.keys(), + sequences_per_group, + random_seed=args.subsample_seed, + ) + else: + print_err(f"Sampling at {sequences_per_group} per group.") + assert type(sequences_per_group) is int + group_sizes = {group: sequences_per_group for group in records_per_group.keys()} queues_by_group = create_queues_by_group(group_sizes) # Make a second pass through the metadata, only considering records that diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index ed0f73c9d..11f2eef86 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -1,15 +1,21 @@ +from collections import defaultdict import heapq import itertools import uuid import numpy as np import pandas as pd -from typing import Collection +from textwrap import dedent +from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple, Union -from augur.dates import get_iso_year_week +from augur.dates import get_year_month, get_year_week from augur.errors import AugurError from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err from . import constants +from .weights_file import WEIGHTS_COLUMN, COLUMN_VALUE_FOR_DEFAULT_WEIGHT, get_default_weight, get_weighted_columns, read_weights_file + +Group = Tuple[str, ...] +"""Combination of grouping column values in tuple form.""" def get_groups_for_subsampling(strains, metadata, group_by=None): @@ -45,7 +51,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> group_by = ["year", "month"] >>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain - {'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))} + {'strain1': (2020, '2020-01'), 'strain2': (2020, '2020-02')} If we omit the grouping columns, the result will group by a dummy column. @@ -67,7 +73,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> group_by = ["year", "month", "missing_column"] >>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain - {'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')} + {'strain1': (2020, '2020-01', 'unknown'), 'strain2': (2020, '2020-02', 'unknown')} We can group metadata without any non-ID columns. @@ -138,15 +144,16 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): if constants.DATE_YEAR_COLUMN in generated_columns_requested: metadata[constants.DATE_YEAR_COLUMN] = metadata[f'{temp_prefix}year'] if constants.DATE_MONTH_COLUMN in generated_columns_requested: - metadata[constants.DATE_MONTH_COLUMN] = list(zip( - metadata[f'{temp_prefix}year'], - metadata[f'{temp_prefix}month'] - )) + metadata[constants.DATE_MONTH_COLUMN] = metadata.apply(lambda row: get_year_month( + row[f'{temp_prefix}year'], + row[f'{temp_prefix}month'] + ), axis=1 + ) if constants.DATE_WEEK_COLUMN in generated_columns_requested: # Note that week = (year, week) from the date.isocalendar(). 
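+        # For example, 2019-12-31 falls in ISO year 2020, week 1.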
# Do not combine the raw year with the ISO week number alone, # since raw year ≠ ISO year. - metadata[constants.DATE_WEEK_COLUMN] = metadata.apply(lambda row: get_iso_year_week( + metadata[constants.DATE_WEEK_COLUMN] = metadata.apply(lambda row: get_year_week( row[f'{temp_prefix}year'], row[f'{temp_prefix}month'], row[f'{temp_prefix}day'] @@ -297,6 +304,203 @@ def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None): return max_sizes_per_group +TARGET_SIZE_COLUMN = '_augur_filter_target_size' +INPUT_SIZE_COLUMN = '_augur_filter_input_size' +OUTPUT_SIZE_COLUMN = '_augur_filter_subsampling_output_size' + + +def get_weighted_group_sizes( + records_per_group: Dict[Group, int], + group_by: List[str], + weights_file: str, + target_total_size: int, + output_sizes_file: Optional[str], + random_seed: Optional[int], + ) -> Dict[Group, int]: + """Return target group sizes based on weights defined in ``weights_file``. + """ + groups = records_per_group.keys() + + weights = read_weights_file(weights_file) + + weighted_columns = get_weighted_columns(weights_file) + + # Other columns in group_by are considered unweighted. + unweighted_columns = list(set(group_by) - set(weighted_columns)) + + if unweighted_columns: + # This has the side effect of weighting the values *alongside* (rather + # than within) each weighted group. After dropping unused groups, adjust + # weights to ensure equal weighting of unweighted columns *within* each + # weighted group defined by the weighted columns. + weights = _add_unweighted_columns(weights, groups, group_by, unweighted_columns) + + weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, groups) + weights = _drop_unused_groups(weights, groups, group_by) + + weights = _adjust_weights_for_unweighted_columns(weights, weighted_columns, unweighted_columns) + else: + weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, groups) + weights = _drop_unused_groups(weights, groups, group_by) + + weights = _calculate_weighted_group_sizes(weights, target_total_size, random_seed) + + # Add columns to summarize the input data + weights[INPUT_SIZE_COLUMN] = weights.apply(lambda row: records_per_group[tuple(row[group_by].values)], axis=1) + weights[OUTPUT_SIZE_COLUMN] = weights[[INPUT_SIZE_COLUMN, TARGET_SIZE_COLUMN]].min(axis=1) + + # Warn on any under-sampled groups + for _, row in weights.iterrows(): + if row[INPUT_SIZE_COLUMN] < row[TARGET_SIZE_COLUMN]: + sequences = 'sequence' if row[TARGET_SIZE_COLUMN] == 1 else 'sequences' + are = 'is' if row[INPUT_SIZE_COLUMN] == 1 else 'are' + group = list(f'{col}={value!r}' for col, value in row[group_by].items()) + print_err(f"WARNING: Targeted {row[TARGET_SIZE_COLUMN]} {sequences} for group {group} but only {row[INPUT_SIZE_COLUMN]} {are} available.") + + if output_sizes_file: + weights.to_csv(output_sizes_file, index=False, sep='\t') + + return dict(zip(weights[group_by].apply(tuple, axis=1), weights[TARGET_SIZE_COLUMN])) + + +def _add_unweighted_columns( + weights: pd.DataFrame, + groups: Iterable[Group], + group_by: List[str], + unweighted_columns: List[str], + ) -> pd.DataFrame: + """Add the unweighted columns to the weights DataFrame. + + This is done by extending the existing weights to the newly created groups. + """ + + # Get unique values for each unweighted column. 
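+    # For example, with group_by=['year', 'location'] and weights defined
+    # only on 'location', groups like (2000, 'A') and (2001, 'B') collect
+    # into {'year': {2000, 2001}} here.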
+ values_for_unweighted_columns = defaultdict(set) + for group in groups: + # NOTE: The ordering of entries in `group` corresponds to the column + # names in `group_by`, but only because `get_groups_for_subsampling` + # conveniently retains the order. This could be more tightly coupled, + # but it works. + column_to_value_map = dict(zip(group_by, group)) + for column in unweighted_columns: + values_for_unweighted_columns[column].add(column_to_value_map[column]) + + # Create a DataFrame for all permutations of values in unweighted columns. + lists = [list(values_for_unweighted_columns[column]) for column in unweighted_columns] + unweighted_permutations = pd.DataFrame(list(itertools.product(*lists)), columns=unweighted_columns) + + return pd.merge(unweighted_permutations, weights, how='cross') + + +def _drop_unused_groups( + weights: pd.DataFrame, + groups: Collection[Group], + group_by: List[str], + ) -> pd.DataFrame: + """Drop any groups from ``weights`` that don't appear in ``groups``. + """ + weights.set_index(group_by, inplace=True) + + # Pandas only uses MultiIndex if there is more than one column in the index. + valid_index: Set[Union[Group, str]] + if len(group_by) > 1: + valid_index = set(groups) + else: + valid_index = set(group[0] for group in groups) + + extra_groups = set(weights.index) - valid_index + if extra_groups: + count = len(extra_groups) + unit = "group" if count == 1 else "groups" + print_err(f"NOTE: Skipping {count} {unit} due to lack of entries in metadata.") + weights = weights[weights.index.isin(valid_index)] + + weights.reset_index(inplace=True) + + return weights + + +def _adjust_weights_for_unweighted_columns( + weights: pd.DataFrame, + weighted_columns: List[str], + unweighted_columns: Collection[str], + ) -> pd.DataFrame: + """Adjust weights for unweighted columns to reflect equal weighting within each weighted group. + """ + columns = 'column' if len(unweighted_columns) == 1 else 'columns' + those = 'that' if len(unweighted_columns) == 1 else 'those' + print_err(f"NOTE: Weights were not provided for the {columns} {', '.join(repr(col) for col in unweighted_columns)}. Using equal weights across values in {those} {columns}.") + + weights_grouped = weights.groupby(weighted_columns) + weights[WEIGHTS_COLUMN] = weights_grouped[WEIGHTS_COLUMN].transform(lambda x: x / len(x)) + + return weights + + +def _calculate_weighted_group_sizes( + weights: pd.DataFrame, + target_total_size: int, + random_seed: Optional[int], + ) -> pd.DataFrame: + """Calculate maximum group sizes based on weights. + """ + weights[TARGET_SIZE_COLUMN] = pd.Series(weights[WEIGHTS_COLUMN] / weights[WEIGHTS_COLUMN].sum() * target_total_size) + + # Group sizes must be whole numbers. Round probabilistically by adding a + # random number between [0,1) and truncating the decimal part. + rng = np.random.default_rng(random_seed) + weights[TARGET_SIZE_COLUMN] = (weights[TARGET_SIZE_COLUMN].add(pd.Series(rng.random(len(weights))))).astype(int) + + return weights + + +def _handle_incomplete_weights( + weights: pd.DataFrame, + weights_file: str, + weighted_columns: List[str], + group_by: List[str], + groups: Iterable[Group], + ) -> pd.DataFrame: + """Handle the case where the weights file does not cover all rows in the metadata. + """ + missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1)) + + if not missing_groups: + return weights + + # Collect the column values that are missing weights. 
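+    # For example, if the metadata contains location 'C' but the weights file
+    # only covers 'A' and 'B', this collects {'location': {'C'}} for use in
+    # the error or warning message below.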
+ missing_values_by_column = defaultdict(set) + for group in missing_groups: + # NOTE: The ordering of entries in `group` corresponds to the column + # names in `group_by`, but only because `get_groups_for_subsampling` + # conveniently retains the order. This could be more tightly coupled, + # but it works. + column_to_value_map = dict(zip(group_by, group)) + for column in weighted_columns: + missing_values_by_column[column].add(column_to_value_map[column]) + + columns_with_values = '\n - '.join(f'{column!r}: {list(sorted(values))}' for column, values in sorted(missing_values_by_column.items())) + + default_weight = get_default_weight(weights, weighted_columns) + + if not default_weight: + raise AugurError(dedent(f"""\ + The input metadata contains these values under the following columns that are not covered by {weights_file!r}: + - {columns_with_values} + To fix this, either: + (1) specify weights explicitly - add entries to {weights_file!r} for the values above, or + (2) specify a default weight - add an entry to {weights_file!r} with the value {COLUMN_VALUE_FOR_DEFAULT_WEIGHT!r} for all columns""")) + else: + print_err(dedent(f"""\ + WARNING: The input metadata contains these values under the following columns that are not directly covered by {weights_file!r}: + - {columns_with_values} + The default weight of {default_weight!r} will be used for all groups defined by those values.""")) + + missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by) + missing_weights[WEIGHTS_COLUMN] = default_weight + return pd.merge(weights, missing_weights, on=[*group_by, WEIGHTS_COLUMN], how='outer') + + def create_queues_by_group(max_sizes_per_group): return {group: PriorityQueue(max_size) for group, max_size in max_sizes_per_group.items()} diff --git a/augur/filter/validate_arguments.py b/augur/filter/validate_arguments.py index 49ee310bd..9e639a83d 100644 --- a/augur/filter/validate_arguments.py +++ b/augur/filter/validate_arguments.py @@ -1,4 +1,5 @@ from augur.errors import AugurError +from augur.filter.weights_file import get_weighted_columns from augur.io.vcf import is_vcf as filename_is_vcf @@ -43,3 +44,22 @@ def validate_arguments(args): # If user requested grouping, confirm that other required inputs are provided, too. if args.group_by and not any((args.sequences_per_group, args.subsample_max_sequences)): raise AugurError("You must specify a number of sequences per group or maximum sequences to subsample.") + + # Weighted columns must be specified explicitly. + if args.group_by_weights: + weighted_columns = get_weighted_columns(args.group_by_weights) + if (not set(weighted_columns) <= set(args.group_by)): + raise AugurError("Columns in --group-by-weights must be a subset of columns provided in --group-by.") + + # --output-group-by-sizes is only available for --group-by-weights. + if args.output_group_by_sizes and not args.group_by_weights: + raise AugurError( + "--output-group-by-sizes is only available for --group-by-weights. " + "It may be added to other sampling methods in the future - see " + ) + + # --group-by-weights cannot be used with --no-probabilistic-sampling. + if args.group_by_weights and not args.probabilistic_sampling: + raise AugurError( + "--group-by-weights cannot be used with --no-probabilistic-sampling." 
+    )
diff --git a/augur/filter/weights_file.py b/augur/filter/weights_file.py
new file mode 100644
index 000000000..3c3389923
--- /dev/null
+++ b/augur/filter/weights_file.py
@@ -0,0 +1,63 @@
+import pandas as pd
+from textwrap import dedent
+from typing import List
+from augur.errors import AugurError
+
+
+WEIGHTS_COLUMN = 'weight'
+COLUMN_VALUE_FOR_DEFAULT_WEIGHT = 'default'
+
+
+class InvalidWeightsFile(AugurError):
+    def __init__(self, file, error_message):
+        super().__init__(f"Bad weights file {file!r}.\n{error_message}")
+
+
+def read_weights_file(weights_file):
+    weights = pd.read_csv(weights_file, delimiter='\t', comment='#')
+
+    if not pd.api.types.is_numeric_dtype(weights[WEIGHTS_COLUMN]):
+        non_numeric_weight_lines = [index + 2 for index in weights[~weights[WEIGHTS_COLUMN].str.isnumeric()].index.tolist()]
+        raise InvalidWeightsFile(weights_file, dedent(f"""\
+            Found non-numeric weights on the following lines: {non_numeric_weight_lines}
+            {WEIGHTS_COLUMN!r} column must be numeric."""))
+
+    if any(weights[WEIGHTS_COLUMN] < 0):
+        negative_weight_lines = [index + 2 for index in weights[weights[WEIGHTS_COLUMN] < 0].index.tolist()]
+        raise InvalidWeightsFile(weights_file, dedent(f"""\
+            Found negative weights on the following lines: {negative_weight_lines}
+            {WEIGHTS_COLUMN!r} column must be non-negative."""))
+
+    return weights
+
+
+def get_weighted_columns(weights_file):
+    with open(weights_file) as f:
+        has_rows = False
+        for row in f:
+            has_rows = True
+            if row.startswith('#'):
+                continue
+            columns = row.rstrip().split('\t')
+            break
+        if not has_rows:
+            raise InvalidWeightsFile(weights_file, "File is empty.")
+    columns.remove(WEIGHTS_COLUMN)
+    return columns
+
+
+def get_default_weight(weights: pd.DataFrame, weighted_columns: List[str]):
+    default_weight_values = weights[(weights[weighted_columns] == COLUMN_VALUE_FOR_DEFAULT_WEIGHT).all(axis=1)][WEIGHTS_COLUMN].unique()
+
+    if len(default_weight_values) > 1:
+        # TODO: raise InvalidWeightsFile, not AugurError. This function takes
+        # the weights DataFrame instead of the filepath, so it does not have
+        # the file parameter to InvalidWeightsFile, and I didn't want to pass
+        # an extra filepath parameter just to have it available for the custom
+        # exception class. One idea would be to define a custom class to
+        # represent a weights file, however this seemed overkill in a quick
+        # prototype:
+        #
+        raise AugurError(f"Multiple default weights were specified: {', '.join(repr(weight) for weight in default_weight_values)}. Only one default weight entry can be accepted.")
+    if len(default_weight_values) == 1:
+        return default_weight_values[0]
diff --git a/docs/api/developer/augur.filter.rst b/docs/api/developer/augur.filter.rst
index 44942a2fd..c24ddbac3 100644
--- a/docs/api/developer/augur.filter.rst
+++ b/docs/api/developer/augur.filter.rst
@@ -17,3 +17,4 @@ Submodules
    augur.filter.io
    augur.filter.subsample
    augur.filter.validate_arguments
+   augur.filter.weights_file
diff --git a/docs/api/developer/augur.filter.weights_file.rst b/docs/api/developer/augur.filter.weights_file.rst
new file mode 100644
index 000000000..31be68a6b
--- /dev/null
+++ b/docs/api/developer/augur.filter.weights_file.rst
@@ -0,0 +1,7 @@
+augur.filter.weights\_file module
+=================================
+
+.. 
automodule:: augur.filter.weights_file + :members: + :undoc-members: + :show-inheritance: diff --git a/tests/filter/test_subsample.py b/tests/filter/test_subsample.py index d454d0e53..b8e427f7e 100644 --- a/tests/filter/test_subsample.py +++ b/tests/filter/test_subsample.py @@ -70,11 +70,11 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys) strains = metadata.index.tolist() group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1), 'unknown'), - 'SEQ_2': ('A', 2020, (2020, 2), 'unknown'), - 'SEQ_3': ('B', 2020, (2020, 3), 'unknown'), - 'SEQ_4': ('B', 2020, (2020, 4), 'unknown'), - 'SEQ_5': ('B', 2020, (2020, 5), 'unknown') + 'SEQ_1': ('A', 2020, '2020-01', 'unknown'), + 'SEQ_2': ('A', 2020, '2020-02', 'unknown'), + 'SEQ_3': ('B', 2020, '2020-03', 'unknown'), + 'SEQ_4': ('B', 2020, '2020-04', 'unknown'), + 'SEQ_5': ('B', 2020, '2020-05', 'unknown') } captured = capsys.readouterr() assert captured.err == "WARNING: Some of the specified group-by categories couldn't be found: invalid\nFiltering by group may behave differently than expected!\n" @@ -150,9 +150,9 @@ def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFr strains = metadata.index.tolist() group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_2': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 1)), - 'SEQ_4': ('B', 2020, (2020, 1)), - 'SEQ_5': ('B', 2020, (2020, 1)) + 'SEQ_1': ('A', 2020, '2020-01'), + 'SEQ_2': ('A', 2020, '2020-01'), + 'SEQ_3': ('B', 2020, '2020-01'), + 'SEQ_4': ('B', 2020, '2020-01'), + 'SEQ_5': ('B', 2020, '2020-01') } diff --git a/tests/functional/filter/cram/subsample-output-group-by-sizes-error.t b/tests/functional/filter/cram/subsample-output-group-by-sizes-error.t new file mode 100644 index 000000000..2e17dc60f --- /dev/null +++ b/tests/functional/filter/cram/subsample-output-group-by-sizes-error.t @@ -0,0 +1,14 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +--output-group-by-sizes does not work without --group-by-weights. + + $ ${AUGUR} filter \ + > --metadata "$TESTDIR/../data/metadata.tsv" \ + > --group-by year month \ + > --subsample-max-sequences 100 \ + > --output-group-by-sizes target_group_sizes.tsv \ + > --output-strains strains.txt + ERROR: --output-group-by-sizes is only available for --group-by-weights. It may be added to other sampling methods in the future - see + [2] diff --git a/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t b/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t new file mode 100644 index 000000000..b11028c24 --- /dev/null +++ b/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t @@ -0,0 +1,122 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +Generate metadata file with 250 rows. + + $ echo "strain date location" > metadata.tsv + $ for i in $(seq 1 50); do + > echo "2000A_$i 2000 A" >> metadata.tsv + > echo "2000B_$i 2000 B" >> metadata.tsv + > echo "2001A_$i 2001 A" >> metadata.tsv + > echo "2001B_$i 2001 B" >> metadata.tsv + > echo "2002B_$i 2002 B" >> metadata.tsv + > done + +Weight locations A:B as 2:1. This is reflected in target_group_sizes.tsv below. 
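+(With a total weight of 3, the targets work out to 100 * 2/3 ≈ 67 for location
+A and 100 * 1/3 ≈ 33 for location B, subject to probabilistic rounding.)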
+ + $ cat >weights-A2B1.tsv <<~~ + > location weight + > A 2 + > B 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights-A2B1.tsv \ + > --subsample-max-sequences 100 \ + > --subsample-seed 0 \ + > --output-group-by-sizes target_group_sizes.tsv \ + > --output-metadata filtered.tsv 2>/dev/null + + $ cat target_group_sizes.tsv | tsv-pretty + location weight _augur_filter_target_size _augur_filter_input_size _augur_filter_subsampling_output_size + A 2 67 100 67 + B 1 33 150 33 + +There are also enough rows per group that the output metadata directly reflects +the target group sizes. + + $ cat filtered.tsv | tail -n +2 | cut -f3 | sort | uniq -c + \s*67 A (re) + \s*33 B (re) + +Using 1:1 weights is similarly straightforward, with 50 sequences from each location. + + $ cat >weights-A1B1.tsv <<~~ + > location weight + > A 1 + > B 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights-A1B1.tsv \ + > --subsample-max-sequences 100 \ + > --subsample-seed 0 \ + > --output-group-by-sizes target_group_sizes.tsv \ + > --output-strains strains.txt 2>/dev/null + + $ cat target_group_sizes.tsv | tsv-pretty + location weight _augur_filter_target_size _augur_filter_input_size _augur_filter_subsampling_output_size + A 1 50 100 50 + B 1 50 150 50 + +Keep the 1:1 location weighting, but add uniform sampling on year. +The uniform sampling happens "within" each weighted column value, so the 1:1 +location weighting is reflected even though there is an imbalance in years +available per location. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by year location \ + > --group-by-weights weights-A1B1.tsv \ + > --subsample-max-sequences 100 \ + > --subsample-seed 0 \ + > --output-group-by-sizes target_group_sizes.tsv \ + > --output-strains strains.txt 2>/dev/null + + $ cat target_group_sizes.tsv | tsv-pretty + year location weight _augur_filter_target_size _augur_filter_input_size _augur_filter_subsampling_output_size + 2000 A 0.5 25 50 25 + 2000 B 0.3333333333333333 16 50 16 + 2001 A 0.5 25 50 25 + 2001 B 0.3333333333333333 16 50 16 + 2002 B 0.3333333333333333 17 50 17 + +If a single sequence is added for group (2002,A), the weighting now appears +"equal" among all years and locations. + +However, there is only 1 sequence available in (2002,A), much lower than the +requested 17, so the total number of sequences outputted is lower than requested. + + $ echo "2002A_1 2002 A" >> metadata.tsv + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by year location \ + > --group-by-weights weights-A1B1.tsv \ + > --subsample-max-sequences 100 \ + > --subsample-seed 0 \ + > --output-group-by-sizes target_group_sizes.tsv \ + > --output-strains strains.txt + Sampling with weights defined by weights-A1B1.tsv. + NOTE: Weights were not provided for the column 'year'. Using equal weights across values in that column. + WARNING: Targeted 17 sequences for group ['year=2002', "location='A'"] but only 1 is available. 
+ 168 strains were dropped during filtering + 168 were dropped because of subsampling criteria + 83 strains passed all filters + + $ cat target_group_sizes.tsv | tsv-pretty + year location weight _augur_filter_target_size _augur_filter_input_size _augur_filter_subsampling_output_size + 2000 A 0.3333333333333333 17 50 17 + 2000 B 0.3333333333333333 16 50 16 + 2001 A 0.3333333333333333 16 50 16 + 2001 B 0.3333333333333333 16 50 16 + 2002 A 0.3333333333333333 17 1 1 + 2002 B 0.3333333333333333 17 50 17 + + $ wc -l strains.txt + \s*83 .* (re) diff --git a/tests/functional/filter/cram/subsample-weighted-comments.t b/tests/functional/filter/cram/subsample-weighted-comments.t new file mode 100644 index 000000000..1fba471dd --- /dev/null +++ b/tests/functional/filter/cram/subsample-weighted-comments.t @@ -0,0 +1,40 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +Set up files. + + $ cat >metadata.tsv <<~~ + > strain date location + > SEQ1 2000-01-01 A + > SEQ2 2000-01-02 A + > SEQ3 2000-01-03 B + > SEQ4 2000-01-04 B + > SEQ5 2000-02-01 A + > SEQ6 2000-02-02 A + > SEQ7 2000-03-01 B + > SEQ8 2000-03-02 B + > ~~ + +Comments in the weights file are valid. + + $ cat >weights.tsv <<~~ + > # This is a comment + > ## So is this + > location weight + > A 2 + > B 1 + > C 3 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + NOTE: Skipping 1 group due to lack of entries in metadata. + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters diff --git a/tests/functional/filter/cram/subsample-weighted-invalid-file.t b/tests/functional/filter/cram/subsample-weighted-invalid-file.t new file mode 100644 index 000000000..3159bbef1 --- /dev/null +++ b/tests/functional/filter/cram/subsample-weighted-invalid-file.t @@ -0,0 +1,77 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +Set up files. + + $ cat >metadata.tsv <<~~ + > strain date location + > SEQ1 2000-01-01 A + > SEQ2 2000-01-02 A + > SEQ3 2000-01-03 B + > SEQ4 2000-01-04 B + > SEQ5 2000-02-01 A + > SEQ6 2000-02-02 A + > SEQ7 2000-03-01 B + > SEQ8 2000-03-02 B + > ~~ + +Weights must be non-negative. + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > B 1 + > C -1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + ERROR: Bad weights file 'weights.tsv'. + Found negative weights on the following lines: [4] + 'weight' column must be non-negative. + [2] + +Weights must be numeric. + + $ cat >weights.tsv <<~~ + > location weight + > A yes + > B 1 + > C no + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + ERROR: Bad weights file 'weights.tsv'. + Found non-numeric weights on the following lines: [2, 4] + 'weight' column must be numeric. + [2] + +Weights file cannot be empty. 
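+An empty file has no header row, so there are no weighted columns to read.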
+ + $ cat >weights.tsv <<~~ + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + ERROR: Bad weights file 'weights.tsv'. + File is empty. + [2] diff --git a/tests/functional/filter/cram/subsample-weighted-validation-errors.t b/tests/functional/filter/cram/subsample-weighted-validation-errors.t new file mode 100644 index 000000000..c6400fe74 --- /dev/null +++ b/tests/functional/filter/cram/subsample-weighted-validation-errors.t @@ -0,0 +1,70 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +Set up files. + + $ cat >metadata.tsv <<~~ + > strain date location + > SEQ1 2000-01-01 A + > SEQ2 2000-01-02 A + > SEQ3 2000-01-03 B + > SEQ4 2000-01-04 B + > SEQ5 2000-02-01 A + > SEQ6 2000-02-02 A + > SEQ7 2000-03-01 B + > SEQ8 2000-03-02 B + > ~~ + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > B 1 + > ~~ + +When --group-by-weights is specified, all columns must be provided in +--group-by. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + ERROR: Columns in --group-by-weights must be a subset of columns provided in --group-by. + [2] + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by month \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + ERROR: Columns in --group-by-weights must be a subset of columns provided in --group-by. + [2] + +--output-group-by-sizes is only available for --group-by-weights. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-group-by-sizes sizes.tsv \ + > --output-strains strains.txt + ERROR: --output-group-by-sizes is only available for --group-by-weights. It may be added to other sampling methods in the future - see + [2] + +--group-by-weights cannot be used with --no-probabilistic-sampling. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --no-probabilistic-sampling \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + ERROR: --group-by-weights cannot be used with --no-probabilistic-sampling. + [2] diff --git a/tests/functional/filter/cram/subsample-weighted.t b/tests/functional/filter/cram/subsample-weighted.t new file mode 100644 index 000000000..1737911e9 --- /dev/null +++ b/tests/functional/filter/cram/subsample-weighted.t @@ -0,0 +1,169 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +Set up files. + + $ cat >metadata.tsv <<~~ + > strain date location + > SEQ1 2000-01-01 A + > SEQ2 2000-01-02 A + > SEQ3 2000-01-03 B + > SEQ4 2000-01-04 B + > SEQ5 2000-02-01 A + > SEQ6 2000-02-02 A + > SEQ7 2000-03-01 B + > SEQ8 2000-03-02 B + > ~~ + +Sampling with location weights only. + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > B 1 + > C 3 + > ~~ + +This should take 4 from location A and 2 from location B. The weight for +location C is ignored because there are no corresponding rows in the metadata. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. 
+ NOTE: Skipping 1 group due to lack of entries in metadata. + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters + + $ cat strains.txt + SEQ1 + SEQ2 + SEQ5 + SEQ6 + SEQ7 + SEQ8 + +Sampling with weights on location and uniform sampling on date (--group-by +month) should work. + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > B 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location month \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + NOTE: Skipping 2 groups due to lack of entries in metadata. + NOTE: Weights were not provided for the column 'month'. Using equal weights across values in that column. + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters + +Sampling with incomplete weights should show an error. + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + ERROR: The input metadata contains these values under the following columns that are not covered by 'weights.tsv': + - 'location': ['B'] + To fix this, either: + (1) specify weights explicitly - add entries to 'weights.tsv' for the values above, or + (2) specify a default weight - add an entry to 'weights.tsv' with the value 'default' for all columns + [2] + +Re-running with a default weight shows a warning and continues. + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > default 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by month location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv': + - 'location': ['B'] + The default weight of 1 will be used for all groups defined by those values. + NOTE: Skipping 4 groups due to lack of entries in metadata. + NOTE: Weights were not provided for the column 'month'. Using equal weights across values in that column. + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters + +To specify a default weight, the value 'default' must be set for all weighted columns. + + $ cat >weights.tsv <<~~ + > location month weight + > A 2000-01 2 + > A 2000-02 2 + > default 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by month location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. 
+ ERROR: The input metadata contains these values under the following columns that are not covered by 'weights.tsv': + - 'location': ['B'] + - 'month': ['2000-01', '2000-03'] + To fix this, either: + (1) specify weights explicitly - add entries to 'weights.tsv' for the values above, or + (2) specify a default weight - add an entry to 'weights.tsv' with the value 'default' for all columns + [2] + + $ cat >weights.tsv <<~~ + > location month weight + > A 2000-01 2 + > A 2000-02 2 + > default default 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by month location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv': + - 'location': ['B'] + - 'month': ['2000-01', '2000-03'] + The default weight of 1 will be used for all groups defined by those values. + NOTE: Skipping 1 group due to lack of entries in metadata. + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters
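+
+Here the groups present in the metadata are (2000-01, A) and (2000-02, A) with
+weight 2 each, plus (2000-01, B) and (2000-03, B) at the default weight of 1.
+The total weight is 6, so the 6 requested sequences are allocated as 2 + 2 + 1 + 1.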