From 68072a6200a2711ed8269ac2234ac0c8b2ff0325 Mon Sep 17 00:00:00 2001
From: "@alex.papadopoulos"
Date: Wed, 14 Feb 2024 17:14:17 +0000
Subject: [PATCH]
- Allow passing a max-template-date as an argument to colabfold_batch
- Add logging to improve debuggability
- Log call parameters of ColabFold's internal script functions
- Batch MSAs before structure prediction on runs with very large numbers of sequences
- Support one .m8 template file per sequence in the structure prediction step (colabfold_batch)
- General refactoring: separate functionality into functions to increase modularity and separation of concerns
---
 colabfold/batch.py         | 359 ++++++++++++-----
 colabfold/mmseqs/search.py | 786 +++++++++++++++++++++++--------------
 colabfold/utils.py         |  16 +
 3 files changed, 772 insertions(+), 389 deletions(-)
 mode change 100644 => 100755 colabfold/batch.py

diff --git a/colabfold/batch.py b/colabfold/batch.py
old mode 100644
new mode 100755
index bcf1000d..2f326549
--- a/colabfold/batch.py
+++ b/colabfold/batch.py
@@ -20,10 +20,11 @@
 import pickle
 import gzip

-from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, ArgumentTypeError
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 from io import StringIO
+from datetime import datetime

 import importlib_metadata
 import numpy as np
@@ -65,6 +66,7 @@
     safe_filename,
     setup_logging,
     CFMMCIFIO,
+    log_function_call,
 )
 from colabfold.relax import relax_me
@@ -81,6 +83,22 @@
 logging.getLogger('jax._src.xla_bridge').addFilter(lambda _: False) # jax >=0.4.6
 logging.getLogger('jax._src.lib.xla_bridge').addFilter(lambda _: False) # jax < 0.4.5

+
+def split_into_batches(lst, batch_size=500):
+    # Split the list into chunks of size batch_size
+    for i in range(0, len(lst), batch_size):
+        yield lst[i:i + batch_size]
+
+
+# Validate that a command-line argument is a YYYY-MM-DD date
+def valid_date(s):
+    try:
+        _ = datetime.strptime(s, "%Y-%m-%d")
+        return s  # Keep the value as a string; it is only validated here
+    except ValueError:
+        raise ArgumentTypeError(f"Not a valid date: '{s}'. 
Required format: YYYY-MM-DD") + + def mk_mock_template( query_sequence: Union[List[str], str], num_temp: int = 1 ) -> Dict[str, Any]: @@ -118,11 +136,10 @@ def mk_mock_template( return template_features def mk_template( - a3m_lines: str, template_path: str, query_sequence: str -) -> Dict[str, Any]: + a3m_lines: str, template_path: str, query_sequence: str, max_template_date: str) -> Dict[str, Any]: template_featurizer = templates.HhsearchHitFeaturizer( mmcif_dir=template_path, - max_template_date="2100-01-01", + max_template_date=max_template_date, max_hits=20, kalign_binary_path="kalign", release_dates_path=None, @@ -620,6 +637,9 @@ def get_queries( for file in sorted(input_path.iterdir()): if not file.is_file(): continue + if file.suffix.lower() == ".m8": + # Skipping logging template hit file in directory + continue if file.suffix.lower() not in [".a3m", ".fasta", ".faa"]: logger.warning(f"non-fasta/a3m file in input directory: {file}") continue @@ -726,6 +746,7 @@ def get_msa_and_templates( pairing_strategy: str = "greedy", host_url: str = DEFAULT_API_SERVER, user_agent: str = "", + max_template_date: str = "" ) -> Tuple[ Optional[List[str]], Optional[List[str]], List[str], List[int], List[Dict[str, Any]] ]: @@ -786,6 +807,7 @@ def get_msa_and_templates( a3m_lines_mmseqs2[index], template_paths[index], query_seqs_unique[index], + max_template_date ) if len(template_feature["template_domain_names"]) == 0: template_feature = mk_mock_template(query_seqs_unique[index]) @@ -1223,6 +1245,7 @@ def run( result_dir: Union[str, Path], num_models: int, is_complex: bool, + input_path: Union[str, Path] = None, num_recycles: Optional[int] = None, recycle_early_stop_tolerance: Optional[float] = None, model_order: List[int] = [1,2,3,4,5], @@ -1255,10 +1278,12 @@ def run( save_recycles: bool = False, use_dropout: bool = False, use_gpu_relax: bool = False, + use_unpacked_pdbs: bool = False, stop_at_score: float = 100, dpi: int = 200, max_seq: Optional[int] = None, max_extra_seq: Optional[int] = None, + max_template_date: Optional[str] = None, pdb_hit_file: Optional[Path] = None, local_pdb_path: Optional[Path] = None, use_cluster_profile: bool = True, @@ -1286,7 +1311,6 @@ def run( # disable GPU on tensorflow tf.config.set_visible_devices([], 'GPU') - from alphafold.notebooks.notebook_utils import get_pae_json from colabfold.alphafold.models import load_models_and_params from colabfold.colabfold import plot_paes, plot_plddts from colabfold.plot import plot_msa_v2 @@ -1376,6 +1400,7 @@ def run( "rank_by": rank_by, "max_seq": max_seq, "max_extra_seq": max_extra_seq, + "max_template_date": max_template_date, "pair_mode": pair_mode, "pairing_strategy": pairing_strategy, "host_url": host_url, @@ -1389,6 +1414,7 @@ def run( "use_cluster_profile": use_cluster_profile, "use_fuse": use_fuse, "use_bfloat16": use_bfloat16, + "use_unpacked_pdbs": use_unpacked_pdbs, "version": importlib_metadata.version("colabfold"), } config_out_file = result_dir.joinpath("config.json") @@ -1401,12 +1427,26 @@ def run( model_type if num_models > 0 else "", use_msa, use_env, use_templates, use_amber, result_dir ) - if pdb_hit_file is not None: - if local_pdb_path is None: + if pdb_hit_file: + if not local_pdb_path: raise ValueError("local_pdb_path is not specified.") else: custom_template_path = result_dir / "templates" put_mmciffiles_into_resultdir(pdb_hit_file, local_pdb_path, custom_template_path) + elif use_unpacked_pdbs: + if not local_pdb_path: + raise ValueError("local_pdb_path is not specified.") + elif not input_path: + 
raise ValueError("input_path is not specified.") + else: + custom_template_path = result_dir / "templates" + for query in queries: + if len(query) > 0: + m8file = next(input_path.glob(f"{query[0]}_*.m8"), None) + put_mmciffiles_into_resultdir(Path(input_path/m8file), local_pdb_path, + custom_template_path) + else: + logger.warning(f"Skipping {query[0]} as it has no sequence") if custom_template_path is not None: mk_hhsearch_db(custom_template_path) @@ -1414,7 +1454,6 @@ def run( pad_len = 0 ranks, metrics = [],[] first_job = True - job_number = 0 for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries): if jobname_prefix is not None: # pad job number based on number of queries @@ -1455,7 +1494,7 @@ def run( if a3m_lines is None: (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \ = get_msa_and_templates(jobname, query_sequence, a3m_lines, result_dir, msa_mode, use_templates, - custom_template_path, pair_mode, pairing_strategy, host_url, user_agent) + custom_template_path, pair_mode, pairing_strategy, host_url, user_agent, max_template_date) elif a3m_lines is not None: (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \ @@ -1463,7 +1502,7 @@ def run( if use_templates: (_, _, _, _, template_features) \ = get_msa_and_templates(jobname, query_seqs_unique, unpaired_msa, result_dir, 'single_sequence', use_templates, - custom_template_path, pair_mode, pairing_strategy, host_url, user_agent) + custom_template_path, pair_mode, pairing_strategy, host_url, user_agent, max_template_date) if num_models == 0: with open(pickled_msa_and_templates, 'wb') as f: @@ -1674,7 +1713,9 @@ def set_model_type(is_complex: bool, model_type: str) -> str: model_type = "alphafold2_ptm" return model_type -def main(): + +# This function sets up and returns the argument parser +def parse_arguments(): parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) parser.add_argument( "input", @@ -1698,7 +1739,7 @@ def main(): "single_sequence", ], help="Databases to use to create the MSA: UniRef30+Environmental (default), UniRef30 only or None. " - "Using an A3M file as input overwrites this option.", + "Using an A3M file as input overwrites this option.", ) msa_group.add_argument( "--pair-mode", @@ -1710,10 +1751,10 @@ def main(): msa_group.add_argument( "--pair-strategy", help="How sequences are paired during MSA pairing for complex prediction. " - "complete: MSA sequences should only be paired if the same species exists in all MSAs. " - "greedy: MSA sequences should only be paired if the same species exists in at least two MSAs. " - "Typically, greedy produces better predictions as it results in more paired sequences. " - "However, in some cases complete pairing might help, especially if MSAs are already large and can be well paired. ", + "complete: MSA sequences should only be paired if the same species exists in all MSAs. " + "greedy: MSA sequences should only be paired if the same species exists in at least two MSAs. " + "Typically, greedy produces better predictions as it results in more paired sequences. " + "However, in some cases complete pairing might help, especially if MSAs are already large and can be well paired. ", type=str, default="greedy", choices=["complete", "greedy"], @@ -1723,58 +1764,64 @@ def main(): default=False, action="store_true", help="Query PDB templates from the MSA server. " - 'If this parameter is not set, "--custom-template-path" and "--pdb-hit-file" will not be used. 
' - "Warning: This can result in the MSA server being queried with A3M input. " + 'If this parameter is not set, "--custom-template-path" and "--pdb-hit-file" will not be used. ' + "Warning: This can result in the MSA server being queried with A3M input. " ) msa_group.add_argument( "--custom-template-path", type=str, default=None, help="Directory with PDB files to provide as custom templates to the predictor. " - "No templates will be queried from the MSA server. " - "'--templates' argument is also required to enable this.", + "No templates will be queried from the MSA server. " + "'--templates' argument is also required to enable this.", ) msa_group.add_argument( "--pdb-hit-file", default=None, help="Path to a BLAST-m8 formatted PDB hit file corresponding to the input A3M file (e.g. pdb70.m8). " - "Typically, this parameter should be used for a MSA generated by 'colabfold_search'. " - "'--templates' argument is also required to enable this.", + "Typically, this parameter should be used for a MSA generated by 'colabfold_search'. " + "'--templates' argument is also required to enable this.", ) msa_group.add_argument( "--local-pdb-path", default=None, help="Directory of a local mirror of the PDB mmCIF database (e.g. /path/to/pdb/divided). " - "If provided, PDB files from the directory are used for templates specified by '--pdb-hit-file'. ", + "If provided, PDB files from the directory are used for templates specified by '--pdb-hit-file'. ", + ) + msa_group.add_argument( + "--max-template-date", + default="2031-01-01", + type=valid_date, + help="Max template date to be used for AlphaFold structure prediction.", ) pred_group = parser.add_argument_group("Prediction arguments", "") pred_group.add_argument( "--num-recycle", help="Number of prediction recycles. " - "Increasing recycles can improve the prediction quality but slows down the prediction.", + "Increasing recycles can improve the prediction quality but slows down the prediction.", type=int, default=None, ) pred_group.add_argument( "--recycle-early-stop-tolerance", help="Specify convergence criteria. " - "Run recycles until the distance between recycles is within the given tolerance value.", + "Run recycles until the distance between recycles is within the given tolerance value.", type=float, default=None, ) pred_group.add_argument( "--num-ensemble", help="Number of ensembles. " - "The trunk of the network is run multiple times with different random choices for the MSA cluster centers. " - "This can result in a better prediction at the cost of longer runtime. ", + "The trunk of the network is run multiple times with different random choices for the MSA cluster centers. " + "This can result in a better prediction at the cost of longer runtime. ", type=int, default=1, ) pred_group.add_argument( "--num-seeds", help="Number of seeds to try. Will iterate from range(random_seed, random_seed+num_seeds). " - "This can result in a better/different prediction at the cost of longer runtime. ", + "This can result in a better/different prediction at the cost of longer runtime. ", type=int, default=1, ) @@ -1787,7 +1834,7 @@ def main(): pred_group.add_argument( "--num-models", help="Number of models to use for structure prediction. 
" - "Reducing the number of models speeds up the prediction but results in lower quality.", + "Reducing the number of models speeds up the prediction but results in lower quality.", type=int, default=5, choices=[1, 2, 3, 4, 5], @@ -1795,9 +1842,9 @@ def main(): pred_group.add_argument( "--model-type", help="Predict structure/complex using the given model. " - 'Auto will pick "alphafold2_ptm" for structure predictions and "alphafold2_multimer_v3" for complexes. ' - "Older versions of the AF2 models are generally worse, however they can sometimes result in better predictions. " - "If the model is not already downloaded, it will be automatically downloaded. ", + 'Auto will pick "alphafold2_ptm" for structure predictions and "alphafold2_multimer_v3" for complexes. ' + "Older versions of the AF2 models are generally worse, however they can sometimes result in better predictions. " + "If the model is not already downloaded, it will be automatically downloaded. ", type=str, default="auto", choices=[ @@ -1816,26 +1863,26 @@ def main(): default=False, action="store_true", help="Activate dropouts during inference to sample from uncertainty of the models. " - "This can result in different predictions and can be (carefully!) used for conformations sampling.", + "This can result in different predictions and can be (carefully!) used for conformations sampling.", ) pred_group.add_argument( "--max-seq", help="Number of sequence clusters to use. " - "This can result in different predictions and can be (carefully!) used for conformations sampling.", + "This can result in different predictions and can be (carefully!) used for conformations sampling.", type=int, default=None, ) pred_group.add_argument( "--max-extra-seq", help="Number of extra sequences to use. " - "This can result in different predictions and can be (carefully!) used for conformations sampling.", + "This can result in different predictions and can be (carefully!) used for conformations sampling.", type=int, default=None, ) pred_group.add_argument( "--max-msa", help="Defines: `max-seq:max-extra-seq` number of sequences to use in one go. " - '"--max-seq" and "--max-extra-seq" are ignored if this parameter is set.', + '"--max-seq" and "--max-extra-seq" are ignored if this parameter is set.', type=str, default=None, ) @@ -1846,6 +1893,7 @@ def main(): help="Experimental: For multimer models, disable cluster profiles.", ) pred_group.add_argument("--data", help="Path to AlphaFold2 weights directory.") + pred_group.add_argument("--custom-weights", help="Path to AlphaFold2 predownloaded weights.") relax_group = parser.add_argument_group("Relaxation arguments", "") relax_group.add_argument( @@ -1853,12 +1901,12 @@ def main(): default=False, action="store_true", help="Enable OpenMM/Amber for structure relaxation. " - "Can improve the quality of side-chains at a cost of longer runtime. " + "Can improve the quality of side-chains at a cost of longer runtime. " ) relax_group.add_argument( "--num-relax", help="Specify how many of the top ranked structures to relax using OpenMM/Amber. " - "Typically, relaxing the top-ranked prediction is enough and speeds up the runtime. ", + "Typically, relaxing the top-ranked prediction is enough and speeds up the runtime. ", type=int, default=0, ) @@ -1867,7 +1915,7 @@ def main(): type=int, default=2000, help="Maximum number of iterations for the relaxation process. 
" - "AlphaFold2 sets this to unlimited (0), however, we found that this can lead to very long relaxation times for some inputs.", + "AlphaFold2 sets this to unlimited (0), however, we found that this can lead to very long relaxation times for some inputs.", ) relax_group.add_argument( "--relax-tolerance", @@ -1892,8 +1940,8 @@ def main(): default=False, action="store_true", help="Run OpenMM/Amber on GPU instead of CPU. " - "This can significantly speed up the relaxation runtime, however, might lead to compatibility issues with CUDA. " - "Unsupported on AMD/ROCM and Apple Silicon.", + "This can significantly speed up the relaxation runtime, however, might lead to compatibility issues with CUDA. " + "Unsupported on AMD/ROCM and Apple Silicon.", ) output_group = parser.add_argument_group("Output arguments", "") @@ -1907,7 +1955,7 @@ def main(): output_group.add_argument( "--stop-at-score", help="Compute models until pLDDT (single chain) or pTM-score (multimer) > threshold is reached. " - "This speeds up prediction by running less models for easier queries.", + "This speeds up prediction by running less models for easier queries.", type=float, default=100, ) @@ -1922,7 +1970,7 @@ def main(): default=False, action="store_true", help="Save all raw outputs from model to a pickle file. " - "Useful for downstream use in other models." + "Useful for downstream use in other models." ) output_group.add_argument( "--save-recycles", @@ -1957,7 +2005,7 @@ def main(): output_group.add_argument( "--sort-queries-by", help="Sort input queries by: none, length, random. " - "Sorting by length speeds up prediction as models are recompiled less often.", + "Sorting by length speeds up prediction as models are recompiled less often.", type=str, default="length", choices=["none", "length", "random"], @@ -1982,22 +2030,73 @@ def main(): type=int, default=10, help="Whenever the input length changes, the model needs to be recompiled. " - "We pad sequences by the specified length, so we can e.g., compute sequences from length 100 to 110 without recompiling. " - "Individual predictions will become marginally slower due to longer input, " - "but overall performance increases due to not recompiling. " - "Set to 0 to disable.", + "We pad sequences by the specified length, so we can e.g., compute sequences from length 100 to 110 without recompiling. " + "Individual predictions will become marginally slower due to longer input, " + "but overall performance increases due to not recompiling. 
" + "Set to 0 to disable.", ) - args = parser.parse_args() - - if (args.custom_template_path is not None) and (args.pdb_hit_file is not None): + return parser.parse_args() + +@log_function_call +def compute_pdbs( + input_path, + results, + msa_only=False, + msa_mode="mmseqs2_uniref_env", + pair_mode="unpaired_paired", + pair_strategy="greedy", + templates=False, + custom_template_path=None, + pdb_hit_file=None, + local_pdb_path=None, + max_template_date="2031-01-01", + num_recycle=None, + recycle_early_stop_tolerance=None, + num_ensemble=1, + num_seeds=1, + random_seed=0, + num_models=5, + model_type="auto", + model_order="1,2,3,4,5", + use_dropout=False, + max_seq=None, + max_extra_seq=None, + max_msa=None, + disable_cluster_profile=False, + data=None, + custom_weights=None, + amber=False, + num_relax=0, + relax_max_iterations=2000, + relax_tolerance=2.39, + relax_stiffness=10.0, + relax_max_outer_iterations=3, + use_gpu_relax=False, + use_unpacked_pdbs=False, + rank="auto", + stop_at_score=100, + jobname_prefix=None, + save_all=False, + save_recycles=False, + save_single_representations=False, + save_pair_representations=False, + overwrite_existing_results=False, + zip_results=False, + sort_queries_by="length", + host_url=DEFAULT_API_SERVER, + disable_unified_memory=False, + recompile_padding=10, + log_file=None +): + if (custom_template_path is not None) and (pdb_hit_file is not None): raise RuntimeError("Arguments --pdb-hit-file and --custom-template-path cannot be used simultaneously.") - # disable unified memory - if args.disable_unified_memory: + # disable unified memory + if disable_unified_memory: for k in ENV.keys(): if k in os.environ: del os.environ[k] - setup_logging(Path(args.results).joinpath("log.txt")) + setup_logging(Path(log_file) if log_file else Path(results).joinpath("log.txt")) version = importlib_metadata.version("colabfold") commit = get_commit() @@ -2006,76 +2105,138 @@ def main(): logger.info(f"Running colabfold {version}") - data_dir = Path(args.data or default_data_dir) + data_dir = Path(data or default_data_dir) - queries, is_complex = get_queries(args.input, args.sort_queries_by) - model_type = set_model_type(is_complex, args.model_type) + queries, is_complex = get_queries(input_path, sort_queries_by) + model_type = set_model_type(is_complex, model_type) - if args.msa_only: - args.num_models = 0 + if msa_only: + num_models = 0 - if args.num_models > 0: - download_alphafold_params(model_type, data_dir) + if num_models > 0: + if not custom_weights: + download_alphafold_params(model_type, data_dir) + else: + data_dir = custom_weights - if args.msa_mode != "single_sequence" and not args.templates: + if msa_mode != "single_sequence" and not templates: uses_api = any((query[2] is None for query in queries)) - if uses_api and args.host_url == DEFAULT_API_SERVER: + if uses_api and host_url == DEFAULT_API_SERVER: print(ACCEPT_DEFAULT_TERMS, file=sys.stderr) - model_order = [int(i) for i in args.model_order.split(",")] + model_order = [int(i) for i in model_order.split(",")] - assert args.recompile_padding >= 0, "Can't apply negative padding" + assert recompile_padding >= 0, "Can't apply negative padding" # backward compatibility - if args.amber and args.num_relax == 0: - args.num_relax = args.num_models * args.num_seeds + if amber and num_relax == 0: + num_relax = num_models * num_seeds user_agent = f"colabfold/{version}" - run( - queries=queries, - result_dir=args.results, - use_templates=args.templates, - custom_template_path=args.custom_template_path, - 
num_relax=args.num_relax, - relax_max_iterations=args.relax_max_iterations, - relax_tolerance=args.relax_tolerance, - relax_stiffness=args.relax_stiffness, - relax_max_outer_iterations=args.relax_max_outer_iterations, + # Initialization time and memory requirements seem to be excessive for fasta files with more than 500 sequences + for batch_queries in split_into_batches(queries, batch_size=500): + run( + input_path=input_path, + queries=batch_queries, + result_dir=results, + use_templates=templates, + custom_template_path=custom_template_path, + num_relax=num_relax, + relax_max_iterations=relax_max_iterations, + relax_tolerance=relax_tolerance, + relax_stiffness=relax_stiffness, + relax_max_outer_iterations=relax_max_outer_iterations, + msa_mode=msa_mode, + model_type=model_type, + num_models=num_models, + num_recycles=num_recycle, + recycle_early_stop_tolerance=recycle_early_stop_tolerance, + num_ensemble=num_ensemble, + model_order=model_order, + is_complex=is_complex, + keep_existing_results=not overwrite_existing_results, + rank_by=rank, + pair_mode=pair_mode, + pairing_strategy=pair_strategy, + data_dir=data_dir, + host_url=host_url, + user_agent=user_agent, + random_seed=random_seed, + num_seeds=num_seeds, + stop_at_score=stop_at_score, + recompile_padding=recompile_padding, + zip_results=zip_results, + save_single_representations=save_single_representations, + save_pair_representations=save_pair_representations, + use_dropout=use_dropout, + max_seq=max_seq, + max_extra_seq=max_extra_seq, + max_msa=max_msa, + max_template_date=max_template_date, + pdb_hit_file=pdb_hit_file, + local_pdb_path=local_pdb_path, + use_cluster_profile=not disable_cluster_profile, + use_gpu_relax=use_gpu_relax, + use_unpacked_pdbs=use_unpacked_pdbs, + jobname_prefix=jobname_prefix, + save_all=save_all, + save_recycles=save_recycles, + ) + + +# This is the main function that now only handles argument parsing +def main(): + args = parse_arguments() + compute_pdbs( + input_path=args.input, + results=args.results, + msa_only=args.msa_only, msa_mode=args.msa_mode, - model_type=model_type, - num_models=args.num_models, - num_recycles=args.num_recycle, + pair_mode=args.pair_mode, + pair_strategy=args.pair_strategy, + templates=args.templates, + custom_template_path=args.custom_template_path, + pdb_hit_file=args.pdb_hit_file, + local_pdb_path=args.local_pdb_path, + max_template_date=args.max_template_date, + num_recycle=args.num_recycle, recycle_early_stop_tolerance=args.recycle_early_stop_tolerance, num_ensemble=args.num_ensemble, - model_order=model_order, - is_complex=is_complex, - keep_existing_results=not args.overwrite_existing_results, - rank_by=args.rank, - pair_mode=args.pair_mode, - pairing_strategy=args.pair_strategy, - data_dir=data_dir, - host_url=args.host_url, - user_agent=user_agent, - random_seed=args.random_seed, num_seeds=args.num_seeds, - stop_at_score=args.stop_at_score, - recompile_padding=args.recompile_padding, - zip_results=args.zip, - save_single_representations=args.save_single_representations, - save_pair_representations=args.save_pair_representations, + random_seed=args.random_seed, + num_models=args.num_models, + model_type=args.model_type, + model_order=args.model_order, use_dropout=args.use_dropout, max_seq=args.max_seq, max_extra_seq=args.max_extra_seq, max_msa=args.max_msa, - pdb_hit_file=args.pdb_hit_file, - local_pdb_path=args.local_pdb_path, - use_cluster_profile=not args.disable_cluster_profile, - use_gpu_relax = args.use_gpu_relax, + 
disable_cluster_profile=args.disable_cluster_profile,
+        data=args.data,
+        custom_weights=args.custom_weights,
+        amber=args.amber,
+        num_relax=args.num_relax,
+        relax_max_iterations=args.relax_max_iterations,
+        relax_tolerance=args.relax_tolerance,
+        relax_stiffness=args.relax_stiffness,
+        relax_max_outer_iterations=args.relax_max_outer_iterations,
+        use_gpu_relax=args.use_gpu_relax,
+        rank=args.rank,
+        stop_at_score=args.stop_at_score,
         jobname_prefix=args.jobname_prefix,
         save_all=args.save_all,
         save_recycles=args.save_recycles,
+        save_single_representations=args.save_single_representations,
+        save_pair_representations=args.save_pair_representations,
+        overwrite_existing_results=args.overwrite_existing_results,
+        zip_results=args.zip,  # the parser defines "--zip", so the parsed attribute is args.zip
+        sort_queries_by=args.sort_queries_by,
+        host_url=args.host_url,
+        disable_unified_memory=args.disable_unified_memory,
+        recompile_padding=args.recompile_padding,
     )

+
 if __name__ == "__main__":
     main()

diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py
index 36e0e59f..66e0b301 100644
--- a/colabfold/mmseqs/search.py
+++ b/colabfold/mmseqs/search.py
@@ -14,45 +14,94 @@
 from typing import List, Union

 from colabfold.batch import get_queries, msa_to_str
-from colabfold.utils import safe_filename
+from colabfold.utils import safe_filename, log_function_call

 logger = logging.getLogger(__name__)


 def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
+    """
+    Run mmseqs with the given parameters
+    :param mmseqs: Path to the mmseqs binary
+    :param params: List of parameters to pass to mmseqs
+    :return:
+    """
     params_log = " ".join(str(i) for i in params)
     logger.info(f"Running {mmseqs} {params_log}")
-    # hide MMseqs2 verbose paramters list that clogs up the log
+    # hide MMseqs2 verbose parameters list that clogs up the log
     os.environ["MMSEQS_CALL_DEPTH"] = "1"
-    subprocess.check_call([mmseqs] + params)
+    # Open a subprocess and direct stdout and stderr to subprocess.PIPE
+    with subprocess.Popen([mmseqs] + params, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
+                          bufsize=1) as proc:
+        # Read the output line by line as it becomes available
+        for line in proc.stdout:
+            logger.info(line.strip())  # Log each line from the output
+        # Wait for the subprocess to finish and get the exit code
+        proc.wait()
+    if proc.returncode != 0:
+        raise subprocess.CalledProcessError(proc.returncode, [mmseqs] + params)
+
+
+def safe_rename_files(base, old_suffix, new_suffix):
+    # Iterate through every item in the base directory
+    for item in base.iterdir():
+        # Check if the item is a file and has the correct extension
+        if item.is_file() and item.suffix == old_suffix:
+            # Construct the new filename using the safe_filename function
+            new_filename = safe_filename(item.stem) + new_suffix
+            # Rename the old file to the new one
+            item.rename(base.joinpath(new_filename))
+
+
+def rename_m8_files(queries_unique, base, template_db):
+    id = 0
+    for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
+        with base.joinpath(f"{safe_filename(raw_jobname)}_{template_db}.m8").open(
+            "w"
+        ) as f:
+            for _ in range(len(query_seqs_cardinality)):
+                with base.joinpath(f"{id}.m8").open("r") as g:
+                    f.write(g.read())
+                os.remove(base.joinpath(f"{id}.m8"))
+                id += 1

-def mmseqs_search_monomer(
-    dbbase: Path,
-    base: Path,
-    uniref_db: Path = Path("uniref30_2302_db"),
-    template_db: Path = Path(""),  # Unused by default
-    metagenomic_db: Path = Path("colabfold_envdb_202108_db"),
-    mmseqs: Path = Path("mmseqs"),
-    use_env: bool = True,
-    use_templates: bool = False,
-    
filter: bool = True, - expand_eval: float = math.inf, - align_eval: int = 10, - diff: int = 3000, - qsc: float = -20.0, - max_accept: int = 1000000, - prefilter_mode: int = 0, - s: float = 8, - db_load_mode: int = 2, - threads: int = 32, -): - """Run mmseqs with a local colabfold database set +def rename_a3m_files(queries_unique, base): + logging.info("Renaming a3m files with args %s", queries_unique) + for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique): + logging.info(f"Renaming {base.joinpath(f'{job_number}.a3m')} to " + f"{base.joinpath(f'{safe_filename(raw_jobname)}.a3m')}") + os.rename( + base.joinpath(f"{job_number}.a3m"), + base.joinpath(f"{safe_filename(raw_jobname)}.a3m"), + ) - db1: uniprot db (UniRef30) - db2: Template (unused by default) - db3: metagenomic db (colabfold_envdb_202108 or bfd_mgy_colabfold, the former is preferred) - """ + +def get_db_suffix(db, db_load_mode, dbtype="env"): + if not Path(f"{db}.dbtype").is_file(): + raise FileNotFoundError(f"Database {db} does not exist") + if ( + ( + not Path(f"{db}.idx").is_file() + and not Path(f"{db}.idx.index").is_file() + ) + or os.environ.get("MMSEQS_IGNORE_INDEX", False) + ): + logger.info("Search does not use index") + db_load_mode = 0 + dbSuffix1 = "_seq" if dbtype == "env" else "" + dbSuffix2 = "_aln" if dbtype == "env" else "" + else: + dbSuffix1 = ".idx" + dbSuffix2 = ".idx" + + return dbSuffix1, dbSuffix2, db_load_mode + + +@log_function_call +def search_uniref_db(base, uniref_db, filter: int = 1, mmseqs: Path = Path("mmseqs"),align_eval: int = 10, + qsc: float = -20.0, max_accept: int = 1000000, db_load_mode: int = 0, threads: int = 32, + s: float = 8, diff: int = 3000, expand_eval: float = math.inf, prefilter_mode: int = 0): if filter: # 0.1 was not used in benchmarks due to POSIX shell bug in line above # EXPAND_EVAL=0.1 @@ -60,163 +109,160 @@ def mmseqs_search_monomer( qsc = 0.8 max_accept = 100000 - used_dbs = [uniref_db] - if use_templates: - used_dbs.append(template_db) - if use_env: - used_dbs.append(metagenomic_db) - - for db in used_dbs: - if not dbbase.joinpath(f"{db}.dbtype").is_file(): - raise FileNotFoundError(f"Database {db} does not exist") - if ( - ( - not dbbase.joinpath(f"{db}.idx").is_file() - and not dbbase.joinpath(f"{db}.idx.index").is_file() - ) - or os.environ.get("MMSEQS_IGNORE_INDEX", False) - ): - logger.info("Search does not use index") - db_load_mode = 0 - dbSuffix1 = "_seq" - dbSuffix2 = "_aln" - dbSuffix3 = "" - else: - dbSuffix1 = ".idx" - dbSuffix2 = ".idx" - dbSuffix3 = ".idx" + dbSuffix1, dbSuffix2, db_load_mode = get_db_suffix(uniref_db, db_load_mode) + + search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", + "10000"] - # fmt: off - # @formatter:off - search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000"] search_param += ["--prefilter-mode", str(prefilter_mode)] if s is not None: search_param += ["-s", "{:.1f}".format(s)] else: search_param += ["--k-score", "'seq:96,prof:80'"] + filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", + "0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95",] + expand_param = ["--expansion-mode", "0", "-e", str(expand_eval), "--expand-filter-clusters", str(filter), + "--max-seq-id", "0.95",] - filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", 
"0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95",] - expand_param = ["--expansion-mode", "0", "-e", str(expand_eval), "--expand-filter-clusters", str(filter), "--max-seq-id", "0.95",] - - run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param) + run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), uniref_db, base.joinpath("res"), base.joinpath("tmp"), + "--threads", str(threads)] + search_param) run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")]) run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")]) - run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param) - run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"]) - run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), + run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", base.joinpath("res"), + f"{uniref_db}{dbSuffix2}", base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), + "--threads", str(threads)] + expand_param) + run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), f"{uniref_db}{dbSuffix1}", base.joinpath("res_exp"), + base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), + "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"]) + run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode", str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads", str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"]) - run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), - base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode", - "6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", + base.joinpath("res_exp_realign_filter"), base.joinpath(f"{uniref_db.name}.a3m"), + "--msa-format-mode", "6", "--db-load-mode", str(db_load_mode), "--threads", + str(threads)] + filter_param) subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign")]) subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp")]) subprocess.run([mmseqs] + ["rmdb", base.joinpath("res")]) subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign_filter")]) + shutil.rmtree(base.joinpath("tmp")) - if use_env: - run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(metagenomic_db), base.joinpath("res_env"), - base.joinpath("tmp3"), "--threads", str(threads)] + search_param) - run_mmseqs(mmseqs, ["expandaln", base.joinpath("prof_res"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), base.joinpath("res_env"), - dbbase.joinpath(f"{metagenomic_db}{dbSuffix2}"), base.joinpath("res_env_exp"), "-e", 
str(expand_eval), - "--expansion-mode", "0", "--db-load-mode", str(db_load_mode), "--threads", str(threads)]) - run_mmseqs(mmseqs, ["align", base.joinpath("tmp3/latest/profile_1"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), - base.joinpath("res_env_exp"), base.joinpath("res_env_exp_realign"), "--db-load-mode", - str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", - str(threads), "--alt-ali", "10", "-a"]) - run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), - base.joinpath("res_env_exp_realign"), base.joinpath("res_env_exp_realign_filter"), - "--db-load-mode", str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", - "--max-seq-id", "1.0", "--threads", str(threads), "--filter-min-enable", "100"]) - run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), - base.joinpath("res_env_exp_realign_filter"), - base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m"), "--msa-format-mode", "6", - "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) - - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign_filter")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env")]) - - run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")]) - else: - run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")]) - - if use_templates: - run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(template_db), base.joinpath("res_pdb"), - base.joinpath("tmp2"), "--db-load-mode", str(db_load_mode), "--threads", str(threads), "-s", "7.5", "-a", "-e", "0.1", "--prefilter-mode", str(prefilter_mode)]) - run_mmseqs(mmseqs, ["convertalis", base.joinpath("prof_res"), dbbase.joinpath(f"{template_db}{dbSuffix3}"), base.joinpath("res_pdb"), - base.joinpath(f"{template_db}"), "--format-output", - "query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar", - "--db-output", "1", - "--db-load-mode", str(db_load_mode), "--threads", str(threads)]) - run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db}"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"{template_db}")]) - - run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) - # @formatter:on - # fmt: on - for file in base.glob("prof_res*"): - file.unlink() - shutil.rmtree(base.joinpath("tmp")) - if use_templates: - shutil.rmtree(base.joinpath("tmp2")) - if use_env: - shutil.rmtree(base.joinpath("tmp3")) +@log_function_call +def search_env_db(base, metagenomic_db, mmseqs: Path = Path("mmseqs"), align_eval: int = 10, qsc: float = -20.0, + max_accept: int = 1000000, db_load_mode: int = 0, threads: int = 32, expand_eval: float = math.inf, + diff: int = 3000, filter: int = 1, s: float = 8, split_memory_limit: str = "100G"): + + if 
filter: + # 0.1 was not used in benchmarks due to POSIX shell bug in line above + # EXPAND_EVAL=0.1 + align_eval = 10 + qsc = 0.8 + max_accept = 100000 + dbSuffix1, dbSuffix2, db_load_mode = get_db_suffix(metagenomic_db, db_load_mode) + + filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", + "0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95", ] + search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", + "10000", "--split-memory-limit", split_memory_limit] + if s: + search_param += ["-s", "{:.1f}".format(s)] + else: + search_param += ["--k-score", "'seq:96,prof:80'"] -def mmseqs_search_pair( - dbbase: Path, + run_mmseqs(mmseqs, + ["search", base.joinpath("prof_res"), metagenomic_db, base.joinpath(f"res_env_{metagenomic_db.name}"), + base.joinpath(f"tmp_{metagenomic_db.name}"), "--threads", str(threads)] + search_param) + run_mmseqs(mmseqs, + ["expandaln", base.joinpath("prof_res"), f"{metagenomic_db}{dbSuffix1}", + base.joinpath(f"res_env_{metagenomic_db.name}"), f"{metagenomic_db}{dbSuffix2}", + base.joinpath(f"res_env_{metagenomic_db.name}_exp"), "-e", str(expand_eval), "--expansion-mode", "0", + "--db-load-mode", str(db_load_mode), "--threads", str(threads)]) + run_mmseqs(mmseqs, ["align", base.joinpath(f"tmp_{metagenomic_db.name}/latest/profile_1"), + f"{metagenomic_db}{dbSuffix1}", base.joinpath(f"res_env_{metagenomic_db.name}_exp"), + base.joinpath(f"res_env_{metagenomic_db.name}_exp_realign"), "--db-load-mode", + str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", + str(threads), "--alt-ali", "10", "-a"]) + run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), f"{metagenomic_db}{dbSuffix1}", + base.joinpath(f"res_env_{metagenomic_db.name}_exp_realign"), + base.joinpath(f"res_env_{metagenomic_db.name}_exp_realign_filter"), + "--db-load-mode", str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", + "--max-seq-id", "1.0", "--threads", str(threads), "--filter-min-enable", "100"]) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), f"{metagenomic_db}{dbSuffix1}", + base.joinpath(f"res_env_{metagenomic_db.name}_exp_realign_filter"), + base.joinpath(f"{metagenomic_db.name}.a3m"), "--msa-format-mode", "6", + "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) + + run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"res_env_{metagenomic_db.name}_exp_realign_filter")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"res_env_{metagenomic_db.name}_exp_realign")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"res_env_{metagenomic_db.name}_exp")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"res_env_{metagenomic_db.name}")]) + + shutil.rmtree(base.joinpath(f"tmp_{metagenomic_db.name}")) + + +@log_function_call +def search_template_db(base, template_db, prefilter_mode: int = 0, mmseqs: Path = Path("mmseqs"), db_load_mode: int = 0, + threads: int = 32): + + dbSuffix, _, db_load_mode = get_db_suffix(template_db, db_load_mode, "template") + + run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), template_db, base.joinpath("res_pdb"), + base.joinpath("tmp2"), "--db-load-mode", str(db_load_mode), "--threads", str(threads), "-s", + "7.5", "-a", "-e", "0.1", "--prefilter-mode", str(prefilter_mode)]) + run_mmseqs(mmseqs, ["convertalis", base.joinpath("prof_res"), f"{template_db}{dbSuffix}", + base.joinpath("res_pdb"), base.joinpath(f"{template_db.name}"), "--format-output", + 
"query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar", + "--db-output", "1", + "--db-load-mode", str(db_load_mode), "--threads", str(threads)]) + + +@log_function_call +def search_pair( base: Path, uniref_db: Path = Path("uniref30_2302_db"), mmseqs: Path = Path("mmseqs"), - prefilter_mode: int = 0, s: float = 8, threads: int = 64, db_load_mode: int = 2, pairing_strategy: int = 0, ): - if not dbbase.joinpath(f"{uniref_db}.dbtype").is_file(): - raise FileNotFoundError(f"Database {uniref_db} does not exist") - if ( - ( - not dbbase.joinpath(f"{uniref_db}.idx").is_file() - and not dbbase.joinpath(f"{uniref_db}.idx.index").is_file() - ) - or os.environ.get("MMSEQS_IGNORE_INDEX", False) - ): - logger.info("Search does not use index") - db_load_mode = 0 - dbSuffix1 = "_seq" - dbSuffix2 = "_aln" - else: - dbSuffix1 = ".idx" - dbSuffix2 = ".idx" + dbSuffix1, dbSuffix2, db_load_mode = get_db_suffix(uniref_db, db_load_mode) # fmt: off # @formatter:off - search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000",] - search_param += ["--prefilter-mode", str(prefilter_mode)] + search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", + "10000",] if s is not None: search_param += ["-s", "{:.1f}".format(s)] else: search_param += ["--k-score", "'seq:96,prof:80'"] expand_param = ["--expansion-mode", "0", "-e", "inf", "--expand-filter-clusters", "0", "--max-seq-id", "0.95",] - run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads),] + search_param,) - run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads),] + expand_param,) - run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", "0.001", "--max-accept", "1000000", "--threads", str(threads), "-c", "0.5", "--cov-mode", "1",],) - run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}"), base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_pair"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "0", "--threads", str(threads), ],) - run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],) - run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],) - run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],) - run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".paired.a3m",],) + run_mmseqs(mmseqs, 
["search", base.joinpath("qdb"), uniref_db, base.joinpath("res"), base.joinpath("tmp"), + "--threads", str(threads),] + search_param,) + run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", base.joinpath("res"), + f"{uniref_db}{dbSuffix2}", base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), + "--threads", str(threads),] + expand_param,) + run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", base.joinpath("res_exp"), + base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", "0.001", + "--max-accept", "1000000", "--threads", str(threads), "-c", "0.5", "--cov-mode", "1",],) + run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), f"{uniref_db}", base.joinpath("res_exp_realign"), + base.joinpath("res_exp_realign_pair"), "--db-load-mode", str(db_load_mode), "--pairing-mode", + str(pairing_strategy), "--pairing-dummy-mode", "0", "--threads", str(threads), ],) + run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", + base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), + "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],) + run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), f"{uniref_db}", + base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", + str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", + "--threads", str(threads),],) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), f"{uniref_db}{dbSuffix1}", + base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), + "--msa-format-mode", "5", "--threads", str(threads),],) + run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", + "--unpack-suffix", ".paired.a3m",],) run_mmseqs(mmseqs, ["rmdb", base.joinpath("qdb")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("qdb_h")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) @@ -231,7 +277,256 @@ def mmseqs_search_pair( # fmt: on -def main(): +@log_function_call +def merge_dbs(base, mmseqs: Path = Path("mmseqs")): + msa_files = list(base.glob("*.a3m")) + run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m")] + msa_files) + for msa_file in msa_files: + run_mmseqs(mmseqs, ["rmdb", msa_file]) + + +@log_function_call +def unpack_msa_files(base, mmseqs: Path = Path("mmseqs"), template_db_name: str = None, use_lookup: bool = False): + if use_lookup: + shutil.copyfile(base.joinpath("qdb.lookup"), base.joinpath("final.a3m.lookup")) + run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-suffix", ".a3m"]) + if template_db_name: + shutil.copyfile(base.joinpath("qdb.lookup"), base.joinpath(f"{template_db_name}.lookup")) + run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db_name}"), base.joinpath("."), + "--unpack-suffix", ".m8"]) + else: + run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", + "0", "--unpack-suffix", ".a3m"]) + if template_db_name: + run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db_name}"), base.joinpath("."), + "--unpack-name-mode", "0", "--unpack-suffix", ".m8"]) + + run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"{template_db_name}")]) + 
shutil.rmtree(base.joinpath("tmp2")) + for file in base.glob("prof_res*"): + file.unlink() + + +@log_function_call +def rename_msa_files(base, queries_unique, template_db_name: str = None, use_lookup=False): + if use_lookup: + safe_rename_files(base, ".a3m", ".a3m") + if template_db_name: + safe_rename_files(base, ".m8", f"_{template_db_name}.m8") + else: + rename_a3m_files(queries_unique, base) + if template_db_name: + rename_m8_files(queries_unique, base, template_db_name) + + +def mmseqs_search_monomer( + base: Path, + uniref_db: Path, + template_db: Path = None, # Unused by default + env_dbs: [Path] = None, + mmseqs: Path = Path("mmseqs"), + filter: int = 1, + expand_eval: float = math.inf, + align_eval: int = 10, + diff: int = 3000, + qsc: float = -20.0, + max_accept: int = 1000000, + s: float = 8, + db_load_mode: int = 0, + threads: int = 32, + prefilter_mode: int = 0, +): + """Run mmseqs with a local colabfold database set + + db1: uniprot db (UniRef30) + db2: Template (unused by default) + db3: metagenomic db (colabfold_envdb_202108 or bfd_mgy_colabfold, the former is preferred) + """ + + search_uniref_db(base=base, uniref_db=uniref_db, filter=filter, mmseqs=mmseqs, align_eval=align_eval, qsc=qsc, + max_accept=max_accept, db_load_mode=db_load_mode, threads=threads, s=s, diff=diff, + prefilter_mode=prefilter_mode) + if env_dbs: + for metagenomic_db in env_dbs: + search_env_db(base=base, metagenomic_db=metagenomic_db, mmseqs=mmseqs, align_eval=align_eval, qsc=qsc, + max_accept=max_accept, db_load_mode=db_load_mode, threads=threads, expand_eval=expand_eval, + diff=diff, filter=filter) + merge_dbs(base, mmseqs) + else: + run_mmseqs(mmseqs, ["mvdb", base.joinpath(f"{uniref_db.name}.a3m"), base.joinpath("final.a3m")]) + + if template_db: + search_template_db(mmseqs=mmseqs, base=base, template_db=template_db, db_load_mode=db_load_mode, + threads=threads, prefilter_mode=prefilter_mode) + + +@log_function_call +def mmseqs_search_multimer( + base: Path, + uniref_db: Path, + mmseqs: Path = Path("mmseqs"), + s: float = 8, + db_load_mode: int = 0, + threads: int = 32, + pairing_strategy: int = 0, + queries_unique: list[list[Union[str, list[str], list[int]]]] = None, +): + search_pair( + mmseqs=mmseqs, + base=base, + uniref_db=uniref_db, + s=s, + db_load_mode=db_load_mode, + threads=threads, + pairing_strategy=pairing_strategy, + ) + + id = 0 + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + unpaired_msa = [] + paired_msa = None + if len(query_seqs_cardinality) > 1: + paired_msa = [] + for _ in query_sequences: + with base.joinpath(f"{id}.a3m").open("r") as f: + unpaired_msa.append(f.read()) + base.joinpath(f"{id}.a3m").unlink() + if len(query_seqs_cardinality) > 1: + with base.joinpath(f"{id}.paired.a3m").open("r") as f: + paired_msa.append(f.read()) + base.joinpath(f"{id}.paired.a3m").unlink() + id += 1 + msa = msa_to_str( + unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality + ) + base.joinpath(f"{job_number}.a3m").write_text(msa) + + +def create_query_db(query, base, mmseqs: Path = Path("mmseqs")): + queries, is_complex = get_queries(query) + + queries_unique = [] + for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries): + # remove duplicates before searching + query_sequences = ( + [query_sequences] if isinstance(query_sequences, str) else query_sequences + ) + query_seqs_unique = [] + for x in query_sequences: + if x not in query_seqs_unique: + query_seqs_unique.append(x) + 
query_seqs_cardinality = [0] * len(query_seqs_unique) + for seq in query_sequences: + seq_idx = query_seqs_unique.index(seq) + query_seqs_cardinality[seq_idx] += 1 + + queries_unique.append([raw_jobname, query_seqs_unique, query_seqs_cardinality]) + + base.mkdir(exist_ok=True, parents=True) + query_file = base.joinpath("query.fas") + with query_file.open("w") as f: + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + for j, seq in enumerate(query_sequences): + # The header of first sequence set as 101 + query_seq_headername = 101 + j + f.write(f">{query_seq_headername}\n{seq}\n") + + run_mmseqs( + mmseqs, ["createdb", query_file, base.joinpath("qdb"), "--shuffle", "0"], + ) + with base.joinpath("qdb.lookup").open("w") as f: + id = 0 + file_number = 0 + for job_number, ( + raw_jobname, + query_sequences, + query_seqs_cardinality, + ) in enumerate(queries_unique): + for _ in query_sequences: + raw_jobname_first = raw_jobname.split()[0] + f.write(f"{id}\t{raw_jobname_first}\t{file_number}\n") + id += 1 + file_number += 1 + + return is_complex, queries_unique, query_file + + +@log_function_call +def compute_msas( + query: Path, + base: Path, + dbbase: Path = None, + s: float = None, + env_dbs: [Path] = None, + uniref_db: Path = None, + template_db: Path = None, + filter: int = 1, + mmseqs: Path = Path("mmseqs"), + expand_eval: float = math.inf, + align_eval: int = 10, + diff: int = 3000, + qsc: float = -20.0, + max_accept: int = 1000000, + pairing_strategy: int = 0, + db_load_mode: int = 0, + threads: int = 64, + use_lookup: bool = False, +): + + is_complex, queries_unique, query_file = create_query_db(query, base, mmseqs) + mmseqs_search_monomer( + mmseqs=mmseqs, + base=base, + uniref_db=dbbase/uniref_db, + template_db=dbbase/template_db if template_db else None, + env_dbs=[dbbase/env_db for env_db in env_dbs] if env_dbs else None, + filter=filter, + expand_eval=expand_eval, + align_eval=align_eval, + diff=diff, + qsc=qsc, + max_accept=max_accept, + s=s, + db_load_mode=db_load_mode, + threads=threads, + ) + + unpack_msa_files(base=base, mmseqs=mmseqs, template_db_name=(dbbase / template_db).name if template_db else None, + use_lookup=use_lookup) + + if is_complex is True: + mmseqs_search_multimer( + base=base, + uniref_db=dbbase/uniref_db, + mmseqs=mmseqs, + s=s, + db_load_mode=db_load_mode, + threads=threads, + pairing_strategy=pairing_strategy, + queries_unique=queries_unique, + ) + + rename_msa_files(base=base, queries_unique=queries_unique, + template_db_name=(dbbase / template_db).name if template_db else None, + use_lookup=use_lookup) + + query_file.unlink() + run_mmseqs(mmseqs, ["rmdb", base.joinpath("qdb")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("qdb_h")]) + + +def parse_arguments(): parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) parser.add_argument( "query", @@ -251,27 +546,31 @@ def main(): type=int, default=0, choices=[0, 1, 2], - help="Prefiltering algorithm to use: 0: k-mer (high-mem), 1: ungapped (high-cpu), 2: exhaustive (no prefilter, very slow). See wiki for more details: https://github.com/sokrypton/ColabFold/wiki#colabfold_search", + help="Prefiltering algorithm to use: 0: k-mer (high-mem), 1: ungapped (high-cpu), " + "2: exhaustive (no prefilter, very slow). See wiki for more details: " + "https://github.com/sokrypton/ColabFold/wiki#colabfold_search", ) parser.add_argument( "-s", type=float, default=None, - help="MMseqs2 sensitivity. 
+
+
+def parse_arguments():
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
     parser.add_argument(
         "query",
@@ -251,27 +546,31 @@ def main():
         type=int,
         default=0,
         choices=[0, 1, 2],
-        help="Prefiltering algorithm to use: 0: k-mer (high-mem), 1: ungapped (high-cpu), 2: exhaustive (no prefilter, very slow). See wiki for more details: https://github.com/sokrypton/ColabFold/wiki#colabfold_search",
+        help="Prefiltering algorithm to use: 0: k-mer (high-mem), 1: ungapped (high-cpu), "
+             "2: exhaustive (no prefilter, very slow). See wiki for more details: "
+             "https://github.com/sokrypton/ColabFold/wiki#colabfold_search",
     )
     parser.add_argument(
         "-s",
         type=float,
         default=None,
-        help="MMseqs2 sensitivity. Lowering this will result in a much faster search but possibly sparser MSAs. By default, the k-mer threshold is directly set to the same one of the server, which corresponds to a sensitivity of ~8.",
+        help="MMseqs2 sensitivity. Lowering this will result in a much faster search but possibly sparser MSAs. "
+             "By default, the k-mer threshold is directly set to the same one of the server, "
+             "which corresponds to a sensitivity of ~8.",
     )
+
+    parser.add_argument("--uniref-db", type=Path, default=Path("uniref30_2302_db"),
+                        help="UniRef database")
+    parser.add_argument("--template-db", type=Path, default=Path(""),
+                        help="Templates database")
+    parser.add_argument("--env-dbs", type=Path, nargs='+', default=[Path("colabfold_envdb_202108_db")],
+                        help="Environmental databases")
+
+    # Backwards compatibility
     # dbs are uniref, templates and environmental
     # We normally don't use templates
-    parser.add_argument(
-        "--db1", type=Path, default=Path("uniref30_2302_db"), help="UniRef database"
-    )
-    parser.add_argument("--db2", type=Path, default=Path(""), help="Templates database")
-    parser.add_argument(
-        "--db3",
-        type=Path,
-        default=Path("colabfold_envdb_202108_db"),
-        help="Environmental database",
-    )
-
+    parser.add_argument("--db1", type=Path, help="UniRef database")
+    parser.add_argument("--db2", type=Path, help="Templates database")
+    parser.add_argument("--db3", type=Path, help="Environmental database")
     # poor man's boolean arguments
     parser.add_argument(
         "--use-env", type=int, default=1, choices=[0, 1], help="Use --db3"
     )
@@ -335,138 +634,45 @@ def main():
     )
 
     args = parser.parse_args()
-    logging.basicConfig(level = logging.INFO)
+    # Backwards compatibility
+    if args.db1:
+        args.uniref_db = args.db1
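+    # For example (illustrative): the legacy invocation
+    #   colabfold_search --db1 uniref30_2302_db query.fasta /path/to/dbs msas/
+    # now behaves like
+    #   colabfold_search --uniref-db uniref30_2302_db query.fasta /path/to/dbs msas/
+    # --db2 and --db3 are mapped onto --template-db and --env-dbs just below.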
 
-    queries, is_complex = get_queries(args.query, None)
-
-    queries_unique = []
-    for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries):
-        # remove duplicates before searching
-        query_sequences = (
-            [query_sequences] if isinstance(query_sequences, str) else query_sequences
-        )
-        query_seqs_unique = []
-        for x in query_sequences:
-            if x not in query_seqs_unique:
-                query_seqs_unique.append(x)
-        query_seqs_cardinality = [0] * len(query_seqs_unique)
-        for seq in query_sequences:
-            seq_idx = query_seqs_unique.index(seq)
-            query_seqs_cardinality[seq_idx] += 1
-
-        queries_unique.append([raw_jobname, query_seqs_unique, query_seqs_cardinality])
-
-    args.base.mkdir(exist_ok=True, parents=True)
-    query_file = args.base.joinpath("query.fas")
-    with query_file.open("w") as f:
-        for job_number, (
-            raw_jobname,
-            query_sequences,
-            query_seqs_cardinality,
-        ) in enumerate(queries_unique):
-            for j, seq in enumerate(query_sequences):
-                # The header of first sequence set as 101
-                query_seq_headername = 101 + j
-                f.write(f">{query_seq_headername}\n{seq}\n")
-
-    run_mmseqs(
-        args.mmseqs,
-        ["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"],
-    )
-    with args.base.joinpath("qdb.lookup").open("w") as f:
-        id = 0
-        file_number = 0
-        for job_number, (
-            raw_jobname,
-            query_sequences,
-            query_seqs_cardinality,
-        ) in enumerate(queries_unique):
-            for seq in query_sequences:
-                raw_jobname_first = raw_jobname.split()[0]
-                f.write(f"{id}\t{raw_jobname_first}\t{file_number}\n")
-                id += 1
-            file_number += 1
-
-    mmseqs_search_monomer(
-        mmseqs=args.mmseqs,
-        dbbase=args.dbbase,
-        base=args.base,
-        uniref_db=args.db1,
-        template_db=args.db2,
-        metagenomic_db=args.db3,
-        use_env=args.use_env,
-        use_templates=args.use_templates,
-        filter=args.filter,
-        expand_eval=args.expand_eval,
-        align_eval=args.align_eval,
-        diff=args.diff,
-        qsc=args.qsc,
-        max_accept=args.max_accept,
-        prefilter_mode=args.prefilter_mode,
-        s=args.s,
-        db_load_mode=args.db_load_mode,
-        threads=args.threads,
-    )
-    if is_complex is True:
-        mmseqs_search_pair(
-            mmseqs=args.mmseqs,
-            dbbase=args.dbbase,
-            base=args.base,
-            uniref_db=args.db1,
-            prefilter_mode=args.prefilter_mode,
-            s=args.s,
-            db_load_mode=args.db_load_mode,
-            threads=args.threads,
-            pairing_strategy=args.pairing_strategy,
-        )
+    if args.use_templates:
+        if args.db2:
+            args.template_db = args.db2
+    else:
+        args.template_db = None
 
-    id = 0
-    for job_number, (
-        raw_jobname,
-        query_sequences,
-        query_seqs_cardinality,
-    ) in enumerate(queries_unique):
-        unpaired_msa = []
-        paired_msa = None
-        if len(query_seqs_cardinality) > 1:
-            paired_msa = []
-        for seq in query_sequences:
-            with args.base.joinpath(f"{id}.a3m").open("r") as f:
-                unpaired_msa.append(f.read())
-            args.base.joinpath(f"{id}.a3m").unlink()
-            if len(query_seqs_cardinality) > 1:
-                with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
-                    paired_msa.append(f.read())
-                args.base.joinpath(f"{id}.paired.a3m").unlink()
-            id += 1
-        msa = msa_to_str(
-            unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
-        )
-        args.base.joinpath(f"{job_number}.a3m").write_text(msa)
+    if args.use_env:
+        if args.db3:
+            args.env_dbs = [args.db3]
+    else:
+        args.env_dbs = None
 
-    # rename a3m files
-    for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
-        os.rename(
-            args.base.joinpath(f"{job_number}.a3m"),
-            args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
-        )
+    return args
 
-    # rename m8 files
-    if args.use_templates:
-        id = 0
-        for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
-            with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
-                "w"
-            ) as f:
-                for _ in range(len(query_seqs_cardinality)):
-                    with args.base.joinpath(f"{id}.m8").open("r") as g:
-                        f.write(g.read())
-                    os.remove(args.base.joinpath(f"{id}.m8"))
-                    id += 1
-    query_file.unlink()
-    run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
-    run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
+
+def main():
+    args = parse_arguments()
+    compute_msas(query=args.query,
+                 base=args.base,
+                 dbbase=args.dbbase,
+                 s=args.s,
+                 uniref_db=args.uniref_db,
+                 template_db=args.template_db,
+                 env_dbs=args.env_dbs,
+                 filter=args.filter,
+                 mmseqs=args.mmseqs,
+                 expand_eval=args.expand_eval,
+                 align_eval=args.align_eval,
+                 diff=args.diff,
+                 qsc=args.qsc,
+                 max_accept=args.max_accept,
+                 pairing_strategy=args.pairing_strategy,
+                 db_load_mode=args.db_load_mode,
+                 threads=args.threads,
+                 prefilter_mode=args.prefilter_mode,
+                 )
 
 if __name__ == "__main__":
diff --git a/colabfold/utils.py b/colabfold/utils.py
index 20fe35fa..7f9c2e14 100644
--- a/colabfold/utils.py
+++ b/colabfold/utils.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import warnings
+import functools
 from pathlib import Path
 from typing import Optional
@@ -28,6 +29,21 @@
 precompute all MSAs with `colabfold_search` or host your own API and pass it to
 `--host-url`
 """
+
+def log_function_call(func):
+    """Decorator to log function calls."""
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        # Convert args and kwargs to strings to log them
+        args_str = ', '.join(repr(a) for a in args)
+        kwargs_str = ', '.join(f'{k}={v!r}' for k, v in kwargs.items())
+        signature = ', '.join(s for s in (args_str, kwargs_str) if s)
+        logging.info(f"Calling {func.__name__}({signature})")
+        result = func(*args, **kwargs)
+        logging.info(f"{func.__name__} returned {result!r}")
+        return result
+    return wrapper
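+
+# Example (illustrative): given
+#
+#     @log_function_call
+#     def add(a, b):
+#         return a + b
+#
+# calling add(1, b=2) logs, at INFO level:
+#
+#     Calling add(1, b=2)
+#     add returned 3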
+
+
 class TqdmHandler(logging.StreamHandler):
     """https://stackoverflow.com/a/38895482/3549270"""