diff --git a/.github/workflows/test-snake.yml b/.github/workflows/test-snake.yml
index 8b090e6..f59c04c 100644
--- a/.github/workflows/test-snake.yml
+++ b/.github/workflows/test-snake.yml
@@ -49,9 +49,10 @@ jobs:
       - name: Install dependencies
         run: |
           pip install -e .
+          pip install pytest
 
-      - name: Snakemake Testing
+      - name: Snakemake Unit Testing
         run: |
-          snakemake --cores 1 --snakefile workflow/Snakefile --directory .test --verbose
+          pytest workflow/.tests
 
-
\ No newline at end of file
+      # TODO: add dry-run testing
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5e68225..b90c2b1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,11 +32,21 @@ repos:
     rev: v0.10.2
     hooks:
       - id: snakefmt
+
   - repo: local
     hooks:
-      - id: snakemake-dryrun
-        name: Snakemake Dry Run
-        entry: bash -c 'cd workflow && poetry run snakemake -n'
+      - id: snakemake-unit-testing
+        name: Snakemake Unit Testing
+        entry: bash -c 'poetry run pytest workflow/.tests'
         language: system
-        files: (Snakefile|\.smk$)
-        pass_filenames: false
\ No newline at end of file
+        types: [python]
+
+# TODO enable dry-run testing
+  # - repo: local
+  #   hooks:
+  #     - id: snakemake-dryrun
+  #       name: Snakemake Dry Run
+  #       entry: bash -c 'cd workflow && poetry run snakemake -n'
+  #       language: system
+  #       files: (Snakefile|\.smk$)
+  #       pass_filenames: false
\ No newline at end of file
diff --git a/config/amplicon_cov.smk b/config/amplicon_cov.smk
index 4da3ef9..8d03af9 100644
--- a/config/amplicon_cov.smk
+++ b/config/amplicon_cov.smk
@@ -1,5 +1,5 @@
 # Inputs
-sample_list_dir : "/cluster/project/pangolin/work-amplicon-coverage/test_data/"
-sample_dir : "/cluster/project/pangolin/work-amplicon-coverage/test_data/samples"
+sample_list_dir: "/cluster/project/pangolin/work-amplicon-coverage/test_data/"
+sample_dir: "/cluster/project/pangolin/work-amplicon-coverage/test_data/samples"
 # Outputs
-output_dir : "/cluster/home/koehng/temp/amplicon_cov/"
+output_dir: "/cluster/home/koehng/temp/amplicon_cov/"
diff --git a/pyproject.toml b/pyproject.toml
index fefee27..5722cb3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ seaborn = "^0.13.2"
 pandas-stubs = "^2.2.2.240807"
 click = "^8.1.7"
 snakemake = "^8.20.4"
+interrogate = "^1.7.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.2.1"
diff --git a/workflow/.tests/unit/common.py b/workflow/.tests/unit/common.py
index 58d6954..c92d9cb 100644
--- a/workflow/.tests/unit/common.py
+++ b/workflow/.tests/unit/common.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 import subprocess as sp
 import os
+import pandas as pd
 import sys
 
 
@@ -74,3 +75,26 @@ def compare_files(self, generated_file, expected_file):
         Compare the generated file with the expected file.
         """
         sp.check_output(["cmp", generated_file, expected_file])
+
+
+def compare_csv_files(
+    file1_path: str, file2_path: str, tolerance: float = 1e-4
+) -> bool:
+    """
+    Compare two CSV files with a given tolerance.
+ """ + df1 = pd.read_csv(file1_path, skiprows=[1]) + df2 = pd.read_csv(file2_path, skiprows=[1]) + + if df1.shape != df2.shape: + raise ValueError("DataFrames have different shapes") + + # check that the data frames contrain the same data types + assert df1.dtypes.equals(df2.dtypes) + + # check that the data frames contain the same data + pd.testing.assert_frame_equal( + df1, df2, check_exact=False, rtol=tolerance, atol=tolerance + ) + + return True diff --git a/workflow/.tests/unit/test_make_price_data.py b/workflow/.tests/unit/test_make_price_data.py index d422005..ac2bc01 100644 --- a/workflow/.tests/unit/test_make_price_data.py +++ b/workflow/.tests/unit/test_make_price_data.py @@ -1,3 +1,7 @@ +""" +This script tests the make_price_data rule. +""" + import os import sys import subprocess as sp @@ -5,10 +9,15 @@ import shutil from pathlib import Path +from common import compare_csv_files + sys.path.insert(0, os.path.dirname(__file__)) def test_make_price_data(): + """ + Test the make_price_data rule. + """ with TemporaryDirectory() as tmpdir: workdir = Path(tmpdir) / "workdir" workdir.mkdir(exist_ok=True) @@ -55,17 +64,12 @@ def test_make_price_data(): assert (workdir / "results" / "statistics.csv").exists() # Compare output with expected result - result = sp.run( - [ - "diff", - str(workdir / "results" / "statistics.csv"), - str(expected_path / "statistics.csv"), - ], - capture_output=True, - text=True, + files_match = compare_csv_files( + str(workdir / "results" / "statistics.csv"), + str(expected_path / "statistics.csv"), ) - assert result.returncode == 0, f"Files are different:\n{result.stdout}" + assert files_match, "Files are different within the specified tolerance" ### Main diff --git a/workflow/rules/amplicon_cov.smk b/workflow/rules/amplicon_cov.smk index 1d753db..822582a 100644 --- a/workflow/rules/amplicon_cov.smk +++ b/workflow/rules/amplicon_cov.smk @@ -9,15 +9,15 @@ rule relative_amplicon_coverage_per_batch: Calculate the relative amplicon coverage for all samples in the batch specific samples{batch}.tsv file. """ input: - sample_list = config['sample_list_dir'] + "samples{batch}.tsv", - samples = config['sample_dir'] + sample_list=config["sample_list_dir"] + "samples{batch}.tsv", + samples=config["sample_dir"], output: - heatmap = config["output_dir"] + "{batch}/cov_heatmap.pdf", + heatmap=config["output_dir"] + "{batch}/cov_heatmap.pdf", params: - primers_fp ="../resources/amplicon_cov/articV3primers.bed", - output_dir = config["output_dir"] + "{batch}/" + primers_fp="../resources/amplicon_cov/articV3primers.bed", + output_dir=config["output_dir"] + "{batch}/", log: - config["output_dir"] + "relative_amplicon_coverage_per_batch/{batch}.log" + config["output_dir"] + "relative_amplicon_coverage_per_batch/{batch}.log", shell: """ mkdir -p {params.output_dir} @@ -33,19 +33,20 @@ rule relative_amplicon_coverage_per_batch: rule get_samples_per_batch: input: - samples_list = config['sample_list_dir'] + "samples.tsv" + samples_list=config["sample_list_dir"] + "samples.tsv", output: - samples_batch = config['sample_list_dir'] + "samples{batch}.tsv", + samples_batch=config["sample_list_dir"] + "samples{batch}.tsv", log: - config["output_dir"] + "get_samples_per_batch_{batch}.log" + config["output_dir"] + "get_samples_per_batch_{batch}.log", shell: """ grep {wildcards.batch} {input.samples_list} > {output.samples_batch} """ + rule get_coverage_for_batch: """ Calculate the relative amplicon coverage for all samples in the batch specific samples{batch}.tsv file. 
""" input: - samples = f"{config['output_dir']}20240705_AAFH52MM5/cov_heatmap.pdf", \ No newline at end of file + samples=f"{config['output_dir']}20240705_AAFH52MM5/cov_heatmap.pdf", diff --git a/workflow/rules/base_coverage.smk b/workflow/rules/base_coverage.smk index 5a94cc6..4a5b345 100644 --- a/workflow/rules/base_coverage.smk +++ b/workflow/rules/base_coverage.smk @@ -27,17 +27,18 @@ rule basecnt_coverage_depth: """Generate matrix of coverage depth per base position """ input: - mutations_of_interest = config["mutations_of_interest_dir"], - timeline = config["timeline_fp"] + mutations_of_interest=config["mutations_of_interest_dir"], + timeline=config["timeline_fp"], output: - output_file = config["outdir"] + "{location}/mut_base_coverage_{location}_{enddate}.csv" + output_file=config["outdir"] + + "{location}/mut_base_coverage_{location}_{enddate}.csv", params: - startdate = "2024-01-01", - enddate = "{enddate}", - location = "{location}", + startdate="2024-01-01", + enddate="{enddate}", + location="{location}", # TODO: add protocol and subset params, see extract_sample_ID log: - "logs/basecnt_coverage_depth/{location}_{enddate}.log" + "logs/basecnt_coverage_depth/{location}_{enddate}.log", run: logging.info("Running basecnt_coverage_depth") ug.analyze.run_basecnt_coverage( @@ -45,27 +46,28 @@ rule basecnt_coverage_depth: timeline_file_dir=input.timeline, mutations_of_interest_dir=input.mutations_of_interest, output_file=output.output_file, - startdate = params.startdate, - enddate = params.enddate, - location = params.location + startdate=params.startdate, + enddate=params.enddate, + location=params.location, ) rule total_coverage_depth: """ Calcultate the total coverage depth - """ + """ input: - mutations_of_interest = config["mutations_of_interest_dir"], - timeline = config["timeline_fp"] + mutations_of_interest=config["mutations_of_interest_dir"], + timeline=config["timeline_fp"], output: - output_file = config["outdir"] + "{location}/mut_total_coverage_{location}_{enddate}.csv" + output_file=config["outdir"] + + "{location}/mut_total_coverage_{location}_{enddate}.csv", params: - startdate = "2024-01-01", - enddate = "{enddate}", - location = "{location}", + startdate="2024-01-01", + enddate="{enddate}", + location="{location}", # TODO: add protocol and subset params, see extract_sample_ID log: - "logs/basecnt_coverage_depth/{location}_{enddate}.log" + "logs/basecnt_coverage_depth/{location}_{enddate}.log", run: logging.info("Running total_coverage_depth") ug.analyze.run_total_coverage_depth( @@ -73,28 +75,33 @@ rule total_coverage_depth: mutations_of_interest_fp=input.mutations_of_interest, timeline_file_dir=input.timeline, output_file=output.output_file, - startdate = params.startdate, - enddate = params.enddate, - location = params.location + startdate=params.startdate, + enddate=params.enddate, + location=params.location, ) + # snakemake lint=off rule mutation_statistics: """Compute mutation frequencies from the basecnt and general coverages and report the statistics """ input: - basecnt_coverage = config["outdir"] + "{location}/mut_base_coverage_{location}_{enddate}.csv", - total_coverage = config["outdir"] + "{location}/mut_total_coverage_{location}_{enddate}.csv" + basecnt_coverage=config["outdir"] + + "{location}/mut_base_coverage_{location}_{enddate}.csv", + total_coverage=config["outdir"] + + "{location}/mut_total_coverage_{location}_{enddate}.csv", params: - location = "{location}", - enddate = "{enddate}" + location="{location}", + enddate="{enddate}", log: - 
"logs/basecnt_coverage_depth/{location}_{enddate}.log" + "logs/basecnt_coverage_depth/{location}_{enddate}.log", output: - heatmap = config["outdir"] + "{location}/heatmap_{location}_{enddate}.pdf", - lineplot = config["outdir"] + "{location}/lineplot_{location}_{enddate}.pdf", - frequency_data_matrix = config["outdir"] + "{location}/frequency_data_matrix_{location}_{enddate}.csv", - mutations_statistics = config["outdir"] + "{location}/mutations_statistics__{location}_{enddate}.csv" + heatmap=config["outdir"] + "{location}/heatmap_{location}_{enddate}.pdf", + lineplot=config["outdir"] + "{location}/lineplot_{location}_{enddate}.pdf", + frequency_data_matrix=config["outdir"] + + "{location}/frequency_data_matrix_{location}_{enddate}.csv", + mutations_statistics=config["outdir"] + + "{location}/mutations_statistics__{location}_{enddate}.csv", run: logging.info("Running mutation_statistics") # Median frequency with IQR @@ -128,7 +135,9 @@ rule mutation_statistics: df = frequency_data_matrix.transpose() sns.set_style("white") - g = sns.heatmap(df, yticklabels=samples, cmap="Blues", linewidths=0, linecolor="none") + g = sns.heatmap( + df, yticklabels=samples, cmap="Blues", linewidths=0, linecolor="none" + ) fig = g.get_figure() plt.yticks(rotation=0, fontsize=8) @@ -152,27 +161,27 @@ rule mutation_statistics: # LINE PLOT explanatory_labels = { "C23039G": "C23039G (KP.3)", - "G22599C": "G22599C (KP.2)", - } + "G22599C": "G22599C (KP.2)", + } - # Plot line plot - sns.set(rc={"figure.figsize": (10, 5)}) - sns.set_style("white") + # Plot line plot + sns.set(rc={"figure.figsize": (10, 5)}) + sns.set_style("white") - # Transpose the DataFrame to have samples as columns - df = frequency_data_matrix.transpose() + # Transpose the DataFrame to have samples as columns + df = frequency_data_matrix.transpose() - # Create the line plot - fig, ax = plt.subplots() + # Create the line plot + fig, ax = plt.subplots() - # Plot each sample - for sample in df.columns: - explanatory_label = explanatory_labels.get(sample, sample) - ax.plot( + # Plot each sample + for sample in df.columns: + explanatory_label = explanatory_labels.get(sample, sample) + ax.plot( df.index, df[sample], label=explanatory_label, marker="o" ) # 'o' adds points to the line plot - # Customize the plot + # Customize the plot plt.xticks(rotation=45, fontsize=6, ha="right") plt.yticks(fontsize=8) plt.xlabel("Day", fontsize=10) @@ -208,7 +217,9 @@ rule mutation_statistics: iqr = q3 - q1 # Combine results into a DataFrame - results_2_weeks = pd.DataFrame({"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3}) + results_2_weeks = pd.DataFrame( + {"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3} + ) # select most recent 6 weeks six_weeks_ago = most_recent_date - timedelta(weeks=6) @@ -221,7 +232,9 @@ rule mutation_statistics: iqr = q3 - q1 # Combine results into a DataFrame - results_6_weeks = pd.DataFrame({"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3}) + results_6_weeks = pd.DataFrame( + {"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3} + ) # select most recent 12 weeks (3 months) three_months_ago = most_recent_date - timedelta(weeks=12) @@ -234,7 +247,9 @@ rule mutation_statistics: iqr = q3 - q1 # Combine results into a DataFrame - results_12_weeks = pd.DataFrame({"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3}) + results_12_weeks = pd.DataFrame( + {"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3} + ) # select most recent 24 weeks (6 months) six_months_ago = most_recent_date - timedelta(weeks=24) @@ -247,7 +262,9 @@ rule 
@@ -247,7 +262,9 @@ rule mutation_statistics:
         iqr = q3 - q1
 
         # Combine results into a DataFrame
-        results_24_weeks = pd.DataFrame({"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3})
+        results_24_weeks = pd.DataFrame(
+            {"Median": medians, "IQR": iqr, "Q1": q1, "Q3": q3}
+        )
 
         # Combine information about each mutation into a single df
         dict_mut = {}
@@ -273,14 +290,17 @@
         )
         logging.info("Saved mutation statistics")
 
+
 # snakemake lint=on
+
 rule mutation_statistics_Zürich_2024_07_03:
     """
     Run mutation_statistics for Zürich on and enddate 2024-07-03
-    """
+    """
     input:
         config["outdir"] + "Zürich (ZH)/lineplot_Zürich (ZH)_2024-07-03.pdf",
         config["outdir"] + "Zürich (ZH)/heatmap_Zürich (ZH)_2024-07-03.pdf",
-        config["outdir"] + "Zürich (ZH)/frequency_data_matrix_Zürich (ZH)_2024-07-03.csv",
-        config["outdir"] + "Zürich (ZH)/mutations_statistics__Zürich (ZH)_2024-07-03.csv"
-
+        config["outdir"]
+        + "Zürich (ZH)/frequency_data_matrix_Zürich (ZH)_2024-07-03.csv",
+        config["outdir"]
+        + "Zürich (ZH)/mutations_statistics__Zürich (ZH)_2024-07-03.csv",
diff --git a/workflow/rules/smk_testing.smk b/workflow/rules/smk_testing.smk
index f2e4825..2ae713f 100644
--- a/workflow/rules/smk_testing.smk
+++ b/workflow/rules/smk_testing.smk
@@ -1,17 +1,19 @@
-import pandas as pd 
+import pandas as pd
 import logging
 from datetime import datetime, timedelta
 
+
 # Use the specific config file for this test
 configfile: "../config/smk_testing_config.yaml"
 
+
 rule make_price_data:
     input:
-        orderbook = config["orderbook"]
+        orderbook=config["orderbook"],
     output:
-        statistics = config["statistics"]
+        statistics=config["statistics"],
     params:
-        interval = config["interval"]
+        interval=config["interval"],
     run:
         # Read the data
         data = pd.read_csv(input.orderbook)
@@ -23,7 +25,9 @@ rule make_price_data:
         # choose bounds for the intervals in seconds based on the config
        interval_seconds = params.interval * 60  # convert minutes to seconds
        bounds = range(int(start_time), int(end_time) + 1, interval_seconds)
-        statistics = data.groupby(pd.cut(data["Time"], bins=bounds)).agg(["mean", "std", "min", "max"])
+        statistics = data.groupby(pd.cut(data["Time"], bins=bounds)).agg(
+            ["mean", "std", "min", "max"]
+        )
 
         # save the statistics
-        statistics.to_csv(output.statistics, index=False)
\ No newline at end of file
+        statistics.to_csv(output.statistics, index=False)
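
Example usage (illustrative, not part of the diff above): a minimal sketch of how the new compare_csv_files helper from workflow/.tests/unit/common.py could be exercised in a further pytest unit test. The CSV contents, the test name, and the relative sys.path entry are assumptions for illustration, not code from this change.

import sys
from pathlib import Path
from tempfile import TemporaryDirectory

# Assumes pytest is invoked from the repository root; adjust the path otherwise.
sys.path.insert(0, "workflow/.tests/unit")

from common import compare_csv_files


def test_compare_csv_files_within_tolerance():
    """Two CSVs whose values differ by less than the default 1e-4 tolerance."""
    with TemporaryDirectory() as tmpdir:
        a = Path(tmpdir) / "a.csv"
        b = Path(tmpdir) / "b.csv"
        # compare_csv_files reads with skiprows=[1], so the first row after the
        # header is ignored; the dummy rows below demonstrate that behaviour.
        a.write_text("x,y\n0.0,0.0\n1.00000,2.00000\n")
        b.write_text("x,y\n9.9,9.9\n1.00001,2.00001\n")
        assert compare_csv_files(str(a), str(b))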