diff --git a/lib/galaxy/job_metrics/__init__.py b/lib/galaxy/job_metrics/__init__.py index 0a315ffdf31d..13a93e6a8ca0 100644 --- a/lib/galaxy/job_metrics/__init__.py +++ b/lib/galaxy/job_metrics/__init__.py @@ -20,10 +20,12 @@ ) from typing import ( Any, + cast, Dict, List, NamedTuple, Optional, + TYPE_CHECKING, ) from galaxy import util @@ -34,6 +36,9 @@ Safety, ) +if TYPE_CHECKING: + from galaxy.job_metrics.instrumenters import InstrumentPlugin + log = logging.getLogger(__name__) @@ -72,7 +77,7 @@ class JobMetrics: def __init__(self, conf_file=None, conf_dict=None, **kwargs): """Load :class:`JobInstrumenter` objects from specified configuration file.""" - self.plugin_classes = self.__plugins_dict() + self.plugin_classes = cast(Dict[str, "InstrumentPlugin"], self.__plugins_dict()) if conf_file and os.path.exists(conf_file): self.default_job_instrumenter = JobInstrumenter.from_file(self.plugin_classes, conf_file, **kwargs) elif conf_dict or conf_dict is None: diff --git a/lib/galaxy/job_metrics/formatting.py b/lib/galaxy/job_metrics/formatting.py index b37de650f96f..cbb4bff4b84b 100644 --- a/lib/galaxy/job_metrics/formatting.py +++ b/lib/galaxy/job_metrics/formatting.py @@ -14,8 +14,8 @@ class FormattedMetric(NamedTuple): class JobMetricFormatter: """Format job metric key-value pairs for human consumption in Web UI.""" - def format(self, key: Any, value: Any) -> FormattedMetric: - return FormattedMetric(str(key), str(value)) + def format(self, key: str, value: Any) -> FormattedMetric: + return FormattedMetric(key, str(value)) def seconds_to_str(value: int) -> str: diff --git a/lib/galaxy/model/dataset_collections/types/list.py b/lib/galaxy/model/dataset_collections/types/list.py index 6fccedf1728e..18ce4db76537 100644 --- a/lib/galaxy/model/dataset_collections/types/list.py +++ b/lib/galaxy/model/dataset_collections/types/list.py @@ -7,9 +7,6 @@ class ListDatasetCollectionType(BaseDatasetCollectionType): collection_type = "list" - def __init__(self): - pass - def generate_elements(self, elements): for identifier, element in elements.items(): association = DatasetCollectionElement( diff --git a/lib/galaxy/model/dataset_collections/types/paired.py b/lib/galaxy/model/dataset_collections/types/paired.py index cfbd19c04343..4ae95a1442a2 100644 --- a/lib/galaxy/model/dataset_collections/types/paired.py +++ b/lib/galaxy/model/dataset_collections/types/paired.py @@ -15,9 +15,6 @@ class PairedDatasetCollectionType(BaseDatasetCollectionType): collection_type = "paired" - def __init__(self): - pass - def generate_elements(self, elements): if forward_dataset := elements.get(FORWARD_IDENTIFIER): left_association = DatasetCollectionElement( diff --git a/lib/galaxy/tool_util/cwl/util.py b/lib/galaxy/tool_util/cwl/util.py index fd32c9079263..4d2dc2977d11 100644 --- a/lib/galaxy/tool_util/cwl/util.py +++ b/lib/galaxy/tool_util/cwl/util.py @@ -146,8 +146,8 @@ def galactic_job_json( for Galaxy. """ - datasets = [] - dataset_collections = [] + datasets: List[Dict[str, Any]] = [] + dataset_collections: List[Dict[str, Any]] = [] def response_to_hda(target: UploadTarget, upload_response: Dict[str, Any]) -> Dict[str, str]: assert isinstance(upload_response, dict), upload_response @@ -277,7 +277,7 @@ def replacement_file(value): return upload_file(file_path, secondary_files_tar_path, filetype=filetype, **kwd) - def replacement_directory(value): + def replacement_directory(value: Dict[str, Any]) -> Dict[str, Any]: file_path = value.get("location", None) or value.get("path", None) if file_path is None: return value diff --git a/lib/galaxy/tools/wrappers.py b/lib/galaxy/tools/wrappers.py index b8591ea749d5..8f8e79be9aaa 100644 --- a/lib/galaxy/tools/wrappers.py +++ b/lib/galaxy/tools/wrappers.py @@ -20,6 +20,7 @@ ) from packaging.version import Version +from typing_extensions import TypeAlias from galaxy.model import ( DatasetCollection, @@ -605,6 +606,9 @@ def __bool__(self) -> bool: __nonzero__ = __bool__ +DatasetCollectionElementWrapper: TypeAlias = Union["DatasetCollectionWrapper", DatasetFilenameWrapper] + + class DatasetCollectionWrapper(ToolParameterValueWrapper, HasDatasets): name: Optional[str] collection: DatasetCollection @@ -642,15 +646,15 @@ def __init__( self.collection = collection elements = collection.elements - element_instances = {} + element_instances: Dict[str, DatasetCollectionElementWrapper] = {} - element_instance_list = [] + element_instance_list: List[DatasetCollectionElementWrapper] = [] for dataset_collection_element in elements: element_object = dataset_collection_element.element_object element_identifier = dataset_collection_element.element_identifier if dataset_collection_element.is_collection: - element_wrapper: Union[DatasetCollectionWrapper, DatasetFilenameWrapper] = DatasetCollectionWrapper( + element_wrapper: DatasetCollectionElementWrapper = DatasetCollectionWrapper( job_working_directory, dataset_collection_element, **kwargs ) else: @@ -757,7 +761,7 @@ def serialize( def is_input_supplied(self) -> bool: return self.__input_supplied - def __getitem__(self, key: Union[str, int]) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]: + def __getitem__(self, key: Union[str, int]) -> Optional[DatasetCollectionElementWrapper]: if not self.__input_supplied: return None if isinstance(key, int): @@ -765,7 +769,7 @@ def __getitem__(self, key: Union[str, int]) -> Union[None, "DatasetCollectionWra else: return self.__element_instances[key] - def __getattr__(self, key: str) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]: + def __getattr__(self, key: str) -> Optional[DatasetCollectionElementWrapper]: if not self.__input_supplied: return None try: @@ -775,7 +779,7 @@ def __getattr__(self, key: str) -> Union[None, "DatasetCollectionWrapper", Datas def __iter__( self, - ) -> Iterator[Union["DatasetCollectionWrapper", DatasetFilenameWrapper]]: + ) -> Iterator[DatasetCollectionElementWrapper]: if not self.__input_supplied: return [].__iter__() return self.__element_instance_list.__iter__() diff --git a/lib/galaxy/workflow/modules.py b/lib/galaxy/workflow/modules.py index da53d9de6159..a6138d2588d2 100644 --- a/lib/galaxy/workflow/modules.py +++ b/lib/galaxy/workflow/modules.py @@ -463,7 +463,9 @@ def decode_runtime_state(self, step, runtime_state): state.decode(runtime_state, Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app) return state - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: """Execute the given workflow invocation step. Use the supplied workflow progress object to track outputs, find @@ -508,7 +510,7 @@ def get_informal_replacement_parameters(self, step) -> List[str]: return [] - def compute_collection_info(self, progress, step, all_inputs): + def compute_collection_info(self, progress: "WorkflowProgress", step, all_inputs): """ Use get_all_inputs (if implemented) to determine collection mapping for execution. """ @@ -526,7 +528,7 @@ def compute_collection_info(self, progress, step, all_inputs): collection_info.when_values = progress.when_values return collection_info or progress.subworkflow_collection_info - def _find_collections_to_match(self, progress, step, all_inputs): + def _find_collections_to_match(self, progress: "WorkflowProgress", step, all_inputs) -> matching.CollectionsToMatch: collections_to_match = matching.CollectionsToMatch() dataset_collection_type_descriptions = self.trans.app.dataset_collection_manager.collection_type_descriptions @@ -756,7 +758,9 @@ def get_post_job_actions(self, incoming): def get_content_id(self): return self.trans.security.encode_id(self.subworkflow.id) - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: """Execute the given workflow step in the given workflow invocation. Use the supplied workflow progress object to track outputs, find inputs, etc... @@ -929,7 +933,9 @@ def get_runtime_state(self): def get_all_inputs(self, data_only=False, connectable_only=False): return [] - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: invocation = invocation_step.workflow_invocation step = invocation_step.workflow_step input_value = step.state.inputs["input"] @@ -963,6 +969,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False): if content: invocation.add_input(content, step.id) progress.set_outputs_for_input(invocation_step, step_outputs) + return None def recover_mapping(self, invocation_step, progress): progress.set_outputs_for_input(invocation_step, already_persisted=True) @@ -1522,7 +1529,9 @@ def get_all_outputs(self, data_only=False): ) ] - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: step = invocation_step.workflow_step input_value = step.state.inputs["input"] if input_value is None: @@ -1535,6 +1544,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False): input_value = default_value.get("value", NO_REPLACEMENT) step_outputs = dict(output=input_value) progress.set_outputs_for_input(invocation_step, step_outputs) + return None def step_state_to_tool_state(self, state): state = safe_loads(state) @@ -1666,9 +1676,12 @@ def get_runtime_state(self): state.inputs = dict() return state - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: step = invocation_step.workflow_step progress.mark_step_outputs_delayed(step, why="executing pause step") + return None def recover_mapping(self, invocation_step, progress): if invocation_step: @@ -2131,7 +2144,9 @@ def decode_runtime_state(self, step, runtime_state): f"Tool {self.tool_id} missing. Cannot recover runtime state.", tool_id=self.tool_id ) - def execute(self, trans, progress, invocation_step, use_cached_job=False): + def execute( + self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False + ) -> Optional[bool]: invocation = invocation_step.workflow_invocation step = invocation_step.workflow_step tool = trans.app.toolbox.get_tool(step.tool_id, tool_version=step.tool_version, tool_uuid=step.tool_uuid) @@ -2171,7 +2186,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False): found_replacement_keys = set() # Connect up - def callback(input, prefixed_name, **kwargs): + def callback(input, prefixed_name: str, **kwargs): input_dict = all_inputs_by_name[prefixed_name] replacement: Union[model.Dataset, NoReplacement] = NO_REPLACEMENT diff --git a/lib/galaxy/workflow/run.py b/lib/galaxy/workflow/run.py index c7eab00a7a1d..4d1c024fbdd7 100644 --- a/lib/galaxy/workflow/run.py +++ b/lib/galaxy/workflow/run.py @@ -144,6 +144,8 @@ def queue_invoke( class WorkflowInvoker: + progress: "WorkflowProgress" + def __init__( self, trans: "WorkRequestContext", @@ -348,7 +350,7 @@ class WorkflowProgress: def __init__( self, workflow_invocation: WorkflowInvocation, - inputs_by_step_id: Any, + inputs_by_step_id: Dict[int, Any], module_injector: ModuleInjector, param_map: Dict[int, Dict[str, Any]], jobs_per_scheduling_iteration: int = -1, @@ -415,7 +417,7 @@ def remaining_steps( remaining_steps.append((step, invocation_step)) return remaining_steps - def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any: + def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]): replacement: Union[ modules.NoReplacement, model.DatasetCollectionInstance, @@ -447,7 +449,7 @@ def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[st replacement = raw_to_galaxy(trans, step_input.default_value) return replacement - def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True) -> Any: + def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True): output_step_id = connection.output_step.id output_name = connection.output_name if output_step_id not in self.outputs: @@ -530,7 +532,7 @@ def replacement_for_connection(self, connection: "WorkflowStepConnection", is_da return replacement - def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") -> Any: + def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput"): step = workflow_output.workflow_step output_name = workflow_output.output_name step_outputs = self.outputs[step.id] @@ -541,7 +543,10 @@ def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") -> return step_outputs[output_name] def set_outputs_for_input( - self, invocation_step: WorkflowInvocationStep, outputs: Any = None, already_persisted: bool = False + self, + invocation_step: WorkflowInvocationStep, + outputs: Optional[Dict[str, Any]] = None, + already_persisted: bool = False, ) -> None: step = invocation_step.workflow_step diff --git a/test/integration/conftest.py b/test/integration/conftest.py index f450af85511f..3a0619d076cf 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -10,17 +10,6 @@ def celery_includes(): return ["galaxy.celery.tasks"] -def pytest_collection_finish(session): - try: - # This needs to be run after test collection - from .test_config_defaults import DRIVER - - DRIVER.tear_down() - print("Galaxy test driver shutdown successful") - except Exception: - pass - - @pytest.fixture def temp_file(): with tempfile.NamedTemporaryFile(delete=True, mode="wb") as fh: