Skip to content

Commit

Permalink
MRG: add label output & input options to compare and plot, for be…
Browse files Browse the repository at this point in the history
…tter customization (#2598)

Adds `sourmash compare --labels-to` and `sourmash plot --labels-from` to
support better label customization.

Fixes #2452
Fixes #2915

## `sourmash compare --labels-to`

This command will generate a 'labels-to' file. Running:
```
sourmash compare tests/test-data/demo/*.sig -o compare-demo \
    --labels-to compare-demo-labels.csv
```
will produce a file that looks like this:

file `compare-demo-labels.csv`:
```csv
sort_order,md5,label,name,filename,signature_file
1,60f7e23c24a8d94791cc7a8680c493f9,SRR2060939_1.fastq.gz,,SRR2060939_1.fastq.gz,../tests/test-data/demo/SRR2060939_1.sig
2,4e94e60265e04f0763142e20b52c0da1,SRR2060939_2.fastq.gz,,SRR2060939_2.fastq.gz,../tests/test-data/demo/SRR2060939_2.sig
3,f71e78178af9e45e6f1d87a0c53c465c,SRR2241509_1.fastq.gz,,SRR2241509_1.fastq.gz,../tests/test-data/demo/SRR2241509_1.sig
4,6d6e87e1154e95b279e5e7db414bc37b,SRR2255622_1.fastq.gz,,SRR2255622_1.fastq.gz,../tests/test-data/demo/SRR2255622_1.sig
5,0107d767a345eff67ecdaed2ee5cd7ba,SRR453566_1.fastq.gz,,SRR453566_1.fastq.gz,../tests/test-data/demo/SRR453566_1.sig
6,f0c834bc306651d2b9321fb21d3e8d8f,SRR453569_1.fastq.gz,,SRR453569_1.fastq.gz,../tests/test-data/demo/SRR453569_1.sig
7,b59473c94ff2889eca5d7165936e64b3,SRR453570_1.fastq.gz,,SRR453570_1.fastq.gz,../tests/test-data/demo/SRR453570_1.sig
```

The `label` column in this file can be edited to suit the user's needs.
The index column is `sort_order`; all other columns can be ignored,
deleted, or updated without consequence.

## `sourmash plot --labels-from`

This command will load labels from a file. Running:
```
sourmash plot --labels-from compare-demo-new-labels.csv compare-demo
```
uses the `label` column from the CSV as labels, in the order specified
by the `sort_order` column (interpreted as integers and sorted from
lowest to highest). All other columns are ignored.

## Example in a Jupyter Notebook

Some example code for updating the labels is available here:


https://github.com/sourmash-bio/sourmash/blob/compare_labels/doc/plotting-compare.ipynb

## TODO

- [x] add test for `args.labeltext and args.labels_from` check
- [x] check the notebook update

## Future:

- [ ] Consider switching to `LinearIndex` in the signature loading code,
as that would let us maintain the location in the code without the
current machinations. Also worth thinking about enabling lazy loading,
which some future `Index`-code based modification might support.
- [ ] consider if and how to validate --labels-from CSV file...

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ctb and pre-commit-ci[bot] authored Feb 6, 2024
1 parent a128ee3 commit c6831fd
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 20 deletions.
2 changes: 1 addition & 1 deletion doc/command-line.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ sourmash compare <sourmash signature file> [ <sourmash signature file> ... ]
Options:

* `--output <filename>` -- save the output matrix to this file, as a numpy binary matrix.
* `--csv <filename>` -- save the output matrix to this file in CSV format.
* `--distance-matrix` -- create and output a distance matrix, instead of a similarity matrix.
* `--ksize <k>` -- do the comparisons at this k-mer size.
* `--containment` -- calculate containment instead of similarity; `C(i, j) = size(i intersection j) / size(i)`
Expand All @@ -233,6 +232,7 @@ Options:
* `--ignore-abundance` -- ignore abundances in signatures.
* `--picklist <pickfile>:<colname>:<coltype>` -- select a subset of signatures with [a picklist](#using-picklists-to-subset-large-collections-of-signatures)
* `--csv <outfile.csv>` -- save the output matrix in CSV format.
* `--labels-to <labels.csv>` -- create a CSV file (spreadsheet) that can be passed in to `sourmash plot` with `--labels-from` in order to customize the labels.

**Note:** compare by default produces a symmetric similarity matrix
that can be used for clustering in downstream tasks. With `--containment`,
Expand Down
5 changes: 5 additions & 0 deletions src/sourmash/cli/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def subparser(subparsers):
metavar="F",
help="write matrix to specified file in CSV format (with column " "headers)",
)
subparser.add_argument(
"--labels-to",
"--labels-save",
help="a CSV file containing label information",
)
subparser.add_argument(
"-p",
"--processes",
Expand Down
5 changes: 5 additions & 0 deletions src/sourmash/cli/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def subparser(subparsers):
help="write clustered matrix and labels out in CSV format (with column"
" headers) to this file",
)
subparser.add_argument(
"--labels-from",
"--labels-load",
help="a CSV file containing label information to use on plot; implies --labels",
)


def main(args):
Expand Down
81 changes: 64 additions & 17 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,19 @@ def compare(args):
notify(
f"\nwarning: no signatures loaded at given ksize/molecule type/picklist from {filename}"
)
siglist.extend(loaded)

# track ksizes/moltypes
# add to siglist; track ksizes/moltypes
s = None
for s in loaded:
siglist.append((s, filename))
ksizes.add(s.minhash.ksize)
moltypes.add(sourmash_args.get_moltype(s))

if s is None:
notify(
f"\nwarning: no signatures loaded at given ksize/molecule type/picklist from {filename}"
)

# error out while loading if we have more than one ksize/moltype
if len(ksizes) > 1 or len(moltypes) > 1:
break
Expand Down Expand Up @@ -105,7 +111,7 @@ def compare(args):

# check to make sure they're potentially compatible - either using
# scaled, or not.
scaled_sigs = [s.minhash.scaled for s in siglist]
scaled_sigs = [s.minhash.scaled for (s, _) in siglist]
is_scaled = all(scaled_sigs)
is_scaled_2 = any(scaled_sigs)

Expand Down Expand Up @@ -145,16 +151,20 @@ def compare(args):

# notify about implicit --ignore-abundance:
if is_containment or return_ani:
track_abundances = any(s.minhash.track_abundance for s in siglist)
track_abundances = any(s.minhash.track_abundance for s, _ in siglist)
if track_abundances:
notify(
"NOTE: --containment, --max-containment, --avg-containment, and --estimate-ani ignore signature abundances."
)

# CTB: note, up to this point, we could do everything with manifests
# w/o actually loading any signatures. I'm not sure the manifest
# API allows it tho.

# if using scaled sketches or --scaled, downsample to common max scaled.
printed_scaled_msg = False
if is_scaled:
max_scaled = max(s.minhash.scaled for s in siglist)
max_scaled = max(s.minhash.scaled for s, _ in siglist)
if args.scaled:
args.scaled = int(args.scaled)

Expand All @@ -166,7 +176,7 @@ def compare(args):
notify(f"WARNING: continuing with scaled value of {max_scaled}.")

new_siglist = []
for s in siglist:
for s, filename in siglist:
if not size_may_be_inaccurate and not s.minhash.size_is_accurate():
size_may_be_inaccurate = True
if s.minhash.scaled != max_scaled:
Expand All @@ -177,9 +187,9 @@ def compare(args):
printed_scaled_msg = True
with s.update() as s:
s.minhash = s.minhash.downsample(scaled=max_scaled)
new_siglist.append(s)
new_siglist.append((s, filename))
else:
new_siglist.append(s)
new_siglist.append((s, filename))
siglist = new_siglist
elif args.scaled is not None:
error("ERROR: cannot specify --scaled with non-scaled signatures.")
Expand All @@ -196,16 +206,20 @@ def compare(args):

# do all-by-all calculation

labeltext = [str(item) for item in siglist]
labeltext = [str(ss) for ss, _ in siglist]
sigsonly = [ss for ss, _ in siglist]
if args.containment:
similarity = compare_serial_containment(siglist, return_ani=return_ani)
similarity = compare_serial_containment(sigsonly, return_ani=return_ani)
elif args.max_containment:
similarity = compare_serial_max_containment(siglist, return_ani=return_ani)
similarity = compare_serial_max_containment(sigsonly, return_ani=return_ani)
elif args.avg_containment:
similarity = compare_serial_avg_containment(siglist, return_ani=return_ani)
similarity = compare_serial_avg_containment(sigsonly, return_ani=return_ani)
else:
similarity = compare_all_pairs(
siglist, args.ignore_abundance, n_jobs=args.processes, return_ani=return_ani
sigsonly,
args.ignore_abundance,
n_jobs=args.processes,
return_ani=return_ani,
)

# if distance matrix desired, switch to 1-similarity
Expand All @@ -215,7 +229,7 @@ def compare(args):
matrix = similarity

if len(siglist) < 30:
for i, ss in enumerate(siglist):
for i, (ss, filename) in enumerate(siglist):
# for small matrices, pretty-print some output
name_num = f"{i}-{str(ss)}"
if len(name_num) > 20:
Expand Down Expand Up @@ -246,6 +260,25 @@ def compare(args):
with open(args.output, "wb") as fp:
numpy.save(fp, matrix)

# output labels information via --labels-to?
if args.labels_to:
labeloutname = args.labels_to
notify(f"saving labels to: {labeloutname}")
with sourmash_args.FileOutputCSV(labeloutname) as fp:
w = csv.writer(fp)
w.writerow(
["sort_order", "md5", "label", "name", "filename", "signature_file"]
)

for n, (ss, location) in enumerate(siglist):
md5 = ss.md5sum()
sigfile = location
label = str(ss)
name = ss.name
filename = ss.filename

w.writerow([str(n + 1), md5, label, name, filename, sigfile])

# output CSV?
if args.csv:
with FileOutputCSV(args.csv) as csv_fp:
Expand Down Expand Up @@ -289,7 +322,10 @@ def plot(args):
notify("...got {} x {} matrix.", *D.shape)

# see sourmash#2790 for details :)
if args.labeltext or args.labels:
if args.labeltext or args.labels or args.labels_from:
if args.labeltext and args.labels_from:
notify("ERROR: cannot supply both --labeltext and --labels-from")
sys.exit(-1)
display_labels = True
args.labels = True # override => labels always true
elif args.labels is None and not args.indices:
Expand All @@ -303,13 +339,24 @@ def plot(args):
else:
display_labels = False

if args.labels:
if args.labels_from:
labelfilename = args.labels_from
notify(f"loading labels from CSV file '{labelfilename}'")

labeltext = []
with sourmash_args.FileInputCSV(labelfilename) as r:
for row in r:
order, label = row["sort_order"], row["label"]
labeltext.append((int(order), label))
labeltext.sort()
labeltext = [t[1] for t in labeltext]
elif args.labels:
if args.labeltext:
labelfilename = args.labeltext
else:
labelfilename = D_filename + ".labels.txt"

notify(f"loading labels from {labelfilename}")
notify(f"loading labels from text file '{labelfilename}'")
with open(labelfilename) as f:
labeltext = [x.strip() for x in f]

Expand Down
5 changes: 5 additions & 0 deletions tests/test-data/compare/labels_from-test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
sort_order,md5,label,name,filename,signature_file
4,8a619747693c045afde376263841806b,genome-s10+s11-CHANGED,genome-s10+s11,-,/Users/t/dev/sourmash/tests/test-data/genome-s10+s11.sig
3,ff511252a80bb9a7dbb0acf62626e123,genome-s12-CHANGED,genome-s12,genome-s12.fa.gz,/Users/t/dev/sourmash/tests/test-data/genome-s12.fa.gz.sig
2,1437d8eae64bad9bdc8d13e1daa0a43e,genome-s11-CHANGED,genome-s11,genome-s11.fa.gz,/Users/t/dev/sourmash/tests/test-data/genome-s11.fa.gz.sig
1,4cb3290263eba24548f5bef38bcaefc9,genome-s10-CHANGED,genome-s10,genome-s10.fa.gz,/Users/t/dev/sourmash/tests/test-data/genome-s10.fa.gz.sig
116 changes: 114 additions & 2 deletions tests/test_sourmash.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def test_compare_serial(runtmp):

testsigs = utils.get_test_data("genome-s1*.sig")
testsigs = glob.glob(testsigs)
assert len(testsigs) == 4

c.run_sourmash("compare", "-o", "cmp", "-k", "21", "--dna", *testsigs)

Expand Down Expand Up @@ -1252,7 +1253,7 @@ def test_plot_override_labeltext(runtmp):

print(runtmp.last_result.out)

assert "loading labels from new.labels.txt" in runtmp.last_result.err
assert "loading labels from text file 'new.labels.txt'" in runtmp.last_result.err

expected = """\
0\ta
Expand Down Expand Up @@ -1291,7 +1292,7 @@ def test_plot_override_labeltext_fail(runtmp):
print(runtmp.last_result.out)
print(runtmp.last_result.err)
assert runtmp.last_result.status != 0
assert "loading labels from new.labels.txt" in runtmp.last_result.err
assert "loading labels from text file 'new.labels.txt'" in runtmp.last_result.err
assert "3 labels != matrix size, exiting" in runtmp.last_result.err


Expand Down Expand Up @@ -1406,6 +1407,117 @@ def test_plot_subsample_2(runtmp):
assert expected in runtmp.last_result.out


def test_compare_and_plot_labels_from_to(runtmp):
    # Round trip: 'compare --labels-to' writes a labels CSV, and
    # 'plot --labels-from' then reads that same CSV back in.
    sigfiles = [
        utils.get_test_data(signame)
        for signame in (
            "genome-s10.fa.gz.sig",
            "genome-s11.fa.gz.sig",
            "genome-s12.fa.gz.sig",
            "genome-s10+s11.sig",
        )
    ]
    labels_csv = runtmp.output("label.csv")

    compare_args = ["compare", *sigfiles, "-o", "cmp", "-k", "21", "--dna"]
    compare_args += ["--labels-to", labels_csv]
    runtmp.run_sourmash(*compare_args)

    runtmp.sourmash("plot", "cmp", "--labels-from", labels_csv)

    result = runtmp.last_result
    print(result.out)

    assert "loading labels from CSV file" in result.err

    # the unedited labels are expected in the same order the sketches
    # were passed to 'compare'.
    expected = """\
0\tgenome-s10
1\tgenome-s11
2\tgenome-s12
3\tgenome-s10+s11"""
    assert expected in result.out


def test_compare_and_plot_labels_from_changed(runtmp):
    # 'plot --labels-from' should pick up hand-edited labels from a CSV.
    sigfiles = [
        utils.get_test_data(signame)
        for signame in (
            "genome-s10.fa.gz.sig",
            "genome-s11.fa.gz.sig",
            "genome-s12.fa.gz.sig",
            "genome-s10+s11.sig",
        )
    ]

    # pre-made labels file whose 'label' column carries a -CHANGED suffix
    labels_csv = utils.get_test_data("compare/labels_from-test.csv")

    compare_args = ["compare", *sigfiles, "-o", "cmp", "-k", "21", "--dna"]
    runtmp.run_sourmash(*compare_args)

    runtmp.sourmash("plot", "cmp", "--labels-from", labels_csv)

    result = runtmp.last_result
    print(result.out)

    assert "loading labels from CSV file" in result.err

    expected = """\
0\tgenome-s10-CHANGED
1\tgenome-s11-CHANGED
2\tgenome-s12-CHANGED
3\tgenome-s10+s11-CHANGED"""
    assert expected in result.out


def test_compare_and_plot_labels_from_error(runtmp):
    # supplying both --labels-from and --labeltext to 'plot' must fail.
    sigfiles = [
        utils.get_test_data(signame)
        for signame in (
            "genome-s10.fa.gz.sig",
            "genome-s11.fa.gz.sig",
            "genome-s12.fa.gz.sig",
            "genome-s10+s11.sig",
        )
    ]
    labels_csv = utils.get_test_data("compare/labels_from-test.csv")

    compare_args = ["compare", *sigfiles, "-o", "cmp", "-k", "21", "--dna"]
    runtmp.run_sourmash(*compare_args)

    # the same CSV is handed to both options; only the conflict matters here
    plot_args = ["plot", "cmp", "--labels-from", labels_csv]
    plot_args += ["--labeltext", labels_csv]
    with pytest.raises(SourmashCommandFailed):
        runtmp.sourmash(*plot_args, fail_ok=True)

    # checked after the 'with' block so that it runs once the command has failed
    err = runtmp.last_result.err
    assert "ERROR: cannot supply both --labeltext and --labels-from" in err


@utils.in_tempdir
def test_search_query_sig_does_not_exist(c):
testdata1 = utils.get_test_data("short.fa")
Expand Down

0 comments on commit c6831fd

Please sign in to comment.