diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py
index d463fa379e3b..679fa44fdedc 100644
--- a/lib/galaxy/tools/__init__.py
+++ b/lib/galaxy/tools/__init__.py
@@ -11,7 +11,19 @@ import threading
 from datetime import datetime
 from pathlib import Path
-from typing import cast, Dict, List, NamedTuple, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    cast,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 from urllib.parse import unquote_plus
 
 import packaging.version
@@ -51,7 +63,7 @@ from galaxy.tool_util.toolbox import BaseGalaxyToolBox
 from galaxy.tool_util.toolbox.views.sources import StaticToolBoxViewSources
 from galaxy.tools import expressions
-from galaxy.tools.actions import DefaultToolAction
+from galaxy.tools.actions import DefaultToolAction, ToolAction
 from galaxy.tools.actions.data_manager import DataManagerToolAction
 from galaxy.tools.actions.data_source import DataSourceToolAction
 from galaxy.tools.actions.model_operations import ModelOperationToolAction
@@ -115,6 +127,8 @@
     MappingParameters,
 )
 
+if TYPE_CHECKING:
+    from galaxy.tools.actions.metadata import SetMetadataToolAction
 
 log = logging.getLogger(__name__)
@@ -442,7 +456,7 @@ def _get_tool_shed_repository(self, tool_shed, name, owner, installed_changeset_
 
     def __build_tool_version_select_field(self, tools, tool_id, set_selected):
         """Build a SelectField whose options are the ids for the received list of tools."""
-        options = []
+        options: List[Tuple[str, str]] = []
         for tool in tools:
             options.insert(0, (tool.version, tool.id))
         select_field = SelectField(name='tool_id')
@@ -505,6 +519,11 @@ def copy(self):
         return new_state
 
 
+class _Options(Bunch):
+    sanitize: str
+    refresh: str
+
+
 class Tool(Dictifiable):
     """
     Represents a computational tool that can be executed through Galaxy.
@@ -514,8 +533,11 @@ class Tool(Dictifiable):
     requires_setting_metadata = True
     produces_entry_points = False
     default_tool_action = DefaultToolAction
+    tool_action: ToolAction
     tool_type_local = False
     dict_collection_visible_keys = ['id', 'name', 'version', 'description', 'labels']
+    __help: Optional[threading.Lock]
+    __help_by_page: Union[threading.Lock, List[str]]
 
     def __init__(self, config_file, tool_source, app, guid=None, repository_id=None, tool_shed_repository=None, allow_code_files=True, dynamic=False):
         """Load a tool from the config named by `config_file`"""
@@ -535,9 +557,9 @@ def __init__(self, config_file, tool_source, app, guid=None, repository_id=None,
         self.stdio_regexes = list()
         self.inputs_by_page = list()
         self.display_by_page = list()
-        self.action = '/tool_runner/index'
-        self.target = 'galaxy_main'
-        self.method = 'post'
+        self.action: Union[str, Tuple[str, str]] = "/tool_runner/index"
+        self.target = "galaxy_main"
+        self.method = "post"
         self.labels = []
         self.check_values = True
         self.nginx_upload = False
@@ -728,8 +750,11 @@ def requires_galaxy_python_environment(self):
         else:
             unversioned_legacy_tool = self.old_id in GALAXY_LIB_TOOLS_UNVERSIONED
             versioned_legacy_tool = self.old_id in GALAXY_LIB_TOOLS_VERSIONED
-            legacy_tool = unversioned_legacy_tool or \
-                (versioned_legacy_tool and self.version_object < GALAXY_LIB_TOOLS_VERSIONED[self.old_id])
+            legacy_tool = unversioned_legacy_tool or (
+                versioned_legacy_tool
+                and self.old_id
+                and self.version_object < GALAXY_LIB_TOOLS_VERSIONED[self.old_id]
+            )
         return legacy_tool
 
     def __get_job_tool_configuration(self, job_params=None):
@@ -941,11 +966,12 @@ def parse(self, tool_source, guid=None, dynamic=False):
             self.__parse_legacy_features(tool_source)
 
         # Load any tool specific options (optional)
-        self.options = dict(
-            sanitize=tool_source.parse_sanitize(),
-            refresh=tool_source.parse_refresh(),
+        self.options = _Options(
+            **dict(
+                sanitize=tool_source.parse_sanitize(),
+                refresh=tool_source.parse_refresh(),
+            )
         )
-        self.options = Bunch(** self.options)
 
         # Read in name of galaxy.json metadata file and how to parse it.
         self.provided_metadata_file = tool_source.parse_provided_metadata_file()
@@ -1025,9 +1051,9 @@ def parse(self, tool_source, guid=None, dynamic=False):
             self.ports = tool_source.parse_interactivetool()
 
     def __parse_legacy_features(self, tool_source):
-        self.code_namespace = dict()
-        self.hook_map = {}
-        self.uihints = {}
+        self.code_namespace: Dict[str, str] = {}
+        self.hook_map: Dict[str, str] = {}
+        self.uihints: Dict[str, str] = {}
 
         if not hasattr(tool_source, 'root'):
             return
@@ -1051,7 +1077,11 @@ def __parse_legacy_features(self, tool_source):
                 compiled_code = compile(code_string, code_path, 'exec')
                 exec(compiled_code, self.code_namespace)
             except Exception:
-                if refactoring_tool and self.python_template_version.release[0] < 3:
+                if (
+                    refactoring_tool
+                    and self.python_template_version
+                    and self.python_template_version.release[0] < 3
+                ):
                     # Could be a code file that uses python 2 syntax
                     translated_code = str(refactoring_tool.refactor_string(code_string, name='auto_translated_code_file'))
                     compiled_code = compile(translated_code, f"futurized_{code_path}", 'exec')
@@ -1200,7 +1230,7 @@ def parse_inputs(self, tool_source):
         # Load parameters (optional)
         self.inputs = {}
         pages = tool_source.parse_input_pages()
-        enctypes = set()
+        enctypes: Set[str] = set()
         if pages.inputs_defined:
             if hasattr(pages, "input_elem"):
                 input_elem = pages.input_elem
@@ -1212,13 +1242,21 @@
                 # a string. The actual action needs to get url_for run to add any
                 # prefixes, and we want to avoid adding the prefix to the
                 # nginx_upload_path.
-                if self.nginx_upload and self.app.config.nginx_upload_path:
-                    if '?' in unquote_plus(self.action):
-                        raise Exception('URL parameters in a non-default tool action can not be used '
-                                        'in conjunction with nginx upload. Please convert them to '
-                                        'hidden POST parameters')
-                    self.action = (f"{self.app.config.nginx_upload_path}?nginx_redir=",
-                                   unquote_plus(self.action))
+                if (
+                    self.nginx_upload
+                    and self.app.config.nginx_upload_path
+                    and not isinstance(self.action, tuple)
+                ):
+                    if "?" in unquote_plus(self.action):
+                        raise Exception(
+                            "URL parameters in a non-default tool action can not be used "
+                            "in conjunction with nginx upload. Please convert them to "
+                            "hidden POST parameters"
+                        )
+                    self.action = (
+                        f"{self.app.config.nginx_upload_path}?nginx_redir=",
+                        unquote_plus(self.action),
+                    )
                 self.target = input_elem.get("target", self.target)
                 self.method = input_elem.get("method", self.method)
             # Parse the actual parameters
@@ -1291,7 +1329,7 @@ def _parse_citations(self, tool_source):
             return []
 
         root = tool_source.root
-        citations = []
+        citations: List[str] = []
        citations_elem = root.find("citations")
         if citations_elem is None:
             return citations
@@ -1311,40 +1349,51 @@ def parse_input_elem(self, page_source, enctypes, context=None):
         groups (repeat, conditional) or param elements. Groups will be parsed
         recursively.
         """
-        rval = {}
+        rval: Dict[str, Any] = {}
         context = ExpressionContext(rval, context)
         for input_source in page_source.parse_input_sources():
             # Repeat group
             input_type = input_source.parse_input_type()
             if input_type == "repeat":
-                group = Repeat()
-                group.name = input_source.get("name")
-                group.title = input_source.get("title")
-                group.help = input_source.get("help", None)
+                group_r = Repeat()
+                group_r.name = input_source.get("name")
+                group_r.title = input_source.get("title")
+                group_r.help = input_source.get("help", None)
                 page_source = input_source.parse_nested_inputs_source()
-                group.inputs = self.parse_input_elem(page_source, enctypes, context)
-                group.default = int(input_source.get("default", 0))
-                group.min = int(input_source.get("min", 0))
+                group_r.inputs = self.parse_input_elem(page_source, enctypes, context)
+                group_r.default = int(input_source.get("default", 0))
+                group_r.min = int(input_source.get("min", 0))
                 # Use float instead of int so that 'inf' can be used for no max
-                group.max = float(input_source.get("max", "inf"))
-                assert group.min <= group.max, ValueError(f"Tool with id '{self.id}': min repeat count must be less-than-or-equal to the max.")
+                group_r.max = float(input_source.get("max", "inf"))
+                assert group_r.min <= group_r.max, ValueError(
+                    f"Tool with id '{self.id}': min repeat count must be less-than-or-equal to the max."
+                )
                 # Force default to be within min-max range
-                group.default = min(max(group.default, group.min), group.max)
-                rval[group.name] = group
+                group_r.default = cast(
+                    int, min(max(group_r.default, group_r.min), group_r.max)
+                )
+                rval[group_r.name] = group_r
             elif input_type == "conditional":
-                group = Conditional()
-                group.name = input_source.get("name")
-                group.value_ref = input_source.get('value_ref', None)
-                group.value_ref_in_group = input_source.get_bool('value_ref_in_group', True)
+                group_c = Conditional()
+                group_c.name = input_source.get("name")
+                group_c.value_ref = input_source.get("value_ref", None)
+                group_c.value_ref_in_group = input_source.get_bool(
+                    "value_ref_in_group", True
+                )
                 value_from = input_source.get("value_from", None)
                 if value_from:
                     value_from = value_from.split(':')
-                    group.value_from = locals().get(value_from[0])
-                    group.test_param = rval[group.value_ref]
-                    group.test_param.refresh_on_change = True
+                    temp_value_from = locals().get(value_from[0])
+                    group_c.test_param = rval[group_c.value_ref]
+                    group_c.test_param.refresh_on_change = True
                     for attr in value_from[1].split('.'):
-                        group.value_from = getattr(group.value_from, attr)
-                    for case_value, case_inputs in group.value_from(context, group, self).items():
+                        temp_value_from = getattr(temp_value_from, attr)
+                    group_c.value_from = temp_value_from  # type: ignore[assignment]
+                    # ^^ due to https://github.com/python/mypy/issues/2427
+                    assert group_c.value_from
+                    for case_value, case_inputs in group_c.value_from(
+                        context, group_c, self
+                    ).items():
                         case = ConditionalWhen()
                         case.value = case_value
                         if case_inputs:
@@ -1352,60 +1401,84 @@ def parse_input_elem(self, page_source, enctypes, context=None):
                             case.inputs = self.parse_input_elem(page_source, enctypes, context)
                         else:
                             case.inputs = {}
-                        group.cases.append(case)
+                        group_c.cases.append(case)
                 else:
                     # Should have one child "input" which determines the case
                     test_param_input_source = input_source.parse_test_input_source()
-                    group.test_param = self.parse_param_elem(test_param_input_source, enctypes, context)
-                    if group.test_param.optional:
+                    group_c.test_param = self.parse_param_elem(
+                        test_param_input_source, enctypes, context
+                    )
+                    if group_c.test_param.optional:
                         log.debug(f"Tool with id '{self.id}': declares a conditional test parameter as optional, this is invalid and will be ignored.")
-                        group.test_param.optional = False
-                    possible_cases = list(group.test_param.legal_values)  # store possible cases, undefined whens will have no inputs
+                        group_c.test_param.optional = False
+                    possible_cases = list(
+                        group_c.test_param.legal_values
+                    )  # store possible cases, undefined whens will have no inputs
                     # Must refresh when test_param changes
-                    group.test_param.refresh_on_change = True
+                    group_c.test_param.refresh_on_change = True
                     # And a set of possible cases
                     for (value, case_inputs_source) in input_source.parse_when_input_sources():
                         case = ConditionalWhen()
                         case.value = value
                         case.inputs = self.parse_input_elem(case_inputs_source, enctypes, context)
-                        group.cases.append(case)
+                        group_c.cases.append(case)
                         try:
                             possible_cases.remove(case.value)
                         except Exception:
-                            log.debug("Tool with id '%s': a when tag has been defined for '%s (%s) --> %s', but does not appear to be selectable." %
-                                      (self.id, group.name, group.test_param.name, case.value))
+                            log.debug(
+                                "Tool with id '%s': a when tag has been defined for '%s (%s) --> %s', but does not appear to be selectable."
+                                % (
+                                    self.id,
+                                    group_c.name,
+                                    group_c.test_param.name,
+                                    case.value,
+                                )
+                            )
                     for unspecified_case in possible_cases:
-                        log.warning("Tool with id '%s': a when tag has not been defined for '%s (%s) --> %s', assuming empty inputs." %
-                                    (self.id, group.name, group.test_param.name, unspecified_case))
+                        log.warning(
+                            "Tool with id '%s': a when tag has not been defined for '%s (%s) --> %s', assuming empty inputs."
+                            % (
+                                self.id,
+                                group_c.name,
+                                group_c.test_param.name,
+                                unspecified_case,
+                            )
+                        )
                         case = ConditionalWhen()
                         case.value = unspecified_case
                         case.inputs = {}
-                        group.cases.append(case)
-                rval[group.name] = group
+                        group_c.cases.append(case)
+                rval[group_c.name] = group_c
             elif input_type == "section":
-                group = Section()
-                group.name = input_source.get("name")
-                group.title = input_source.get("title")
-                group.help = input_source.get("help", None)
-                group.expanded = input_source.get_bool("expanded", False)
+                group_s = Section()
+                group_s.name = input_source.get("name")
+                group_s.title = input_source.get("title")
+                group_s.help = input_source.get("help", None)
+                group_s.expanded = input_source.get_bool("expanded", False)
                 page_source = input_source.parse_nested_inputs_source()
-                group.inputs = self.parse_input_elem(page_source, enctypes, context)
-                rval[group.name] = group
+                group_s.inputs = self.parse_input_elem(page_source, enctypes, context)
+                rval[group_s.name] = group_s
             elif input_type == "upload_dataset":
                 elem = input_source.elem()
-                group = UploadDataset()
-                group.name = elem.get("name")
-                group.title = elem.get("title")
-                group.file_type_name = elem.get('file_type_name', group.file_type_name)
-                group.default_file_type = elem.get('default_file_type', group.default_file_type)
-                group.metadata_ref = elem.get('metadata_ref', group.metadata_ref)
+                group_u = UploadDataset()
+                group_u.name = elem.get("name")
+                group_u.title = elem.get("title")
+                group_u.file_type_name = elem.get(
+                    "file_type_name", group_u.file_type_name
+                )
+                group_u.default_file_type = elem.get(
+                    "default_file_type", group_u.default_file_type
+                )
+                group_u.metadata_ref = elem.get("metadata_ref", group_u.metadata_ref)
                 try:
-                    rval[group.file_type_name].refresh_on_change = True
+                    rval[group_u.file_type_name].refresh_on_change = True
                 except KeyError:
                     pass
                 group_page_source = XmlPageSource(elem)
-                group.inputs = self.parse_input_elem(group_page_source, enctypes, context)
-                rval[group.name] = group
+                group_u.inputs = self.parse_input_elem(
+                    group_page_source, enctypes, context
+                )
+                rval[group_u.name] = group_u
             elif input_type == "param":
                 param = self.parse_param_elem(input_source, enctypes, context)
                 rval[param.name] = param
@@ -1501,7 +1574,7 @@ def __ensure_help(self):
     def __inititalize_help(self):
         tool_source = self.__help_source
         self.__help = None
-        self.__help_by_page = []
+        __help_by_page = []
         help_footer = ""
         help_text = tool_source.parse_help()
         if help_text is not None:
@@ -1527,20 +1600,25 @@
             # Multiple help page case
             if help_pages:
                 for help_page in help_pages:
-                    self.__help_by_page.append(help_page.text)
+                    __help_by_page.append(help_page.text)
                     help_footer = help_footer + help_page.tail
         # Each page has to rendered all-together because of backreferences allowed by rst
         try:
-            self.__help_by_page = [Template(rst_to_html(help_header + x + help_footer),
-                                            input_encoding='utf-8',
-                                            default_filters=['decode.utf8'],
-                                            encoding_errors='replace')
-                                   for x in self.__help_by_page]
+            __help_by_page = [
+                Template(
+                    rst_to_html(help_header + x + help_footer),
+                    input_encoding="utf-8",
+                    default_filters=["decode.utf8"],
+                    encoding_errors="replace",
+                )
+                for x in __help_by_page
+            ]
         except Exception:
             log.exception("Exception while parsing multi-page help for tool with id '%s'", self.id)
         # Pad out help pages to match npages ... could this be done better?
-        while len(self.__help_by_page) < self.npages:
-            self.__help_by_page.append(self.__help)
+        while len(__help_by_page) < self.npages:
+            __help_by_page.append(self.__help)
+        self.__help_by_page = __help_by_page
 
     def find_output_def(self, name):
         # name is JobToOutputDatasetAssociation name.
@@ -1658,7 +1736,7 @@ def expand_incoming(self, trans, incoming, request_context, input_format='legacy
         all_params = []
         for expanded_incoming in expanded_incomings:
             params = {}
-            errors = {}
+            errors: Dict[str, str] = {}
             if self.input_translator:
                 self.input_translator.translate(expanded_incoming)
             if not self.check_values:
@@ -2058,8 +2136,8 @@ def to_archive(self):
         tool = self
         tarball_files = []
         temp_files = []
-        with open(os.path.abspath(tool.config_file)) as fh:
-            tool_xml = fh.read()
+        with open(os.path.abspath(tool.config_file)) as fh1:
+            tool_xml = fh1.read()
         # Retrieve tool help images and rewrite the tool's xml into a temporary file with the path
         # modified to be relative to the repository root.
         image_found = False
@@ -2079,9 +2157,11 @@ def to_archive(self):
                     tool_xml = tool_xml.replace('${static_path}/%s' % tarball_path, tarball_path)
         # If one or more tool help images were found, add the modified tool XML to the tarball instead of the original.
         if image_found:
-            with tempfile.NamedTemporaryFile(mode='w', suffix='.xml', delete=False) as fh:
-                new_tool_config = fh.name
-                fh.write(tool_xml)
+            with tempfile.NamedTemporaryFile(
+                mode="w", suffix=".xml", delete=False
+            ) as fh2:
+                new_tool_config = fh2.name
+                fh2.write(tool_xml)
             tool_tup = (new_tool_config, os.path.split(tool.config_file)[-1])
             temp_files.append(new_tool_config)
         else:
@@ -2139,15 +2219,19 @@ def to_archive(self):
                 if len(data_table_definitions) > 0:
                     # Put the data table definition XML in a temporary file.
                     table_definition = '<?xml version="1.0" encoding="utf-8"?>\n\n<tables>\n    %s</tables>'
-                    table_definition = table_definition % '\n'.join(data_table_definitions)
-                    with tempfile.NamedTemporaryFile(mode='w', delete=False) as fh:
-                        table_conf = fh.name
-                        fh.write(table_definition)
+                    table_definition = table_definition % "\n".join(
+                        data_table_definitions
+                    )
+                    with tempfile.NamedTemporaryFile(
+                        mode="w", delete=False
+                    ) as fh3:
+                        table_conf = fh3.name
+                        fh3.write(table_definition)
                     tarball_files.append((table_conf, os.path.join('tool-data', 'tool_data_table_conf.xml.sample')))
                     temp_files.append(table_conf)
         # Create the tarball.
-        with tempfile.NamedTemporaryFile(suffix='.tgz', delete=False) as fh:
-            tarball_archive = fh.name
+        with tempfile.NamedTemporaryFile(suffix=".tgz", delete=False) as fh4:
+            tarball_archive = fh4.name
         tarball = tarfile.open(name=tarball_archive, mode='w:gz')
         # Add the files from the previously generated list.
         for fspath, tarpath in tarball_files:
@@ -2263,8 +2347,8 @@ def to_json(self, trans, kwd=None, job=None, workflow_building_mode=False, histo
             set_dataset_matcher_factory(request_context, self)
 
             # create tool state
-            state_inputs = {}
-            state_errors = {}
+            state_inputs: Dict[str, str] = {}
+            state_errors: Dict[str, str] = {}
             populate_state(request_context, self.inputs, params.__dict__, state_inputs, state_errors)
 
             # create tool model
@@ -2320,7 +2404,8 @@ def populate_model(self, request_context, inputs, state_inputs, group_inputs, ot
             if input.type == 'repeat':
                 tool_dict = input.to_dict(request_context)
                 group_size = len(group_state)
-                group_cache = tool_dict['cache'] = [None] * group_size
+                tool_dict["cache"] = [None] * group_size
+                group_cache: List[List[str]] = tool_dict["cache"]
                 for i in range(group_size):
                     group_cache[i] = []
                     self.populate_model(request_context, input.inputs, group_state[i], group_cache[i], other_values)
@@ -2519,7 +2604,11 @@ def exec_before_job(self, app, inp_data, out_data, param_dict=None):
             json_params['output_data'].append(data_dict)
             if json_filename is None:
                 json_filename = file_name
-        with open(json_filename, 'w') as out:
+        if not json_filename:
+            raise Exception(
+                "Must call 'exec_before_job' with 'out_data' containing at least one entry."
+            )
+        with open(json_filename, "w") as out:
             out.write(json.dumps(json_params))
 
@@ -2567,7 +2656,7 @@ def exec_before_job(self, app, inp_data, out_data, param_dict=None):
         if param_dict is None:
             raise Exception("Internal error - param_dict is empty.")
 
-        job = {}
+        job: Dict[str, str] = {}
         json_wrap(self.inputs, param_dict, self.profile, job, handle_files='OBJECT')
         expression_inputs = {
             'job': job,
@@ -2670,7 +2759,11 @@ def exec_before_job(self, app, inp_data, out_data, param_dict=None):
             json_params['output_data'].append(data_dict)
             if json_filename is None:
                 json_filename = file_name
-        with open(json_filename, 'w') as out:
+        if not json_filename:
+            raise Exception(
+                "Must call 'exec_before_job' with 'out_data' containing at least one entry."
+            )
+        with open(json_filename, "w") as out:
             out.write(json.dumps(json_params))
 
@@ -2692,6 +2785,7 @@ class SetMetadataTool(Tool):
     """
     tool_type = 'set_metadata'
     requires_setting_metadata = False
+    tool_action: "SetMetadataToolAction"
 
     def regenerate_imported_metadata_if_needed(self, hda, history, job):
         if len(hda.metadata_file_types) > 0:
@@ -3032,7 +3126,7 @@ def produce_outputs(self, trans, out_data, output_collections, incoming, history
         new_element_structure = {}
 
         # Which inputs does the identifier appear in.
-        identifiers_map = {}
+        identifiers_map: Dict[str, List[int]] = {}
         for input_num, input_list in enumerate(input_lists):
             for dce in input_list.collection.elements:
                 element_identifier = dce.element_identifier
@@ -3072,7 +3166,7 @@ def produce_outputs(self, trans, out_data, output_collections, incoming, history
                 if dupl_actions == "keep_first" and identifier_seen:
                     continue
 
-                if add_suffix:
+                if add_suffix and suffix_pattern:
                     suffix = suffix_pattern.replace("#", str(copy + 1))
                     effective_identifer = f"{element_identifier}{suffix}"
                 else:
diff --git a/lib/galaxy/tools/actions/__init__.py b/lib/galaxy/tools/actions/__init__.py
index e3fd7c3741f6..5a25e7495efc 100644
--- a/lib/galaxy/tools/actions/__init__.py
+++ b/lib/galaxy/tools/actions/__init__.py
@@ -51,12 +51,12 @@ class ToolAction:
     been converted and validated).
""" - def execute(self, tool, trans, incoming=None, set_output_hid=True): + def execute(self, tool, trans, incoming=None, set_output_hid=True, **kwargs): incoming = incoming or {} raise TypeError("Abstract method") -class DefaultToolAction: +class DefaultToolAction(ToolAction): """Default tool action is to run an external command""" produces_real_jobs = True diff --git a/lib/galaxy/tools/parameters/grouping.py b/lib/galaxy/tools/parameters/grouping.py index 0774cd4e14bb..b0fb702c06be 100644 --- a/lib/galaxy/tools/parameters/grouping.py +++ b/lib/galaxy/tools/parameters/grouping.py @@ -5,7 +5,7 @@ import logging import os import unicodedata -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING from galaxy.datatypes import data, sniff from galaxy.exceptions import ( @@ -24,6 +24,7 @@ if TYPE_CHECKING: from galaxy.tools.parameter.basic import ToolParameter + from galaxy.tools import Tool log = logging.getLogger(__name__) URI_PREFIXES = [f"{x}://" for x in ["http", "https", "ftp", "file", "gxfiles", "gximport", "gxuserimport", "gxftp"]] @@ -687,6 +688,9 @@ def get_filenames(context): class Conditional(Group): type = "conditional" + value_from: Callable[ + ["Conditional", ExpressionContext, "Conditional", "Tool"], Mapping[str, str] + ] def __init__(self): Group.__init__(self) diff --git a/setup.cfg b/setup.cfg index ef0494971047..bd0ffaaae7f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -583,8 +583,6 @@ check_untyped_defs = False check_untyped_defs = False [mypy-galaxy.queue_worker] check_untyped_defs = False -[mypy-galaxy.tools] -check_untyped_defs = False [mypy-galaxy.jobs.mapper] check_untyped_defs = False [mypy-galaxy.jobs.runners]