
Commit bc0d075
Merge pull request galaxyproject#17601 from nsoranzo/type_annots
Type annotation improvements
mvdbeek authored Mar 6, 2024
2 parents 1d5fb96 + fac8eed commit bc0d075
Showing 9 changed files with 55 additions and 43 deletions.
7 changes: 6 additions & 1 deletion lib/galaxy/job_metrics/__init__.py
@@ -20,10 +20,12 @@
 )
 from typing import (
     Any,
+    cast,
     Dict,
     List,
     NamedTuple,
     Optional,
+    TYPE_CHECKING,
 )
 
 from galaxy import util
@@ -34,6 +36,9 @@
     Safety,
 )
 
+if TYPE_CHECKING:
+    from galaxy.job_metrics.instrumenters import InstrumentPlugin
+
 log = logging.getLogger(__name__)
 
 
@@ -72,7 +77,7 @@ class JobMetrics:
 
     def __init__(self, conf_file=None, conf_dict=None, **kwargs):
         """Load :class:`JobInstrumenter` objects from specified configuration file."""
-        self.plugin_classes = self.__plugins_dict()
+        self.plugin_classes = cast(Dict[str, "InstrumentPlugin"], self.__plugins_dict())
         if conf_file and os.path.exists(conf_file):
             self.default_job_instrumenter = JobInstrumenter.from_file(self.plugin_classes, conf_file, **kwargs)
         elif conf_dict or conf_dict is None:
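
Note: this hunk combines two common typing patterns: an import that exists only for the type checker (guarded by TYPE_CHECKING, so it cannot introduce a runtime import cycle) and a cast() that tells the checker what a dynamically built dict contains. A minimal, self-contained sketch of the same idea, with a hypothetical plugin_discovery module and BasePlugin class standing in for the Galaxy-specific names:

```python
from typing import TYPE_CHECKING, Dict, cast

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so it cannot
    # create a circular import between the two modules.
    from plugin_discovery import BasePlugin  # hypothetical module


def _discover_plugins() -> dict:
    """Stand-in for dynamic discovery the type checker cannot see through."""
    return {}


class PluginRegistry:
    def __init__(self) -> None:
        # cast() is a no-op at runtime; it only narrows the type for mypy.
        self.plugin_classes = cast(Dict[str, "BasePlugin"], _discover_plugins())
```
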
4 changes: 2 additions & 2 deletions lib/galaxy/job_metrics/formatting.py
@@ -14,8 +14,8 @@ class FormattedMetric(NamedTuple):
 class JobMetricFormatter:
     """Format job metric key-value pairs for human consumption in Web UI."""
 
-    def format(self, key: Any, value: Any) -> FormattedMetric:
-        return FormattedMetric(str(key), str(value))
+    def format(self, key: str, value: Any) -> FormattedMetric:
+        return FormattedMetric(key, str(value))
 
 
 def seconds_to_str(value: int) -> str:
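
Note: with key typed as str rather than Any, the redundant str(key) call can be dropped and type checkers will flag callers that pass a non-string key. A small runnable sketch, assuming FormattedMetric is a two-field NamedTuple (the field names below are illustrative, not taken from this diff):

```python
from typing import Any, NamedTuple


class FormattedMetric(NamedTuple):
    title: str  # field names assumed for illustration
    value: str


class JobMetricFormatter:
    """Format job metric key-value pairs for human consumption."""

    def format(self, key: str, value: Any) -> FormattedMetric:
        # key is already a str, so only the arbitrary value needs stringifying.
        return FormattedMetric(key, str(value))


print(JobMetricFormatter().format("runtime_seconds", 42))
# FormattedMetric(title='runtime_seconds', value='42')
```
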
3 changes: 0 additions & 3 deletions lib/galaxy/model/dataset_collections/types/list.py
@@ -7,9 +7,6 @@ class ListDatasetCollectionType(BaseDatasetCollectionType):
 
     collection_type = "list"
 
-    def __init__(self):
-        pass
-
     def generate_elements(self, elements):
         for identifier, element in elements.items():
             association = DatasetCollectionElement(
3 changes: 0 additions & 3 deletions lib/galaxy/model/dataset_collections/types/paired.py
@@ -15,9 +15,6 @@ class PairedDatasetCollectionType(BaseDatasetCollectionType):
 
     collection_type = "paired"
 
-    def __init__(self):
-        pass
-
     def generate_elements(self, elements):
         if forward_dataset := elements.get(FORWARD_IDENTIFIER):
             left_association = DatasetCollectionElement(
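
Note: both collection-type classes dropped an __init__ whose body was just pass; omitting it changes nothing at runtime because the inherited constructor runs either way. A tiny sketch with hypothetical class names:

```python
class BaseCollectionType:
    collection_type = "base"


class WithNoopInit(BaseCollectionType):
    collection_type = "list"

    def __init__(self):  # adds nothing beyond the inherited constructor
        pass


class WithoutInit(BaseCollectionType):
    collection_type = "list"


# Behaviour is identical; the empty __init__ was pure noise.
assert WithNoopInit().collection_type == WithoutInit().collection_type == "list"
```
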
6 changes: 3 additions & 3 deletions lib/galaxy/tool_util/cwl/util.py
@@ -146,8 +146,8 @@ def galactic_job_json(
     for Galaxy.
     """
 
-    datasets = []
-    dataset_collections = []
+    datasets: List[Dict[str, Any]] = []
+    dataset_collections: List[Dict[str, Any]] = []
 
     def response_to_hda(target: UploadTarget, upload_response: Dict[str, Any]) -> Dict[str, str]:
         assert isinstance(upload_response, dict), upload_response
@@ -277,7 +277,7 @@ def replacement_file(value):
 
         return upload_file(file_path, secondary_files_tar_path, filetype=filetype, **kwd)
 
-    def replacement_directory(value):
+    def replacement_directory(value: Dict[str, Any]) -> Dict[str, Any]:
         file_path = value.get("location", None) or value.get("path", None)
         if file_path is None:
             return value
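
Note: an empty list literal gives the type checker nothing to infer from, so annotating it up front is what makes later appends checkable. A short, generic illustration (not Galaxy code):

```python
from typing import Any, Dict, List

# Without the annotation, strict mypy reports that 'datasets' needs a type annotation.
datasets: List[Dict[str, Any]] = []

datasets.append({"src": "hda", "id": "abc123"})  # OK
datasets.append("not-a-dict")  # flagged by mypy: incompatible argument type
```
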
16 changes: 10 additions & 6 deletions lib/galaxy/tools/wrappers.py
@@ -20,6 +20,7 @@
 )
 
 from packaging.version import Version
+from typing_extensions import TypeAlias
 
 from galaxy.model import (
     DatasetCollection,
@@ -605,6 +606,9 @@ def __bool__(self) -> bool:
     __nonzero__ = __bool__
 
 
+DatasetCollectionElementWrapper: TypeAlias = Union["DatasetCollectionWrapper", DatasetFilenameWrapper]
+
+
 class DatasetCollectionWrapper(ToolParameterValueWrapper, HasDatasets):
     name: Optional[str]
     collection: DatasetCollection
@@ -642,15 +646,15 @@ def __init__(
         self.collection = collection
 
         elements = collection.elements
-        element_instances = {}
+        element_instances: Dict[str, DatasetCollectionElementWrapper] = {}
 
-        element_instance_list = []
+        element_instance_list: List[DatasetCollectionElementWrapper] = []
         for dataset_collection_element in elements:
             element_object = dataset_collection_element.element_object
             element_identifier = dataset_collection_element.element_identifier
 
             if dataset_collection_element.is_collection:
-                element_wrapper: Union[DatasetCollectionWrapper, DatasetFilenameWrapper] = DatasetCollectionWrapper(
+                element_wrapper: DatasetCollectionElementWrapper = DatasetCollectionWrapper(
                     job_working_directory, dataset_collection_element, **kwargs
                 )
             else:
@@ -757,15 +761,15 @@ def serialize(
     def is_input_supplied(self) -> bool:
         return self.__input_supplied
 
-    def __getitem__(self, key: Union[str, int]) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]:
+    def __getitem__(self, key: Union[str, int]) -> Optional[DatasetCollectionElementWrapper]:
         if not self.__input_supplied:
             return None
         if isinstance(key, int):
            return self.__element_instance_list[key]
         else:
             return self.__element_instances[key]
 
-    def __getattr__(self, key: str) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]:
+    def __getattr__(self, key: str) -> Optional[DatasetCollectionElementWrapper]:
         if not self.__input_supplied:
             return None
         try:
@@ -775,7 +779,7 @@ def __getattr__(self, key: str) -> Union[None, "DatasetCollectionWrapper", DatasetFilenameWrapper]:
 
     def __iter__(
         self,
-    ) -> Iterator[Union["DatasetCollectionWrapper", DatasetFilenameWrapper]]:
+    ) -> Iterator[DatasetCollectionElementWrapper]:
         if not self.__input_supplied:
             return [].__iter__()
         return self.__element_instance_list.__iter__()
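
Note: the union of wrapper types recurs in several signatures here, which is what a TypeAlias is for, and the quoted name inside the alias is a forward reference to a class defined further down the module. A condensed, self-contained sketch of the pattern with illustrative class names:

```python
from typing import Dict, List, Optional, Union

from typing_extensions import TypeAlias  # typing.TypeAlias on Python >= 3.10


class FileWrapper:
    """Leaf element stand-in."""


# Forward reference: CollectionWrapper is only defined below this line.
ElementWrapper: TypeAlias = Union["CollectionWrapper", FileWrapper]


class CollectionWrapper:
    def __init__(self) -> None:
        self._by_identifier: Dict[str, ElementWrapper] = {}
        self._ordered: List[ElementWrapper] = []

    def add(self, identifier: str, element: ElementWrapper) -> None:
        self._by_identifier[identifier] = element
        self._ordered.append(element)

    def __getitem__(self, key: Union[str, int]) -> Optional[ElementWrapper]:
        # Integer keys index by position, string keys by element identifier.
        if isinstance(key, int):
            return self._ordered[key]
        return self._by_identifier.get(key)
```
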
33 changes: 24 additions & 9 deletions lib/galaxy/workflow/modules.py
@@ -463,7 +463,9 @@ def decode_runtime_state(self, step, runtime_state):
         state.decode(runtime_state, Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app)
         return state
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         """Execute the given workflow invocation step.
 
         Use the supplied workflow progress object to track outputs, find
@@ -508,7 +510,7 @@ def get_informal_replacement_parameters(self, step) -> List[str]:
 
         return []
 
-    def compute_collection_info(self, progress, step, all_inputs):
+    def compute_collection_info(self, progress: "WorkflowProgress", step, all_inputs):
         """
         Use get_all_inputs (if implemented) to determine collection mapping for execution.
         """
@@ -526,7 +528,7 @@ def compute_collection_info(self, progress, step, all_inputs):
            collection_info.when_values = progress.when_values
         return collection_info or progress.subworkflow_collection_info
 
-    def _find_collections_to_match(self, progress, step, all_inputs):
+    def _find_collections_to_match(self, progress: "WorkflowProgress", step, all_inputs) -> matching.CollectionsToMatch:
         collections_to_match = matching.CollectionsToMatch()
         dataset_collection_type_descriptions = self.trans.app.dataset_collection_manager.collection_type_descriptions
 
@@ -756,7 +758,9 @@ def get_post_job_actions(self, incoming):
     def get_content_id(self):
         return self.trans.security.encode_id(self.subworkflow.id)
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         """Execute the given workflow step in the given workflow invocation.
         Use the supplied workflow progress object to track outputs, find
         inputs, etc...
@@ -929,7 +933,9 @@ def get_runtime_state(self):
     def get_all_inputs(self, data_only=False, connectable_only=False):
         return []
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         invocation = invocation_step.workflow_invocation
         step = invocation_step.workflow_step
         input_value = step.state.inputs["input"]
@@ -963,6 +969,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
         if content:
             invocation.add_input(content, step.id)
         progress.set_outputs_for_input(invocation_step, step_outputs)
+        return None
 
     def recover_mapping(self, invocation_step, progress):
         progress.set_outputs_for_input(invocation_step, already_persisted=True)
@@ -1522,7 +1529,9 @@ def get_all_outputs(self, data_only=False):
             )
         ]
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         step = invocation_step.workflow_step
         input_value = step.state.inputs["input"]
         if input_value is None:
@@ -1535,6 +1544,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
            input_value = default_value.get("value", NO_REPLACEMENT)
         step_outputs = dict(output=input_value)
         progress.set_outputs_for_input(invocation_step, step_outputs)
+        return None
 
     def step_state_to_tool_state(self, state):
         state = safe_loads(state)
@@ -1666,9 +1676,12 @@ def get_runtime_state(self):
         state.inputs = dict()
         return state
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         step = invocation_step.workflow_step
         progress.mark_step_outputs_delayed(step, why="executing pause step")
+        return None
 
     def recover_mapping(self, invocation_step, progress):
         if invocation_step:
@@ -2131,7 +2144,9 @@ def decode_runtime_state(self, step, runtime_state):
                f"Tool {self.tool_id} missing. Cannot recover runtime state.", tool_id=self.tool_id
             )
 
-    def execute(self, trans, progress, invocation_step, use_cached_job=False):
+    def execute(
+        self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
+    ) -> Optional[bool]:
         invocation = invocation_step.workflow_invocation
         step = invocation_step.workflow_step
         tool = trans.app.toolbox.get_tool(step.tool_id, tool_version=step.tool_version, tool_uuid=step.tool_uuid)
@@ -2171,7 +2186,7 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
         found_replacement_keys = set()
 
         # Connect up
-        def callback(input, prefixed_name, **kwargs):
+        def callback(input, prefixed_name: str, **kwargs):
             input_dict = all_inputs_by_name[prefixed_name]
 
             replacement: Union[model.Dataset, NoReplacement] = NO_REPLACEMENT
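
Note: every execute override now shares one annotated signature returning Optional[bool], and the implementations that previously fell off the end gained an explicit return None. A minimal sketch of that shape, with placeholder class and argument names:

```python
from typing import Optional


class WorkflowStepModule:
    """Placeholder base class illustrating the annotated hook."""

    def execute(self, step_id: int, use_cached_job: bool = False) -> Optional[bool]:
        raise NotImplementedError


class PauseModule(WorkflowStepModule):
    def execute(self, step_id: int, use_cached_job: bool = False) -> Optional[bool]:
        # Nothing to report for a pause step: return None explicitly rather
        # than falling off the end of the function.
        return None


class ToolModule(WorkflowStepModule):
    def execute(self, step_id: int, use_cached_job: bool = False) -> Optional[bool]:
        # A concrete module can also hand a boolean flag back to its caller.
        return not use_cached_job
```
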
15 changes: 10 additions & 5 deletions lib/galaxy/workflow/run.py
@@ -144,6 +144,8 @@ def queue_invoke(
 
 
 class WorkflowInvoker:
+    progress: "WorkflowProgress"
+
     def __init__(
         self,
         trans: "WorkRequestContext",
@@ -348,7 +350,7 @@ class WorkflowProgress:
     def __init__(
         self,
         workflow_invocation: WorkflowInvocation,
-        inputs_by_step_id: Any,
+        inputs_by_step_id: Dict[int, Any],
         module_injector: ModuleInjector,
         param_map: Dict[int, Dict[str, Any]],
         jobs_per_scheduling_iteration: int = -1,
@@ -415,7 +417,7 @@ def remaining_steps(
            remaining_steps.append((step, invocation_step))
         return remaining_steps
 
-    def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any:
+    def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]):
         replacement: Union[
             modules.NoReplacement,
             model.DatasetCollectionInstance,
@@ -447,7 +449,7 @@ def replacement_for_input(self, trans, step: "WorkflowStep", input_dict: Dict[str, Any]) -> Any:
            replacement = raw_to_galaxy(trans, step_input.default_value)
         return replacement
 
-    def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True) -> Any:
+    def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True):
         output_step_id = connection.output_step.id
         output_name = connection.output_name
         if output_step_id not in self.outputs:
@@ -530,7 +532,7 @@ def replacement_for_connection(self, connection: "WorkflowStepConnection", is_data: bool = True) -> Any:
 
         return replacement
 
-    def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") -> Any:
+    def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput"):
         step = workflow_output.workflow_step
         output_name = workflow_output.output_name
         step_outputs = self.outputs[step.id]
@@ -541,7 +543,10 @@ def get_replacement_workflow_output(self, workflow_output: "WorkflowOutput") -> Any:
         return step_outputs[output_name]
 
     def set_outputs_for_input(
-        self, invocation_step: WorkflowInvocationStep, outputs: Any = None, already_persisted: bool = False
+        self,
+        invocation_step: WorkflowInvocationStep,
+        outputs: Optional[Dict[str, Any]] = None,
+        already_persisted: bool = False,
     ) -> None:
         step = invocation_step.workflow_step
 
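
Note: the class-level progress: "WorkflowProgress" annotation documents an attribute that is only assigned later during invocation, and the quoted name is a forward reference; set_outputs_for_input also swaps a bare Any default for Optional[Dict[str, Any]] = None. A compact sketch of both ideas with illustrative names:

```python
from typing import Any, Dict, Optional


class Invoker:
    # Class-level annotation only: no value is assigned here, but readers and
    # the type checker know the attribute's type once invoke() sets it.
    progress: "Progress"

    def invoke(self, inputs_by_step_id: Dict[int, Any]) -> None:
        self.progress = Progress(inputs_by_step_id)


class Progress:
    def __init__(self, inputs_by_step_id: Dict[int, Any]) -> None:
        self.inputs_by_step_id = inputs_by_step_id
        self.outputs: Dict[int, Dict[str, Any]] = {}

    def set_outputs_for_input(
        self,
        step_id: int,
        outputs: Optional[Dict[str, Any]] = None,
        already_persisted: bool = False,
    ) -> None:
        # A None default keeps the parameter optional without the
        # mutable-default pitfall of outputs={}.
        self.outputs.setdefault(step_id, {}).update(outputs or {})


invoker = Invoker()
invoker.invoke({1: {"input": "dataset-1"}})
invoker.progress.set_outputs_for_input(1, {"output": "dataset-2"})
```
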
11 changes: 0 additions & 11 deletions test/integration/conftest.py
@@ -10,17 +10,6 @@ def celery_includes():
     return ["galaxy.celery.tasks"]
 
 
-def pytest_collection_finish(session):
-    try:
-        # This needs to be run after test collection
-        from .test_config_defaults import DRIVER
-
-        DRIVER.tear_down()
-        print("Galaxy test driver shutdown successful")
-    except Exception:
-        pass
-
-
 @pytest.fixture
 def temp_file():
     with tempfile.NamedTemporaryFile(delete=True, mode="wb") as fh:
