diff --git a/README.md b/README.md index c36bbfaa..690fc630 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,16 @@ # nuad +nuad is a Python library that enables one to specify constraints on a DNA (or RNA) nanostructure made from synthetic DNA/RNA and then attempts to find concrete DNA sequences that satisfy the constraints. + +Note: If you are reading this on the PyPI website, many links below won't work. They are relative links intended to be read on the [GitHub README page](https://github.com/UC-Davis-molecular-computing/nuad/tree/main#readme). + ## Table of contents * [Overview](#overview) * [API documentation](#api-documentation) * [Installation](#installation) + * [Installing nuad](#installing-nuad) + * [Installing NUPACK and ViennaRNA](#installing-nupack-and-viennarna) * [Data model](#data-model) * [Constraint evaluations must be pure functions of their inputs](#constraint-evaluations-must-be-pure-functions-of-their-inputs) * [Examples](#examples) @@ -17,8 +23,6 @@ nuad stands for "NUcleic Acid Designer".† It is a Python library that enables one to specify constraints on a DNA (or RNA) nanostructure made from synthetic DNA/RNA (for example, "*all strands should have complex free energy at least -2.0 kcal/mol according to [NUPACK](http://www.nupack.org/)*", or "*every binding domain should have binding energy with its perfect complement between -8.0 kcal/mol and -9.0 kcal/mol in the [nearest-neighbor energy model](https://en.wikipedia.org/wiki/Nucleic_acid_thermodynamics#Nearest-neighbor_method)*"), and then attempts to find concrete DNA sequences that satisfy the constraints. It is not a standalone program, unlike other DNA sequence designers such as [NUPACK](http://www.nupack.org/design/new). Instead, it attempts to be more expressive than existing DNA sequence designers, at the cost of being less simple to use. 
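Reviewer note: the table of contents above mentions that "constraint evaluations must be pure functions of their inputs", and later hunks in this diff migrate the example evaluate functions to return a `(score, summary)` pair. A self-contained sketch of that shape (hypothetical names, no nuad import — this is an illustration of the pattern, not nuad's actual API):

```python
from typing import Optional, Tuple

# Sketch of the (score, summary) evaluate-function shape this diff migrates
# the examples to. A score of 0 means the constraint is satisfied; a positive
# score penalizes the candidate sequence. The function is pure: its result
# depends only on its arguments.
def four_g_evaluate(seqs: Tuple[str, ...], strand_name: Optional[str]) -> Tuple[float, str]:
    seq = seqs[0]
    score = 1000.0 if 'GGGG' in seq else 0.0
    marker = ' ** violation **' if score > 0 else ''
    return score, f'{strand_name}: {seq}{marker}'

print(four_g_evaluate(('ACGGGGTT',), 'strand 0'))
print(four_g_evaluate(('ACGTACGT',), 'strand 1'))
```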
The nuad library helps you to write your own DNA sequence designer, in case existing designers cannot capture the particular constraints of your project. -Note: If you are reading this on the PyPI website, many links below won't work. They are relative links intended to be read on the [GitHub README page](https://github.com/UC-Davis-molecular-computing/nuad/tree/main#readme). - Note: The nuad package was originally called dsd (DNA sequence designer), so you may see some old references to this name for the package. †A secondary reason for the name of the package is that some work was done when the primary author was on sabbatical in Maynooth, Ireland, whose original Irish name is [*Maigh Nuad*](https://en.wikipedia.org/wiki/Maynooth#Etymology). @@ -29,47 +33,66 @@ The API documentation is on readthedocs: https://nuad.readthedocs.io/ ## Installation -nuad requires Python version 3.7 or higher. Currently, it cannot be installed using pip (see [issue #12](https://github.com/UC-Davis-molecular-computing/nuad/issues/12)). +nuad requires Python version 3.7 or higher. Currently, although it can be installed using pip by typing `pip install nuad`, it depends on two pieces of software that are not installed automatically by pip (see [issue #12](https://github.com/UC-Davis-molecular-computing/nuad/issues/12)). nuad uses [NUPACK](http://www.nupack.org/downloads) and [ViennaRNA](https://www.tbi.univie.ac.at/RNA/#download), which must be installed separately (see below for link to installation instructions). While it is technically possible to use nuad without them, most of the pre-packaged constraints require them. -To use NUPACK on Windows, you should use [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install-win10), which essentially installs a command-line-only Linux inside of your Windows system, which has access to your Windows file system. 
If you are using Windows, you can then run python code calling the nuad library from WSL (which will appear to the Python virtual machine as though it is running on Linux). WSL is necessary to use any of the constraints that use NUPACK 4. +To use NUPACK on Windows, you must use [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install), which essentially installs a command-line-only Linux inside your Windows system, with access to your Windows file system. If you are using Windows, you can then run Python code calling the nuad library from WSL (to Python, the code will appear to be running on Linux). WSL is necessary to use any of the constraints that use NUPACK 4. + +### Installing nuad + +To install nuad, you can either install it using pip (the slightly simpler option) or using git. No matter which method you choose, you must also install NUPACK and ViennaRNA separately (see [instructions below](#installing-nupack-and-viennarna)). + +- pip + + At the command line (WSL for Windows, not the PowerShell prompt), type -To install nuad: + ``` + pip install nuad + ``` -1. Download the git repo, by one of two methods: - - Install [git](https://git-scm.com/downloads) if necessary, then type + +- git + + 1. Download the git repo, by one of two methods: + - Install [git](https://git-scm.com/downloads) if necessary, then type - ```git clone https://github.com/UC-Davis-molecular-computing/nuad.git``` + ```git clone https://github.com/UC-Davis-molecular-computing/nuad.git``` - at the command line, or - - on the page `https://github.com/UC-Davis-molecular-computing/nuad`, click on Code → Download Zip: + at the command line, or + - on the page `https://github.com/UC-Davis-molecular-computing/nuad`, click on Code → Download Zip: - ![](images/screenshot-download-zip.png) + ![](images/screenshot-download-zip.png) - and then unzip somewhere on your file system. + and then unzip somewhere on your file system. -2. 
Add the directory `nuad` that you just created to your `PYTHONPATH` environment variable. In Linux, Mac, or [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install-win10), this is done by adding this line to your startup script (e.g., `~/.bashrc`, or `~/.bash_profile` for Mac OS), where `/path/to/nuad` represents the path to the `nuad` directory: + 2. Add the directory `nuad` that you just created to your `PYTHONPATH` environment variable. In Linux, Mac, or [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install-win10), this is done by adding this line to your startup script (e.g., `~/.bashrc`, or `~/.bash_profile` for Mac OS), where `/path/to/nuad` represents the path to the `nuad` directory: - ``` - export PYTHONPATH="${PYTHONPATH}:/path/to/nuad" - ``` + ``` + export PYTHONPATH="${PYTHONPATH}:/path/to/nuad" + ``` -3. Install the Python packages dependencies listed in the file [requirements.txt](https://github.com/UC-Davis-molecular-computing/nuad/blob/main/requirements.txt) by typing + 3. Install the Python package dependencies listed in the file [requirements.txt](https://github.com/UC-Davis-molecular-computing/nuad/blob/main/requirements.txt) by typing - ``` - pip install numpy ordered_set psutil pathos scadnano xlwt xlrd - ``` + ``` + pip install numpy ordered_set psutil pathos xlwt xlrd tabulate scadnano + ``` - at the command line. + at the command line. If you have Python 3.7, you will also need to install the `typing_extensions` package: `pip install typing_extensions` + +### Installing NUPACK and ViennaRNA + +Recall that if you are using Windows, you must do all installation through [WSL](https://docs.microsoft.com/en-us/windows/wsl/install) (Windows Subsystem for Linux). 
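Reviewer note: the `PYTHONPATH` step above can be sanity-checked without installing anything, since directories listed in `PYTHONPATH` are added to `sys.path` when the interpreter starts. A stdlib-only check (the `/path/to/nuad` placeholder is the README's, not a real path):

```python
import os
import subprocess
import sys

# Launch a fresh interpreter with PYTHONPATH set -- the same effect as the
# `export PYTHONPATH=...` line in the README -- and confirm the directory
# shows up on sys.path (Python adds it even if the directory doesn't exist).
env = dict(os.environ, PYTHONPATH='/path/to/nuad')
result = subprocess.run(
    [sys.executable, '-c', "import sys; print('/path/to/nuad' in sys.path)"],
    env=env, capture_output=True, text=True,
)
print(result.stdout.strip())
```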
+ +Install NUPACK (version 4) and ViennaRNA following their installation instructions ([NUPACK installation](https://docs.nupack.org/start/#maclinux-installation), [ViennaRNA installation](https://www.tbi.univie.ac.at/RNA/ViennaRNA/doc/html/install.html), and [ViennaRNA downloads](https://www.tbi.univie.ac.at/RNA/#download)). If you do not install one of them, you can still install nuad, but most of the useful functions specifying pre-packaged constraints will be unavailable to call. + +After installing ViennaRNA, it may be necessary to add its executables directory (the directory containing executable programs such as RNAduplex) to your `PATH` environment variable, similar to how the `PYTHONPATH` variable is adjusted above. NUPACK 4 does not come with an executable, so this step is unnecessary; NUPACK 4 is called directly from within Python. -4. Install NUPACK (version 4) and ViennaRNA following their installation instructions ([NUPACK installation](https://docs.nupack.org/start/#maclinux-installation), [ViennaRNA installation](https://www.tbi.univie.ac.at/RNA/ViennaRNA/doc/html/install.html), and [ViennaRNA downloads](https://www.tbi.univie.ac.at/RNA/#download)). (If you do not install one of them, you can still install nuad, but most of the useful functions specifying pre-packaged constraints will be unavailable to call.) If installing on Windows, you must first install [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl/install-win10), and then install NUPACK and ViennaRNA from within WSL. After installing ViennaRNA, it may be necessary to add its executables directory (the directory containing executable programs such as RNAduplex) to your `PATH` environment variable. (Similarly to how the `PYTHONPATH` variable is adjusted above.) NUPACK 4 does not come with an executable, so this step is unnecessary; it is called directly from within Python. +To test that NUPACK 4 is installed correctly, run `python3 -m pip show nupack`. 
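Reviewer note: since nuad is importable even when NUPACK or ViennaRNA is missing (only the pre-packaged constraints become unavailable), a script can probe for them up front. A stdlib-only sketch; the module name `nupack` matches the `pip show nupack` check above, and ViennaRNA is probed via its `RNAduplex` executable on the `PATH`:

```python
import importlib.util
import shutil

# NUPACK 4 ships as a Python package, so probe for an importable module;
# ViennaRNA ships command-line executables such as RNAduplex, so probe
# the PATH instead.
def nupack_available() -> bool:
    return importlib.util.find_spec('nupack') is not None

def viennarna_available() -> bool:
    return shutil.which('RNAduplex') is not None

print('NUPACK:', 'found' if nupack_available() else 'missing (NUPACK-based constraints unavailable)')
print('ViennaRNA:', 'found' if viennarna_available() else 'missing (ViennaRNA-based constraints unavailable)')
```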
- To test that NUPACK 4 is installed correctly, run `python3 -m pip show nupack`. - To test that ViennaRNA is installed correctly, type `RNAduplex` at the command line. +To test that ViennaRNA is installed correctly, type `RNAduplex` at the command line. -5. Test NUPACK and ViennaRNA are available from within nuad by typing `python` at the command line, then typing `import nuad`. It should import without errors: +Test NUPACK and ViennaRNA are available from within nuad by typing `python` at the command line, then typing `import nuad`. It should import without errors: ```python $ python @@ -80,7 +103,7 @@ To install nuad: >>> ``` - To test that NUPACK and ViennaRNA can each be called from within the Python library (note that if you do not install NUPACK and/or ViennaRNA, then only a subset of the following will succeed): +To test that NUPACK and ViennaRNA can each be called from within the Python library (note that if you do not install NUPACK and/or ViennaRNA, then only a subset of the following will succeed): ```python >>> import nuad.vienna_nupack as nv diff --git a/examples/many_strands_no_common_domains.py b/examples/many_strands_no_common_domains.py index 49484090..15af1f39 100644 --- a/examples/many_strands_no_common_domains.py +++ b/examples/many_strands_no_common_domains.py @@ -48,30 +48,30 @@ def main() -> None: # many 4-domain strands with no common domains, 4 domains each, every domain length = 10 # just for testing parallel processing - # num_strands = 2 + # num_strands = 3 # num_strands = 5 # num_strands = 10 - # num_strands = 50 - num_strands = 100 + num_strands = 50 + # num_strands = 100 # num_strands = 355 + design = nc.Design() # si wi ni ei # strand i is [----------|----------|----------|----------> - strands = [nc.Strand([f's{i}', f'w{i}', f'n{i}', f'e{i}']) for i in range(num_strands)] + for i in range(num_strands): + design.add_strand([f's{i}', f'w{i}', f'n{i}', f'e{i}']) - # some_fixed = False - some_fixed = True + some_fixed = False + # 
some_fixed = True if some_fixed: # fix all domains of strand 0 and one domain of strand 1 - for domain in strands[0].domains: + for domain in design.strands[0].domains: domain.set_fixed_sequence('ACGTACGTAC') - strands[1].domains[0].set_fixed_sequence('ACGTACGTAC') + design.strands[1].domains[0].set_fixed_sequence('ACGTACGTAC') parallel = False # parallel = True - design = nc.Design(strands) - numpy_constraints: List[NumpyConstraint] = [ nc.NearestNeighborEnergyConstraint(-9.3, -9.0, 52.0), # nc.BaseCountConstraint(base='G', high_count=1), @@ -107,14 +107,14 @@ def main() -> None: ) if some_fixed: - for strand in strands[1:]: # skip all domains on strand 0 since all its domains are fixed + for strand in design.strands[1:]: # skip all domains on strand 0 since all its domains are fixed for domain in strand.domains[:2]: if domain.name != 's1': # skip for s1 since that domain is fixed domain.pool = domain_pool_10 for domain in strand.domains[2:]: domain.pool = domain_pool_11 else: - for strand in strands: + for strand in design.strands: for domain in strand.domains[:2]: domain.pool = domain_pool_10 for domain in strand.domains[2:]: @@ -122,7 +122,7 @@ def main() -> None: # have to set nupack_complex_secondary_structure_constraint after DomainPools are set, # so that we know the domain lengths - strand_complexes = [nc.Complex((strand,)) for i, strand in enumerate(strands[2:])] + strand_complexes = [nc.Complex(strand) for i, strand in enumerate(design.strands[2:])] strand_base_pair_prob_constraint = nc.nupack_complex_base_pair_probability_constraint( strand_complexes=strand_complexes) @@ -142,16 +142,20 @@ def main() -> None: strand_individual_ss_constraint = nc.nupack_strand_complex_free_energy_constraint( threshold=-1.0, temperature=52, short_description='StrandSS', parallel=parallel) + strand_individual_ss_constraint2 = nc.nupack_strand_complex_free_energy_constraint( + threshold=-1.0, temperature=52, short_description='StrandSS2', parallel=parallel) + 
strand_pair_nupack_constraint = nc.nupack_strand_pairs_constraint( threshold=3.0, temperature=52, short_description='StrandPairNUPACK', parallel=parallel, weight=0.1) params = ns.SearchParameters(constraints=[ # domain_nupack_ss_constraint, strand_individual_ss_constraint, + # strand_individual_ss_constraint2, + strand_pairs_rna_duplex_constraint, # strand_pair_nupack_constraint, # domain_pair_nupack_constraint, # domain_pairs_rna_duplex_constraint, - strand_pairs_rna_duplex_constraint, # strand_base_pair_prob_constraint, # nc.domains_not_substrings_of_each_other_constraint(), ], @@ -163,6 +167,8 @@ def main() -> None: save_report_for_all_updates=True, save_design_for_all_updates=True, force_overwrite=True, + scrolling_output=False, + # report_only_violations=False, ) ns.search_for_dna_sequences(design, params) diff --git a/examples/sample_designer.py b/examples/sample_designer.py index f11e2131..f1d432ae 100644 --- a/examples/sample_designer.py +++ b/examples/sample_designer.py @@ -74,13 +74,12 @@ def main() -> None: # | w4* s4* # \===========--==========] - strand0: nc.Strand[str] = nc.Strand(['s1', 'w1', 'n1', 'e1'], name='strand 0') - strand1: nc.Strand[str] = nc.Strand(['s2', 'w2', 'n2', 'e2'], name='strand 1') - strand2: nc.Strand[None] = nc.Strand(['n2*', 'e1*', 'n3*', 'e3*'], name='strand 2') - strand3: nc.Strand[str] = nc.Strand(['s4*', 'w4*', 's1*', 'w2*'], name='strand 3') - strands = [strand0, strand1, strand2, strand3] + initial_design = nc.Design() - initial_design = nc.Design(strands) + strand0: nc.Strand[str] = initial_design.add_strand(['s1', 'w1', 'n1', 'e1'], name='strand 0') + strand1: nc.Strand[str] = initial_design.add_strand(['s2', 'w2', 'n2', 'e2'], name='strand 1') + strand2: nc.Strand[None] = initial_design.add_strand(['n2*', 'e1*', 'n3*', 'e3*'], name='strand 2') + strand3: nc.Strand[str] = initial_design.add_strand(['s4*', 'w4*', 's1*', 'w2*'], name='strand 3') if args.initial_design_filename is not None: with 
open(args.initial_design_filename, 'r') as file: diff --git a/examples/seesaw_gate.py b/examples/seesaw_gate.py index 012ca61d..80d5c188 100644 --- a/examples/seesaw_gate.py +++ b/examples/seesaw_gate.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Tuple # Test ComplexConstraint evaluate import nuad.constraints as nc @@ -44,6 +44,8 @@ ] TOEHOLD_DOMAIN_POOL: nc.DomainPool = nc.DomainPool('toehold_domain_pool', 5) +design = nc.Design() + # s2 S2 T s1 S1 # [==--=============--=====--==-=============> @@ -52,7 +54,7 @@ def seesaw_signal_strand(gate1: int, gate2: int) -> nc.Strand: d1_sub = f'{SIGNAL_DOMAIN_SUB_PREFIX}{gate1}' d2 = f'{SIGNAL_DOMAIN_PREFIX}{gate2}' d2_sub = f'{SIGNAL_DOMAIN_SUB_PREFIX}{gate2}' - s: nc.Strand = nc.Strand([d2_sub, d2, TOEHOLD_DOMAIN, d1_sub, d1], name=f'signal {gate1} {gate2}') + s: nc.Strand = design.add_strand([d2_sub, d2, TOEHOLD_DOMAIN, d1_sub, d1], name=f'signal {gate1} {gate2}') s.domains[0].pool = SUB_LONG_DOMAIN_POOL s.domains[1].pool = NON_SUB_LONG_DOMAIN_POOL s.domains[2].pool = TOEHOLD_DOMAIN_POOL @@ -67,7 +69,7 @@ def seesaw_signal_strand(gate1: int, gate2: int) -> nc.Strand: def gate_base_strand(gate: int) -> nc.Strand: d = f'{SIGNAL_DOMAIN_PREFIX}{gate}{COMPLEMENT_SUFFIX}' d_sub = f'{SIGNAL_DOMAIN_SUB_PREFIX}{gate}{COMPLEMENT_SUFFIX}' - s: nc.Strand = nc.Strand( + s: nc.Strand = design.add_strand( [TOEHOLD_COMPLEMENT, d, d_sub, TOEHOLD_COMPLEMENT], name=f'gate {gate}') s.domains[0].pool = TOEHOLD_DOMAIN_POOL s.domains[1].pool = NON_SUB_LONG_DOMAIN_POOL @@ -81,7 +83,7 @@ def gate_base_strand(gate: int) -> nc.Strand: def waste_strand(gate: int) -> nc.Strand: d = f'{SIGNAL_DOMAIN_PREFIX}{gate}' d_sub = f'{SIGNAL_DOMAIN_SUB_PREFIX}{gate}' - s: nc.Strand = nc.Strand([d_sub, d], name=f'waste {gate}') + s: nc.Strand = design.add_strand([d_sub, d], name=f'waste {gate}') s.domains[0].pool = SUB_LONG_DOMAIN_POOL s.domains[1].pool = NON_SUB_LONG_DOMAIN_POOL return s @@ -95,7 +97,7 @@ def 
threshold_base_strand(gate1: int, gate2: int) -> nc.Strand: d2 = f'{SIGNAL_DOMAIN_PREFIX}{gate2}{COMPLEMENT_SUFFIX}' d2_sub = f'{SIGNAL_DOMAIN_SUB_PREFIX}{gate2}{COMPLEMENT_SUFFIX}' - s: nc.Strand = nc.Strand( + s: nc.Strand = design.add_strand( [d1_sub, TOEHOLD_COMPLEMENT, d2, d2_sub], name=f'threshold {gate1} {gate2}') s.domains[0].pool = SUB_LONG_DOMAIN_POOL @@ -113,7 +115,7 @@ def reporter_base_strand(gate) -> nc.Strand: d = f'{SIGNAL_DOMAIN_PREFIX}{gate}{COMPLEMENT_SUFFIX}' d_sub = f'{SIGNAL_DOMAIN_SUB_PREFIX}{gate}{COMPLEMENT_SUFFIX}' - s: nc.Strand = nc.Strand( + s: nc.Strand = design.add_strand( [TOEHOLD_COMPLEMENT, d, d_sub], name=f'reporter {gate}') s.domains[0].pool = TOEHOLD_DOMAIN_POOL s.domains[1].pool = NON_SUB_LONG_DOMAIN_POOL @@ -219,8 +221,8 @@ def reporter_base_strand(gate) -> nc.Strand: # | | | # INTERIOR_TO_STRAND DANGLE_5P # T* S5* s5* T* -g_5_s_5_6_complex = (signal_5_6_strand, gate_5_base_strand) -g_5_s_5_7_complex = (signal_5_7_strand, gate_5_base_strand) +g_5_s_5_6_complex = nc.Complex(signal_5_6_strand, gate_5_base_strand) +g_5_s_5_7_complex = nc.Complex(signal_5_7_strand, gate_5_base_strand) g_5_s_5_6_nonimplicit_base_pairs = [(signal_5_6_toehold_addr, gate_5_bound_toehold_3p_addr)] g_5_s_5_6_complex_constraint = nc.nupack_complex_base_pair_probability_constraint( @@ -255,7 +257,7 @@ def reporter_base_strand(gate) -> nc.Strand: # | | | # INTERIOR_TO_STRAND DANGLE_3P # T* S5* s5* T* -g_5_s_2_5_complex = (signal_2_5_strand, gate_5_base_strand) +g_5_s_2_5_complex = nc.Complex(signal_2_5_strand, gate_5_base_strand) g_5_s_2_5_nonimplicit_base_pairs = [(signal_2_5_toehold_addr, gate_5_bound_toehold_5p_addr)] g_5_s_2_5_complex_constraint = nc.nupack_complex_base_pair_probability_constraint( strand_complexes=[g_5_s_2_5_complex], @@ -292,7 +294,7 @@ def reporter_base_strand(gate) -> nc.Strand: # | | # INTERIOR_TO_STRAND BLUNT_END # s2* T* S5* s5* -t_2_5_w_5_complex = (waste_5_strand, threshold_2_5_base_strand) +t_2_5_w_5_complex = 
nc.Complex(waste_5_strand, threshold_2_5_base_strand) t_2_5_w_5_complex_constraint = nc.nupack_complex_base_pair_probability_constraint( strand_complexes=[t_2_5_w_5_complex]) @@ -323,7 +325,7 @@ def reporter_base_strand(gate) -> nc.Strand: # INTERIOR_TO_STRAND INTERIOR_TO_STRAND BLUNT_END # s2* T* S5* s5* -waste_2_5_complex = (signal_2_5_strand, threshold_2_5_base_strand) +waste_2_5_complex = nc.Complex(signal_2_5_strand, threshold_2_5_base_strand) waste_2_5_complex_constraint = nc.nupack_complex_base_pair_probability_constraint( strand_complexes=[waste_2_5_complex]) @@ -349,7 +351,7 @@ def reporter_base_strand(gate) -> nc.Strand: # | | # INTERIOR_TO_STRAND BLUNT_END # T* S6* s6* -reporter_6_complex = (waste_6_strand, reporter_6_base_strand) +reporter_6_complex = nc.Complex(waste_6_strand, reporter_6_base_strand) reporter_6_complex_constraint = nc.nupack_complex_base_pair_probability_constraint( strand_complexes=[reporter_6_complex]) @@ -376,16 +378,16 @@ def reporter_base_strand(gate) -> nc.Strand: # | | | # INTERIOR_TO_STRAND BLUNT_END # T* S6* s6* -f_waste_6_complex = (signal_5_6_strand, reporter_6_base_strand) +f_waste_6_complex = nc.Complex(signal_5_6_strand, reporter_6_base_strand) f_waste_6_complex_constraint = nc.nupack_complex_base_pair_probability_constraint( strand_complexes=[f_waste_6_complex]) -def four_g_constraint_evaluate(seq: str, strand: Optional[nc.Strand]): - if 'GGGG' in seq: - return 1000 - else: - return 0 +def four_g_constraint_evaluate(seqs: Tuple[str, ...], strand: Optional[nc.Strand]) -> Tuple[float, str]: + seq = seqs[0] + score = 1000 if 'GGGG' in seq else 0 + violation_str = "" if 'GGGG' not in strand.sequence() else "** violation**" + return score, f"{strand.name}: {strand.sequence()}{violation_str}" def four_g_constraint_summary(strand: nc.Strand): @@ -396,11 +398,10 @@ def four_g_constraint_summary(strand: nc.Strand): four_g_constraint = nc.StrandConstraint(description="4GConstraint", short_description="4GConstraint", 
evaluate=four_g_constraint_evaluate, - strands=tuple(strands), - summary=four_g_constraint_summary) + strands=tuple(strands), ) # Constraints -complex_constraints = [ +constraints = [ g_5_s_5_6_complex_constraint, g_5_s_2_5_complex_constraint, t_2_5_w_5_complex_constraint, @@ -408,11 +409,13 @@ def four_g_constraint_summary(strand: nc.Strand): reporter_6_complex_constraint, f_waste_6_complex_constraint, ] +constraints.append(four_g_constraint) -seesaw_design = nc.Design(strands=strands, constraints=complex_constraints + [four_g_constraint]) +seesaw_design = nc.Design(strands=strands) params = ns.SearchParameters( # weigh_violations_equally=True, - report_delay=0.0, + constraints=constraints, + # report_delay=0.0, out_directory='output/seesaw_gate', report_only_violations=False, ) diff --git a/examples/square_root_circuit.py b/examples/square_root_circuit.py index 01c12cfa..7e340b96 100644 --- a/examples/square_root_circuit.py +++ b/examples/square_root_circuit.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from math import ceil, floor -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import itertools import nuad.search as ns # type: ignore @@ -36,11 +36,12 @@ SIGNAL_DOMAIN_LENGTH - EXTENDED_TOEHOLD_LENGTH) SUBDOMAIN_S_POOL: nc.DomainPool = nc.DomainPool(f'SUBDOMAIN_S_POOL', EXTENDED_TOEHOLD_LENGTH) TOEHOLD_DOMAIN_POOL: nc.DomainPool = nc.DomainPool( - 'TOEHOLD_DOMAIN_POOL', TOEHOLD_LENGTH, [three_letter_code_constraint]) + name='TOEHOLD_DOMAIN_POOL', length=TOEHOLD_LENGTH, numpy_constraints=[three_letter_code_constraint]) SIGNAL_DOMAIN_POOL: nc.DomainPool = nc.DomainPool( - 'SIGNAL_DOMAIN_POOL', SIGNAL_DOMAIN_LENGTH, - [three_letter_code_constraint, c_content_constraint, no_aaaaa_constraint, no_gggg_constraint]) + name='SIGNAL_DOMAIN_POOL', length=SIGNAL_DOMAIN_LENGTH, + numpy_constraints=[three_letter_code_constraint, c_content_constraint, 
no_aaaaa_constraint, + no_gggg_constraint]) # Alias dc_complex_constraint = nc.nupack_complex_base_pair_probability_constraint @@ -257,7 +258,7 @@ def reporter_bottom_strand(gate) -> nc.Strand: def input_gate_complex_constraint( - input_gate_complexes: List[Tuple[nc.Strand, nc.Strand]]) -> nc.ComplexConstraint: + input_gate_complexes: List[nc.Complex]) -> nc.ComplexConstraint: """Returns a input:gate complex constraint .. code-block:: none @@ -295,16 +296,14 @@ def input_gate_complex_constraint( addr_t = template_top_strand.address_of_first_domain_occurence('T') addr_t_star = template_bot_strand.address_of_first_domain_occurence('T*') return dc_complex_constraint( - strand_complexes=cast( - List[Tuple[nc.Strand, ...]], - input_gate_complexes), + strand_complexes=input_gate_complexes, nonimplicit_base_pairs=[(addr_t, addr_t_star)], description="input:gate Complex", short_description="input:gate") def gate_output_complex_constraint( - gate_output_complexes: List[Tuple[nc.Strand, ...]], + gate_output_complexes: List[nc.Complex], base_pair_prob_by_type: Optional[Dict[nc.BasePairType, float]] = None, description: str = 'gate:output') -> nc.ComplexConstraint: """Returns a gate:output complex constraint @@ -431,25 +430,20 @@ def violated(seq: str): return True return False - def evaluate(seq: str, strand: Optional[nc.Strand]): + def evaluate(seqs: Tuple[str, ...], strand: Optional[nc.Strand]) -> Tuple[float, str]: + seq = seqs[0] if violated(seq): - return 100 + violation_str = '** violation**' + score = 100 else: - return 0 - - def summary(strand: nc.Strand): - violation_str: str - if violated(strand.sequence()): violation_str = '' - else: - violation_str = "** violation**" - return f"{strand.name}: {strand.sequence()}{violation_str}" + score = 0 + return score, f"{strand.name}: {strand.sequence()}{violation_str}" return nc.StrandConstraint(description="Strand Substring Constraint", short_description="Strand Substring Constraint", evaluate=evaluate, - 
strands=tuple(strands), - summary=summary) + strands=tuple(strands)) @dataclass @@ -635,25 +629,27 @@ def _set_strands(self) -> None: def _add_input_gate_complex_constraint(self) -> None: """Adds input:gate complexes to self.constraint """ - input_gate_complexes = [] + input_gate_strands = [] for (input_, gate), s in self.signal_strands.items(): if gate in self.gate_base_strands: g = self.gate_base_strands[gate] - input_gate_complexes.append((s, g)) + input_gate_strands.append((s, g)) + input_gate_complexes = [nc.Complex(*strands) for strands in input_gate_strands] self.constraints.append( - input_gate_complex_constraint( - input_gate_complexes)) + input_gate_complex_constraint(input_gate_complexes)) def _add_gate_output_complex_constriant(self) -> None: """Adds gate:output complexes to self.constraint """ - gate_output_complexes: List[Tuple[nc.Strand, ...]] = [] + gate_output_strands: List[Tuple[nc.Strand, ...]] = [] for (gate, _), s in self.signal_strands.items(): if gate in self.gate_base_strands: g = self.gate_base_strands[gate] - gate_output_complexes.append((s, g)) + gate_output_strands.append((s, g)) + + gate_output_complexes = [nc.Complex(*strands) for strands in gate_output_strands] self.constraints.append( gate_output_complex_constraint( @@ -664,13 +660,15 @@ def _add_gate_output_complex_constriant(self) -> None: def _add_gate_fuel_complex_constriant(self) -> None: """Adds gate:fuel complexes to self.constraint """ - gate_output_complexes: List[Tuple[nc.Strand, ...]] = [] + gate_output_strands: List[Tuple[nc.Strand, ...]] = [] for gate in self.fuel_strands: if gate in self.fuel_strands: f = self.fuel_strands[gate] g = self.gate_base_strands[gate] - gate_output_complexes.append((f, g)) + gate_output_strands.append((f, g)) + + gate_output_complexes = [nc.Complex(*strands) for strands in gate_output_strands] # TODO: Make it so that only specific base pairs have lower threshold (such as base index 1) # which is an A that can bind to any T but it doesn't 
matter which. @@ -698,10 +696,12 @@ def _add_threshold_complex_constraint(self) -> None: 16 35 s2* T* S5* s5* """ - threshold_complexes: List[Tuple[nc.Strand, ...]] = [] + threshold_strands: List[Tuple[nc.Strand, ...]] = [] for (_, gate), thres_bottom_strand in self.threshold_bottom_strands.items(): waste_strand = self.threshold_top_strands[gate] - threshold_complexes.append((waste_strand, thres_bottom_strand)) + threshold_strands.append((waste_strand, thres_bottom_strand)) + + threshold_complexes = [nc.Complex(*strands) for strands in threshold_strands] self.constraints.append( dc_complex_constraint( @@ -727,12 +727,14 @@ def _add_threshold_waste_complex_constraint(self) -> None: 36 55 s2* T* S5* s5* """ - threshold_waste_complexes: List[Tuple[nc.Strand, ...]] = [] + threshold_waste_strands: List[Tuple[nc.Strand, ...]] = [] for (input_, gate), thres_bottom_strand in self.threshold_bottom_strands.items(): sig_strand = self.signal_strands[(input_, gate)] - threshold_waste_complexes.append( + threshold_waste_strands.append( (sig_strand, thres_bottom_strand)) + threshold_waste_complexes = [nc.Complex(*strands) for strands in threshold_waste_strands] + self.constraints.append( dc_complex_constraint( threshold_waste_complexes, @@ -755,10 +757,12 @@ def _add_reporter_complex_constraint(self) -> None: 33 T* S6* s6* """ - reporter_complexes: List[Tuple[nc.Strand, ...]] = [] + reporter_strands: List[Tuple[nc.Strand, ...]] = [] for (_, gate), reporter_bottom_strand_ in self.reporter_bottom_strands.items(): waste_strand = self.reporter_top_strands[gate] - reporter_complexes.append((waste_strand, reporter_bottom_strand_)) + reporter_strands.append((waste_strand, reporter_bottom_strand_)) + + reporter_complexes = [nc.Complex(*strands) for strands in reporter_strands] self.constraints.append( dc_complex_constraint( @@ -783,12 +787,14 @@ def _add_reporter_waste_complex_constraint(self) -> None: 53 T* S6* s6* """ - reporter_waste_complexes: List[Tuple[nc.Strand, ...]] = [] + 
reporter_waste_strands: List[Tuple[nc.Strand, ...]] = [] for (input_, gate), reporter_bottom_strand_ in self.reporter_bottom_strands.items(): signal_strand_ = self.signal_strands[(input_, gate)] - reporter_waste_complexes.append( + reporter_waste_strands.append( (signal_strand_, reporter_bottom_strand_)) + reporter_waste_complexes = [nc.Complex(*strands) for strands in reporter_waste_strands] + self.constraints.append( dc_complex_constraint( reporter_waste_complexes, @@ -928,11 +934,12 @@ def main() -> None: constraints: List[nc.Constraint] = [base_difference_constraint(recognition_domains), strand_substring_constraint(non_fuel_strands, ILLEGAL_SUBSTRINGS)] constraints.extend(seesaw_circuit.constraints) # make mypy happy about the generics with List - design = nc.Design(strands=strands, constraints=constraints) - params = ns.SearchParameters(out_directory='output/square_root_circuit', + design = nc.Design(strands=strands) + params = ns.SearchParameters(constraints=constraints, + out_directory='output/square_root_circuit', # weigh_violations_equally=True, - # restart=True, - report_delay=0.0) + # restart=True + ) ns.search_for_dna_sequences(design, params) diff --git a/examples/sst_canvas.py b/examples/sst_canvas.py index 23ccf357..5a579c84 100644 --- a/examples/sst_canvas.py +++ b/examples/sst_canvas.py @@ -132,7 +132,8 @@ def create_design(width: int, height: int) -> nc.Design: domain_pool_10 = nc.DomainPool(f'length-10_domains', 10, numpy_constraints=numpy_constraints) domain_pool_11 = nc.DomainPool(f'length-11_domains', 11, numpy_constraints=numpy_constraints) - tiles = [] + design = nc.Design() + for x in range(width): for y in range(height): # domains are named after the strand for which they are on the bottom, @@ -166,9 +167,8 @@ def create_design(width: int, height: int) -> nc.Design: w_domain_name = f'we_{x}_{y}' n_domain_name = f'ns_{x - 1}_{y}*' e_domain_name = f'we_{x}_{y + 1}*' - tile = nc.Strand(domain_names=[s_domain_name, w_domain_name, 
n_domain_name, e_domain_name], - name=f't_{x}_{y}') - tiles.append(tile) + tile = design.add_strand( + domain_names=[s_domain_name, w_domain_name, n_domain_name, e_domain_name], name=f't_{x}_{y}') if (x + y) % 2 == 0: outer_pool = domain_pool_11 @@ -187,7 +187,6 @@ def create_design(width: int, height: int) -> nc.Design: if not w_domain.has_pool(): w_domain.pool = inner_pool - design = nc.Design(strands=tiles) return design diff --git a/nuad/__version__.py b/nuad/__version__.py index 3ba17ec7..1304c322 100644 --- a/nuad/__version__.py +++ b/nuad/__version__.py @@ -1 +1 @@ -version = '0.1.8' # version line; WARNING: do not remove or change this line or comment +version = '0.2.0' # version line; WARNING: do not remove or change this line or comment diff --git a/nuad/constraints.py b/nuad/constraints.py index 71c6be31..bf8b1564 100644 --- a/nuad/constraints.py +++ b/nuad/constraints.py @@ -1347,6 +1347,24 @@ def mandatory_field(ret_type: Type, json_map: Dict, main_key: str, *legacy_keys: class Part(ABC): + def __eq__(self, other: Part) -> bool: + return type(self) == type(other) and self.name == other.name + + # Remember to set subclass __hash__ equal to this implementation; see here: + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + def __hash__(self) -> int: + return hash(self.key()) + + @property + @abstractmethod + def name(self) -> str: + pass + + @abstractmethod + def key(self) -> str: + # used as key in dictionary + pass + @staticmethod @abstractmethod def name_of_part_type(self) -> str: pass @@ -1372,10 +1390,22 @@ class DomainPair(Part, Generic[DomainLabel]): domain1: Domain domain2: Domain + def __post_init__(self) -> None: + # make this symmetric so that dict lookups work + if self.domain1.name > self.domain2.name: + self.domain1, self.domain2 = self.domain2, self.domain1 + + # needed to avoid unhashable type error; see + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + __hash__ = Part.__hash__ + @property def 
name(self) -> str: return f'{self.domain1.name}, {self.domain2.name}' + def key(self) -> str: + return f'DomainPair[{self.domain1.name}, {self.domain2.name}]' + @staticmethod def name_of_part_type(self) -> str: return 'domain pair' @@ -1389,7 +1419,7 @@ def fixed(self) -> bool: @dataclass -class Domain(JSONSerializable, Part, Generic[DomainLabel]): +class Domain(Part, JSONSerializable, Generic[DomainLabel]): """ Represents a contiguous substring of the DNA sequence of a :any:`Strand`, which is intended to be either single-stranded, or to bind fully to the Watson-Crick complement of the :any:`Domain`. @@ -1542,6 +1572,16 @@ def __init__(self, name: str, pool: Optional[DomainPool] = None, sequence: Optio def name_of_part_type(self) -> str: return 'domain' + def key(self) -> str: + return f'Domain({self.name})' + + # needed to avoid unhashable type error; see + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + __hash__ = Part.__hash__ + + def __repr__(self) -> str: + return self._name + def individual_parts(self) -> Tuple[Domain, ...]: return self, @@ -1603,17 +1643,6 @@ def from_json_serializable(json_map: Dict[str, Any], name=name, sequence=sequence, fixed=fixed, pool=pool, label=label) return domain - def __hash__(self) -> int: - return hash(self._name) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Domain): - return False - return self._name == other._name - - def __repr__(self) -> str: - return self._name - @property def name(self) -> str: """ @@ -2089,14 +2118,10 @@ def _independent_descendent(self) -> Optional[Domain]: return None -_domains_interned: Dict[str, Domain] = {} - - def domains_not_substrings_of_each_other_constraint( check_complements: bool = True, short_description: str = 'dom neq', weight: float = 1.0, min_length: int = 0, - pairs: Optional[Iterable[Tuple[Domain, Domain]]] = None) \ - -> DomainPairConstraint: + pairs: Optional[Iterable[Tuple[Domain, Domain]]] = None) -> DomainPairConstraint: """ Returns 
constraint ensuring no two domains are substrings of each other. Note that this ensures that no two :any:`Domain`'s are equal if they are the same length. @@ -2112,7 +2137,8 @@ def domains_not_substrings_of_each_other_constraint( For instance if `min_length` is 4, then having two domains with sequences AAAA and CAAAAC would violate this constraint, but domains with sequences AAA and CAAAC would not. :param pairs: - pairs of domains to check (by default all pairs of unequal domains are compared) + pairs of domains to check. + By default all pairs of unequal domains are compared unless both are fixed. :return: a :any:`DomainPairConstraint` ensuring no two domain sequences contain each other as a substring (in particular, if they are equal length, then they are not the same domain) @@ -2242,10 +2268,22 @@ class StrandPair(Part, Generic[StrandLabel, DomainLabel]): strand1: Strand strand2: Strand + def __post_init__(self) -> None: + # make this symmetric so make dict lookups work + if self.strand1.name > self.strand2.name: + self.strand1, self.strand2 = self.strand2, self.strand1 + + # needed to avoid unhashable type error; see + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + __hash__ = Part.__hash__ + @property def name(self) -> str: return f'{self.strand1.name}, {self.strand2.name}' + def key(self) -> str: + return f'StrandPair[{self.strand1.name}, {self.strand2.name}]' + @staticmethod def name_of_part_type(self) -> str: return 'strand pair' @@ -2261,10 +2299,29 @@ def fixed(self) -> bool: @dataclass class Complex(Part, Generic[StrandLabel, DomainLabel]): strands: Tuple[Strand, ...] + """The strands in this complex.""" + + def __init__(self, *args: Strand) -> None: + """ + Creates a complex of strands given as arguments, e.g., ``Complex(strand1, strand2)`` creates + a 2-strand complex. 
+ """ + for strand in args: + if not isinstance(strand, Strand): + raise TypeError(f'must pass Strands to constructor for complex, not {strand}') + self.strands = tuple(args) + + # needed to avoid unhashable type error; see + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + __hash__ = Part.__hash__ @property def name(self) -> str: - return ', '.join(strand.name for strand in self.strands) + strand_names = ', '.join(strand.name for strand in self.strands) + return f'Complex[{strand_names}]' + + def key(self) -> str: + return f'Complex[{self.name}]' @staticmethod def name_of_part_type(self) -> str: @@ -2288,7 +2345,7 @@ def fixed(self) -> bool: @dataclass -class Strand(JSONSerializable, Generic[StrandLabel, DomainLabel], Part): +class Strand(Part, JSONSerializable, Generic[StrandLabel, DomainLabel]): """Represents a DNA strand, made of several :any:`Domain`'s. """ domains: List[Domain[DomainLabel]] @@ -2358,7 +2415,6 @@ class Strand(JSONSerializable, Generic[StrandLabel, DomainLabel], Part): """ def __init__(self, - domain_names: Optional[List[str]] = None, domains: Optional[List[Domain[DomainLabel]]] = None, starred_domain_indices: Optional[Iterable[int]] = None, group: str = default_strand_group, @@ -2367,25 +2423,14 @@ def __init__(self, idt: Optional[IDTFields] = None, ) -> None: """ - A :any:`Strand` can be created either by listing explicit :any:`Domain` objects - via parameter `domains`, or by giving names via parameter `domain_names`. - If `domain_names` is specified, then by convention those that end with a ``*`` are - assumed to be starred. Also, :any:`Domain`'s created in this way are "interned" as global variables; - no two :any:`Domain`'s with the same name will be created, and subsequent uses of the same - name will refer to the same :any:`Domain` object. + A :any:`Strand` can be created only by listing explicit :any:`Domain` objects + via parameter `domains`. 
To specify a :any:`Strand` by giving domain *names*, see the method + :meth:`Design.add_strand`. - :param domain_names: - Names of the :any:`Domain`'s on this :any:`Strand`. - Mutually exclusive with :py:data:`Strand.domains` and :py:data:`Strand.starred_domain_indices`. :param domains: - Dictionary mapping each :any:`Domain` on this :any:`Strand` to the Boolean value indicating - whether it is a starred :any:`Domain`. - Mutually exclusive with :py:data:`Strand.domain_names`, and must be specified jointly with - :py:data:`Strand.starred_domain_indices`. + list of :any:`Domain`'s on this :any:`Strand` :param starred_domain_indices: Indices of :any:`Domain`'s in `domains` that are starred. - Mutually exclusive with :py:data:`Strand.domain_names`, and must be specified jointly with - :py:data:`Strand.domains`. :param group: name of group of this :any:`Strand`. :param name: @@ -2399,33 +2444,6 @@ def __init__(self, self._all_intersecting_domains = None self.group = group self._name = name - if (domain_names is not None and not (domains is None and starred_domain_indices is None)) or \ - (domain_names is None and not (domains is not None and starred_domain_indices is not None)): - raise ValueError('exactly one of domain_names or ' - 'domains and starred_domain_indices must be non-None\n' - f'domain_names: {domain_names}\n' - f'domains: {domains}\n' - f'starred_domain_indices: {starred_domain_indices}') - - elif domain_names is not None: - domains = [] - starred_domain_indices = OrderedSet() - for idx, domain_name in enumerate(domain_names): - is_starred = domain_name.endswith('*') - if is_starred: - domain_name = domain_name[:-1] - - # domain = Domain(name) if name not in _domains_interned else _domains_interned[name] - domain: Domain - if domain_name not in _domains_interned: - domain = Domain(name=domain_name) - _domains_interned[domain_name] = domain - else: - domain = _domains_interned[domain_name] - - domains.append(domain) - if is_starred: - 
starred_domain_indices.add(idx) # XXX: moved this check to Design constructor to allow subdomain graphs to be # constructed gradually while building up the design @@ -2450,6 +2468,13 @@ def __init__(self, def name_of_part_type(self) -> str: return 'strand' + def key(self) -> str: + return f'Strand({self._hash_domain_names_concatenated})' + + # needed to avoid unhashable type error; see + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + __hash__ = Part.__hash__ + def individual_parts(self) -> Tuple[Strand, ...]: return self, @@ -2516,14 +2541,6 @@ def intersects_domain(self, domain: Domain) -> bool: """ return domain in self.all_intersecting_domains() - def __hash__(self) -> int: - return self._hash_domain_names_concatenated - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Strand): - return False - return self._domain_names_concatenated == other._domain_names_concatenated - def length(self) -> int: """ :return: @@ -2985,6 +3002,8 @@ class Design(Generic[StrandLabel, DomainLabel], JSONSerializable): strands: List[Strand[StrandLabel, DomainLabel]] """List of all :any:`Strand`'s in this :any:`Design`.""" + _domains_interned: Dict[str, Domain] + ################################################# # derived fields, so not specified in constructor @@ -3016,7 +3035,7 @@ class Design(Generic[StrandLabel, DomainLabel], JSONSerializable): Computed from :py:data:`Design.strands`, so not specified in constructor. 
""" - def __init__(self, strands: Iterable[Strand]) -> None: + def __init__(self, strands: Iterable[Strand] = ()) -> None: """ :param strands: the :any:`Strand`'s in this :any:`Design` @@ -3025,6 +3044,7 @@ def __init__(self, strands: Iterable[Strand]) -> None: self.check_all_subdomain_graphs_acyclic() self.check_all_subdomain_graphs_uniquely_assignable() self.compute_derived_fields() + self._domains_interned = {} def compute_derived_fields(self) -> None: """ @@ -3183,6 +3203,108 @@ def from_json_serializable(json_map: Dict[str, Any], return Design(strands=strands) + def add_strand(self, + domain_names: Optional[List[str]] = None, + domains: Optional[List[Domain[DomainLabel]]] = None, + starred_domain_indices: Optional[Iterable[int]] = None, + group: str = default_strand_group, + name: Optional[str] = None, + label: Optional[StrandLabel] = None, + idt: Optional[IDTFields] = None, + ) -> Strand: + """ + This is an alternative way to create strands instead of calling the :any:`Strand` constructor + explicitly. It behaves similarly to the :any:`Strand` constructor, but it has an option + to specify :any:`Domain`'s simply by giving a name. + + A :any:`Strand` can be created either by listing explicit :any:`Domain` objects via parameter + `domains` (as in the :any:`Strand` constructor), or by giving names via parameter `domain_names`. + If `domain_names` is specified, then by convention those that end with a ``*`` are + assumed to be starred. Also, :any:`Domain`'s created in this way are "interned" as variables + in a cache stored in the :any:`Design` object; + no two :any:`Domain`'s with the same name in this design will be created, + and subsequent uses of the same name will refer to the same :any:`Domain` object. + + :param domain_names: + Names of the :any:`Domain`'s on this :any:`Strand`. + Mutually exclusive with :py:data:`Strand.domains` and :py:data:`Strand.starred_domain_indices`. + :param domains: + list of :any:`Domain`'s on this :any:`Strand`. 
+ Mutually exclusive with :py:data:`Strand.domain_names`, and must be specified jointly with + :py:data:`Strand.starred_domain_indices`. + :param starred_domain_indices: + Indices of :any:`Domain`'s in `domains` that are starred. + Mutually exclusive with :py:data:`Strand.domain_names`, and must be specified jointly with + :py:data:`Strand.domains`. + :param group: + name of group of this :any:`Strand`. + :param name: + Name of this :any:`Strand`. + :param label: + Label to associate with this :any:`Strand`. + :param idt: + :any:`IDTFields` object to associate with this :any:`Strand`; needed to call + methods for exporting to IDT formats (e.g., :meth:`Strand.write_idt_bulk_input_file`) + :return: + the :any:`Strand` that is created + """ + if (domain_names is not None and not (domains is None and starred_domain_indices is None)) or \ + (domain_names is None and not (domains is not None and starred_domain_indices is not None)): + raise ValueError('exactly one of domain_names or ' + 'domains and starred_domain_indices must be non-None\n' + f'domain_names: {domain_names}\n' + f'domains: {domains}\n' + f'starred_domain_indices: {starred_domain_indices}') + + elif domain_names is not None: + domains = [] + starred_domain_indices = OrderedSet() + for idx, domain_name in enumerate(domain_names): + is_starred = domain_name.endswith('*') + if is_starred: + domain_name = domain_name[:-1] + + # domain = Domain(name) if name not in _domains_interned else _domains_interned[name] + domain: Domain + if domain_name not in self._domains_interned: + domain = Domain(name=domain_name) + self._domains_interned[domain_name] = domain + else: + domain = self._domains_interned[domain_name] + + domains.append(domain) + if is_starred: + starred_domain_indices.add(idx) + + domains_of_strand = list(domains) # type: ignore + strand = Strand(domains=domains_of_strand, + starred_domain_indices=starred_domain_indices, + group=group, + name=name, + label=label, + idt=idt) + + for existing_strand in 
self.strands: + if strand.name == existing_strand.name: + raise ValueError(f'strand name {strand.name} already exists for this strand:\n' + f' {existing_strand}\n' + f'so it cannot be used for the new strand\n' + f' {strand}') + self.strands.append(strand) + + for domain_in_strand in strand.domains: + domains_in_tree = domain_in_strand.all_domains_in_tree() + for domain in domains_in_tree: + if domain not in self.domains: + self.domains.append(domain) + name = domain.name + if name in self.domains_by_name and domain is not self.domains_by_name[name]: + raise ValueError(f'domain names must be unique, ' + f'but I found two different domains with name {domain.name}') + self.domains_by_name[domain.name] = domain + + return strand + @staticmethod def assign_modifications_to_strands(strands: List[Strand], strand_jsons: List[dict], all_mods: Dict[str, dm.Modification]) -> None: @@ -3530,8 +3652,8 @@ def from_scadnano_design(sc_design: sc.Design[StrandLabel, DomainLabel], # make dsd StrandGroups, taking names from Strands and Domains, # and assign (and maybe fix) DNA sequences - dsd_strands: List[Strand] = [] strand_names: Set[str] = set() + design: Design[StrandLabel, DomainLabel] = Design() for group, sc_strands in sc_strand_groups.items(): for sc_strand in sc_strands: # do not include strands with the same name more than once @@ -3543,13 +3665,13 @@ def from_scadnano_design(sc_design: sc.Design[StrandLabel, DomainLabel], domain_names: List[str] = [domain.name for domain in sc_strand.domains] sequence = sc_strand.dna_sequence - dsd_strand: Strand[StrandLabel, DomainLabel] = Strand(domain_names=domain_names, - group=group, - name=sc_strand.name, - label=sc_strand.label) + nuad_strand: Strand[StrandLabel, DomainLabel] = design.add_strand(domain_names=domain_names, + group=group, + name=sc_strand.name, + label=sc_strand.label) # assign sequence if sequence is not None: - for dsd_domain, sc_domain in zip(dsd_strand.domains, sc_strand.domains): + for dsd_domain, sc_domain in 
zip(nuad_strand.domains, sc_strand.domains): domain_sequence = sc_domain.dna_sequence # if this is a starred domain, # take the WC complement first so the dsd Domain stores the "canonical" sequence @@ -3562,17 +3684,17 @@ def from_scadnano_design(sc_design: sc.Design[StrandLabel, DomainLabel], dsd_domain.set_sequence(domain_sequence) # set domain labels - for dsd_domain, sc_domain in zip(dsd_strand.domains, sc_strand.domains): + for dsd_domain, sc_domain in zip(nuad_strand.domains, sc_strand.domains): if dsd_domain.label is None: dsd_domain.label = sc_domain.label elif sc_domain.label is not None and warn_existing_domain_labels: logger.warning(f'warning; dsd domain already has label {dsd_domain.label}; ' f'skipping assignment of scadnano label {sc_domain.label}') - dsd_strands.append(dsd_strand) - strand_names.add(dsd_strand.name) + strand_names.add(nuad_strand.name) + + design.compute_derived_fields() - design: Design[StrandLabel, DomainLabel] = Design(strands=dsd_strands) return design @staticmethod @@ -3838,256 +3960,50 @@ def check_all_subdomain_graphs_uniquely_assignable(self) -> None: d._check_acyclic_subdomain_graph() # noqa d._check_subdomain_graph_is_uniquely_assignable() # noqa + def check_names_unique(self) -> None: + # domain names already checked in compute_derived_fields() + self.check_strand_names_unique() + self.check_domain_pool_names_unique() + + def check_strand_names_unique(self) -> None: + strands_by_name = {} + for strand in self.strands: + name = strand.name + if name in strands_by_name: + raise ValueError(f'found two strands with name {name}:\n' + f' {strand}\n' + f'and\n' + f' {strands_by_name[name]}') + + def check_domain_pool_names_unique(self) -> None: + # self.domain_pools() already computed by compute_derived_fields() + domain_pools_by_name = {} + for pool in self.domain_pools(): + name = pool.name + if name in domain_pools_by_name: + raise ValueError(f'found two DomainPools with name {name}:\n' + f' {pool}\n' + f'and\n' + f' 
{domain_pools_by_name[name]}') + else: + domain_pools_by_name[pool.name] = pool + # represents a "Design Part", e.g., Strand, Tuple[Domain, Domain], etc... whatever portion of the Design # is checked by the constraint +# NOTE: this is needed in addition to the abstract base class Part, because it allows mypy type checking +# of the various different types of evaluate and evaluate_bulk functions. Otherwise they have more +# abstract type signatures, and we can't write something like evaluate(strand: Strand) +# Maybe if we eventually get rid of the parts and only pass in the sequences, this will not be needed. DesignPart = TypeVar('DesignPart', Domain, Strand, DomainPair, StrandPair, Complex, - # Iterable[Domain], - # Iterable[Strand], - # Iterable[Tuple[Domain, Domain]], - # Iterable[Tuple[Strand, Strand]], - # Iterable[Complex], Design) -@dataclass -class Violation(Generic[DesignPart]): - # Represents a violation of a single :any:`Constraint` in a :any:`Design`. The "part" of the :any:`Design` - # that violated the constraint is generic type `DesignPart` (e.g., for :any:`StrandPairConstraint`, - # DesignPart = :any:`Pair` [:any:`Strand`]). - - constraint: Constraint - # :any:`Constraint` that was violated to result in this :any:`Violation`. 
- - part: DesignPart - # DesignPart that caused this violation - - domains: FrozenSet[Domain] # = field(init=False, hash=False, compare=False, default=None) - # :any:`Domain`'s that were involved in violating :py:data:`Violation.constraint` - - summary: str - - score: float - - def __init__(self, constraint: Constraint, part: DesignPart, domains: Iterable[Domain], - score: float, summary: str) -> None: - # :param constraint: - # :any:`Constraint` that was violated to result in this - # :param domains: - # :any:`Domain`'s that were involved in violating :py:data:`Violation.constraint` - # :param score: - # total "score" of this violation, typically something like an excess energy over a - # threshold, squared, multiplied by the :data:`Constraint.weight` - object.__setattr__(self, 'constraint', constraint) - object.__setattr__(self, 'part', part) - domains_frozen = frozenset(domains) - object.__setattr__(self, 'domains', domains_frozen) - object.__setattr__(self, 'score', score) - object.__setattr__(self, 'summary', summary) - - def __repr__(self) -> str: - return f'Violation({self.constraint.short_description}, score={self.score:.2f}, ' \ - f'summary={self.summary})' - - def __str__(self) -> str: - return repr(self) - - # _Violation equality based on identity; different Violations in memory are considered different, - # even if all data between them matches. Don't create the same Violation twice! - def __hash__(self): - return super().__hash__() - - def __eq__(self, other): - return self is other - - -@dataclass -class ViolationSet: - # Represents violations of :any:`Constraint`'s in a :any:`Design`. - # - # It is designed to be efficiently updateable when a single :any:`Domain` changes, to efficiently update - # only those violations of :any:`Constraint`'s that could have been affected by the changed :any:`Domain`. - - violations_all: Dict[Constraint, OrderedSet[Violation]] - # Dict mapping each :any:`Constraint` to the set of all :any:`Violation`'s of it. 
- - domain_to_violations: Dict[Domain, OrderedSet[Violation]] - # Dict mapping each :any:`constraint.Domain` to the set of all :any:`Violation`'s for which it is blamed - - violations_nonfixed: Dict[Constraint, OrderedSet[Violation]] - # Dict mapping each :any:`Constraint` to the set of all :any:`Violations` - # that are associated to non-fixed :any:`constraint.Domain`'s. - - violations_fixed: Dict[Constraint, OrderedSet[Violation]] - # Dict mapping each :any:`Constraint` to the set of all :any:`Violations` - # that are associated to fixed :any:`constraint.Domain`'s. - - num_checked: Dict[Constraint, int] - - # number of instances of each :any:`Constraint` that were checked - - def __init__(self) -> None: - self.violations_all = defaultdict(OrderedSet) - self.domain_to_violations = defaultdict(OrderedSet) - self.violations_nonfixed = defaultdict(OrderedSet) - self.violations_fixed = defaultdict(OrderedSet) - self.num_checked = defaultdict(int) - - def __repr__(self): - lines = "\n ".join(map(str, self.violations_all.values())) - return f'ViolationSet(\n {lines})' - - def __str__(self): - return repr(self) - - def update(self, new_violations: Dict[Domain, OrderedSet[Violation]]) -> None: - # Update this :any:`ViolationSet` by merging in new violations from `new_violations`. - # - # :param new_violations: dict mapping each :any:`Domain` to the set of :any:`Violation`'s - # for which it is blamed - for domain, domain_violations in new_violations.items(): - self.domain_to_violations[domain].update(domain_violations) - for violation in domain_violations: - self.violations_all[violation.constraint].add(violation) - if not violation.part.fixed: - self.violations_nonfixed[violation.constraint].add(violation) - else: - self.violations_fixed[violation.constraint].add(violation) - - def clone(self) -> ViolationSet: - # Returns a deep-ish copy of this :any:`ViolationSet`. - # :py:data:`ViolationSet.all_violations` is a new list, - # but containing the same :any:`Violation`'s. 
- # :py:data:`ViolationSet.domain_to_violations` is a new dict, - # and each of its values is a new set, but each of the :any:`Domain`'s and :any:`Violation`'s - # is the same object as in the original :any:`ViolationSet`. - # - # This is required for efficiently processing :any:`Violation`'s from one search iteration to the - # next. - # - # :return: A deep-ish copy of this :any:`ViolationSet`. - domain_to_violations_deep_copy = defaultdict(OrderedSet, self.domain_to_violations) - for domain, violations in domain_to_violations_deep_copy.items(): - domain_to_violations_deep_copy[domain] = OrderedSet(violations) - - violations_all_deep_copy = defaultdict(OrderedSet, self.violations_all) - for constraint, violations in violations_all_deep_copy.items(): - violations_all_deep_copy[constraint] = OrderedSet(violations) - - violations_nonfixed_deep_copy = defaultdict(OrderedSet, self.violations_nonfixed) - for constraint, violations in violations_nonfixed_deep_copy.items(): - violations_nonfixed_deep_copy[constraint] = OrderedSet(violations) - - violations_fixed_deep_copy = defaultdict(OrderedSet, self.violations_fixed) - for constraint, violations in violations_fixed_deep_copy.items(): - violations_fixed_deep_copy[constraint] = OrderedSet(violations) - - result = ViolationSet() - result.violations_all = violations_all_deep_copy - result.domain_to_violations = domain_to_violations_deep_copy - result.violations_nonfixed = violations_nonfixed_deep_copy - result.violations_fixed = violations_fixed_deep_copy - - return result - - def remove_violations_of_domain(self, domain: Domain) -> None: - # Removes any :any:`Violation`'s blamed on `domain`. 
- # :param domain: the :any:`Domain` whose :any:`Violation`'s should be removed - - # XXX: need to make a copy of this set, since we are modifying the sets in place - # (values in self.domain_to_violations) - violations_of_domain = set(self.domain_to_violations[domain]) - - for violations in self.violations_all.values(): - violations -= violations_of_domain - for violations in self.violations_nonfixed.values(): - violations -= violations_of_domain - for violations in self.violations_fixed.values(): - violations -= violations_of_domain - - for violations_of_other_domain in self.domain_to_violations.values(): - violations_of_other_domain -= violations_of_domain - - assert len(self.domain_to_violations[domain]) == 0 - - def total_score(self) -> float: - """ - :return: Total score of all violations. - """ - return sum(violation.score - for violations in self.violations_all.values() - for violation in violations) - - def total_score_nonfixed(self) -> float: - # :return: - # Total score of all violations attributed to :any:`constraint.Domain`'s with - # :any:`constraint.Domain.fixed` = False. - return sum(violation.score - for violations in self.violations_nonfixed.values() - for violation in violations) - - def total_score_fixed(self) -> float: - # :return: - # Total score of all violations attributed to :any:`constraint.Domain`'s with - # :any:`constraint.Domain.fixed` = False. - return sum(violation.score - for violations in self.violations_fixed.values() - for violation in violations) - - def score_of_constraint(self, constraint: Constraint) -> float: - """ - :param constraint: - constraint to filter scores on - :return: - Total score of all violations due to `constraint`. 
- """ - return sum(violation.score - for violations in self.violations_all.values() - for violation in violations - if violation.constraint == constraint) - - def score_of_constraint_nonfixed(self, constraint: Constraint) -> float: - # :param constraint: - # constraint to filter scores on - # :return: - # Total score of all nonfixed violations due to `constraint`. - return sum(violation.score - for violations in self.violations_nonfixed.values() - for violation in violations - if violation.constraint == constraint) - - def score_of_constraint_fixed(self, constraint: Constraint) -> float: - # :param constraint: - # constraint to filter scores on - # :return: - # Total score of all fixed violations due to `constraint`. - return sum(violation.score - for violations in self.violations_fixed.values() - for violation in violations - if violation.constraint == constraint) - - def num_violations(self) -> float: - # :return: Total number of violations. - return sum(len(violations) for violations in self.violations_all.values()) - - def num_violations_nonfixed(self) -> float: - # :return: Total number of nonfixed violations. - return sum(len(violations) for violations in self.violations_nonfixed.values()) - - def num_violations_fixed(self) -> float: - # :return: Total number of fixed violations. 
- return sum(len(violations) for violations in self.violations_fixed.values()) - - def has_nonfixed_violations(self) -> bool: - # :return: whether there are any nonfixed Violations in this ViolationSet - return self.num_violations_nonfixed() > 0 - - @dataclass(frozen=True, eq=False) class Constraint(Generic[DesignPart], ABC): """ @@ -4323,7 +4239,10 @@ class ConstraintWithStrandPairs(Constraint[DesignPart], Generic[DesignPart]): # @dataclass(frozen=True, eq=False) # type: ignore class DomainPairConstraint(ConstraintWithDomainPairs[DomainPair], SingularConstraint[DomainPair]): - """Constraint that applies to a pair of :any:`Domain`'s.""" + """Constraint that applies to a pair of :any:`Domain`'s. + + These should be symmetric, meaning that the constraint will give the same evaluation whether its + evaluate method is given the pair (domain1, domain2), or the pair (domain2, domain1).""" @staticmethod def part_name() -> str: @@ -4333,7 +4252,10 @@ def part_name() -> str: @dataclass(frozen=True, eq=False) # type: ignore class StrandPairConstraint(ConstraintWithStrandPairs[StrandPair], SingularConstraint[StrandPair]): - """Constraint that applies to a pair of :any:`Strand`'s.""" + """Constraint that applies to a pair of :any:`Strand`'s. 
+ + These should be symmetric, meaning that the constraint will give the same evaluation whether its + evaluate method is given the pair (strand1, strand2), or the pair (strand2, strand1).""" @staticmethod def part_name() -> str: @@ -4976,10 +4898,9 @@ def evaluate_bulk(domain_pairs: Iterable[DomainPair]) -> List[Tuple[DomainPair, for pair, energy in zip(domain_pairs, energies): excess = threshold - energy - if excess > 0.0: - summary = f'{energy:6.2f} kcal/mol' - pair_score_summary = (pair, excess, summary) - pairs_scores_summaries.append(pair_score_summary) + summary = f'{energy:6.2f} kcal/mol' + pair_score_summary = (pair, excess, summary) + pairs_scores_summaries.append(pair_score_summary) return pairs_scores_summaries pairs_tuple = None @@ -5354,10 +5275,9 @@ def evaluate(strand_pairs: Iterable[StrandPair]) -> List[Tuple[StrandPair, float for pair, energy in zip(strand_pairs, energies): excess = threshold - energy - if excess > 0.0: - summary = f'{energy:6.2f} kcal/mol' - pair_score_summary = (pair, excess, summary) - pairs_scores_summaries.append(pair_score_summary) + summary = f'{energy:6.2f} kcal/mol' + pair_score_summary = (pair, excess, summary) + pairs_scores_summaries.append(pair_score_summary) return pairs_scores_summaries pairs_tuple = None @@ -5456,10 +5376,9 @@ def evaluate(strand_pairs: Iterable[StrandPair]) -> List[Tuple[StrandPair, float for pair, energy in zip(strand_pairs, energies): excess = threshold - energy - if excess > 0.0: - summary = f'{energy:6.2f} kcal/mol' - pair_score_summary = (pair, excess, summary) - pairs_scores_summaries.append(pair_score_summary) + summary = f'{energy:6.2f} kcal/mol' + pair_score_summary = (pair, excess, summary) + pairs_scores_summaries.append(pair_score_summary) return pairs_scores_summaries pairs_tuple = None @@ -6855,7 +6774,7 @@ class _BasePair: """ -def _get_implicitly_bound_domain_addresses(strand_complex: Complex, +def _get_implicitly_bound_domain_addresses(strand_complex: Iterable[Strand], 
nonimplicit_base_pairs_domain_names: Optional[Set[str]] = None) \ -> Dict[StrandDomainAddress, StrandDomainAddress]: """Returns a map of all the implicitly bound domain addresses @@ -6983,7 +6902,7 @@ def _leafify_strand( def _get_base_pair_domain_endpoints_to_check( - strand_complex: Complex, + strand_complex: Iterable[Strand], nonimplicit_base_pairs: Iterable[BoundDomains] = None) -> Set[_BasePairDomainEndpoint]: """Returns the set of all the _BasePairDomainEndpoint to check @@ -7001,8 +6920,8 @@ def _get_base_pair_domain_endpoints_to_check( addr_translation_table: Dict[StrandDomainAddress, List[StrandDomainAddress]] = {} # Need to convert strands into strands lowest level subdomains - leafify_strand_complex = Complex(tuple( - [_leafify_strand(strand, addr_translation_table) for strand in strand_complex])) + leafify_strand_complex = Complex( + *[_leafify_strand(strand, addr_translation_table) for strand in strand_complex]) new_nonimplicit_base_pairs = [] if nonimplicit_base_pairs: diff --git a/nuad/search.py b/nuad/search.py index 446b8141..5a6b07ab 100644 --- a/nuad/search.py +++ b/nuad/search.py @@ -20,8 +20,8 @@ from collections import defaultdict, deque import collections.abc as abc from dataclasses import dataclass, field -from typing import List, Tuple, Sequence, FrozenSet, Optional, Dict, Callable, Iterable, Any, \ - Deque, TypeVar, Union +from typing import List, Tuple, Sequence, FrozenSet, Optional, Dict, Callable, Iterable, \ + Deque, TypeVar, Union, Generic, Iterator, Any import statistics import textwrap import re @@ -74,81 +74,6 @@ def default_output_directory() -> str: return os.path.join('output', f'{script_name_no_ext()}--{timestamp()}') -def _violations_of_constraints(design: Design, - constraints: [nc.Constraint], - never_increase_score: bool, - domains_changed: Optional[Iterable[Domain]], - violation_set_old: Optional[nc.ViolationSet], - iteration: int, - ) -> nc.ViolationSet: - """ - :param design: - The :any:`Design` for which to find DNA 
sequences. - :param constraints: - List of :any:`constraints.Constraint`'s to apply - :param domains_changed: - The :any:`Domain`'s that just changed; if None, then recalculate all constraints, otherwise assume no - constraints changed that do not involve a :any:`Domain` in `domains_changed`. - :param violation_set_old: - :any:`ViolationSet` to update, assuming `domain_changed` is the only :any:`Domain` that changed. - :param never_increase_score: - Indicates whether the search algorithm is using an update rule that never increases the total score - of violations (i.e., it only goes downhill). If so we can optimize and stop this function early as - soon as we find that the violations discovered so far exceed the total score of the current optimal - solution. In later stages of the search, when the optimal solution so far has very few violated - constraints, this vastly speeds up the search by allowing most of the constraint checking to be - skipping for most choices of DNA sequences to `domain_changed`. - :param iteration: - Current iteration number; useful for debugging (e.g., conditional breakpoints). 
- :return: - dict mapping each :any:`Domain` to the list of constraints it violated - """ - - if iteration > 0: - pass # to quiet PEP warnings - - if not ((domains_changed is None and violation_set_old is None) or ( - domains_changed is not None and violation_set_old is not None)): - raise ValueError('domains_changed and violation_set_old should both be None or both be not None; ' - f'domains_changed = {domains_changed}' - f'violation_set_old = {violation_set_old}') - - # remove violations involving domains_changed, since they might evaluate differently now - violation_set: nc.ViolationSet - if domains_changed is None: - violation_set = nc.ViolationSet() - else: - assert violation_set_old is not None - violation_set = violation_set_old.clone() # Keep old in case no improvement - for domain_changed in domains_changed: - assert not domain_changed.fixed - violation_set.remove_violations_of_domain(domain_changed) - - # find new violations of parts involving domains in domains_changed, and add them to violation_set - for constraint in constraints: - parts_to_check = find_parts_to_check(constraint, design, domains_changed) - - current_score_gap = violation_set_old.total_score() - violation_set.total_score() \ - if never_increase_score and violation_set_old is not None else None - - violations, quit_early_in_func = _violations_of_constraint( - parts=parts_to_check, constraint=constraint, current_score_gap=current_score_gap, - domains_changed=domains_changed, design=design) - violation_set.update(violations) - - parts_to_check_total = find_parts_to_check(constraint, design, None) - violation_set.num_checked[constraint] = len(parts_to_check_total) - - quit_early = _quit_early(never_increase_score, violation_set, violation_set_old) - assert quit_early == quit_early_in_func - if quit_early: - return violation_set - - return violation_set - - -# optimization so we don't keep recomputing parts to check for each constraint, -# only used when domains_changed is None, otherwise 
the parts to check depends on the domains that changed _parts_to_check_cache = {} @@ -188,21 +113,23 @@ def find_parts_to_check(constraint: nc.Constraint, design: nc.Design, return parts_to_check +# XXX: important that this is absolute constant. Sometimes this is called for the total weight of all +# violations, and sometimes just for the difference between old and new (the latter are smaller). +# If using relative epsilon, then those can disagree and trigger the assert statement that +# checks that _violations_of_constraints quit_early agrees with the subroutines it calls. +_epsilon = 0.000001 + + def _is_significantly_greater(x: float, y: float) -> bool: - # epsilon = min(abs(x), abs(y)) * 0.001 - # XXX: important that this is absolute constant. Sometimes this is called for the total weight of all - # violations, and sometimes just for the difference between old and new (the latter are smaller). - # If using relative epsilon, then those can disagree and trigger the assert statement that - # checks that _violations_of_constraints quit_early agrees with the subroutines it calls. - epsilon = 0.001 - return x > y + epsilon + return x > y + _epsilon + + +def _is_significantly_less(x: float, y: float) -> bool: + return x < y - _epsilon -def _quit_early(never_increase_score: bool, - violation_set: nc.ViolationSet, - violation_set_old: Optional[nc.ViolationSet]) -> bool: - return (never_increase_score and violation_set_old is not None - and _is_significantly_greater(violation_set.total_score(), violation_set_old.total_score())) +def _is_significantly_different(x: float, y: float) -> bool: + return abs(x - y) > _epsilon def _at_least_one_domain_unfixed(pair: Tuple[Domain, Domain]) -> bool: @@ -214,10 +141,10 @@ def _determine_domains_to_check(all_domains: Iterable[Domain], constraint: ConstraintWithDomains) -> Sequence[Domain]: """ Determines domains to check in `all_domains`. 
- If `domains_changed` is None, then this is all that are not fixed if constraint.domains + If `domains_new` is None, then this is all that are not fixed if constraint.domains is None, otherwise it is constraint.domains. - If `domains_changed` is not None, then among those domains specified above, - it is just those in `domains_changed` that appear in `all_domains`. + If `domains_new` is not None, then among those domains specified above, + it is just those in `domains_new` that appear in `all_domains`. """ # either all pairs, or just constraint.pairs if specified domains_to_check_if_domain_changed_none = all_domains \ @@ -266,7 +193,7 @@ def _determine_domain_pairs_to_check(all_domains: Iterable[Domain], it is all pairs where one of the two is `domain_changed`. """ # some code is repeated here, but otherwise it's way too slow on a large design to iterate over - # all pairs of domains only to filter out most of them that don't intersect domains_changed + # all pairs of domains only to filter out most of them that don't intersect domains_new if domains_changed is None: # either all pairs, or just constraint.pairs if specified if constraint.pairs is not None: @@ -279,7 +206,8 @@ def not_subdomain(dom1: Domain, dom2: Domain) -> bool: pairs = all_pairs(all_domains, with_replacement=constraint.check_domain_against_itself, where=not_subdomain) - domain_pairs_to_check = [DomainPair(domain1, domain2) for domain1, domain2 in pairs] + domain_pairs_to_check = [DomainPair(domain1, domain2) for domain1, domain2 in pairs + if not (domain1.fixed and domain2.fixed)] else: # either all pairs, or just constraint.pairs if specified @@ -313,7 +241,7 @@ def _determine_strand_pairs_to_check(all_strands: Iterable[Strand], Similar to _determine_domain_pairs_to_check but for strands. 
""" # some code is repeated here, but otherwise it's way too slow on a large design to iterate over - # all pairs of strands only to filter out most of them that don't intersect domains_changed + # all pairs of strands only to filter out most of them that don't intersect domains_new if domains_changed is None: # either all pairs, or just constraint.pairs if specified if constraint.pairs is not None: @@ -336,8 +264,8 @@ def _determine_strand_pairs_to_check(all_strands: Iterable[Strand], if domain_changed in strand.domains] for strand_with_domain_changed in strands_with_domain_changed: for other_strand in all_strands: - if strand_with_domain_changed is not other_strand or \ - constraint.check_strand_against_itself: + if (strand_with_domain_changed is not other_strand or + constraint.check_strand_against_itself): strand_pairs_to_check.append(StrandPair(strand_with_domain_changed, other_strand)) return strand_pairs_to_check @@ -391,61 +319,6 @@ def _strands_containing_domains(domains: Optional[Iterable[Domain]], strands: Li _empty_frozen_set: FrozenSet = frozenset() -def _violations_of_constraint(parts: Sequence[DesignPart], - constraint: Constraint[DesignPart], - current_score_gap: Optional[float], - domains_changed: Optional[Iterable[Domain]] = None, - design: Optional[Design] = None, # only used with DesignConstraint - ) -> Tuple[Dict[Domain, OrderedSet[nc.Violation]], bool]: - score_discovered_here: float = 0.0 - quit_early = False - - # measure violations of constraints and collect in list of triples (part, score, summary) - violating_parts_scores_summaries: List[Tuple[DesignPart, float, str]] = [] - if isinstance(constraint, SingularConstraint): - if not constraint.parallel or len(parts) == 1 or nc.cpu_count() == 1: - for part in parts: - seqs = tuple(indv_part.sequence() for indv_part in part.individual_parts()) - score, summary = constraint.call_evaluate(seqs, part) - if score > 0.0: - violating_parts_scores_summaries.append((part, score, summary)) - if 
current_score_gap is not None: - score_discovered_here += score - if _is_significantly_greater(score_discovered_here, current_score_gap): - quit_early = True - break - else: - raise NotImplementedError('TODO: implement parallelization') - - elif isinstance(constraint, (BulkConstraint, DesignConstraint)): - if isinstance(constraint, DesignConstraint): - violating_parts_scores_summaries = constraint.call_evaluate_design(design, domains_changed) - else: - # XXX: I don't understand the mypy error on the next line - violating_parts_scores_summaries = constraint.call_evaluate_bulk(parts) # type: ignore - - # we can't quit this function early, - # but we can let the caller know to stop evaluating constraints - total_score = sum(score for _, score, _ in violating_parts_scores_summaries) - if current_score_gap is not None: - score_discovered_here += total_score - if _is_significantly_greater(score_discovered_here, current_score_gap): - quit_early = True - else: - raise AssertionError(f'constraint {constraint} of unrecognized type {constraint.__class__.__name__}') - - # assign blame for violations to domains by looking up associated domains in each part - violations: Dict[Domain, OrderedSet[nc.Violation]] = defaultdict(OrderedSet) - for part, score, summary in violating_parts_scores_summaries: - domains = _independent_domains_in_part(part, exclude_fixed=False) - violation = nc.Violation(constraint=constraint, part=part, domains=domains, - score=score, summary=summary) - for domain in domains: - violations[domain].add(violation) - - return violations, quit_early - - def _independent_domains_in_part(part: nc.DesignPart, exclude_fixed: bool) -> List[Domain]: """ :param part: @@ -502,7 +375,7 @@ def _sequences_fragile_format_output_to_file(design: Design, def _write_intermediate_files(*, design: nc.Design, params: SearchParameters, rng: numpy.random.Generator, num_new_optimal: int, directories: _Directories, - violation_set: nc.ViolationSet) -> None: + eval_set: EvaluationSet) 
-> None: num_new_optimal_padded = f'{num_new_optimal}' if params.num_digits_update is None \ else f'{num_new_optimal:0{params.num_digits_update}d}' @@ -516,7 +389,7 @@ def _write_intermediate_files(*, design: nc.Design, params: SearchParameters, rn num_new_optimal_padded=num_new_optimal_padded) _write_report(params=params, directories=directories, - num_new_optimal_padded=num_new_optimal_padded, violation_set=violation_set) + num_new_optimal_padded=num_new_optimal_padded, eval_set=eval_set) def _write_design(design: Design, params: SearchParameters, directories: _Directories, @@ -551,12 +424,12 @@ def _write_sequences(design: Design, params: SearchParameters, directories: _Dir def _write_report(params: SearchParameters, directories: _Directories, - num_new_optimal_padded: str, violation_set: nc.ViolationSet) -> None: + num_new_optimal_padded: str, eval_set: EvaluationSet) -> None: content = f'''\ Report on constraints ===================== ''' + summary_of_constraints(params.constraints, params.report_only_violations, - violation_set=violation_set) + eval_set=eval_set) best_filename = directories.best_report_full_filename_noext() idx_filename = directories.indexed_report_full_filename_noext(num_new_optimal_padded) \ @@ -704,7 +577,7 @@ def best_report_full_filename_noext(self) -> str: def _check_design(design: nc.Design) -> None: - # verify design is legal + # verify design is legal in senses not already checked for strand in design.strands: for domain in strand.domains: @@ -801,7 +674,7 @@ class SearchParameters: info_log_file: bool = False """ - By default, the text written to the screen through logger.info (on the logger instance used in + If True, the text written to the screen through logger.info (on the logger instance used in dsd.constraints) is written to the file log_info.log in the directory `out_directory`. """ @@ -882,6 +755,51 @@ class SearchParameters: Whether to log the time taken per iteration to the screen. 
""" + scrolling_output: bool = True + r""" + If True, then screen output "scrolls" on the screen, i.e., a newline is printed after each iteration, + e.g., + + .. code-block:: console + + $ python sst_canvas.py + using random seed of 1; use this same seed to reproduce this run + number of processes in system: 4 + |-----------|--------|-----------|-----------|----------|---------------| + | iteration | update | opt score | new score | StrandSS | StrandPairRNA | + | 0 | 0 | 2555.9 | 2545.9 | 118.2 | 2437.8 | + |-----------|--------|-----------|-----------|----------|---------------| + | iteration | update | opt score | new score | StrandSS | StrandPairRNA | + | 1 | 1 | 2545.9 | 2593.0 | 120.2 | 2425.6 | + |-----------|--------|-----------|-----------|----------|---------------| + | iteration | update | opt score | new score | StrandSS | StrandPairRNA | + | 2 | 1 | 2545.9 | 2563.1 | 120.2 | 2425.6 | + |-----------|--------|-----------|-----------|----------|---------------| + | iteration | update | opt score | new score | StrandSS | StrandPairRNA | + | 3 | 1 | 2545.9 | 2545.0 | 120.2 | 2425.6 | + |-----------|--------|-----------|-----------|----------|---------------| + | iteration | update | opt score | new score | StrandSS | StrandPairRNA | + | 4 | 2 | 2545.0 | 2510.1 | 121.0 | 2423.9 | + + If False, then the screen output is updated in place: + + .. code-block:: console + + $ python sst_canvas.py + using random seed of 1; use this same seed to reproduce this run + number of processes in system: 4 + |-----------|--------|-----------|-----------|----------|---------------| + | iteration | update | opt score | new score | StrandSS | StrandPairRNA | + | 27 | 14 | 2340.5 | 2320.5 | 109.6 | 2230.9 | + + + This is done by printing the symbol '\r' (carriage return), which sets the print position + back to the start of the line. The terminal screen must be wide enough to handle the output or + this won't work. 
+
+        The search also occasionally logs other things to the screen that may disrupt this a bit.
+    """
+
     def __post_init__(self):
         self._check_constraint_types()

@@ -965,8 +883,6 @@ def search_for_dna_sequences(design: nc.Design, params: SearchParameters) -> Non
     for flexibility.
     """
-    design.check_all_subdomain_graphs_acyclic()
-    design.check_all_subdomain_graphs_uniquely_assignable()

     if params.random_seed is not None:
         logger.info(f'using random seed of {params.random_seed}; '
@@ -976,6 +892,10 @@ def search_for_dna_sequences(design: nc.Design, params: SearchParameters) -> Non
     # StrandPool that contains them.
     # domain_to_strand: Dict[dc.Domain, dc.Strand] = _check_design(design)
     design.compute_derived_fields()
+
+    design.check_all_subdomain_graphs_acyclic()
+    design.check_all_subdomain_graphs_uniquely_assignable()
+    design.check_names_unique()
     _check_design(design)

     directories = _setup_directories(params)
@@ -1012,71 +932,59 @@ def search_for_dna_sequences(design: nc.Design, params: SearchParameters) -> Non
         if rng_restart is not None:
             rng = rng_restart

-        violation_set_opt, domains_opt, scores_opt = _find_violations_and_score(
-            design=design, params=params, iteration=-1)
+        eval_set = EvaluationSet(params.constraints, params.never_increase_score)
+        eval_set.evaluate_all(design)

        if not params.restart:
            # write initial sequences and report
            _write_intermediate_files(design=design, params=params, rng=rng, num_new_optimal=num_new_optimal,
-                                  directories=directories, violation_set=violation_set_opt)
-
-    # this helps with logging if we execute no iterations
-    violation_set_new = violation_set_opt
+                                  directories=directories, eval_set=eval_set)

        iteration = 0
-        stopwatch = Stopwatch()

-        while not _done(iteration, params, violation_set_opt):
+        while not _done(iteration, params, eval_set):
            if params.log_time:
                stopwatch.restart()

            _check_cpu_count(cpu_count)

-            domains_changed, original_sequences = _reassign_domains(domains_opt, scores_opt,
-                                                                    params.max_domains_to_change, rng)
+
domains_new, original_sequences = _reassign_domains(eval_set, params.max_domains_to_change, rng) # evaluate constraints on new Design with domain_to_change's new sequence - violation_set_new, domains_new, scores_new = _find_violations_and_score( - design=design, params=params, domains_changed=domains_changed, - violation_set_old=violation_set_opt, iteration=iteration) + eval_set.evaluate_new(design, domains_new=domains_new) # _double_check_violations_from_scratch(design=design, params=params, iteration=iteration, - # violation_set_new=violation_set_new, - # violation_set_opt=violation_set_opt) + # eval_set=eval_set) - _log_constraint_summary(params=params, - violation_set_opt=violation_set_opt, violation_set_new=violation_set_new, + _log_constraint_summary(params=params, eval_set=eval_set, iteration=iteration, num_new_optimal=num_new_optimal) # based on total score of new constraint violations compared to optimal assignment so far, # decide whether to keep the change - # score_delta = violation_set_new.total_score() - violation_set_opt.total_score() - score_delta = violation_set_new.total_score_nonfixed() - violation_set_opt.total_score_nonfixed() + score_delta = -eval_set.calculate_score_gap() prob_keep_change = params.probability_of_keeping_change(score_delta) keep_change = rng.random() < prob_keep_change if prob_keep_change < 1 else True if not keep_change: - _unassign_domains(domains_changed, original_sequences) + _unassign_domains(domains_new, original_sequences) + eval_set.reset_new() else: # keep new sequence and update information about optimal solution so far - domains_opt = domains_new - scores_opt = scores_new - violation_set_opt = violation_set_new + eval_set.replace_with_new() if score_delta < 0: # increment whenever we actually improve the design num_new_optimal += 1 on_improved_design(num_new_optimal) # type: ignore _write_intermediate_files(design=design, params=params, rng=rng, num_new_optimal=num_new_optimal, directories=directories, - 
violation_set=violation_set_opt)
+                                          eval_set=eval_set)

        iteration += 1

        if params.log_time:
            stopwatch.stop()
            _log_time(stopwatch)

-        _log_constraint_summary(params=params,
-                                violation_set_opt=violation_set_opt, violation_set_new=violation_set_new,
+        _log_constraint_summary(params=params, eval_set=eval_set,
                                 iteration=iteration, num_new_optimal=num_new_optimal)

     finally:
@@ -1092,12 +1000,22 @@ def search_for_dna_sequences(design: nc.Design, params: SearchParameters) -> Non
             nc.logger.removeHandler(directories.info_file_handler)  # noqa

-def _done(iteration: int, params: SearchParameters, violation_set_opt: nc.ViolationSet) -> bool:
-    keep_going = (
-        (violation_set_opt.total_score() > params.target_score if params.target_score is not None else
-         violation_set_opt.has_nonfixed_violations())
-        and (params.max_iterations is None or iteration < params.max_iterations))
-    return not keep_going
+def _done(iteration: int, params: SearchParameters, eval_set: EvaluationSet) -> bool:
+    # unconditionally stop when max_iterations is reached, if specified
+    if params.max_iterations is not None and iteration >= params.max_iterations:
+        return True
+
+    # otherwise if target_score is specified, check that current score is close to it
+    if params.target_score is not None:
+        if _is_significantly_greater(eval_set.total_score, params.target_score):
+            return False
+    else:
+        # otherwise just see if any violations remain that are not fixed
+        # (i.e., that might be correctable by changing domains; fixed violations are unsolvable)
+        if eval_set.has_nonfixed_violations():
+            return False
+
+    return True


 def create_report(design: nc.Design, constraints: Iterable[Constraint]) -> str:
@@ -1120,14 +1038,13 @@ def create_report(design: nc.Design, constraints: Iterable[Constraint]) -> str:
    :return:
        string describing a report of how well `design` does according to `constraints`
    """
-    violation_set: nc.ViolationSet = _violations_of_constraints(
-        design=design, constraints=constraints,
never_increase_score=False, - domains_changed=None, violation_set_old=None, iteration=0) + evaluation_set = EvaluationSet(constraints, False) + evaluation_set.evaluate_all(design) content = f'''\ Report on constraints ===================== -''' + summary_of_constraints(constraints, True, violation_set=violation_set) +''' + summary_of_constraints(constraints, True, eval_set=evaluation_set) return content @@ -1167,31 +1084,27 @@ def _setup_directories(params: SearchParameters) -> _Directories: return directories -def _reassign_domains(domains_opt: List[Domain], scores_opt: List[float], max_domains_to_change: int, +def _reassign_domains(eval_set: EvaluationSet, max_domains_to_change: int, rng: np.random.Generator) -> Tuple[List[Domain], Dict[Domain, str]]: # pick domain to change, with probability proportional to total score of constraints it violates # first weight scores by domain's weight - scores_weighted = [score * domain.weight for domain, score in zip(domains_opt, scores_opt)] + domains = list(eval_set.domain_to_score.keys()) + scores_weighted = [score * domain.weight for domain, score in eval_set.domain_to_score.items()] probs_opt = np.asarray(scores_weighted) probs_opt /= probs_opt.sum() num_domains_to_change = 1 if max_domains_to_change == 1 \ else rng.choice(a=range(1, max_domains_to_change + 1)) - domains_changed: List[Domain] = list(rng.choice(a=domains_opt, p=probs_opt, replace=False, + domains_changed: List[Domain] = list(rng.choice(a=domains, p=probs_opt, replace=False, size=num_domains_to_change)) - # print(f'domains_changed: {domains_changed}') - # fixed Domains should never be blamed for constraint violation assert all(not domain_changed.fixed for domain_changed in domains_changed) - # dependent domains also cannot be blamed, since their independent source should be blamed + # dependent domains also cannot be blamed, since their independent source should have been blamed assert all(not domain_changed.dependent for domain_changed in domains_changed) 
original_sequences: Dict[Domain, str] = {} - # # first re-assign independent domains - # independent_domains = [domain for domain in domains_changed if not domain.dependent] - for domain in domains_changed: # set sequence of domain_changed to random new sequence from its DomainPool assert domain not in original_sequences @@ -1200,19 +1113,6 @@ def _reassign_domains(domains_opt: List[Domain], scores_opt: List[float], max_do new_sequence = domain.pool.generate_sequence(rng, previous_sequence) domain.set_sequence(new_sequence) - # # then for each dependent domain, find the independent domain in its tree that can change it, - # # and re-assign that domain - # dependent_domains = [domain for domain in domains_changed if domain.dependent] - # for dependent_domain in dependent_domains: - # independent_domain = dependent_domain.independent_ancestor_or_descendent() - # assert independent_domain not in original_sequences - # previous_sequence = independent_domain.sequence() - # original_sequences[independent_domain] = previous_sequence - # new_sequence = independent_domain.pool.generate_sequence(rng, previous_sequence) - # independent_domain.set_sequence(new_sequence) - # domains_changed.remove(dependent_domain) - # domains_changed.append(independent_domain) - return domains_changed, original_sequences @@ -1225,30 +1125,44 @@ def _unassign_domains(domains_changed: Iterable[Domain], original_sequences: Dic # to think a new assignment was better than the optimal so far, but a mistake in score accounting # from quitting early meant we had simply stopped looking for violations too soon. 
def _double_check_violations_from_scratch(design: nc.Design, params: SearchParameters, iteration: int, - violation_set_new: nc.ViolationSet, - violation_set_opt: nc.ViolationSet): - violation_set_new_fs, domains_new_fs, scores_new_fs = _find_violations_and_score( - design=design, params=params, iteration=iteration) - # XXX: we shouldn't check that the actual scores are close if quit_early is enabled, because then - # the total score found on quitting early will be less than the total score if not. - # But uncomment this, while disabling quitting early, to test more precisely for "wrong total score". - # import math - # if not math.isclose(violation_set_new.total_score(), violation_set_new_fs.total_score()): - # Instead, we check whether the total score lie on different sides of the opt total score, i.e., - # they make different decisions about whether to change to the new assignment - if (violation_set_new_fs.total_score() - > violation_set_opt.total_score() - >= violation_set_new.total_score()) or \ - (violation_set_new_fs.total_score() - <= violation_set_opt.total_score() - < violation_set_new.total_score()): - logger.warning(f'WARNING! There is a bug in nuad.') - logger.warning(f'total score opt = {violation_set_opt.total_score()}') - logger.warning(f'From scratch, we calculated score {violation_set_new_fs.total_score()}.') - logger.warning(f'Iteratively, we calculated score {violation_set_new.total_score()}.') - logger.warning(f'This means the iterative search is saying something different about ' - f'quitting early than the full search. 
') - logger.warning(f'This happened on iteration {iteration}.') + eval_set: EvaluationSet): + eval_set_from_scratch = EvaluationSet(params.constraints, params.never_increase_score) + eval_set_from_scratch.evaluate_all(design) + score_new = eval_set.total_score_new() + score_opt = eval_set.total_score + score_fs = eval_set_from_scratch.total_score + + problem = False + if not params.never_increase_score: + if _is_significantly_different(score_fs, score_new): + problem = True + else: + # XXX: we shouldn't check that the actual scores are close if quit_early is enabled, because then + # the total score found on quitting early will be less than the total score if not. + # But uncomment this, while disabling quitting early, to test more precisely for "wrong total score". + # import math + # if not math.isclose(violation_set_new.total_score(), violation_set_new_fs.total_score()): + # Instead, we check whether the total score lie on different sides of the opt total score, i.e., + # they make different decisions about whether to change to the new assignment + # if ((score_new + # <= score_opt + # < score_fs) or + # (score_fs + # <= score_opt + # < score_new)): + if ((_is_significantly_less(score_new, score_opt) + and _is_significantly_less(score_opt, score_fs)) or + ((_is_significantly_less(score_fs, score_opt) + and _is_significantly_less(score_opt, score_new)))): + problem = True + if problem: + logger.warning(f'''\ +WARNING! There is a bug in nuad. +From scratch, we calculated score {score_fs}. +The optimal score so far is {score_opt}. +Iteratively, we calculated score {score_new}. +This means the iterative search is saying something different about quitting early than the full search. 
+This happened on iteration {iteration}.''')
         sys.exit(-1)

@@ -1416,95 +1330,69 @@ def _log_time(stopwatch: Stopwatch, include_median: bool = False) -> None:
     time_last_n_calls_available = True


-def _find_violations_and_score(design: Design,
-                               params: SearchParameters,
-                               domains_changed: Optional[Iterable[Domain]] = None,
-                               violation_set_old: Optional[nc.ViolationSet] = None,
-                               iteration: int = -1) \
-        -> Tuple[nc.ViolationSet, List[Domain], List[float]]:
-    """
-    :param design:
-        :any:`Design` to evaluate
-    :param domains_changed:
-        The :any:`Domain` that just changed;
-        if None, then recalculate all constraints,
-        otherwise assume no constraints changed that do not involve `domain`
-    :param violation_set_old:
-        :any:`ViolationSet` to update, assuming `domain_changed` is the only :any:`Domain` that changed
-    :param iteration:
-        Current iteration number; useful for debugging (e.g., conditional breakpoints).
-    :return:
-        Tuple (violations, domains, scores)
-        `violations`: dict mapping each domain to list of constraints that they violated
-        `domains`: list of :any:`Domain`'s that caused violations
-        `scores`: list of scores for each :any:`Domain`, in same order the domains appear, giving
-            the total score of :any:`Constraint`'s violated by the corresponding :any:`Domain`
-    """
-
-    violation_set: nc.ViolationSet = _violations_of_constraints(
-        design, params.constraints, params.never_increase_score,
-        domains_changed, violation_set_old, iteration)
-
-    # NOTE: this filters out the fixed domains,
-    # but we keep them in violation_set for the sake of reports
-    domain_to_score: Dict[Domain, float] = {
-        domain: sum(violation.score for violation in domain_violations)
-        for domain, domain_violations in violation_set.domain_to_violations.items()
-        if not domain.fixed
-    }
-    domains = list(domain_to_score.keys())
-    scores = list(domain_to_score.values())
-
-    return violation_set, domains, scores
-
-
-def _flatten(list_of_lists: Iterable[Iterable[Any]]) -> Iterable[Any]:
+def 
_flatten(list_of_lists: Iterable[Iterable[T]]) -> Iterable[T]: # Flatten one level of nesting return itertools.chain.from_iterable(list_of_lists) +def _remove_first_lines_from_string(s: str, num_lines: int) -> str: + return '\n'.join(s.split('\n')[num_lines:]) + + def _log_constraint_summary(*, params: SearchParameters, - violation_set_opt: nc.ViolationSet, - violation_set_new: nc.ViolationSet, + eval_set: EvaluationSet, iteration: int, num_new_optimal: int) -> None: - score_header = '\niteration|updates|opt score||new score|' - all_constraints_header = '|'.join( - f'{constraint.short_description}' for constraint in params.constraints) - header = score_header + all_constraints_header - - score_opt = violation_set_opt.total_score() - score_new = violation_set_new.total_score() - dec_opt = max(1, math.ceil(math.log(1 / score_opt, 10)) + 2) if score_opt > 0 else 1 - dec_new = max(1, math.ceil(math.log(1 / score_new, 10)) + 2) if score_new > 0 else 1 - score_str = f'{iteration:9}|{num_new_optimal:7}|' \ - f'{score_opt :9.{dec_opt}f}||' \ - f'{score_new :9.{dec_new}f}|' # \ + # If output is not scrolling, only print this once on first iteration. 
+    if params.scrolling_output or iteration == 0:
+        row1 = ['iteration', 'update', 'opt score', 'new score'] + [f'{constraint.short_description}'
+                                                                    for constraint in params.constraints]
+        header = tabulate([row1], tablefmt='github')
+        print(header)
+
+    def _dec(score: float) -> int:
+        # how many decimals after decimal point to use given the score
+        dec_opt = max(1, math.ceil(math.log(1 / score, 10)) + 2) if score > 0 else 1
+        return dec_opt
+
+    score_opt = eval_set.total_score
+    score_new = eval_set.total_score_new()
+
+    dec_opt = _dec(score_opt)
+    dec_new = _dec(score_new)

     all_constraints_strs = []
     for constraint in params.constraints:
-        score = violation_set_new.score_of_constraint(constraint)
+        score = eval_set.score_of_constraint(constraint, True)
         length = len(constraint.short_description)
-        num_decimals = max(1, math.ceil(math.log(1 / score, 10)) + 2) if score > 0 else 1
+        num_decimals = _dec(score)
         constraint_str = f'{score:{length}.{num_decimals}f}'
+        # round further if this would exceed length
+        if len(constraint_str) > length:
+            excess = len(constraint_str) - length
+            num_decimals -= excess
+            if num_decimals < 0:
+                num_decimals = 0
+            constraint_str = f'{score:{length}.{num_decimals}f}'
         all_constraints_strs.append(constraint_str)
-    all_constraints_str = '|'.join(all_constraints_strs)
+    # all_constraints_str = '|'.join(all_constraints_strs)

     # logger.info(header + '\n' + score_str + all_constraints_str)
-    #TODO: use floatfmt per column to adjust decimal places
-    import tabulate as tb
-    tb.PRESERVE_WHITESPACE = True
+    # TODO: use floatfmt per column to adjust decimal places
     row1 = ['iteration', 'update', 'opt score', 'new score'] + [f'{constraint.short_description}'
-                                                                for constraint in params.constraints]
-    iteration_str = f'{iteration:9}'
-    num_new_optimal_str = f'{num_new_optimal:6}'
+                                                                for constraint in params.constraints]
+    # iteration_str = f'{iteration:9}'
+    # num_new_optimal_str = f'{num_new_optimal:6}'
     score_opt_str = f'{score_opt :9.{dec_opt}f}'
score_new_str = f'{score_new :9.{dec_new}f}' row2 = [iteration, num_new_optimal, score_opt_str, score_new_str] + all_constraints_strs # type:ignore table = [row1, row2] table_str = tabulate(table, tablefmt='github', numalign='right', stralign='right') - logger.info(table_str) + table_str = _remove_first_lines_from_string(table_str, 2) + # logger.info(table_str) + line_end = '\n' if params.scrolling_output else '\r' + print(table_str, end=line_end) def assign_sequences_to_domains_randomly_from_pools(design: Design, @@ -1536,11 +1424,12 @@ def assign_sequences_to_domains_randomly_from_pools(design: Design, for domain in independent_domains: skip_nonfixed_msg = skip_fixed_msg = None if warn_fixed_sequences and domain.has_sequence(): - skip_nonfixed_msg = f'Skipping assignment of DNA sequence to domain {domain.name}. ' \ - f'That domain has a NON-FIXED sequence {domain.sequence()}, ' \ + skip_nonfixed_msg = f'Skipping initial assignment of DNA sequence to domain {domain.name}. ' \ + f'That domain currently has a non-fixed sequence {domain.sequence()}, ' \ f'which the search will attempt to replace.' - skip_fixed_msg = f'Skipping assignment of DNA sequence to domain {domain.name}. ' \ - f'That domain has a FIXED sequence {domain.sequence()}.' + skip_fixed_msg = f'Skipping initial assignment of DNA sequence to domain {domain.name}. ' \ + f'That domain has a fixed sequence {domain.sequence()}, ' \ + f'and the search will not replace it.' 
if overwrite_existing_sequences: if not domain.fixed: at_least_one_domain_unfixed = True @@ -1616,20 +1505,487 @@ def keep_change_only_if_no_worse(score_delta: float) -> float: # return keep_change_only_if_better +K1 = TypeVar('K1') +K2 = TypeVar('K2') +V = TypeVar('V') + + +# convenience methods for iterating over 2D dicts + +def keys_2d(dct: Dict[K1, Dict[K2, V]]) -> Iterator[Tuple[K1, K2]]: + for first_key in dct.keys(): + for second_key in dct[first_key].keys(): + yield first_key, second_key + + +def values_2d(dct: Dict[K1, Dict[K2, V]]) -> Iterator[V]: + for first_key in dct.keys(): + for value in dct[first_key].values(): + yield value + + +def items_2d(dct: Dict[K1, Dict[K2, V]]) -> Iterator[Tuple[Tuple[K1, K2], V]]: + for first_key in dct.keys(): + for second_key, value in dct[first_key].items(): + yield (first_key, second_key), value + + +def key_exists_in_2d_dict(dct: Dict[K1, Dict[K2, Any]], key1: K1, key2: K2) -> bool: + # avoid using the bracket [] operator (to avoid triggering the defaultdict behavior) + if key1 not in dct.keys(): + return False + return key2 in dct[key1] + + +@dataclass +class EvaluationSet: + # Represents evaluations of :any:`Constraint`'s in a :any:`Design`. + # + # It is designed to be efficiently updateable when a single :any:`Domain` changes, + # updating only those evaluations of :any:`Constraint`'s that could have been + # affected by the changed :any:`Domain`. + + constraints: List[Constraint] + # list of all constraints + + evaluations: Dict[Constraint, Dict[nc.Part, Evaluation]] + # "2D dict" mapping each (Constraint, Part) pair to its Evaluation. + # has keys for every (Constraint, Part) instance. + + evaluations_new: Dict[Constraint, Dict[nc.Part, Evaluation]] + # "2D dict" mapping some (Constraint, Part) pairs to their new Evaluation + # (after changing domain(s)). + # Unlike evaluations, only has keys for parts affected by the most recent domain changes.
+ + violations: Dict[Constraint, Dict[nc.Part, Evaluation]] + # "2D dict" mapping each (Constraint, Part) pair to its violation + # (an Evaluation that had a positive score) + # has keys for every Constraint + # only has keys for Parts that are violations + + violations_new: Dict[Constraint, Dict[nc.Part, Evaluation]] + # "2D dict" mapping some (Constraint, Part) pairs to their new violations + # (after changing domain(s)). + # Unlike violations, only has keys for parts affected by the most recent domain changes. + + domain_to_evaluations: Dict[Domain, List[Evaluation]] + # Dict mapping each :any:`constraint.Domain` to the list of all :any:`Evaluation`'s for which it is blamed + + domain_to_evaluations_new: Dict[Domain, List[Evaluation]] + + domain_to_violations: Dict[Domain, List[Evaluation]] + # Dict mapping each :any:`constraint.Domain` to the list of all violated :any:`Evaluation`'s for which it is blamed + + domain_to_violations_new: Dict[Domain, List[Evaluation]] + + domain_to_score: Dict[Domain, float] + + domain_to_score_new: Dict[Domain, float] + + total_score: float + # sum of scores of all evaluations + + total_score_nonfixed: float + # for evaluations blaming Domains with Domain.fixed = False + + total_score_fixed: float + # for evaluations blaming Domains with Domain.fixed = True + + num_evaluations: int + num_evaluations_fixed: int + num_evaluations_nonfixed: int + + num_violations: int + num_violations_fixed: int + num_violations_nonfixed: int + + def __init__(self, constraints: Iterable[Constraint], never_increase_score: bool) -> None: + self.constraints = list(constraints) + self.never_increase_score = never_increase_score + self.reset_all() + + def __repr__(self): + all_evals: List[Evaluation] = [evaluation + for part_to_eval in self.evaluations.values() + for evaluation in part_to_eval.values()] +
lines = "\n ".join(map(str, all_evals)) + return f'EvaluationSet(\n {lines})' + + def __str__(self): + return repr(self) + + def reset_all(self) -> None: + self.evaluations = {constraint: {} for constraint in self.constraints} + self.violations = {constraint: {} for constraint in self.constraints} + self.domain_to_evaluations = defaultdict(list) + self.domain_to_violations = defaultdict(list) + self.domain_to_score = defaultdict(float) + self.reset_new() + + def reset_new(self) -> None: + self.evaluations_new = {constraint: {} for constraint in self.constraints} + self.violations_new = {constraint: {} for constraint in self.constraints} + self.domain_to_evaluations_new = defaultdict(list) + self.domain_to_violations_new = defaultdict(list) + self.domain_to_score_new = defaultdict(float) + + def evaluate_all(self, design: Design) -> None: + # called on all parts of the design and sets self.evaluations + self.reset_all() + for constraint in self.constraints: + self.evaluate_constraint(constraint, design, None, None) + self.domain_to_score = EvaluationSet.sum_domain_scores(self.domain_to_violations) + self.update_scores_and_counts() + # _assert_violations_are_accurate(self.evaluations, self.violations) + + def evaluate_new(self, design: Design, domains_new: List[Domain]) -> None: + # called only on changed parts of the design and sets self.evaluations_new + # applies the quit-early optimization, since this is only called when comparing + # to an existing set of evals + self.reset_new() + score_gap = None + if self.never_increase_score: + score_gap = self.calculate_initial_score_gap(design, domains_new) + for constraint in self.constraints: + score_gap = self.evaluate_constraint(constraint, design, score_gap, domains_new) + if score_gap is not None and _is_significantly_greater(0.0, score_gap): + break + self.domain_to_score_new = EvaluationSet.sum_domain_scores(self.domain_to_violations_new) + + @staticmethod + def sum_domain_scores(domain_to_violations: Dict[Domain,
List[Evaluation]]) -> Dict[Domain, float]: + # NOTE: this filters out the fixed domains, + # but we keep them in eval_set for the sake of reports + domain_to_score = { + domain: sum(violation.score for violation in domain_violations) + for domain, domain_violations in domain_to_violations.items() + if not domain.fixed + } + return domain_to_score + + def calculate_score_gap(self, fixed: Optional[bool] = None) -> Optional[float]: + # fixed is None (all violations), True (fixed violations), or False (nonfixed violations) + # total score of evaluations - total score of new evaluations + assert len(self.evaluations) > 0 + total_gap = 0.0 + for ((constraint, part), eval_new) in items_2d(self.evaluations_new): + eval_old = self.evaluations[constraint][part] + assert eval_old.part.fixed == eval_new.part.fixed + assert eval_old.violated == (eval_old.score > 0) + assert eval_new.violated == (eval_new.score > 0) + if fixed is None or eval_old.part.fixed == fixed: + total_gap += eval_old.score - eval_new.score + return total_gap + + def calculate_initial_score_gap(self, design: Design, domains_new: List[Domain]) -> float: + # before evaluations_new is populated, we need to calculate the total score of evaluations + # on parts affected by domains_new, which is the score gap assuming all new evaluations come up 0 + score_gap = 0.0 + for constraint in self.constraints: + parts = find_parts_to_check(constraint, design, domains_new) + for part in parts: + ev = self.evaluations[constraint][part] + score_gap += ev.score + return score_gap + + def evaluate_constraint(self, + constraint: Constraint[DesignPart], + design: Design, # only used with DesignConstraint + score_gap: Optional[float], + domains_new: Optional[Iterable[Domain]], + ) -> float: + # returns score gap = score(old evals) - score(new evals); + # if gap > 0, then the new evals score lower (better) than the old evals they replace + assert ((score_gap is None and domains_new is None) or + (score_gap is not None and domains_new is not None)) + + parts =
find_parts_to_check(constraint, design, domains_new) + + # measure violations of constraints and collect in list of triples (part, score, summary) + violating_parts_scores_summaries: List[Tuple[DesignPart, float, str]] = [] + if isinstance(constraint, SingularConstraint): + if not constraint.parallel or len(parts) == 1 or nc.cpu_count() == 1: + for part in parts: + seqs = tuple(indv_part.sequence() for indv_part in part.individual_parts()) + score, summary = constraint.call_evaluate(seqs, part) + violating_parts_scores_summaries.append((part, score, summary)) + if score > 0.0: + if score_gap is not None: + score_gap -= score + if _is_significantly_greater(0.0, score_gap): + break + else: + raise NotImplementedError('TODO: implement parallelization') + + elif isinstance(constraint, (BulkConstraint, DesignConstraint)): + if isinstance(constraint, DesignConstraint): + violating_parts_scores_summaries = constraint.call_evaluate_design(design, domains_new) + else: + # XXX: I don't understand the mypy error on the next line + violating_parts_scores_summaries = constraint.call_evaluate_bulk(parts) # type: ignore + + # we can't quit this function early, + # but we can let the caller know to stop evaluating constraints + total_score = sum(score for _, score, _ in violating_parts_scores_summaries) + if score_gap is not None: + score_gap -= total_score + else: + raise AssertionError( + f'constraint {constraint} of unrecognized type {constraint.__class__.__name__}') + + # assign blame for violations to domains by looking up associated domains in each part + evals_of_constraint = self.evaluations[constraint] + viols_of_constraint = self.violations[constraint] + domain_to_evals = self.domain_to_evaluations + domain_to_viols = self.domain_to_violations + if domains_new is not None: + evals_of_constraint = self.evaluations_new[constraint] + viols_of_constraint = self.violations_new[constraint] + domain_to_evals = self.domain_to_evaluations_new + domain_to_viols = 
self.domain_to_violations_new + + for part, score, summary in violating_parts_scores_summaries: + domains = _independent_domains_in_part(part, exclude_fixed=False) + evaluation = Evaluation(constraint=constraint, part=part, domains=domains, + score=score, summary=summary, violated=score > 0) + + evals_of_constraint[part] = evaluation + for domain in domains: + domain_to_evals[domain].append(evaluation) + + if evaluation.violated: + viols_of_constraint[part] = evaluation + for domain in domains: + domain_to_viols[domain].append(evaluation) + + return score_gap + + def replace_with_new(self) -> None: + # uses Evaluations in self.evaluations_new to replace equivalent ones in self.evaluations + # same for violations + + # IMPORTANT: update scores before we start to modify the EvaluationSet + self.total_score = self.total_score_new() + self.total_score_fixed = self.total_score_new(True) + self.total_score_nonfixed = self.total_score_new(False) + + # update all evaluations + for (constraint, part), evaluation in items_2d(self.evaluations_new): + # CONSIDER updating everything in this loop by looking up eval.violated + self.evaluations[constraint][part] = evaluation + + # update dict mapping domain to list of evals/violations for which it is blamed + for domain in evaluation.domains: + self.domain_to_evaluations[domain] = self.domain_to_evaluations_new[domain] + self.domain_to_violations[domain] = self.domain_to_violations_new[domain] + + viols_by_part = self.violations[constraint] + if evaluation.violated: + # if was not a violation before, increment total violations + if part not in viols_by_part: + self.num_violations += 1 + if evaluation.part.fixed: + self.num_violations_fixed += 1 + else: + self.num_violations_nonfixed += 1 + # add it to violations, or replace existing violation + viols_by_part[part] = evaluation + elif not evaluation.violated and part in viols_by_part.keys(): + # otherwise remove violation if one was there from old EvaluationSet, + # and decrement
total violations + del viols_by_part[part] + self.num_violations -= 1 + if evaluation.part.fixed: + self.num_violations_fixed -= 1 + else: + self.num_violations_nonfixed -= 1 + + self.reset_new() + + _assert_violations_are_accurate(self.evaluations, self.violations) + + def update_scores_and_counts(self) -> None: + """ + Updates the total scores and counts of evaluations and violations. + """ + self.total_score = self.total_score_fixed = self.total_score_nonfixed = 0.0 + self.num_evaluations = self.num_evaluations_nonfixed = self.num_evaluations_fixed = 0 + self.num_violations = self.num_violations_nonfixed = self.num_violations_fixed = 0 + + # count evaluations + for evaluation in values_2d(self.evaluations): + self.num_evaluations += 1 + if evaluation.part.fixed: + self.num_evaluations_fixed += 1 + else: + self.num_evaluations_nonfixed += 1 + + # count violations and score + for violation in values_2d(self.violations): + self.total_score += violation.score + self.num_violations += 1 + if violation.part.fixed: + self.total_score_fixed += violation.score + self.num_violations_fixed += 1 + else: + self.total_score_nonfixed += violation.score + self.num_violations_nonfixed += 1 + + def total_score_new(self, fixed: Optional[bool] = None) -> float: + # return total score of all evaluations, or only fixed, or only nonfixed + if fixed is None: + total_score_old = self.total_score + elif fixed is True: + total_score_old = self.total_score_fixed + elif fixed is False: + total_score_old = self.total_score_nonfixed + else: + raise AssertionError(f'fixed should be None, True, or False, but is {fixed}') + + total_score_new = total_score_old - self.calculate_score_gap(fixed) + return total_score_new + + def score_of_constraint(self, constraint: Constraint, violations: bool) -> float: + # :param constraint: + # constraint to filter scores on + # :return: + # Total score of all evaluations due to `constraint`.
+ return sum(evaluation.score for evaluation in self.evaluations_of_constraint(constraint, violations)) + + def score_of_constraint_nonfixed(self, constraint: Constraint, violations: bool) -> float: + # :param constraint: + # constraint to filter scores on + # :return: + # Total score of all nonfixed evaluations due to `constraint`. + return sum(evaluation.score for evaluation in self.evaluations_of_constraint(constraint, violations) + if not evaluation.part.fixed) + + def score_of_constraint_fixed(self, constraint: Constraint, violations: bool) -> float: + # :param constraint: + # constraint to filter scores on + # :return: + # Total score of all fixed evaluations due to `constraint`. + return sum(evaluation.score for evaluation in self.evaluations_of_constraint(constraint, violations) + if evaluation.part.fixed) + + def has_nonfixed_evaluations(self) -> bool: + # :return: whether there are any nonfixed Evaluations in this EvaluationSet + return self.num_evaluations_nonfixed > 0 + + def has_nonfixed_violations(self) -> bool: + # :return: whether there are any nonfixed violations in this EvaluationSet + return self.num_violations_nonfixed > 0 + + def evaluations_of_constraint(self, constraint: Constraint, violations: bool) -> List[Evaluation]: + dct = self.violations[constraint] if violations else self.evaluations[constraint] + return list(dct.values()) + + def evaluations_nonfixed_of_constraint(self, constraint: Constraint, + violations: bool) -> List[Evaluation]: + return [ev for ev in self.evaluations_of_constraint(constraint, violations) if not ev.part.fixed] + + def evaluations_fixed_of_constraint(self, constraint: Constraint, violations: bool) -> List[Evaluation]: + return [ev for ev in self.evaluations_of_constraint(constraint, violations) if ev.part.fixed] + + def num_evaluations_of(self, constraint: Constraint) -> int: + return len(self.evaluations[constraint]) + + def num_violations_of(self, constraint: Constraint) -> int: + return
len(self.violations[constraint]) + + +def _assert_violations_are_accurate(evaluations: Dict[Constraint, Dict[nc.Part, Evaluation]], + violations: Dict[Constraint, Dict[nc.Part, Evaluation]]) -> None: + # go through all violations and ensure the violations in it are all in evaluations + for (constraint, part), viol in items_2d(violations): + assert constraint in evaluations.keys() + evals_by_part = evaluations[constraint] + assert part in evals_by_part.keys() + ev = evals_by_part[part] + assert viol == ev + assert viol.violated # also assert that they really are violated + + # now go in reverse and ensure the keys in evaluations not in violations are all not violated + for (constraint, part), ev in items_2d(evaluations): + if part in violations[constraint].keys(): + viol = violations[constraint][part] + assert viol == ev + assert ev.violated + else: + assert not ev.violated + + +@dataclass(frozen=True) +class Evaluation(Generic[DesignPart]): + # Represents an evaluation (possibly a violation) of a single :any:`Constraint` in a :any:`Design`. + # The "part" of the :any:`Design` that was evaluated for the constraint is generic type `DesignPart` + # (e.g., for :any:`StrandPairConstraint`, DesignPart = :any:`Pair` [:any:`Strand`]). + + constraint: Constraint + # :any:`Constraint` that was evaluated to result in this :any:`Evaluation`.
+ + violated: bool + # whether the :any:`Constraint` was violated last time it was evaluated + + part: DesignPart + # DesignPart that was evaluated against the constraint + + domains: FrozenSet[Domain] # = field(init=False, hash=False, compare=False, default=None) + # :any:`Domain`'s that were involved in violating :py:data:`Evaluation.constraint` + + summary: str + + score: float + + def __init__(self, constraint: Constraint, violated: bool, part: DesignPart, domains: Iterable[Domain], + score: float, summary: str) -> None: + # :param constraint: + # :any:`Constraint` that was evaluated to result in this :any:`Evaluation` + # :param domains: + # :any:`Domain`'s that were involved in violating :py:data:`Evaluation.constraint` + # :param score: + # total "score" of this violation, typically something like an excess energy over a + # threshold, squared, multiplied by the :data:`Constraint.weight` + object.__setattr__(self, 'constraint', constraint) + object.__setattr__(self, 'violated', violated) + object.__setattr__(self, 'part', part) + domains_frozen = frozenset(domains) + object.__setattr__(self, 'domains', domains_frozen) + object.__setattr__(self, 'score', score) + object.__setattr__(self, 'summary', summary) + + def __repr__(self) -> str: + return f'Evaluation({self.constraint.short_description}, score={self.score:.2f}, ' \ + f'summary={self.summary}, violated={self.violated})' + + def __str__(self) -> str: + return repr(self) + + # Evaluation equality based on identity; different Evaluations in memory are considered different, + # even if all data between them matches. Don't create the same Evaluation twice!
+ def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return self is other + + #################################################################################### # report generating functions def summary_of_constraints(constraints: Iterable[Constraint], report_only_violations: bool, - violation_set: nc.ViolationSet) -> str: + eval_set: EvaluationSet) -> str: summaries: List[str] = [] # other constraints for constraint in constraints: - summary = summary_of_constraint(constraint, report_only_violations, violation_set) + summary = summary_of_constraint(constraint, report_only_violations, eval_set) summaries.append(summary) - score = violation_set.total_score() - score_unfixed = violation_set.total_score_nonfixed() + score = eval_set.total_score + score_unfixed = eval_set.total_score_nonfixed score_total_summary = f'total score of constraint violations: {score:.2f}' score_unfixed_summary = f'total score of unfixed constraint violations: {score_unfixed:.2f}' @@ -1641,34 +1997,35 @@ def summary_of_constraints(constraints: Iterable[Constraint], report_only_violat def summary_of_constraint(constraint: Constraint, report_only_violations: bool, - violation_set: nc.ViolationSet) -> str: + evaluation_set: EvaluationSet) -> str: if isinstance(constraint, (DomainConstraint, StrandConstraint, DomainPairConstraint, StrandPairConstraint, ComplexConstraint, DomainsConstraint, StrandsConstraint, DomainPairsConstraint, StrandPairsConstraint, ComplexesConstraint)): summaries = [] + num_evals = evaluation_set.num_evaluations_of(constraint) + num_viols = evaluation_set.num_violations_of(constraint) num_violations = 0 - num_checks = violation_set.num_checked[constraint] part_type_name = constraint.part_name() - violations_nonfixed = violation_set.violations_nonfixed[constraint] - violations_fixed = violation_set.violations_fixed[constraint] + evals_nonfixed = evaluation_set.evaluations_nonfixed_of_constraint(constraint, report_only_violations) + evals_fixed = 
evaluation_set.evaluations_fixed_of_constraint(constraint, report_only_violations) - some_fixed_violations = len(violations_fixed) > 0 + some_fixed_evals = len(evals_fixed) > 0 - for violations, header_name in [(violations_nonfixed, f"unfixed {part_type_name}s"), - (violations_fixed, f"fixed {part_type_name}s")]: - if len(violations) == 0: + for evals, header_name in [(evals_nonfixed, f"unfixed {part_type_name}s"), + (evals_fixed, f"fixed {part_type_name}s")]: + if len(evals) == 0: continue - max_part_name_length = max(len(violation.part.name) for violation in violations) - num_violations += len(violations) + max_part_name_length = max(len(violation.part.name) for violation in evals) + num_violations += len(evals) lines_and_scores: List[Tuple[str, float]] = [] - for violation in violations: - line = f'{part_type_name} {violation.part.name:{max_part_name_length}}: ' \ - f'{violation.summary}; score: {violation.score:.2f}' - lines_and_scores.append((line, violation.score)) + for ev in evals: + line = f'{part_type_name} {ev.part.name:{max_part_name_length}}: ' \ + f'{ev.summary}; score: {ev.score:.2f}' + lines_and_scores.append((line, ev.score)) lines_and_scores.sort(key=lambda line_and_score: line_and_score[1], reverse=True) @@ -1676,29 +2033,31 @@ def summary_of_constraint(constraint: Constraint, report_only_violations: bool, content = '\n'.join(lines) # only put header to distinguish fixed from unfixed violations if there are some fixed - full_header = _small_header(header_name, "=") if some_fixed_violations else '' + full_header = _small_header(header_name, "=") if some_fixed_evals else '' summary = full_header + f'\n{content}\n' summaries.append(summary) + if report_only_violations: + assert num_violations == num_viols + content = ''.join(summaries) report = ConstraintReport(constraint=constraint, content=content, - num_violations=num_violations, num_checks=num_checks) + num_violations=num_viols, num_evaluations=num_evals) elif
isinstance(constraint, DesignConstraint): raise NotImplementedError() else: content = f'skipping summary of constraint {constraint.description}; ' \ f'unrecognized type {type(constraint)}' - report = ConstraintReport(constraint=constraint, content=content, num_violations=0, num_checks=0) + report = ConstraintReport(constraint=constraint, content=content, num_violations=0, num_evaluations=0) - summary = add_header_to_content_of_summary(report, violation_set, report_only_violations) + summary = add_header_to_content_of_summary(report, evaluation_set) return summary -def add_header_to_content_of_summary(report: ConstraintReport, violation_set: nc.ViolationSet, - report_only_violations: bool) -> str: - score = violation_set.score_of_constraint(report.constraint) - score_unfixed = violation_set.score_of_constraint_nonfixed(report.constraint) +def add_header_to_content_of_summary(report: ConstraintReport, eval_set: EvaluationSet) -> str: + score = eval_set.score_of_constraint(report.constraint, True) + score_unfixed = eval_set.score_of_constraint_nonfixed(report.constraint, True) if score != score_unfixed: summary_score_unfixed = f'\n* unfixed score of violations: {score_unfixed:.2f}' @@ -1706,16 +2065,15 @@ def add_header_to_content_of_summary(report: ConstraintReport, violation_set: nc summary_score_unfixed = None indented_content = textwrap.indent(report.content, ' ') + summary = f''' **{"*" * len(report.constraint.description)} * {report.constraint.description} -* checks: {report.num_checks} -* violations: {report.num_violations} +* evaluations: {report.num_evaluations} +* violations: {report.num_violations} * score of violations: {score:.2f}{"" if summary_score_unfixed is None else summary_score_unfixed} -{indented_content}''' + ('\nThe option "report_only_violations" is currently being ignored ' - 'when set to False\n' - 'see https://github.com/UC-Davis-molecular-computing/dsd/issues/134\n' - if not report_only_violations else '') +{indented_content}''' + return
summary @@ -1733,7 +2091,7 @@ class ConstraintReport: constraint: Optional['Constraint'] """ The :any:`Constraint` to report on. This can be None if the :any:`Constraint` object is not available - at the time the :py:meth:`Constraint.generate_summary` function is defined. If so it will be + at the time the :meth:`Constraint.generate_summary` function is defined. If so it will be automatically inserted by the report generating code.""" content: str @@ -1747,7 +2105,7 @@ class ConstraintReport: violated the constraint. """ - num_checks: int + num_evaluations: int """ Total number of "parts" of the :any:`Design` (e.g., :any:`Strand`'s, pairs of :any:`Domain`'s) that were checked against the constraint. diff --git a/requirements.txt b/requirements.txt index 9e6af3b0..df8127d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,9 @@ numpy psutil ordered_set pathos -scadnano xlwt xlrd nupack -typing_extensions; python_version < "3.8" \ No newline at end of file +tabulate +scadnano +typing_extensions; python_version < "3.8" diff --git a/tests/test.py b/tests/test.py index c4048a49..2697b882 100644 --- a/tests/test.py +++ b/tests/test.py @@ -16,12 +16,6 @@ _domain_pools: Dict[int, DomainPool] = {} -def clear_domains_interned() -> None: - """Clear interned domains. - """ - constraints._domains_interned.clear() - - def assign_domain_pool_of_length(length: int) -> DomainPool: """Returns a DomainPool of given size @@ -38,7 +32,7 @@ def assign_domain_pool_of_length(length: int) -> DomainPool: return new_domain_pool -def construct_strand(domain_names: List[str], domain_lengths: List[int]) -> Strand: +def construct_strand(design: Design, domain_names: List[str], domain_lengths: List[int]) -> Strand: """Constructs a strand with given domain names and domain lengths. 
:param domain_names: Names of the domain on the strand @@ -55,7 +49,7 @@ def construct_strand(domain_names: List[str], domain_lengths: List[int]) -> Stra f'domain_names contained {len(domain_names)} names ' f'but domain_lengths contained {len(domain_lengths)} ' f'lengths') - s: Strand = Strand(domain_names) + s: Strand = design.add_strand(domain_names=domain_names) for (i, length) in enumerate(domain_lengths): s.domains[i].pool = assign_domain_pool_of_length(length) s.compute_derived_fields() @@ -93,19 +87,19 @@ def test_strand_intersecting_domains(self) -> None: self.assertEqual(1, len(s1.domains)) self.assertEqual(a, s1.domains[0]) - s2 = Strand(domains=[b,C], starred_domain_indices=[]) + s2 = Strand(domains=[b, C], starred_domain_indices=[]) self.assertEqual(2, len(s2.domains)) self.assertEqual(b, s2.domains[0]) self.assertEqual(C, s2.domains[1]) - s3 = Strand(domains=[E,F,g,h], starred_domain_indices=[]) + s3 = Strand(domains=[E, F, g, h], starred_domain_indices=[]) self.assertEqual(4, len(s3.domains)) self.assertEqual(E, s3.domains[0]) self.assertEqual(F, s3.domains[1]) self.assertEqual(g, s3.domains[2]) self.assertEqual(h, s3.domains[3]) - s4 = Strand(domains=[E,F,C], starred_domain_indices=[]) + s4 = Strand(domains=[E, F, C], starred_domain_indices=[]) self.assertEqual(3, len(s4.domains)) self.assertEqual(E, s4.domains[0]) self.assertEqual(F, s4.domains[1]) @@ -116,7 +110,7 @@ def test_strand_intersecting_domains(self) -> None: self.assertTrue(s.intersects_domain(domain)) # these strands do not hit every domain - s5 = Strand(domains=[b,g], starred_domain_indices=[]) + s5 = Strand(domains=[b, g], starred_domain_indices=[]) self.assertEqual(2, len(s5.domains)) self.assertEqual(b, s5.domains[0]) self.assertEqual(g, s5.domains[1]) @@ -211,8 +205,8 @@ def test_substrings_circular_except_overlapping_indices(self) -> None: class TestModifyDesignAfterCreated(unittest.TestCase): def setUp(self) -> None: - strand = nc.Strand(domain_names=['x', 'y']) - self.design = 
nc.Design(strands=[strand]) + self.design = nc.Design() + self.design.add_strand(domain_names=['x', 'y']) def add_domain(self): strand = self.design.strands[0] @@ -286,12 +280,11 @@ class TestExportDNASequences(unittest.TestCase): def test_idt_bulk_export(self) -> None: custom_idt = nc.IDTFields(scale='100nm', purification='PAGE') - strands = [ - nc.Strand(domain_names=['a', 'b*', 'c', 'd*'], name='s0', idt=custom_idt), - nc.Strand(domain_names=['d', 'c*', 'e', 'f'], name='s1'), - ] - design = nc.Design(strands) - # a b c d e f + design = nc.Design() + design.add_strand(domain_names=['a', 'b*', 'c', 'd*'], name='s0', idt=custom_idt) + design.add_strand(domain_names=['d', 'c*', 'e', 'f'], name='s1') + + # a b c d e f seqs = ['AACG', 'CCGT', 'GGTA', 'TTAC', 'AAAACCCC', 'AAAAGGGG'] # s0: AACG-ACGG-GGTA-GTAA # s1: TTAC-TACC-AAAACCCC-AAAAGGGG @@ -319,13 +312,11 @@ def test_write_idt_plate_excel_file(self) -> None: for plate_type in [sc.PlateType.wells96, sc.PlateType.wells384]: filename = f'test_excel_export_{plate_type.num_wells_per_plate()}.xls' - strands = [] + design = nc.Design() for strand_idx in range(3 * plate_type.num_wells_per_plate() + 10): idt = nc.IDTFields() - strand = nc.Strand(name=f's{strand_idx}', domain_names=[f'd{strand_idx}'], idt=idt) + strand = design.add_strand(name=f's{strand_idx}', domain_names=[f'd{strand_idx}'], idt=idt) strand.domains[0].set_fixed_sequence('T' * strand_len) - strands.append(strand) - design = nc.Design(strands=strands) design.write_idt_plate_excel_file(filename=filename, plate_type=plate_type) @@ -355,8 +346,8 @@ def test_NearestNeighborEnergyConstraint_raises_exception_if_energies_in_wrong_o class TestInsertDomains(unittest.TestCase): def setUp(self) -> None: - strands = [nc.Strand(domain_names=['a', 'b*', 'c', 'd*'])] - self.design = nc.Design(strands) + self.design = Design() + self.design.add_strand(domain_names=['a', 'b*', 'c', 'd*']) self.strand = self.design.strands[0] def test_no_insertion(self) -> None: @@ -402,8 
+393,6 @@ def test_insert_idx_2_domain_starred(self) -> None: class TestExteriorBaseTypeOfDomain3PEnd(unittest.TestCase): - def setUp(self) -> None: - clear_domains_interned() def test_adjacent_to_exterior_base_pair_on_length_2_domain(self) -> None: """Test that base pair on domain of length two is properly classified as @@ -420,8 +409,9 @@ def test_adjacent_to_exterior_base_pair_on_length_2_domain(self) -> None: [=============--==> b* a* """ - top_strand = construct_strand(['a', 'b'], [2, 13]) - bot_strand = construct_strand(['b*', 'a*'], [13, 2]) + design = Design() + top_strand = construct_strand(design, ['a', 'b'], [2, 13]) + bot_strand = construct_strand(design, ['b*', 'a*'], [13, 2]) top_a = top_strand.address_of_domain(0) @@ -602,8 +592,9 @@ def test_seesaw_gate_output_complex(self): | | | INTERIOR_TO_STRAND DANGLE_5P """ - output_strand = construct_strand(['so', 'So', 'T', 'sg', 'Sg'], [2, 13, 5, 2, 13]) - gate_base_strand = construct_strand(['T*', 'Sg*', 'sg*', 'T*'], [5, 13, 2, 5]) + design = Design() + output_strand = construct_strand(design, ['so', 'So', 'T', 'sg', 'Sg'], [2, 13, 5, 2, 13]) + gate_base_strand = construct_strand(design, ['T*', 'Sg*', 'sg*', 'T*'], [5, 13, 2, 5]) gate_output_complex = [output_strand, gate_base_strand] output_t = output_strand.address_of_domain(2) @@ -653,8 +644,9 @@ def test_seesaw_threshold_complex(self): | | INTERIOR_TO_STRAND BLUNT_END """ - waste_strand = construct_strand(['sg', 'Sg'], [2, 13]) - threshold_base_strand = construct_strand(['si*', 'T*', 'Sg*', 'sg*'], [2, 5, 13, 2]) + design = Design() + waste_strand = construct_strand(design, ['sg', 'Sg'], [2, 13]) + threshold_base_strand = construct_strand(design, ['si*', 'T*', 'Sg*', 'sg*'], [2, 5, 13, 2]) threshold_complex = [waste_strand, threshold_base_strand] expected = set([ @@ -753,8 +745,9 @@ def test_seesaw_reporter_complex(self): | | INTERIOR_TO_STRAND BLUNT_END """ - waste_strand = construct_strand(['so', 'So'], [2, 13]) - reporter_base_strand = 
construct_strand(['T*', 'So*', 'so*'], [5, 13, 2]) + design = Design() + waste_strand = construct_strand(design, ['so', 'So'], [2, 13]) + reporter_base_strand = construct_strand(design, ['T*', 'So*', 'so*'], [5, 13, 2]) reporter_complex = [waste_strand, reporter_base_strand] expected = set([ @@ -795,8 +788,9 @@ def test_seesaw_reporter_waste_complex(self): | | | INTERIOR_TO_STRAND BLUNT_END """ - output_strand = construct_strand(['so', 'So', 'T', 'sg', 'Sg'], [2, 13, 5, 2, 13]) - reporter_base_strand = construct_strand(['T*', 'So*', 'so*'], [5, 13, 2]) + design = Design() + output_strand = construct_strand(design, ['so', 'So', 'T', 'sg', 'Sg'], [2, 13, 5, 2, 13]) + reporter_base_strand = construct_strand(design, ['T*', 'So*', 'so*'], [5, 13, 2]) reporter_waste_complex = [output_strand, reporter_base_strand] expected = set([ @@ -819,8 +813,8 @@ def test_seesaw_reporter_waste_complex(self): class TestStrandDomainAddress(unittest.TestCase): def setUp(self): - clear_domains_interned() - self.strand = construct_strand(['a', 'b', 'c'], [10, 20, 30]) + design = Design() + self.strand = construct_strand(design, ['a', 'b', 'c'], [10, 20, 30]) self.addr = StrandDomainAddress(self.strand, 1) def test_init(self):
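The "2D dict" helpers introduced in this diff (`keys_2d`, `values_2d`, `items_2d`, `key_exists_in_2d_dict`) can be exercised standalone. Below is a minimal runnable sketch: the definitions are copied from the diff (with `key_exists_in_2d_dict` testing membership in the inner dict, `key2 in dct[key1]`, rather than indexing into the inner value), and the sample dict `dct` is a hypothetical stand-in for the `Dict[Constraint, Dict[Part, Evaluation]]` structures used by `EvaluationSet`.

```python
from typing import Any, Dict, Iterator, Tuple, TypeVar

K1 = TypeVar('K1')
K2 = TypeVar('K2')
V = TypeVar('V')


def keys_2d(dct: Dict[K1, Dict[K2, V]]) -> Iterator[Tuple[K1, K2]]:
    # yield every (outer key, inner key) pair
    for first_key in dct.keys():
        for second_key in dct[first_key].keys():
            yield first_key, second_key


def values_2d(dct: Dict[K1, Dict[K2, V]]) -> Iterator[V]:
    # yield every value of every inner dict
    for first_key in dct.keys():
        for value in dct[first_key].values():
            yield value


def items_2d(dct: Dict[K1, Dict[K2, V]]) -> Iterator[Tuple[Tuple[K1, K2], V]]:
    # yield ((outer key, inner key), value) for every entry
    for first_key in dct.keys():
        for second_key, value in dct[first_key].items():
            yield (first_key, second_key), value


def key_exists_in_2d_dict(dct: Dict[K1, Dict[K2, Any]], key1: K1, key2: K2) -> bool:
    # membership test that avoids the bracket [] operator,
    # so a defaultdict is not mutated as a side effect of the lookup
    if key1 not in dct.keys():
        return False
    return key2 in dct[key1]


# hypothetical data: outer keys play the role of Constraints, inner keys of Parts,
# values of Evaluation scores
dct = {'c1': {'p1': 1.5, 'p2': 0.0}, 'c2': {'p1': 2.5}}
assert list(keys_2d(dct)) == [('c1', 'p1'), ('c1', 'p2'), ('c2', 'p1')]
assert sum(values_2d(dct)) == 4.0
assert dict(items_2d(dct))[('c2', 'p1')] == 2.5
assert key_exists_in_2d_dict(dct, 'c1', 'p2')
assert not key_exists_in_2d_dict(dct, 'c3', 'p1')
```

The membership-based check matters because `EvaluationSet` stores some of these structures as `defaultdict`s, where `dct[key1]` would silently insert an empty inner dict for a missing key.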