Skip to content

Commit

Permalink
Type annotations and fixes
Browse files Browse the repository at this point in the history
- Use `galaxy.util.path.StrPath` everywhere
  • Loading branch information
nsoranzo committed Sep 30, 2024
1 parent 3f39b62 commit cd54557
Show file tree
Hide file tree
Showing 27 changed files with 195 additions and 181 deletions.
7 changes: 4 additions & 3 deletions lib/galaxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GalaxyTaskBeforeStartUserRateLimitPostgres,
GalaxyTaskBeforeStartUserRateLimitStandard,
)
from galaxy.config import GalaxyAppConfiguration
from galaxy.config_watchers import ConfigWatchers
from galaxy.datatypes.registry import Registry
from galaxy.files import (
Expand Down Expand Up @@ -206,7 +207,7 @@ def shutdown(self):


class SentryClientMixin:
config: config.GalaxyAppConfiguration
config: GalaxyAppConfiguration
application_stack: ApplicationStack

def configure_sentry_client(self):
Expand Down Expand Up @@ -263,7 +264,7 @@ class MinimalGalaxyApplication(BasicSharedApp, HaltableContainer, SentryClientMi
"""Encapsulates the state of a minimal Galaxy application"""

model: GalaxyModelMapping
config: config.GalaxyAppConfiguration
config: GalaxyAppConfiguration
tool_cache: ToolCache
job_config: jobs.JobConfiguration
toolbox_search: ToolBoxSearch
Expand All @@ -287,7 +288,7 @@ def __init__(self, fsmon=False, **kwargs) -> None:
self.name = "galaxy"
self.is_webapp = False
# Read config file and check for errors
self.config = self._register_singleton(config.GalaxyAppConfiguration, config.GalaxyAppConfiguration(**kwargs))
self.config = self._register_singleton(GalaxyAppConfiguration, GalaxyAppConfiguration(**kwargs))
self.config.check()
config_file = kwargs.get("global_conf", {}).get("__file__", None)
if config_file:
Expand Down
7 changes: 4 additions & 3 deletions lib/galaxy/app_unittest_utils/tools_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from galaxy.tool_util.parser import get_tool_source
from galaxy.tools import create_tool_from_source
from galaxy.util.bunch import Bunch
from galaxy.util.path import StrPath

datatypes_registry = galaxy.datatypes.registry.Registry()
datatypes_registry.load_datatypes()
Expand Down Expand Up @@ -83,10 +84,10 @@ def _init_tool(
tool_id="test_tool",
extra_file_contents=None,
extra_file_path=None,
tool_path=None,
tool_path: Optional[StrPath] = None,
):
if tool_path is None:
self.tool_file = os.path.join(self.test_directory, filename)
self.tool_file: StrPath = os.path.join(self.test_directory, filename)
contents_template = string.Template(tool_contents)
tool_contents = contents_template.safe_substitute(dict(version=version, profile=profile, tool_id=tool_id))
self.__write_tool(tool_contents)
Expand All @@ -96,7 +97,7 @@ def _init_tool(
self.tool_file = tool_path
return self.__setup_tool()

def _init_tool_for_path(self, tool_file):
def _init_tool_for_path(self, tool_file: StrPath):
self.tool_file = tool_file
return self.__setup_tool()

Expand Down
6 changes: 2 additions & 4 deletions lib/galaxy/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@ def setup_data_table_manager(app):


@lru_cache
def cached_create_tool_from_representation(app, raw_tool_source):
return create_tool_from_representation(
app=app, raw_tool_source=raw_tool_source, tool_dir="", tool_source_class="XmlToolSource"
)
def cached_create_tool_from_representation(app: MinimalManagerApp, raw_tool_source: str):
return create_tool_from_representation(app=app, raw_tool_source=raw_tool_source, tool_source_class="XmlToolSource")


@galaxy_task(action="recalculate a user's disk usage")
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy/datatypes/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ class Data(metaclass=DataMeta):
edam_data = "data_0006"
edam_format = "format_1915"
file_ext = "data"
is_subclass = False
# Data is not chunkable by default.
CHUNKABLE = False

Expand Down
85 changes: 47 additions & 38 deletions lib/galaxy/datatypes/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import logging
import os
import pkgutil
from pathlib import Path
from string import Template
from typing import (
Any,
cast,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
Expand All @@ -24,8 +25,12 @@
import galaxy.util
from galaxy.datatypes.protocols import DatasetProtocol
from galaxy.tool_util.edam_util import load_edam_tree
from galaxy.util import RW_R__R__
from galaxy.util import (
Element,
RW_R__R__,
)
from galaxy.util.bunch import Bunch
from galaxy.util.path import StrPath
from . import (
binary,
coverage,
Expand Down Expand Up @@ -65,7 +70,7 @@ def __init__(self, config=None):
self.log.addHandler(logging.NullHandler())
self.config = config
self.edam = edam
self.datatypes_by_extension = {}
self.datatypes_by_extension: Dict[str, Data] = {}
self.datatypes_by_suffix_inferences = {}
self.mimetypes_by_extension = {}
self.datatype_converters = {}
Expand All @@ -75,7 +80,7 @@ def __init__(self, config=None):
self.converter_deps = {}
self.available_tracks = []
self.set_external_metadata_tool = None
self.sniff_order = []
self.sniff_order: List[Data] = []
self.upload_file_formats = []
# Datatype elements defined in local datatypes_conf.xml that contain display applications.
self.display_app_containers = []
Expand Down Expand Up @@ -105,7 +110,7 @@ def __init__(self, config=None):
def load_datatypes(
self,
root_dir=None,
config=None,
config: Optional[Union[Element, StrPath]] = None,
override=True,
use_converters=True,
use_display_applications=True,
Expand All @@ -127,8 +132,8 @@ def __import_module(full_path: str, datatype_module: str):
return module

if root_dir and config:
compressed_sniffers = {}
if isinstance(config, (str, Path)):
compressed_sniffers: Dict[Type[Data], List[Data]] = {}
if isinstance(config, (str, os.PathLike)):
# Parse datatypes_conf.xml
tree = galaxy.util.parse_xml(config)
root = tree.getroot()
Expand All @@ -137,6 +142,7 @@ def __import_module(full_path: str, datatype_module: str):
else:
root = config
registration = root.find("registration")
assert registration is not None
# Set default paths defined in local datatypes_conf.xml.
if use_converters:
if not self.converters_path:
Expand Down Expand Up @@ -167,7 +173,6 @@ def __import_module(full_path: str, datatype_module: str):

for elem in registration.findall("datatype"):
# Keep a status of the process steps to enable stopping the process of handling the datatype if necessary.
ok = True
extension = self.get_extension(elem)
dtype = elem.get("type", None)
type_extension = elem.get("type_extension", None)
Expand Down Expand Up @@ -199,7 +204,9 @@ def __import_module(full_path: str, datatype_module: str):
if override or extension not in self.datatypes_by_extension:
can_process_datatype = True
if can_process_datatype:
datatype_class: Optional[Type[Data]] = None
if dtype is not None:
ok = True
try:
fields = dtype.split(":")
datatype_module = fields[0]
Expand All @@ -208,21 +215,18 @@ def __import_module(full_path: str, datatype_module: str):
self.log.exception("Error parsing datatype definition for dtype %s", str(dtype))
ok = False
if ok:
datatype_class = None
if datatype_class is None:
try:
# The datatype class name must be contained in one of the datatype modules in the Galaxy distribution.
fields = datatype_module.split(".")[1:]
module = __import__(datatype_module)
for mod in fields:
module = getattr(module, mod)
datatype_class = getattr(module, datatype_class_name)
self.log.debug(
f"Retrieved datatype module {str(datatype_module)}:{datatype_class_name} from the datatype registry for extension {extension}."
)
except Exception:
self.log.exception("Error importing datatype module %s", str(datatype_module))
ok = False
try:
# The datatype class name must be contained in one of the datatype modules in the Galaxy distribution.
fields = datatype_module.split(".")[1:]
module = __import__(datatype_module)
for mod in fields:
module = getattr(module, mod)
datatype_class = getattr(module, datatype_class_name)
self.log.debug(
f"Retrieved datatype module {str(datatype_module)}:{datatype_class_name} from the datatype registry for extension {extension}."
)
except Exception:
self.log.exception("Error importing datatype module %s", str(datatype_module))
elif type_extension is not None:
try:
datatype_class = self.datatypes_by_extension[type_extension].__class__
Expand All @@ -233,8 +237,7 @@ def __import_module(full_path: str, datatype_module: str):
self.log.exception(
"Error determining datatype_class for type_extension %s", str(type_extension)
)
ok = False
if ok:
if datatype_class:
# A new tool shed repository that contains custom datatypes is being installed, and since installation is
# occurring after the datatypes registry has been initialized at server startup, its contents cannot be
# overridden by new introduced conflicting data types unless the value of override is True.
Expand Down Expand Up @@ -262,7 +265,7 @@ def __import_module(full_path: str, datatype_module: str):
for upload_warning_el in upload_warning_els:
if upload_warning_template is not None:
raise NotImplementedError("Multiple upload_warnings not implemented")
upload_warning_template = Template(upload_warning_el.text)
upload_warning_template = Template(upload_warning_el.text or "")
datatype_instance = datatype_class()
self.datatypes_by_extension[extension] = datatype_instance
if mimetype is None:
Expand All @@ -282,9 +285,9 @@ def __import_module(full_path: str, datatype_module: str):
# compressed files in the future (e.g. maybe some day faz will be a compressed fasta
# or something along those lines)
for infer_from in elem.findall("infer_from"):
suffix = infer_from.get("suffix", None)
suffix = infer_from.get("suffix")
if suffix is None:
raise Exception("Failed to parse infer_from datatype element")
raise ConfigurationError("Failed to parse infer_from datatype element")
infer_from_suffixes.append(suffix)
self.datatypes_by_suffix_inferences[suffix] = datatype_instance
for converter in elem.findall("converter"):
Expand All @@ -300,9 +303,11 @@ def __import_module(full_path: str, datatype_module: str):
self.converters.append((converter_config, extension, target_datatype))
# Add composite files.
for composite_file in elem.findall("composite_file"):
name = composite_file.get("name", None)
name = composite_file.get("name")
if name is None:
self.log.warning(f"You must provide a name for your composite_file ({composite_file}).")
raise ConfigurationError(
f"You must provide a name for your composite_file ({composite_file})."
)
optional = composite_file.get("optional", False)
mimetype = composite_file.get("mimetype", None)
self.datatypes_by_extension[extension].add_composite_file(
Expand All @@ -321,8 +326,8 @@ def __import_module(full_path: str, datatype_module: str):
composite_files = datatype_instance.get_composite_files()
if composite_files:
_composite_files = []
for name, composite_file in composite_files.items():
_composite_file = composite_file.dict()
for name, composite_file_bunch in composite_files.items():
_composite_file = composite_file_bunch.dict()
_composite_file["name"] = name
_composite_files.append(_composite_file)
datatype_info_dict["composite_files"] = _composite_files
Expand All @@ -332,16 +337,18 @@ def __import_module(full_path: str, datatype_module: str):
compressed_extension = f"{extension}.{auto_compressed_type}"
upper_compressed_type = auto_compressed_type[0].upper() + auto_compressed_type[1:]
auto_compressed_type_name = datatype_class_name + upper_compressed_type
attributes = {}
attributes: Dict[str, Any] = {}
if auto_compressed_type == "gz":
dynamic_parent = binary.GzDynamicCompressedArchive
dynamic_parent: Type[binary.DynamicCompressedArchive] = (
binary.GzDynamicCompressedArchive
)
elif auto_compressed_type == "bz2":
dynamic_parent = binary.Bz2DynamicCompressedArchive
else:
raise Exception(f"Unknown auto compression type [{auto_compressed_type}]")
raise ConfigurationError(f"Unknown auto compression type [{auto_compressed_type}]")
attributes["file_ext"] = compressed_extension
attributes["uncompressed_datatype_instance"] = datatype_instance
compressed_datatype_class = type(
compressed_datatype_class: Type[Data] = type(
auto_compressed_type_name,
(
datatype_class,
Expand Down Expand Up @@ -411,7 +418,7 @@ def __import_module(full_path: str, datatype_module: str):
self._load_build_sites(root)
self.set_default_values()

def append_to_sniff_order():
def append_to_sniff_order() -> None:
sniff_order_classes = {type(_) for _ in self.sniff_order}
for datatype in self.datatypes_by_extension.values():
# Add a datatype only if it is not already in sniff_order, it
Expand Down Expand Up @@ -482,7 +489,9 @@ def get_legacy_sites_by_build(self, site_type, build):
def get_display_sites(self, site_type):
return self.display_sites.get(site_type, [])

def load_datatype_sniffers(self, root, override=False, compressed_sniffers=None):
def load_datatype_sniffers(
self, root, override=False, compressed_sniffers: Optional[Dict[Type["Data"], List["Data"]]] = None
):
"""
Process the sniffers element from a parsed a datatypes XML file located at root_dir/config (if processing the Galaxy
distributed config) or contained within an installed Tool Shed repository.
Expand Down
13 changes: 9 additions & 4 deletions lib/galaxy/datatypes/sniff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
Callable,
Dict,
IO,
Iterable,
NamedTuple,
Optional,
TYPE_CHECKING,
Union,
)

Expand All @@ -44,6 +46,8 @@
pass
import magic # isort:skip

if TYPE_CHECKING:
from .data import Data

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -689,7 +693,7 @@ def _get_file_prefix(filename_or_file_prefix: Union[str, FilePrefix], auto_decom
return filename_or_file_prefix


def run_sniffers_raw(file_prefix: FilePrefix, sniff_order):
def run_sniffers_raw(file_prefix: FilePrefix, sniff_order: Iterable["Data"]):
"""Run through sniffers specified by sniff_order, return None of None match."""
fname = file_prefix.filename
file_ext = None
Expand Down Expand Up @@ -718,15 +722,16 @@ def run_sniffers_raw(file_prefix: FilePrefix, sniff_order):
continue
try:
if hasattr(datatype, "sniff_prefix"):
if file_prefix.compressed_format and getattr(datatype, "compressed_format", None):
datatype_compressed_format = getattr(datatype, "compressed_format", None)
if file_prefix.compressed_format and datatype_compressed_format:
# Compare the compressed format detected
# to the expected.
if file_prefix.compressed_format != datatype.compressed_format:
if file_prefix.compressed_format != datatype_compressed_format:
continue
if datatype.sniff_prefix(file_prefix):
file_ext = datatype.file_ext
break
elif datatype.sniff(fname):
elif hasattr(datatype, "sniff") and datatype.sniff(fname):
file_ext = datatype.file_ext
break
except Exception:
Expand Down
7 changes: 2 additions & 5 deletions lib/galaxy/files/sources/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import time
from os import PathLike
from typing import (
List,
Optional,
Tuple,
Union,
)

from galaxy import exceptions
Expand All @@ -20,8 +18,7 @@
requests,
)
from galaxy.util.config_parsers import IpAllowedListEntryT

TargetPathT = Union[str, PathLike]
from galaxy.util.path import StrPath


def _not_implemented(drs_uri: str, desc: str) -> NotImplementedError:
Expand Down Expand Up @@ -79,7 +76,7 @@ def _get_access_info(obj_url: str, access_method: dict, headers=None) -> Tuple[s

def fetch_drs_to_file(
drs_uri: str,
target_path: TargetPathT,
target_path: StrPath,
user_context: Optional[FileSourcesUserContext],
force_http=False,
retry_options: Optional[RetryOptions] = None,
Expand Down
Loading

0 comments on commit cd54557

Please sign in to comment.