From 46fb01670ab50c39f1c8c70243534891271a77e3 Mon Sep 17 00:00:00 2001 From: Nicola Soranzo Date: Tue, 5 Mar 2024 13:47:02 +0000 Subject: [PATCH] Add/fix type annotations --- lib/galaxy/job_metrics/__init__.py | 7 +++++- lib/galaxy/job_metrics/formatting.py | 2 +- lib/galaxy/tool_util/cwl/util.py | 6 ++--- lib/galaxy/tools/wrappers.py | 4 ++-- lib/galaxy/workflow/modules.py | 33 ++++++++++++++++++++-------- lib/galaxy/workflow/run.py | 15 ++++++++----- 6 files changed, 46 insertions(+), 21 deletions(-) diff --git a/lib/galaxy/job_metrics/__init__.py b/lib/galaxy/job_metrics/__init__.py index 0a315ffdf31d..13a93e6a8ca0 100644 --- a/lib/galaxy/job_metrics/__init__.py +++ b/lib/galaxy/job_metrics/__init__.py @@ -20,10 +20,12 @@ ) from typing import ( Any, + cast, Dict, List, NamedTuple, Optional, + TYPE_CHECKING, ) from galaxy import util @@ -34,6 +36,9 @@ Safety, ) +if TYPE_CHECKING: + from galaxy.job_metrics.instrumenters import InstrumentPlugin + log = logging.getLogger(__name__) @@ -72,7 +77,7 @@ class JobMetrics: def __init__(self, conf_file=None, conf_dict=None, **kwargs): """Load :class:`JobInstrumenter` objects from specified configuration file.""" - self.plugin_classes = self.__plugins_dict() + self.plugin_classes = cast(Dict[str, "InstrumentPlugin"], self.__plugins_dict()) if conf_file and os.path.exists(conf_file): self.default_job_instrumenter = JobInstrumenter.from_file(self.plugin_classes, conf_file, **kwargs) elif conf_dict or conf_dict is None: diff --git a/lib/galaxy/job_metrics/formatting.py b/lib/galaxy/job_metrics/formatting.py index b37de650f96f..49515867883c 100644 --- a/lib/galaxy/job_metrics/formatting.py +++ b/lib/galaxy/job_metrics/formatting.py @@ -14,7 +14,7 @@ class FormattedMetric(NamedTuple): class JobMetricFormatter: """Format job metric key-value pairs for human consumption in Web UI.""" - def format(self, key: Any, value: Any) -> FormattedMetric: + def format(self, key: str, value: Any) -> FormattedMetric: return FormattedMetric(str(key), str(value)) diff --git a/lib/galaxy/tool_util/cwl/util.py b/lib/galaxy/tool_util/cwl/util.py index fd32c9079263..4d2dc2977d11 100644 --- a/lib/galaxy/tool_util/cwl/util.py +++ b/lib/galaxy/tool_util/cwl/util.py @@ -146,8 +146,8 @@ def galactic_job_json( for Galaxy. 
""" - datasets = [] - dataset_collections = [] + datasets: List[Dict[str, Any]] = [] + dataset_collections: List[Dict[str, Any]] = [] def response_to_hda(target: UploadTarget, upload_response: Dict[str, Any]) -> Dict[str, str]: assert isinstance(upload_response, dict), upload_response @@ -277,7 +277,7 @@ def replacement_file(value): return upload_file(file_path, secondary_files_tar_path, filetype=filetype, **kwd) - def replacement_directory(value): + def replacement_directory(value: Dict[str, Any]) -> Dict[str, Any]: file_path = value.get("location", None) or value.get("path", None) if file_path is None: return value diff --git a/lib/galaxy/tools/wrappers.py b/lib/galaxy/tools/wrappers.py index b8591ea749d5..eff1590cacc5 100644 --- a/lib/galaxy/tools/wrappers.py +++ b/lib/galaxy/tools/wrappers.py @@ -642,9 +642,9 @@ def __init__( self.collection = collection elements = collection.elements - element_instances = {} + element_instances: Dict[str, Union[DatasetCollectionWrapper, DatasetFilenameWrapper]] = {} - element_instance_list = [] + element_instance_list: List[Union[DatasetCollectionWrapper, DatasetFilenameWrapper]] = [] for dataset_collection_element in elements: element_object = dataset_collection_element.element_object element_identifier = dataset_collection_element.element_identifier diff --git a/lib/galaxy/workflow/modules.py b/lib/galaxy/workflow/modules.py index da53d9de6159..a6138d2588d2 100644 --- a/lib/galaxy/workflow/modules.py +++ b/lib/galaxy/workflow/modules.py @@ -463,7 +463,9 @@ def decode_runtime_state(self, step, runtime_state): state.decode(runtime_state, Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app) return state - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: """Execute the given workflow invocation step. Use the supplied workflow progress object to track outputs, find @@ -508,7 +510,7 @@ def get_informal_replacement_parameters(self, step) -> List[str]: return [] - def compute_collection_info(self, progress, step, all_inputs): + def compute_collection_info(self, progress: "WorkflowProgress", step, all_inputs): """ Use get_all_inputs (if implemented) to determine collection mapping for execution. """ @@ -526,7 +528,7 @@ def compute_collection_info(self, progress, step, all_inputs): collection_info.when_values = progress.when_values return collection_info or progress.subworkflow_collection_info - def _find_collections_to_match(self, progress, step, all_inputs): + def _find_collections_to_match(self, progress: "WorkflowProgress", step, all_inputs) -> matching.CollectionsToMatch: collections_to_match = matching.CollectionsToMatch() dataset_collection_type_descriptions = self.trans.app.dataset_collection_manager.collection_type_descriptions @@ -756,7 +758,9 @@ def get_post_job_actions(self, incoming): def get_content_id(self): return self.trans.security.encode_id(self.subworkflow.id) - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: """Execute the given workflow step in the given workflow invocation. Use the supplied workflow progress object to track outputs, find inputs, etc... 
@@ -929,7 +933,9 @@ def get_runtime_state(self):
     def get_all_inputs(self, data_only=False, connectable_only=False):
         return []
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         invocation = invocation_step.workflow_invocation
         step = invocation_step.workflow_step
         input_value = step.state.inputs["input"]
@@ -963,6 +969,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
             if content:
                 invocation.add_input(content, step.id)
         progress.set_outputs_for_input(invocation_step, step_outputs)
+        return None
 
     def recover_mapping(self, invocation_step, progress):
         progress.set_outputs_for_input(invocation_step, already_persisted=True)
@@ -1522,7 +1529,9 @@ def get_all_outputs(self, data_only=False):
             )
         ]
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         step = invocation_step.workflow_step
         input_value = step.state.inputs["input"]
         if input_value is None:
@@ -1535,6 +1544,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
                 input_value = default_value.get("value", NO_REPLACEMENT)
         step_outputs = dict(output=input_value)
         progress.set_outputs_for_input(invocation_step, step_outputs)
+        return None
 
     def step_state_to_tool_state(self, state):
         state = safe_loads(state)
@@ -1666,9 +1676,12 @@ def get_runtime_state(self):
         state.inputs = dict()
         return state
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         step = invocation_step.workflow_step
         progress.mark_step_outputs_delayed(step, why="executing pause step")
+        return None
 
     def recover_mapping(self, invocation_step, progress):
         if invocation_step:
@@ -2131,7 +2144,9 @@ def decode_runtime_state(self, step, runtime_state):
                 f"Tool {self.tool_id} missing. Cannot recover runtime state.", tool_id=self.tool_id
             )
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         invocation = invocation_step.workflow_invocation
         step = invocation_step.workflow_step
         tool = trans.app.toolbox.get_tool(step.tool_id, tool_version=step.tool_version, tool_uuid=step.tool_uuid)
@@ -2171,7 +2186,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
         found_replacement_keys = set()
 
         # Connect up
-        def callback(input, prefixed_name, **kwargs):
+        def callback(input, prefixed_name: str, **kwargs):
             input_dict = all_inputs_by_name[prefixed_name]
 
             replacement: Union[model.Dataset, NoReplacement] = NO_REPLACEMENT
diff --git a/lib/galaxy/workflow/run.py b/lib/galaxy/workflow/run.py
index c7eab00a7a1d..4d1c024fbdd7 100644
--- a/lib/galaxy/workflow/run.py
+++ b/lib/galaxy/workflow/run.py
@@ -144,6 +144,8 @@ def queue_invoke(
 
 
 class WorkflowInvoker:
+    progress: "WorkflowProgress"
+
     def __init__(
         self,
         trans: "WorkRequestContext",
@@ -348,7 +350,7 @@ class WorkflowProgress:
     def __init__(
         self,
         workflow_invocation: WorkflowInvocation,
-        inputs_by_step_id: Any,
+        inputs_by_step_id: Dict[int, Any],
         module_injector: ModuleInjector,
         param_map: Dict[int, Dict[str, Any]],
         jobs_per_scheduling_iteration: int = -1,
@@ -415,7 +417,7 @@ def remaining_steps(
                 remaining_steps.append((step, invocation_step))
         return remaining_steps
 
-    def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any:
+    def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]):
         replacement: Union[
             modules.NoReplacement,
             model.DatasetCollectionInstance,
@@ -447,7 +449,7 @@ def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[st
                 replacement = raw_to_galaxy(trans, step_input.default_value)
         return replacement
 
-    def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True) -> Any:
+    def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True):
         output_step_id = connection.output_step.id
         output_name = connection.output_name
         if output_step_id not in self.outputs:
@@ -530,7 +532,7 @@ def replacement_for_connection(self, connection: "WorkflowStepConnection", is_da
 
         return replacement
 
-    def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") -> Any:
+    def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput"):
         step = workflow_output.workflow_step
         output_name = workflow_output.output_name
         step_outputs = self.outputs[step.id]
@@ -541,7 +543,10 @@ def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") ->
         return step_outputs[output_name]
 
     def set_outputs_for_input(
-        self, invocation_step: WorkflowInvocationStep, outputs: Any = None, already_persisted: bool = False
+        self,
+        invocation_step: WorkflowInvocationStep,
+        outputs: Optional[Dict[str, Any]] = None,
+        already_persisted: bool = False,
     ) -> None:
         step = invocation_step.workflow_step