diff --git a/.gitignore b/.gitignore index 379bd2972..81f54f990 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.py[cod] +.hypothesis .mypy_cache # doc builds diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c4227fd18..254fc9855 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,37 @@ repos: args: [--in-place, --remove-all-unused-imports, --remove-unused-variable, --ignore-init-module-imports] - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 22.3.0 hooks: - id: black args: [--line-length, '120'] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.1.1 + hooks: + - id: mypy + additional_dependencies: + - jinja2 + - numpy + - pymongo-stubs + - ruamel.yaml + - types-flask + - types-paramiko + - types-prettytable + - types-python-dateutil + - types-pyyaml + - types-requests + - types-setuptools + - types-six + - types-tabulate + - types-tqdm + args: [] + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: diff --git a/docs_rst/conf.py b/docs_rst/conf.py index 7aebfda79..74265cb4d 100644 --- a/docs_rst/conf.py +++ b/docs_rst/conf.py @@ -12,6 +12,7 @@ import os import sys +from typing import Mapping if sys.version_info < (3, 8): import importlib_metadata as metadata @@ -185,7 +186,7 @@ # -- Options for LaTeX output -------------------------------------------------- -latex_elements = { +latex_elements: Mapping[str, str] = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). 
diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 616f4d463..b8e7e0383 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -14,12 +14,13 @@ from collections import defaultdict from copy import deepcopy from datetime import datetime -from typing import Any, Dict, Iterable, List, Sequence +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union from monty.io import reverse_readline, zopen from monty.os.path import zpath from fireworks.core.fworker import FWorker +from fireworks.core.types import Checkpoint, FromDict, Spec, ToDict from fireworks.fw_config import ( EXCEPT_DETAILS_ON_RERUN, NEGATIVE_FWID_CTR, @@ -53,12 +54,12 @@ class FiretaskBase(defaultdict, FWSerializable, metaclass=abc.ABCMeta): You can set parameters of a Firetask like you'd use a dict. """ - required_params = None # list of str of required parameters to check for consistency upon init + required_params: Optional[List[str]] = None # list of str of required parameters to check for consistency upon init # if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init - optional_params = None + optional_params: Optional[List[str]] = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: dict.__init__(self, *args, **kwargs) required_params = self.required_params or [] @@ -77,7 +78,7 @@ def __init__(self, *args, **kwargs): ) @abc.abstractmethod - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> Optional["FWAction"]: """ This method gets called when the Firetask is run. 
It can take in a Firework spec, perform some task using that data, and then return an @@ -102,27 +103,27 @@ def run_task(self, fw_spec): @serialize_fw @recursive_serialize - def to_dict(self): + def to_dict(self) -> ToDict: return dict(self) @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "FiretaskBase": return cls(m_dict) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.fw_name}>:{dict(self)}" # not strictly needed here for pickle/unpickle, but complements __setstate__ - def __getstate__(self): - return self.to_dict() + def __getstate__(self) -> ToDict: + return self.to_dict() # type: ignore # added to support pickle/unpickle - def __setstate__(self, state): - self.__init__(state) + def __setstate__(self, state) -> None: + self.__init__(state) # type: ignore # added to support pickle/unpickle - def __reduce__(self): + def __reduce__(self) -> Tuple: t = defaultdict.__reduce__(self) return (t[0], (self.to_dict(),), t[2], t[3], t[4]) @@ -136,15 +137,15 @@ class FWAction(FWSerializable): def __init__( self, - stored_data=None, - exit=False, - update_spec=None, - mod_spec=None, - additions=None, - detours=None, - defuse_children=False, - defuse_workflow=False, - propagate=False, + stored_data: Optional[Dict[Any, Any]] = None, + exit: bool = False, + update_spec: Optional[Mapping[str, Any]] = None, + mod_spec: Optional[Sequence[Mapping[str, Any]]] = None, + additions: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, + detours: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, + defuse_children: bool = False, + defuse_workflow: bool = False, + propagate: bool = False, ): """ Args: @@ -162,7 +163,7 @@ def __init__( not only to direct children, but to all dependent FireWorks down to the Workflow's leaves. 
""" - mod_spec = mod_spec if mod_spec is not None else [] + mod_spec = list(mod_spec) if mod_spec is not None else [] additions = additions if additions is not None else [] detours = detours if detours is not None else [] @@ -177,7 +178,7 @@ def __init__( self.propagate = propagate @recursive_serialize - def to_dict(self): + def to_dict(self) -> ToDict: return { "stored_data": self.stored_data, "exit": self.exit, @@ -192,7 +193,7 @@ def to_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "FWAction": d = m_dict additions = [Workflow.from_dict(f) for f in d["additions"]] detours = [Workflow.from_dict(f) for f in d["detours"]] @@ -209,16 +210,22 @@ def from_dict(cls, m_dict): ) @property - def skip_remaining_tasks(self): + def skip_remaining_tasks(self) -> bool: """ If the FWAction gives any dynamic action, we skip the subsequent Firetasks Returns: bool """ - return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow + return ( + self.exit + or bool(self.detours) + or bool(self.additions) + or bool(self.defuse_children) + or bool(self.defuse_workflow) + ) - def __str__(self): + def __str__(self) -> str: return "FWAction\n" + pprint.pformat(self.to_dict()) @@ -242,16 +249,16 @@ class Firework(FWSerializable): # note: if you modify this signature, you must also modify LazyFirework def __init__( self, - tasks, - spec=None, - name=None, - launches=None, - archived_launches=None, - state="WAITING", - created_on=None, - fw_id=None, - parents=None, - updated_on=None, + tasks: Union["FiretaskBase", Sequence["FiretaskBase"]], + spec: Optional[Spec] = None, + name: Optional[str] = None, + launches: Optional[Sequence["Launch"]] = None, + archived_launches: Optional[Sequence["Launch"]] = None, + state: str = "WAITING", + created_on: Optional[datetime] = None, + fw_id: Optional[int] = None, + parents: Optional[Union["Firework", Sequence["Firework"]]] = None, + updated_on: 
Optional[datetime] = None, ): """ Args: @@ -266,9 +273,9 @@ def __init__( updated_on (datetime): last time the STATE was updated. """ - self.tasks = tasks if isinstance(tasks, (list, tuple)) else [tasks] + self.tasks = list(tasks) if isinstance(tasks, Sequence) else [tasks] - self.spec = spec.copy() if spec else {} + self.spec = dict(spec) if spec else {} self.name = name or "Unnamed FW" # do it this way to prevent None # names @@ -279,18 +286,19 @@ def __init__( NEGATIVE_FWID_CTR -= 1 self.fw_id = NEGATIVE_FWID_CTR - self.launches = launches if launches else [] - self.archived_launches = archived_launches if archived_launches else [] + self.launches = list(launches) if launches else [] + self.archived_launches = list(archived_launches) if archived_launches else [] self.created_on = created_on or datetime.utcnow() self.updated_on = updated_on or datetime.utcnow() parents = [parents] if isinstance(parents, Firework) else parents - self.parents = parents if parents else [] + self.parents = list(parents) if parents else [] + assert self.parents is not None self._state = state @property - def state(self): + def state(self) -> str: """ Returns: str: The current state of the Firework @@ -298,7 +306,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state: str) -> None: """ Setter for the the FW state, which triggers updated_on change @@ -309,7 +317,7 @@ def state(self, state): self.updated_on = datetime.utcnow() @recursive_serialize - def to_dict(self): + def to_dict(self) -> ToDict: # put tasks in a special location of the spec spec = self.spec spec["_tasks"] = [t.to_dict() for t in self.tasks] @@ -330,7 +338,7 @@ def to_dict(self): return m_dict - def _rerun(self): + def _rerun(self) -> None: """ Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework can thus be re-run even if it was Launched in the past. 
This method should be called by @@ -351,10 +359,10 @@ def _rerun(self): self.archived_launches.extend(self.launches) self.archived_launches = list(set(self.archived_launches)) # filter duplicates - self.launches = [] + self.launches: List[Launch] = [] # type: ignore self.state = "WAITING" - def to_db_dict(self): + def to_db_dict(self) -> ToDict: """ Return firework dict with updated launches and state. """ @@ -368,7 +376,7 @@ def to_db_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "Firework": tasks = m_dict["spec"]["_tasks"] launches = [Launch.from_dict(tmp) for tmp in m_dict.get("launches", [])] archived_launches = [Launch.from_dict(tmp) for tmp in m_dict.get("archived_launches", [])] @@ -381,8 +389,8 @@ def from_dict(cls, m_dict): tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on ) - def __str__(self): - return f"Firework object: (id: {int(self.fw_id)} , name: {self.fw_name})" + def __str__(self) -> str: + return "Firework object: (id: %i , name: %s)" % (self.fw_id, self.fw_name) def __iter__(self) -> Iterable[FiretaskBase]: return self.tasks.__iter__() @@ -401,7 +409,9 @@ class Tracker(FWSerializable): MAX_TRACKER_LINES = 1000 - def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False): + def __init__( + self, filename: str, nlines: int = TRACKER_LINES, content: str = "", allow_zipped: bool = False + ) -> None: """ Args: filename (str) @@ -416,7 +426,7 @@ def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=Fals self.content = content self.allow_zipped = allow_zipped - def track_file(self, launch_dir=None): + def track_file(self, launch_dir: Optional[str] = None) -> str: """ Reads the monitored file and returns back the last N lines @@ -441,19 +451,19 @@ def track_file(self, launch_dir=None): self.content = "\n".join(reversed(lines)) return self.content - def to_dict(self): + def 
to_dict(self) -> ToDict: m_dict = {"filename": self.filename, "nlines": self.nlines, "allow_zipped": self.allow_zipped} if self.content: m_dict["content"] = self.content return m_dict @classmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "Tracker": return Tracker( m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False) ) - def __str__(self): + def __str__(self) -> str: return f"### Filename: {self.filename}\n{self.content}" @@ -464,17 +474,17 @@ class Launch(FWSerializable): def __init__( self, - state, - launch_dir, - fworker=None, - host=None, - ip=None, - trackers=None, - action=None, - state_history=None, - launch_id=None, - fw_id=None, - ): + state: str, + launch_dir: str, + fworker: Optional["FWorker"] = None, + host: Optional[str] = None, + ip: Optional[str] = None, + trackers: Optional[Sequence["Tracker"]] = None, + action: Optional["FWAction"] = None, + state_history: Optional[Dict[Any, Any]] = None, + launch_id: Optional[int] = None, + fw_id: Optional[int] = None, + ) -> None: """ Args: state (str): the state of the Launch (e.g. RUNNING, COMPLETED) @@ -501,7 +511,7 @@ def __init__( self.launch_id = launch_id self.fw_id = fw_id - def touch_history(self, update_time=None, checkpoint=None): + def touch_history(self, update_time: Optional[datetime] = None, checkpoint: Optional[Checkpoint] = None) -> None: """ Updates the update_on field of the state history of a Launch. Used to ping that a Launch is still alive. @@ -514,7 +524,7 @@ def touch_history(self, update_time=None, checkpoint=None): self.state_history[-1]["checkpoint"] = checkpoint self.state_history[-1]["updated_on"] = update_time - def set_reservation_id(self, reservation_id): + def set_reservation_id(self, reservation_id: Union[int, str]) -> None: """ Adds the job_id to the reservation. 
@@ -527,7 +537,7 @@ def set_reservation_id(self, reservation_id): break @property - def state(self): + def state(self) -> str: """ Returns: str: The current state of the Launch. @@ -535,7 +545,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state: str) -> None: """ Setter for the the Launch's state. Automatically triggers an update to state_history. @@ -546,7 +556,7 @@ def state(self, state): self._update_state_history(state) @property - def time_start(self): + def time_start(self) -> datetime: """ Returns: datetime: the time the Launch started RUNNING @@ -554,7 +564,7 @@ def time_start(self): return self._get_time("RUNNING") @property - def time_end(self): + def time_end(self) -> datetime: """ Returns: datetime: the time the Launch was COMPLETED or FIZZLED @@ -562,7 +572,7 @@ def time_end(self): return self._get_time(["COMPLETED", "FIZZLED"]) @property - def time_reserved(self): + def time_reserved(self) -> datetime: """ Returns: datetime: the time the Launch was RESERVED in the queue @@ -570,7 +580,7 @@ def time_reserved(self): return self._get_time("RESERVED") @property - def last_pinged(self): + def last_pinged(self) -> datetime: """ Returns: datetime: the time the Launch last pinged a heartbeat that it was still running @@ -578,7 +588,7 @@ def last_pinged(self): return self._get_time("RUNNING", True) @property - def runtime_secs(self): + def runtime_secs(self) -> int: # type: ignore """ Returns: int: the number of seconds that the Launch ran for. @@ -586,10 +596,10 @@ def runtime_secs(self): start = self.time_start end = self.time_end if start and end: - return (end - start).total_seconds() + return int((end - start).total_seconds()) @property - def reservedtime_secs(self): + def reservedtime_secs(self) -> int: # type: ignore """ Returns: int: number of seconds the Launch was stuck as RESERVED in a queue. 
@@ -597,10 +607,10 @@ def reservedtime_secs(self): start = self.time_reserved if start: end = self.time_start if self.time_start else datetime.utcnow() - return (end - start).total_seconds() + return int((end - start).total_seconds()) @recursive_serialize - def to_dict(self): + def to_dict(self) -> ToDict: return { "fworker": self.fworker, "fw_id": self.fw_id, @@ -615,7 +625,7 @@ def to_dict(self): } @recursive_serialize - def to_db_dict(self): + def to_db_dict(self) -> ToDict: m_d = self.to_dict() m_d["time_start"] = self.time_start m_d["time_end"] = self.time_end @@ -626,7 +636,7 @@ def to_db_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "Launch": fworker = FWorker.from_dict(m_dict["fworker"]) if m_dict["fworker"] else None action = FWAction.from_dict(m_dict["action"]) if m_dict.get("action") else None trackers = [Tracker.from_dict(f) for f in m_dict["trackers"]] if m_dict.get("trackers") else None @@ -664,7 +674,7 @@ def _update_state_history(self, state): if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key - def _get_time(self, states, use_update_time=False): + def _get_time(self, states, use_update_time: bool = False) -> datetime: """ Internal method to help get the time of various events in the Launch (e.g. RUNNING) from the state history. 
@@ -680,8 +690,8 @@ def _get_time(self, states, use_update_time=False): for data in self.state_history: if data["state"] in states: if use_update_time: - return data["updated_on"] - return data["created_on"] + return data["updated_on"] # type: ignore + return data["created_on"] # type: ignore class Workflow(FWSerializable): @@ -785,12 +795,12 @@ def __reduce__(self): def __init__( self, fireworks: Sequence[Firework], - links_dict: Dict[int, List[int]] = None, - name: str = None, + links_dict: Optional[Dict[int, List[int]]] = None, + name: Optional[str] = None, metadata: Dict[str, Any] = None, - created_on: datetime = None, - updated_on: datetime = None, - fw_states: Dict[int, str] = None, + created_on: Optional[datetime] = None, + updated_on: Optional[datetime] = None, + fw_states: Optional[Dict[int, str]] = None, ) -> None: """ Args: @@ -1198,7 +1208,7 @@ def leaf_fw_ids(self) -> List[int]: leaf_ids.append(id) return leaf_ids - def _reassign_ids(self, old_new: Dict[int, int]) -> None: + def _reassign_ids(self, old_new: Mapping[int, int]) -> None: """ Internal method to reassign Firework ids, e.g. due to database insertion. 
@@ -1223,7 +1233,7 @@ def _reassign_ids(self, old_new: Dict[int, int]) -> None: new_fw_states[old_new.get(fwid, fwid)] = fw_state self.fw_states = new_fw_states - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: return { "fws": [f.to_dict() for f in self.id_fw.values()], "links": self.links.to_dict(), @@ -1233,7 +1243,7 @@ def to_dict(self) -> Dict[str, Any]: "created_on": self.created_on, } - def to_db_dict(self) -> Dict[str, Any]: + def to_db_dict(self) -> ToDict: m_dict = self.links.to_db_dict() m_dict["metadata"] = self.metadata m_dict["state"] = self.state @@ -1243,7 +1253,7 @@ def to_db_dict(self) -> Dict[str, Any]: m_dict["fw_states"] = {str(k): v for (k, v) in self.fw_states.items()} return m_dict - def to_display_dict(self): + def to_display_dict(self) -> ToDict: m_dict = self.to_db_dict() nodes = sorted(m_dict["nodes"]) m_dict["name--id"] = self.name + "--" + str(nodes[0]) @@ -1332,7 +1342,7 @@ def reset(self, reset_ids: bool = True) -> None: self.fw_states = {key: self.id_fw[key].state for key in self.id_fw} @classmethod - def from_dict(cls, m_dict: Dict[str, Any]) -> "Workflow": + def from_dict(cls, m_dict: FromDict) -> "Workflow": """ Return Workflow from its dict representation. @@ -1358,7 +1368,7 @@ def from_dict(cls, m_dict: Dict[str, Any]) -> "Workflow": return Workflow.from_Firework(Firework.from_dict(m_dict)) @classmethod - def from_Firework(cls, fw: Firework, name: str = None, metadata=None) -> "Workflow": + def from_Firework(cls, fw: "Firework", name: Optional[str] = None, metadata=None) -> "Workflow": """ Return Workflow from the given Firework. 
@@ -1373,10 +1383,10 @@ def from_Firework(cls, fw: Firework, name: str = None, metadata=None) -> "Workfl name = name if name else fw.name return Workflow([fw], None, name=name, metadata=metadata, created_on=fw.created_on, updated_on=fw.updated_on) - def __str__(self): + def __str__(self) -> str: return f"Workflow object: (fw_ids: {self.id_fw.keys()} , name: {self.name})" - def remove_fws(self, fw_ids): + def remove_fws(self, fw_ids: Sequence[int]) -> None: """ Remove the fireworks corresponding to the input firework ids and update the workflow i.e the parents of the removed fireworks become the parents of the children fireworks (only if the @@ -1387,7 +1397,7 @@ def remove_fws(self, fw_ids): """ # not working with the copies, causes spurious behavior - wf_dict = deepcopy(self.as_dict()) + wf_dict = deepcopy(self.as_dict()) # type: ignore orig_parent_links = deepcopy(self.links.parent_links) fws = wf_dict["fws"] diff --git a/fireworks/core/fworker.py b/fireworks/core/fworker.py index caa4a8070..b0155688a 100644 --- a/fireworks/core/fworker.py +++ b/fireworks/core/fworker.py @@ -3,7 +3,9 @@ """ import json +from typing import Any, Dict, Optional, Sequence, Union +from fireworks.core.types import FromDict, ToDict from fireworks.fw_config import FWORKER_LOC from fireworks.utilities.fw_serializers import ( DATETIME_HANDLER, @@ -21,7 +23,13 @@ class FWorker(FWSerializable): - def __init__(self, name="Automatically generated Worker", category="", query=None, env=None): + def __init__( + self, + name: str = "Automatically generated Worker", + category: Union[str, Sequence[str]] = "", + query: Optional[Dict[str, Any]] = None, + env: Optional[Dict[str, str]] = None, + ) -> None: """ Args: name (str): the name of the resource, should be unique @@ -41,7 +49,7 @@ def __init__(self, name="Automatically generated Worker", category="", query=Non self.env = env if env else {} @recursive_serialize - def to_dict(self): + def to_dict(self) -> ToDict: return { "name": self.name, 
"category": self.category, @@ -51,11 +59,11 @@ def to_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "FWorker": return FWorker(m_dict["name"], m_dict["category"], json.loads(m_dict["query"]), m_dict.get("env")) @property - def query(self): + def query(self) -> ToDict: """ Returns updated query dict. """ @@ -77,10 +85,10 @@ def query(self): return q @classmethod - def auto_load(cls): + def auto_load(cls) -> "FWorker": """ Returns FWorker object from settings file(my_fworker.yaml). """ if FWORKER_LOC: - return FWorker.from_file(FWORKER_LOC) + return FWorker.from_file(FWORKER_LOC) # type: ignore return FWorker() diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 9e6ffbd16..d0f1ffd4a 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -12,6 +12,7 @@ import warnings from collections import defaultdict from itertools import chain +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import gridfs from bson import ObjectId @@ -22,6 +23,8 @@ from tqdm import tqdm from fireworks.core.firework import Firework, FWAction, Launch, Tracker, Workflow +from fireworks.core.fworker import FWorker +from fireworks.core.types import FromDict, Spec, ToDict from fireworks.fw_config import ( GRIDFS_FALLBACK_COLLECTION, LAUNCHPAD_LOC, @@ -50,7 +53,7 @@ # TODO: lots of duplication reduction and cleanup possible -def sort_aggregation(sort): +def sort_aggregation(sort: Sequence[Tuple[str, int]]) -> List[Dict[str, Any]]: """Build sorting aggregation pipeline. Args: @@ -97,7 +100,13 @@ class WFLock: Calling functions are responsible for handling the error in order to avoid database inconsistencies. 
""" - def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL): + def __init__( + self, + lp: "LaunchPad", + fw_id: int, + expire_secs: int = WFLOCK_EXPIRATION_SECS, + kill: bool = WFLOCK_EXPIRATION_KILL, + ) -> None: """ Args: lp (LaunchPad) @@ -110,7 +119,7 @@ def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EX self.expire_secs = expire_secs self.kill = kill - def __enter__(self): + def __enter__(self) -> None: ctr = 0 waiting_time = 0 # acquire lock @@ -140,7 +149,7 @@ def __enter__(self): {"nodes": self.fw_id, "locked": {"$exists": False}}, {"$set": {"locked": True}} ) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.lp.workflows.find_one_and_update({"nodes": self.fw_id}, {"$unset": {"locked": True}}) @@ -151,19 +160,19 @@ class LaunchPad(FWSerializable): def __init__( self, - host=None, - port=None, - name=None, - username=None, - password=None, - logdir=None, - strm_lvl=None, - user_indices=None, - wf_user_indices=None, - authsource=None, - uri_mode=False, + host: Optional[str] = None, + port: Optional[int] = None, + name: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + logdir: Optional[str] = None, + strm_lvl: Optional[str] = None, + user_indices: Optional[Sequence[int]] = None, + wf_user_indices: Optional[Sequence[int]] = None, + authsource: str = None, + uri_mode: bool = False, mongoclient_kwargs=None, - ): + ) -> None: """ Args: host (str): hostname. If uri_mode is True, a MongoDB connection string URI @@ -234,7 +243,7 @@ def __init__( self.backup_launch_data = {} self.backup_fw_data = {} - def to_dict(self): + def to_dict(self) -> ToDict: """ Note: usernames/passwords are exported as unencrypted Strings! 
""" @@ -253,7 +262,7 @@ def to_dict(self): "mongoclient_kwargs": self.mongoclient_kwargs, } - def update_spec(self, fw_ids, spec_document, mongo=False): + def update_spec(self, fw_ids: Sequence[int], spec_document: Spec, mongo: bool = False) -> None: """ Update fireworks with a spec. Sometimes you need to modify a firework in progress. @@ -280,7 +289,7 @@ def update_spec(self, fw_ids, spec_document, mongo=False): ) @classmethod - def from_dict(cls, d): + def from_dict(cls, d: FromDict) -> "LaunchPad": port = d.get("port", None) name = d.get("name", None) username = d.get("username", None) @@ -308,7 +317,7 @@ def from_dict(cls, d): ) @classmethod - def auto_load(cls): + def auto_load(cls) -> "LaunchPad": if LAUNCHPAD_LOC: return LaunchPad.from_file(LAUNCHPAD_LOC) return LaunchPad() @@ -384,7 +393,7 @@ def maintain(self, infinite=True, maintain_interval=None): self.m_logger.debug(f"Sleeping for {maintain_interval} secs...") time.sleep(maintain_interval) - def add_wf(self, wf, reassign_all=True): + def add_wf(self, wf: Union[Workflow, Firework], reassign_all: bool = True) -> Dict[int, int]: """ Add workflow(or firework) to the launchpad. The firework ids will be reassigned. @@ -468,7 +477,7 @@ def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=True): with WFLock(self, fw_ids[0]): self._update_wf(wf, updated_ids) - def get_launch_by_id(self, launch_id): + def get_launch_by_id(self, launch_id: int) -> Launch: """ Given a Launch id, return details of the Launch. @@ -484,7 +493,7 @@ def get_launch_by_id(self, launch_id): return Launch.from_dict(m_launch) raise ValueError(f"No Launch exists with launch_id: {launch_id}") - def get_fw_dict_by_id(self, fw_id): + def get_fw_dict_by_id(self, fw_id: int): """ Given firework id, return firework dict. 
@@ -512,7 +521,7 @@ def get_fw_dict_by_id(self, fw_id): fw_dict["archived_launches"] = launches return fw_dict - def get_fw_by_id(self, fw_id): + def get_fw_by_id(self, fw_id: int) -> Firework: """ Given a Firework id, give back a Firework object. @@ -524,7 +533,7 @@ def get_fw_by_id(self, fw_id): """ return Firework.from_dict(self.get_fw_dict_by_id(fw_id)) - def get_wf_by_fw_id(self, fw_id): + def get_wf_by_fw_id(self, fw_id: int) -> Workflow: """ Given a Firework id, give back the Workflow containing that Firework. @@ -580,7 +589,7 @@ def get_wf_by_fw_id_lzyfw(self, fw_id): fw_states, ) - def delete_fws(self, fw_ids, delete_launch_dirs=False): + def delete_fws(self, fw_ids: Sequence[int], delete_launch_dirs: bool = False) -> None: """Delete a set of fireworks identified by their fw_ids. ATTENTION: This function serves maintenance purposes and will leave @@ -623,7 +632,7 @@ def delete_fws(self, fw_ids, delete_launch_dirs=False): self.offline_runs.delete_many({"launch_id": {"$in": launch_ids}}) self.fireworks.delete_many({"fw_id": {"$in": fw_ids}}) - def delete_wf(self, fw_id, delete_launch_dirs=False): + def delete_wf(self, fw_id: int, delete_launch_dirs: bool = False) -> None: """ Delete the workflow containing firework with the given id. @@ -638,7 +647,7 @@ def delete_wf(self, fw_id, delete_launch_dirs=False): print("Removing workflow.") self.workflows.delete_one({"nodes": fw_id}) - def get_wf_summary_dict(self, fw_id, mode="more"): + def get_wf_summary_dict(self, fw_id: int, mode: str = "more") -> Dict[str, Any]: """ A much faster way to get summary information about a Workflow by querying only for needed information. 
@@ -724,7 +733,14 @@ def get_wf_summary_dict(self, fw_id, mode="more"): return wf - def get_fw_ids(self, query=None, sort=None, limit=0, count_only=False, launches_mode=False): + def get_fw_ids( + self, + query: Optional[Mapping[str, Any]] = None, + sort=None, + limit: int = 0, + count_only: bool = False, + launches_mode: bool = False, + ) -> List[int]: """ Return all the fw ids that match a query. @@ -814,8 +830,14 @@ def get_wf_ids(self, query=None, sort=None, limit=0, count_only=False): return [fw["nodes"][0] for fw in cursor] def get_fw_ids_in_wfs( - self, wf_query=None, fw_query=None, sort=None, limit=0, count_only=False, launches_mode=False - ): + self, + wf_query: Optional[Dict[str, Any]] = None, + fw_query: Optional[Dict[str, Any]] = None, + sort: Optional[List[Tuple[str, str]]] = None, + limit: int = 0, + count_only: bool = False, + launches_mode: bool = False, + ) -> Union[List[int], int]: """ Return all fw ids that match fw_query within workflows that match wf_query. @@ -889,7 +911,7 @@ def get_fw_ids_in_wfs( cursor = self.workflows.aggregate(aggregation) return [fw["fw_id"] for fw in cursor] - def run_exists(self, fworker=None): + def run_exists(self, fworker: Optional[FWorker] = None) -> bool: """ Checks to see if the database contains any FireWorks that are ready to run. @@ -899,7 +921,7 @@ def run_exists(self, fworker=None): q = fworker.query if fworker else {} return bool(self._get_a_fw_to_run(query=q, checkout=False)) - def future_run_exists(self, fworker=None): + def future_run_exists(self, fworker: Optional[FWorker] = None) -> bool: """Check if database has any current OR future Fireworks available Returns: @@ -1178,7 +1200,7 @@ def _get_a_fw_to_run(self, query=None, fw_id=None, checkout=True): if self._check_fw_for_uniqueness(m_fw): return m_fw - def _get_active_launch_ids(self): + def _get_active_launch_ids(self) -> List[int]: """ Get all the launch ids. 
@@ -1190,7 +1212,14 @@ def _get_active_launch_ids(self): all_launch_ids.extend(l["launches"]) return all_launch_ids - def reserve_fw(self, fworker, launch_dir, host=None, ip=None, fw_id=None): + def reserve_fw( + self, + fworker: FWorker, + launch_dir: str, + host: Optional[str] = None, + ip: Optional[str] = None, + fw_id: Optional[int] = None, + ): """ Checkout the next ready firework and mark the launch reserved. @@ -1589,7 +1618,7 @@ def ping_launch(self, launch_id, ptime=None, checkpoint=None): }, ) - def get_new_fw_id(self, quantity=1): + def get_new_fw_id(self, quantity: int = 1) -> int: """ Checkout the next Firework id @@ -1617,7 +1646,7 @@ def get_new_launch_id(self): "database, please do so by performing a database reset (e.g., lpad reset)" ) - def _upsert_fws(self, fws, reassign_all=False): + def _upsert_fws(self, fws: List[Firework], reassign_all: bool = False) -> Dict[int, int]: """ Insert the fireworks to the 'fireworks' collection. @@ -2038,7 +2067,7 @@ class LazyFirework: db_fields = ("name", "fw_id", "spec", "created_on", "state") db_launch_fields = ("launches", "archived_launches") - def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs): + def __init__(self, fw_id: int, fw_coll, launch_coll, fallback_fs): """ Args: fw_id (int): firework id @@ -2123,21 +2152,21 @@ def updated_on(self, value): self.partial_fw.updated_on = value @property - def parents(self): + def parents(self) -> List[Firework]: if self._fw is not None: return self.partial_fw.parents else: return [] @parents.setter - def parents(self, value): + def parents(self, value) -> None: self.partial_fw.parents = value # Properties that shadow FireWork attributes, but which are # fetched individually from the DB (i.e. 
launch objects) @property - def launches(self): + def launches(self) -> List[Launch]: return self._get_launch_data("launches") @launches.setter @@ -2150,7 +2179,7 @@ def archived_launches(self): return self._get_launch_data("archived_launches") @archived_launches.setter - def archived_launches(self, value): + def archived_launches(self, value) -> None: self._launches["archived_launches"] = True self.partial_fw.archived_launches = value @@ -2177,7 +2206,7 @@ def full_fw(self): # Get a type of Launch object - def _get_launch_data(self, name): + def _get_launch_data(self, name: str) -> List[Launch]: """ Pull launch data individually for each field. @@ -2202,7 +2231,7 @@ def _get_launch_data(self, name): return getattr(fw, name) -def get_action_from_gridfs(action_dict, fallback_fs): +def get_action_from_gridfs(action_dict: Mapping[str, Any], fallback_fs: gridfs.GridFS) -> Dict[str, Any]: """ Helper function to obtain the correct dictionary of the FWAction associated with a launch. If necessary retrieves the information from gridfs based diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py index 3c3d7da36..02375ce42 100644 --- a/fireworks/core/rocket.py +++ b/fireworks/core/rocket.py @@ -15,7 +15,7 @@ import traceback from datetime import datetime from threading import Event, Thread, current_thread -from typing import Dict, Union +from typing import Optional from monty.io import zopen from monty.os.path import zpath @@ -23,6 +23,7 @@ from fireworks.core.firework import Firework, FWAction from fireworks.core.fworker import FWorker from fireworks.core.launchpad import LaunchPad, LockedWorkflowError +from fireworks.core.types import Checkpoint, Spec from fireworks.fw_config import ( PING_TIME_SECS, PRINT_FW_JSON, @@ -56,7 +57,7 @@ def ping_launch(launchpad: LaunchPad, launch_id: int, stop_event: Event, master_ stop_event.wait(PING_TIME_SECS) -def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Union[Event, None]: +def start_ping_launch(launchpad: LaunchPad,
launch_id: int) -> Optional[Event]: fd = FWData() if fd.MULTIPROCESSING: if not launch_id: @@ -70,7 +71,7 @@ def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Union[Event, None return ping_stop -def stop_backgrounds(ping_stop, btask_stops): +def stop_backgrounds(ping_stop: Event, btask_stops) -> None: fd = FWData() if fd.MULTIPROCESSING: fd.Running_IDs[os.getpid()] = None @@ -81,7 +82,7 @@ def stop_backgrounds(ping_stop, btask_stops): b.set() -def background_task(btask, spec, stop_event, master_thread): +def background_task(btask, spec: Spec, stop_event: Event, master_thread: Thread) -> None: num_launched = 0 while not stop_event.is_set() and master_thread.is_alive(): for task in btask.tasks: @@ -93,7 +94,7 @@ def background_task(btask, spec, stop_event, master_thread): break -def start_background_task(btask, spec): +def start_background_task(btask, spec: Spec) -> Event: ping_stop = Event() ping_thread = Thread(target=background_task, args=(btask, spec, ping_stop, current_thread())) ping_thread.start() @@ -409,8 +410,9 @@ def run(self, pdb_on_exception: bool = False) -> bool: l_logger.log(logging.DEBUG, traceback.format_exc()) l_logger.log( logging.WARNING, - f"Firework {self.fw_id} fizzled but couldn't complete the update of the database. " - f"Reason: {e}\nRefresh the WF to recover the result (lpad admin refresh -i {self.fw_id}).", + "Firework {} fizzled but couldn't complete the update of the database." 
+ " Reason: {}\nRefresh the WF to recover the result " + "(lpad admin refresh -i {}).".format(self.fw_id, e, self.fw_id), ) return True else: @@ -426,7 +428,7 @@ def run(self, pdb_on_exception: bool = False) -> bool: return True @staticmethod - def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: Dict[str, any]) -> None: + def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: Checkpoint) -> None: """ Helper function to update checkpoint @@ -446,9 +448,7 @@ def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, che with zopen(fpath, "wt") as f_out: f_out.write(json.dumps(d, ensure_ascii=False)) - def decorate_fwaction( - self, fwaction: FWAction, my_spec: Dict[str, any], m_fw: Firework, launch_dir: str - ) -> FWAction: + def decorate_fwaction(self, fwaction: FWAction, my_spec: Spec, m_fw: Firework, launch_dir: str) -> FWAction: if my_spec.get("_pass_job_info"): job_info = list(my_spec.get("_job_info", [])) diff --git a/fireworks/core/rocket_launcher.py b/fireworks/core/rocket_launcher.py index b9aa5b311..e80f95e27 100644 --- a/fireworks/core/rocket_launcher.py +++ b/fireworks/core/rocket_launcher.py @@ -5,8 +5,10 @@ import os import time from datetime import datetime +from typing import Optional from fireworks.core.fworker import FWorker +from fireworks.core.launchpad import LaunchPad from fireworks.core.rocket import Rocket from fireworks.fw_config import FWORKER_LOC, RAPIDFIRE_SLEEP_SECS from fireworks.utilities.fw_utilities import ( @@ -23,17 +25,23 @@ __date__ = "Feb 22, 2013" -def get_fworker(fworker): - if fworker: +def get_fworker(fworker: Optional[FWorker]) -> FWorker: + if fworker is not None: my_fwkr = fworker elif FWORKER_LOC: - my_fwkr = FWorker.from_file(FWORKER_LOC) + my_fwkr = FWorker.from_file(FWORKER_LOC) # type: ignore else: my_fwkr = FWorker() return my_fwkr -def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_exception=False): 
+def launch_rocket( + launchpad: LaunchPad, + fworker: Optional[FWorker] = None, + fw_id: Optional[int] = None, + strm_lvl: str = "INFO", + pdb_on_exception: bool = False, +) -> bool: """ Run a single rocket in the current directory. @@ -60,16 +68,16 @@ def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_e def rapidfire( - launchpad, - fworker=None, - m_dir=None, - nlaunches=0, - max_loops=-1, - sleep_time=None, - strm_lvl="INFO", - timeout=None, - local_redirect=False, - pdb_on_exception=False, + launchpad: LaunchPad, + fworker: Optional[FWorker] = None, + m_dir: Optional[str] = None, + nlaunches: int = 0, + max_loops: int = -1, + sleep_time: Optional[int] = None, + strm_lvl: str = "INFO", + timeout: Optional[int] = None, + local_redirect: bool = False, + pdb_on_exception: bool = False, ): """ Keeps running Rockets in m_dir until we reach an error.
Automatically creates subdirectories diff --git a/fireworks/core/tests/tasks.py b/fireworks/core/tests/tasks.py index 83aa0bf9f..23fbbc14a 100644 --- a/fireworks/core/tests/tasks.py +++ b/fireworks/core/tests/tasks.py @@ -2,6 +2,7 @@ from unittest import SkipTest from fireworks import FiretaskBase, Firework, FWAction +from fireworks.core.types import Spec from fireworks.utilities.fw_utilities import explicit_serialize @@ -95,7 +96,7 @@ def run_task(self, fw_spec): class DetoursTask(FiretaskBase): optional_params = ["n_detours", "data_per_detour"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: data_per_detour = self.get("data_per_detour", None) n_detours = self.get("n_detours", 10) fws = [] diff --git a/fireworks/core/tests/test_launchpad.py b/fireworks/core/tests/test_launchpad.py index 6cb340bb2..647707623 100644 --- a/fireworks/core/tests/test_launchpad.py +++ b/fireworks/core/tests/test_launchpad.py @@ -15,12 +15,11 @@ from monty.os import cd from pymongo import MongoClient -from pymongo.errors import OperationFailure from pymongo import __version__ as PYMONGO_VERSION +from pymongo.errors import OperationFailure import fireworks.fw_config -from fireworks import Firework, FWorker, LaunchPad, Workflow, FiretaskBase, \ - explicit_serialize +from fireworks import Firework, FWorker, LaunchPad, Workflow from fireworks.core.rocket_launcher import launch_rocket, rapidfire from fireworks.core.tests.tasks import ( DetoursTask, @@ -44,33 +43,28 @@ class AuthenticationTest(unittest.TestCase): def setUpClass(cls): try: client = MongoClient() - client.not_the_admin_db.command("createUser", "myuser", - pwd="mypassword", roles=["dbOwner"]) + client.not_the_admin_db.command("createUser", "myuser", pwd="mypassword", roles=["dbOwner"]) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") def test_no_admin_privileges_for_plebs(self): """Normal users can not authenticate against the admin db.""" with self.assertRaises(OperationFailure): - lp = LaunchPad(name="admin", username="myuser", - password="mypassword", authsource="admin") + lp = LaunchPad(name="admin", username="myuser", password="mypassword", authsource="admin") lp.db.collection.count_documents({}) def test_authenticating_to_users_db(self): """A user should be able to authenticate against a database that they are a user of. """ - lp = LaunchPad(name="not_the_admin_db", username="myuser", - password="mypassword", authsource="not_the_admin_db") + lp = LaunchPad(name="not_the_admin_db", username="myuser", password="mypassword", authsource="not_the_admin_db") lp.db.collection.count_documents({}) def test_authsource_infered_from_db_name(self): """The default behavior is to authenticate against the db that the user is trying to access. """ - lp = LaunchPad(name="not_the_admin_db", username="myuser", - password="mypassword") + lp = LaunchPad(name="not_the_admin_db", username="myuser", password="mypassword") lp.db.collection.count_documents({}) @@ -83,8 +77,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -98,8 +91,7 @@ def setUp(self): self.lp.to_file(self.LP_LOC) def tearDown(self): - self.lp.reset(password=None, require_password=False, - max_reset_wo_password=1000) + self.lp.reset(password=None, require_password=False, max_reset_wo_password=1000) # Delete launch locations if os.path.exists(os.path.join("FW.json")): os.remove("FW.json") @@ -130,8 +122,7 @@ def test_reset(self): # test failsafe in a strict way for _ in range(30): - self.lp.add_wf( - Workflow([Firework(ScriptTask.from_str('echo "hello"'))])) + self.lp.add_wf(Workflow([Firework(ScriptTask.from_str('echo "hello"'))])) self.assertRaises(ValueError, self.lp.reset, "") self.lp.reset("", False, 100) # reset back @@ -164,10 +155,8 @@ def test_add_wfs(self): wfs = [] for _ in range(50): # Add two workflows with 3 and 5 simple fireworks - wf3 = Workflow([Firework(ftask, name="lorem") for _ in range(3)], - name="lorem wf") - wf5 = Workflow([Firework(ftask, name="lorem") for _ in range(5)], - name="lorem wf") + wf3 = Workflow([Firework(ftask, name="lorem") for _ in range(3)], name="lorem wf") + wf5 = Workflow([Firework(ftask, name="lorem") for _ in range(5)], name="lorem wf") wfs.extend([wf3, wf5]) self.lp.bulk_add_wfs(wfs) num_fws_total = sum(len(wf) for wf in wfs) @@ -186,8 +175,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -198,96 +186,80 @@ def setUp(self): # define the individual FireWorks used in the Workflow # Parent Firework fw_p = Firework( - ScriptTask.from_str('echo "Cronus is the ruler of titans"', - {"store_stdout": True}), name="parent", fw_id=1 + ScriptTask.from_str('echo "Cronus is the ruler of titans"', {"store_stdout": True}), name="parent", fw_id=1 ) # Sibling fireworks fw_s1 = Firework( - ScriptTask.from_str('echo "Zeus is son of Cronus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Zeus is son of Cronus"', {"store_stdout": True}), name="sib1", fw_id=2, parents=fw_p, ) fw_s2 = Firework( - ScriptTask.from_str('echo "Poisedon is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Poisedon is brother of Zeus"', {"store_stdout": True}), name="sib2", fw_id=3, parents=fw_p, ) fw_s3 = Firework( - ScriptTask.from_str('echo "Hades is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hades is brother of Zeus"', {"store_stdout": True}), name="sib3", fw_id=4, parents=fw_p, ) fw_s4 = Firework( - ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', {"store_stdout": True}), name="sib4", fw_id=5, parents=fw_p, ) fw_s5 = Firework( - ScriptTask.from_str('echo "Lapetus is son of Oceanus"', - {"store_stdout": True}), name="cousin1", fw_id=6 + ScriptTask.from_str('echo "Lapetus is son of Oceanus"', {"store_stdout": True}), name="cousin1", fw_id=6 ) # Children fireworks fw_c1 = Firework( - ScriptTask.from_str('echo "Ares is son of Zeus"', - {"store_stdout": True}), name="c1", fw_id=7, - parents=fw_s1 + ScriptTask.from_str('echo "Ares is son of Zeus"', {"store_stdout": True}), name="c1", fw_id=7, parents=fw_s1 ) fw_c2 = Firework( ScriptTask.from_str( - 'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', - {"store_stdout": True} + 'echo "Persephone is 
daughter of Zeus & Demeter and wife of Hades"', {"store_stdout": True} ), name="c2", fw_id=8, parents=[fw_s1, fw_s4], ) fw_c3 = Firework( - ScriptTask.from_str( - 'echo "Makaria is daughter of Hades & Persephone"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Makaria is daughter of Hades & Persephone"', {"store_stdout": True}), name="c3", fw_id=9, parents=[fw_s3, fw_c2], ) fw_c4 = Firework( - ScriptTask.from_str('echo "Dione is descendant of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Dione is descendant of Lapetus"', {"store_stdout": True}), name="c4", fw_id=10, parents=fw_s5, ) fw_c5 = Firework( - ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', {"store_stdout": True}), name="c5", fw_id=11, parents=[fw_s1, fw_c4], ) fw_c6 = Firework( - ScriptTask.from_str('echo "Atlas is son of of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Atlas is son of of Lapetus"', {"store_stdout": True}), name="c6", fw_id=12, parents=fw_s5, ) fw_c7 = Firework( - ScriptTask.from_str('echo "Maia is daughter of Atlas"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Maia is daughter of Atlas"', {"store_stdout": True}), name="c7", fw_id=13, parents=fw_c6, ) fw_c8 = Firework( - ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', {"store_stdout": True}), name="c8", fw_id=14, parents=[fw_s1, fw_c7], @@ -295,8 +267,7 @@ def setUp(self): # assemble Workflow from FireWorks and their connections by id workflow = Workflow( - [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, - fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] + [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] ) self.lp.add_wf(workflow) @@ -307,11 +278,11 @@ def setUp(self): self.zeus_sib_fw_ids = {3, 4, 
5} self.par_fw_id = 1 self.all_ids = ( - self.zeus_child_fw_ids - | self.lapetus_desc_fw_ids - | self.zeus_sib_fw_ids - | {self.zeus_fw_id} - | {self.par_fw_id} + self.zeus_child_fw_ids + | self.lapetus_desc_fw_ids + | self.zeus_sib_fw_ids + | {self.zeus_fw_id} + | {self.par_fw_id} ) self.old_wd = os.getcwd() @@ -348,8 +319,7 @@ def test_pause_fw(self): self.assertTrue(self.zeus_sib_fw_ids.issubset(completed_ids)) # Check that Zeus and children are subset of incompleted fwids - fws_no_run = set( - self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) + fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) self.assertIn(self.zeus_fw_id, fws_no_run) self.assertTrue(self.zeus_child_fw_ids.issubset(fws_no_run)) @@ -383,8 +353,7 @@ def test_defuse_fw(self): self.assertTrue(self.zeus_sib_fw_ids.issubset(completed_ids)) # Check that Zeus and children are subset of incompleted fwids - fws_no_run = set( - self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) + fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) self.assertIn(self.zeus_fw_id, fws_no_run) self.assertTrue(self.zeus_child_fw_ids.issubset(fws_no_run)) except Exception: @@ -585,10 +554,7 @@ def test_rerun_fws2(self): self.assertFalse(fw_start_t > ts) -@unittest.skipIf( - PYMONGO_MAJOR_VERSION > 3, - "detect lostruns test not supported for pymongo major version > 3" -) +@unittest.skipIf(PYMONGO_MAJOR_VERSION > 3, "detect lostruns test not supported for pymongo major version > 3") class LaunchPadLostRunsDetectTest(unittest.TestCase): @classmethod def setUpClass(cls): @@ -598,8 +564,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -681,8 +646,7 @@ def run(self): # Wait for fw to start it = 0 - while not any([f.state == "RUNNING" for f in - self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): + while not any([f.state == "RUNNING" for f in self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): time.sleep(1) # Wait 1 sec it += 1 if it == 10: @@ -714,8 +678,7 @@ def run(self): # Wait for running it = 0 - while not any([f.state == "RUNNING" for f in - self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): + while not any([f.state == "RUNNING" for f in self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): time.sleep(1) # Wait 1 sec it += 1 if it == 10: @@ -728,6 +691,7 @@ def run(self): self.assertEqual(wf.fw_states[fw_id], "RUNNING") rp.terminate() + class WorkflowFireworkStatesTest(unittest.TestCase): """ Class to test the firework states locally cached in workflow. @@ -742,8 +706,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -754,95 +717,79 @@ def setUp(self): # define the individual FireWorks used in the Workflow # Parent Firework fw_p = Firework( - ScriptTask.from_str('echo "Cronus is the ruler of titans"', - {"store_stdout": True}), name="parent", fw_id=1 + ScriptTask.from_str('echo "Cronus is the ruler of titans"', {"store_stdout": True}), name="parent", fw_id=1 ) # Sibling fireworks # fw_s1 = Firework(ScriptTask.from_str( # 'echo "Zeus is son of Cronus"', # {'store_stdout':True}), name="sib1", fw_id=2, parents=fw_p) # Timed firework - fw_s1 = Firework(PyTask(func="time.sleep", args=[5]), name="sib1", - fw_id=2, parents=fw_p) + fw_s1 = Firework(PyTask(func="time.sleep", args=[5]), name="sib1", fw_id=2, parents=fw_p) fw_s2 = Firework( - ScriptTask.from_str('echo "Poisedon is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Poisedon is brother of Zeus"', {"store_stdout": True}), name="sib2", fw_id=3, parents=fw_p, ) fw_s3 = Firework( - ScriptTask.from_str('echo "Hades is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hades is brother of Zeus"', {"store_stdout": True}), name="sib3", fw_id=4, parents=fw_p, ) fw_s4 = Firework( - ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', {"store_stdout": True}), name="sib4", fw_id=5, parents=fw_p, ) fw_s5 = Firework( - ScriptTask.from_str('echo "Lapetus is son of Oceanus"', - {"store_stdout": True}), name="cousin1", fw_id=6 + ScriptTask.from_str('echo "Lapetus is son of Oceanus"', {"store_stdout": True}), name="cousin1", fw_id=6 ) # Children fireworks fw_c1 = Firework( - ScriptTask.from_str('echo "Ares is son of Zeus"', - {"store_stdout": True}), name="c1", fw_id=7, - parents=fw_s1 + ScriptTask.from_str('echo "Ares is son of Zeus"', {"store_stdout": True}), name="c1", fw_id=7, parents=fw_s1 ) fw_c2 = Firework( ScriptTask.from_str( - 
'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', - {"store_stdout": True} + 'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', {"store_stdout": True} ), name="c2", fw_id=8, parents=[fw_s1, fw_s4], ) fw_c3 = Firework( - ScriptTask.from_str( - 'echo "Makaria is daughter of Hades & Persephone"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Makaria is daughter of Hades & Persephone"', {"store_stdout": True}), name="c3", fw_id=9, parents=[fw_s3, fw_c2], ) fw_c4 = Firework( - ScriptTask.from_str('echo "Dione is descendant of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Dione is descendant of Lapetus"', {"store_stdout": True}), name="c4", fw_id=10, parents=fw_s5, ) fw_c5 = Firework( - ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', {"store_stdout": True}), name="c5", fw_id=11, parents=[fw_s1, fw_c4], ) fw_c6 = Firework( - ScriptTask.from_str('echo "Atlas is son of of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Atlas is son of of Lapetus"', {"store_stdout": True}), name="c6", fw_id=12, parents=fw_s5, ) fw_c7 = Firework( - ScriptTask.from_str('echo "Maia is daughter of Atlas"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Maia is daughter of Atlas"', {"store_stdout": True}), name="c7", fw_id=13, parents=fw_c6, ) fw_c8 = Firework( - ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', {"store_stdout": True}), name="c8", fw_id=14, parents=[fw_s1, fw_c7], @@ -850,8 +797,7 @@ def setUp(self): # assemble Workflow from FireWorks and their connections by id workflow = Workflow( - [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, - fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] + [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, fw_c4, fw_c5, 
fw_c6, fw_c7, fw_c8] ) self.lp.add_wf(workflow) @@ -862,11 +808,11 @@ def setUp(self): self.zeus_sib_fw_ids = {3, 4, 5} self.par_fw_id = 1 self.all_ids = ( - self.zeus_child_fw_ids - | self.lapetus_desc_fw_ids - | self.zeus_sib_fw_ids - | {self.zeus_fw_id} - | {self.par_fw_id} + self.zeus_child_fw_ids + | self.lapetus_desc_fw_ids + | self.zeus_sib_fw_ids + | {self.zeus_fw_id} + | {self.par_fw_id} ) self.old_wd = os.getcwd() @@ -1023,8 +969,7 @@ def run(self): self.assertEqual(fw_state, fw_cache_state) # Detect lost runs - lost_lids, lost_fwids, inconsistent_fwids = self.lp.detect_lostruns( - expiration_secs=0.5) + lost_lids, lost_fwids, inconsistent_fwids = self.lp.detect_lostruns(expiration_secs=0.5) # Ensure the states are sync wf = self.lp.get_wf_by_fw_id_lzyfw(self.zeus_fw_id) fws = wf.id_fw @@ -1071,8 +1016,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -1086,8 +1030,7 @@ def setUp(self): fw = Firework( [ ExecutionCounterTask(), - ScriptTask.from_str('date +"%s %N"', - parameters={"stdout_file": "date_file"}), + ScriptTask.from_str('date +"%s %N"', parameters={"stdout_file": "date_file"}), ExceptionTestTask(exc_details=self.error_test_dict), ] ) @@ -1141,8 +1084,7 @@ def test_task_level_rerun_cp(self): self.assertEqual(self.lp.get_fw_by_id(1).state, "COMPLETED") self.assertEqual(ExecutionCounterTask.exec_counter, 1) self.assertEqual(ExceptionTestTask.exec_counter, 2) - self.assertTrue(filecmp.cmp(os.path.join(dirs[0], "date_file"), - os.path.join(dirs[1], "date_file"))) + self.assertTrue(filecmp.cmp(os.path.join(dirs[0], "date_file"), os.path.join(dirs[1], "date_file"))) def test_task_level_rerun_prev_dir(self): rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) @@ -1153,8 +1095,7 @@ def test_task_level_rerun_prev_dir(self): fw = self.lp.get_fw_by_id(1) self.assertEqual(os.getcwd(), MODULE_DIR) self.assertEqual(fw.state, "COMPLETED") - self.assertEqual(fw.launches[0].launch_dir, - fw.archived_launches[0].launch_dir) + self.assertEqual(fw.launches[0].launch_dir, fw.archived_launches[0].launch_dir) self.assertEqual(ExecutionCounterTask.exec_counter, 1) self.assertEqual(ExceptionTestTask.exec_counter, 2) @@ -1168,8 +1109,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -1178,17 +1118,13 @@ def tearDownClass(cls): def setUp(self): # set the defaults in the init of wflock to break the lock quickly - fireworks.core.launchpad.WFLock(3, - False).__init__.__func__.__defaults__ = ( - 3, False) + fireworks.core.launchpad.WFLock(3, False).__init__.__func__.__defaults__ = (3, False) self.error_test_dict = {"error": "description", "error_code": 1} fw_slow = Firework(SlowAdditionTask(), spec={"seconds": 10}, fw_id=1) - fw_fast = Firework(WaitWFLockTask(), fw_id=2, - spec={"_add_launchpad_and_fw_id": True}) + fw_fast = Firework(WaitWFLockTask(), fw_id=2, spec={"_add_launchpad_and_fw_id": True}) fw_child = Firework(ScriptTask.from_str('echo "child"'), fw_id=3) - wf = Workflow([fw_slow, fw_fast, fw_child], - {fw_slow: fw_child, fw_fast: fw_child}) + wf = Workflow([fw_slow, fw_fast, fw_child], {fw_slow: fw_child, fw_fast: fw_child}) self.lp.add_wf(wf) self.old_wd = os.getcwd() @@ -1226,8 +1162,7 @@ def run(self): fast_fw = self.lp.get_fw_by_id(2) if fast_fw.state == "FIZZLED": - stacktrace = self.lp.launches.find_one({"fw_id": 2}, { - "action.stored_data._exception._stacktrace": 1})[ + stacktrace = self.lp.launches.find_one({"fw_id": 2}, {"action.stored_data._exception._stacktrace": 1})[ "action" ]["stored_data"]["_exception"]["_stacktrace"] if "SkipTest" in stacktrace: @@ -1276,8 +1211,7 @@ def run(self): fast_fw = self.lp.get_fw_by_id(2) if fast_fw.state == "FIZZLED": - stacktrace = self.lp.launches.find_one({"fw_id": 2}, { - "action.stored_data._exception._stacktrace": 1})[ + stacktrace = self.lp.launches.find_one({"fw_id": 2}, {"action.stored_data._exception._stacktrace": 1})[ "action" ]["stored_data"]["_exception"]["_stacktrace"] if "SkipTest" in stacktrace: @@ -1306,8 +1240,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in 
localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod def tearDownClass(cls): @@ -1318,9 +1251,7 @@ def setUp(self): fireworks.core.firework.EXCEPT_DETAILS_ON_RERUN = True self.error_test_dict = {"error": "description", "error_code": 1} - fw = Firework( - ScriptTask.from_str('echo "test offline"', {"store_stdout": True}), - name="offline_fw", fw_id=1) + fw = Firework(ScriptTask.from_str('echo "test offline"', {"store_stdout": True}), name="offline_fw", fw_id=1) self.lp.add_wf(fw) self.launch_dir = os.path.join(MODULE_DIR, "launcher_offline") @@ -1362,17 +1293,14 @@ def test_recover_errors(self): shutil.rmtree(self.launch_dir) # recover ignoring errors - self.assertIsNotNone( - self.lp.recover_offline(launch_id, ignore_errors=True, - print_errors=True)) + self.assertIsNotNone(self.lp.recover_offline(launch_id, ignore_errors=True, print_errors=True)) fw = self.lp.get_fw_by_id(launch_id) self.assertEqual(fw.state, "RESERVED") # fizzle - self.assertIsNotNone( - self.lp.recover_offline(launch_id, ignore_errors=False)) + self.assertIsNotNone(self.lp.recover_offline(launch_id, ignore_errors=False)) fw = self.lp.get_fw_by_id(launch_id) @@ -1393,8 +1321,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): diff --git a/fireworks/core/types.py b/fireworks/core/types.py new file mode 100644 index 000000000..51e95e95f --- /dev/null +++ b/fireworks/core/types.py @@ -0,0 +1,6 @@ +from typing import Any, Dict, Mapping, MutableMapping + +Checkpoint = MutableMapping[Any, Any] +FromDict = Mapping[str, Any] +Spec = MutableMapping[str, Any] +ToDict = Dict[str, Any] diff --git a/fireworks/flask_site/gunicorn.py b/fireworks/flask_site/gunicorn.py index a65c25a3a..d4071ab2c 100755 --- a/fireworks/flask_site/gunicorn.py +++ b/fireworks/flask_site/gunicorn.py @@ -1,23 +1,24 @@ # Based on http://docs.gunicorn.org/en/19.6.0/custom.html import multiprocessing +from typing import Any, Mapping, Optional import gunicorn.app.base from fireworks.flask_site.app import app as handler_app -def number_of_workers(): +def number_of_workers() -> int: return (multiprocessing.cpu_count() * 2) + 1 class StandaloneApplication(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): + def __init__(self, app, options: Optional[Mapping[str, Any]] = None) -> None: self.options = options or {} self.application = app super().__init__() - def load_config(self): + def load_config(self) -> None: config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): self.cfg.set(key.lower(), value) diff --git a/fireworks/flask_site/helpers.py b/fireworks/flask_site/helpers.py index 63e91a022..3a97d6ccc 100644 --- a/fireworks/flask_site/helpers.py +++ b/fireworks/flask_site/helpers.py @@ -1,4 +1,9 @@ -def get_totals(states, lp): +from typing import Mapping, Sequence + +from fireworks import LaunchPad + + +def get_totals(states: Sequence[str], lp: LaunchPad) -> Mapping[str, Mapping[str, int]]: fw_stats = {} wf_stats = {} for state in states: @@ -7,21 +12,21 @@ def get_totals(states, lp): return {"fw_stats": fw_stats, "wf_stats": wf_stats} -def
fw_filt_given_wf_filt(filt, lp: LaunchPad) -> Mapping[str, Mapping[str, Sequence[int]]]: fw_ids = set() for doc in lp.workflows.find(filt, {"_id": 0, "nodes": 1}): fw_ids |= set(doc["nodes"]) return {"fw_id": {"$in": list(fw_ids)}} -def wf_filt_given_fw_filt(filt, lp): +def wf_filt_given_fw_filt(filt, lp: LaunchPad) -> Mapping[str, Mapping[str, Sequence[int]]]: wf_ids = set() for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1}): wf_ids.add(doc["fw_id"]) return {"nodes": {"$in": list(wf_ids)}} -def uses_index(filt, coll): +def uses_index(filt, coll) -> bool: ii = coll.index_information() fields_filtered = set(filt.keys()) fields_indexed = {v["key"][0][0] for v in ii.values()} diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index bfdf62cb3..bc8691267 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -3,7 +3,7 @@ """ import os -from typing import Any, Dict +from typing import Any, Dict, List, Optional from monty.design_patterns import singleton from monty.serialization import dumpfn, loadfn @@ -43,7 +43,7 @@ PRINT_FW_YAML = False JSON_SCHEMA_VALIDATE = False -JSON_SCHEMA_VALIDATE_LIST = None +JSON_SCHEMA_VALIDATE_LIST: Optional[List[str]] = None PING_TIME_SECS = 3600 # while Running a job, how often to ping back the server that we're still alive RUN_EXPIRATION_SECS = PING_TIME_SECS * 4 # mark job as FIZZLED if not pinged in this time @@ -59,7 +59,7 @@ RAPIDFIRE_SLEEP_SECS = 60 # seconds to sleep between rapidfire loops LAUNCHPAD_LOC = None # where to find the my_launchpad.yaml file -FWORKER_LOC = None # where to find the my_fworker.yaml file +FWORKER_LOC: Optional[str] = None # where to find the my_fworker.yaml file QUEUEADAPTER_LOC = None # where to find the my_qadapter.yaml file CONFIG_FILE_DIR = "." 
# directory containing config files (if not individually set) @@ -167,7 +167,7 @@ def config_to_dict() -> Dict[str, Any]: return d -def write_config(path: str = None) -> None: +def write_config(path: Optional[str] = None) -> None: if path is None: path = os.path.join(os.path.expanduser("~"), ".fireworks", "FW_config.yaml") dumpfn(config_to_dict(), path) diff --git a/fireworks/queue/queue_adapter.py b/fireworks/queue/queue_adapter.py index b28f508ea..760506b09 100644 --- a/fireworks/queue/queue_adapter.py +++ b/fireworks/queue/queue_adapter.py @@ -11,6 +11,7 @@ import threading import traceback import warnings +from typing import Any, Dict from fireworks.utilities.fw_serializers import FWSerializable, serialize_fw from fireworks.utilities.fw_utilities import get_fw_logger @@ -106,7 +107,7 @@ class QueueAdapterBase(collections.defaultdict, FWSerializable): template_file = "OVERRIDE_ME" # path to template file for a queue script submit_cmd = "OVERRIDE_ME" # command to submit jobs, e.g. "qsub" or "squeue" q_name = "OVERRIDE_ME" # (arbitrary) name, e.g. 
"pbs" or "slurm" - defaults = {} # default parameter values for template + defaults: Dict[str, Any] = {} # default parameter values for template def get_script_str(self, launch_dir): """ diff --git a/fireworks/scripts/fwtool b/fireworks/scripts/fwtool index e71d5b5c1..7be00c13a 100755 --- a/fireworks/scripts/fwtool +++ b/fireworks/scripts/fwtool @@ -3,11 +3,12 @@ import os import shutil import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace +from typing import List import yaml -from fireworks.core.firework import Firework +from fireworks.core.firework import FiretaskBase, Firework from fireworks.core.launchpad import LaunchPad from fireworks.user_objects.firetasks.fileio_tasks import FileWriteTask from fireworks.user_objects.firetasks.script_task import ScriptTask @@ -18,9 +19,11 @@ __maintainer__ = "Shyue Ping Ong" __email__ = "ongsp@ucsd.edu" __date__ = "1/6/14" +from typing import Sequence -def create_fw_single(args, fnames, yaml_fname): - tasks = [] + +def create_fw_single(args: Namespace, fnames: Sequence[str], yaml_fname: str) -> None: + tasks: List[FiretaskBase] = [] if fnames: files = [] for fname in fnames: @@ -34,7 +37,7 @@ def create_fw_single(args, fnames, yaml_fname): yaml.dump(fw.to_dict(), f, default_flow_style=False) -def create_fw(args): +def create_fw(args: Namespace) -> None: if not args.directory_mode: create_fw_single(args, args.files_or_dirs, args.output) else: @@ -43,7 +46,7 @@ def create_fw(args): create_fw_single(args, [os.path.join(d, f) for f in os.listdir(d)], output_fname.format(i)) -def do_cleanup(args): +def do_cleanup(args: Namespace) -> None: lp = LaunchPad.from_file(args.launchpad_file) if args.launchpad_file else LaunchPad(strm_lvl=args.loglvl) to_delete = [] for i in lp.get_fw_ids({}): diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 0545189c4..9a6340370 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -1,7 +1,7 @@ """ A 
runnable script for managing a FireWorks database (a command-line interface to launchpad.py) """ - +# mypy: ignore-errors import ast import copy import datetime @@ -55,7 +55,7 @@ DEFAULT_LPAD_YAML = "my_launchpad.yaml" -def pw_check(ids: List[int], args: Namespace, skip_pw: bool = False) -> List[int]: +def pw_check(ids: Sequence[int], args: Namespace, skip_pw: bool = False) -> List[int]: if len(ids) > PW_CHECK_NUM and not skip_pw: m_password = datetime.datetime.now().strftime("%Y-%m-%d") if not args.password: @@ -249,7 +249,7 @@ def add_wf_dir(args: Namespace) -> None: lp.add_wf(fwf) -def print_fws(ids, lp, args: Namespace) -> None: +def print_fws(ids: Sequence[int], lp: LaunchPad, args: Namespace) -> None: """Prints results of some FireWorks query to stdout.""" fws = [] if args.display_format == "ids": @@ -275,7 +275,7 @@ def print_fws(ids, lp, args: Namespace) -> None: print(args.output(fws)) -def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, None] = None) -> Union[List[int], int]: +def get_fw_ids_helper(lp, args, count_only: Optional[bool] = None) -> Union[List[int], int]: """Build fws query from command line options and submit. 
Parameters: @@ -327,9 +327,7 @@ def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, No return ids -def get_fws_helper( - lp: LaunchPad, ids: List[int], args: Namespace -) -> Union[List[int], int, List[Dict[str, Union[str, int, bool]]], Union[str, int, bool]]: +def get_fws_helper(lp: LaunchPad, ids: Sequence[int], args: Namespace) -> Union[int, List[int], List[List[int]]]: """Get fws from ids in a representation according to args.display_format.""" fws = [] if args.display_format == "ids": @@ -493,15 +491,15 @@ def delete_wfs(args: Namespace) -> None: lp.m_logger.info(f"Finished deleting {len(fw_ids)} WFs") -def get_children(links, start, max_depth): - data = {} - for l, c in links.items(): - if l == start: - if len(c) > 0: - data[l] = [get_children(links, i, max_depth) for i in c] - else: - data[l] = c - return data +# def get_children(links, start): +# data = {} +# for l, c in links.items(): +# if l == start: +# if len(c) > 0: +# data[l] = [get_children(links, i) for i in c] +# else: +# data[l] = c +# return data def detect_lostruns(args: Namespace) -> None: @@ -695,7 +693,7 @@ def set_priority(args: Namespace) -> None: lp.m_logger.info(f"Finished setting priorities of {len(fw_ids)} FWs") -def _open_webbrowser(url): +def _open_webbrowser(url: str) -> None: """Open a web browser after a delay to give the web server more startup time.""" import webbrowser @@ -869,14 +867,14 @@ def orphaned(args: Namespace) -> None: print(args.output(fws)) -def get_output_func(format: Literal["json", "yaml"]) -> Callable[[str], Any]: +def get_output_func(format: Literal["json", "yaml"]) -> Callable[[Any], str]: if format == "json": - return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) + return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) # type: ignore else: - return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False) + return lambda x: yaml.safe_dump(recursive_dict(x, 
preserve_unicode=False), default_flow_style=False) # type: ignore -def arg_positive_int(value: str) -> int: +def arg_positive_int(value: Any) -> int: try: ivalue = int(value) except ValueError: @@ -934,7 +932,7 @@ def lpad(argv: Optional[Sequence[str]] = None) -> int: # enhanced display options allow for value 'none' or None (default) for no output enh_disp_args = copy.deepcopy(disp_args) enh_disp_kwargs = copy.deepcopy(disp_kwargs) - enh_disp_kwargs["choices"].append("none") + enh_disp_kwargs["choices"].append("none") # type: ignore enh_disp_kwargs["default"] = None query_args = ["-q", "--query"] diff --git a/fireworks/user_objects/firetasks/dataflow_tasks.py b/fireworks/user_objects/firetasks/dataflow_tasks.py index c59f7d32a..e149ece8d 100644 --- a/fireworks/user_objects/firetasks/dataflow_tasks.py +++ b/fireworks/user_objects/firetasks/dataflow_tasks.py @@ -5,9 +5,11 @@ __copyright__ = "Copyright 2016, Karlsruhe Institute of Technology" import sys +from typing import Any, List, Mapping, Optional from fireworks import Firework from fireworks.core.firework import FiretaskBase, FWAction +from fireworks.core.types import Spec from fireworks.utilities.fw_serializers import load_object if sys.version_info[0] > 2: @@ -74,7 +76,7 @@ class CommandLineTask(FiretaskBase): required_params = ["command_spec"] optional_params = ["inputs", "outputs", "chunk_number"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: cmd_spec = self["command_spec"] ilabels = self.get("inputs") olabels = self.get("outputs") @@ -136,7 +138,7 @@ def run_task(self, fw_spec): return FWAction() @staticmethod - def command_line_tool(command, inputs=None, outputs=None): + def command_line_tool(command, inputs: Optional[List[Mapping[Any, Any]]] = None, outputs: Optional[Spec] = None): """ This function composes and executes a command from provided specifications. 
@@ -163,7 +165,7 @@ def command_line_tool(command, inputs=None, outputs=None): from shutil import copyfile from subprocess import PIPE, Popen - def set_binding(arg): + def set_binding(arg: Mapping[str, Any]) -> str: argstr = "" if "binding" in arg: if "prefix" in arg["binding"]: @@ -283,7 +285,7 @@ class ForeachTask(FiretaskBase): required_params = ["task", "split"] optional_params = ["number of chunks"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["split"], basestring), self["split"] assert isinstance(fw_spec[self["split"]], list) if isinstance(self["task"]["inputs"], list): @@ -321,7 +323,7 @@ class JoinDictTask(FiretaskBase): required_params = ["inputs", "output"] optional_params = ["rename"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["output"], basestring) assert isinstance(self["inputs"], list) @@ -351,7 +353,7 @@ class JoinListTask(FiretaskBase): _fw_name = "JoinListTask" required_params = ["inputs", "output"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["output"], basestring) assert isinstance(self["inputs"], list) if self["output"] not in fw_spec: @@ -377,7 +379,7 @@ class ImportDataTask(FiretaskBase): required_params = ["filename", "mapstring"] optional_params = [] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: import json import operator from functools import reduce diff --git a/fireworks/user_objects/firetasks/script_task.py b/fireworks/user_objects/firetasks/script_task.py index 5630dfb78..ca4a2b396 100644 --- a/fireworks/user_objects/firetasks/script_task.py +++ b/fireworks/user_objects/firetasks/script_task.py @@ -4,9 +4,10 @@ import shlex import subprocess import sys -from typing import Dict, List, Optional, Union +from typing import Any, MutableMapping, Optional from fireworks.core.firework import FiretaskBase, FWAction +from 
fireworks.core.types import Spec if sys.version_info[0] > 2: basestring = str @@ -24,7 +25,7 @@ class ScriptTask(FiretaskBase): required_params = ["script"] _fw_name = "ScriptTask" - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: if self.get("use_global_spec"): self._load_params(fw_spec) else: @@ -37,7 +38,7 @@ def run_task(self, fw_spec): stdin = subprocess.PIPE if self.stdin_key else None return self._run_task_internal(fw_spec, stdin) - def _run_task_internal(self, fw_spec, stdin): + def _run_task_internal(self, fw_spec: Spec, stdin) -> FWAction: # run the program stdout = subprocess.PIPE if self.store_stdout or self.stdout_file else None stderr = subprocess.PIPE if self.store_stderr or self.stderr_file else None @@ -120,7 +121,7 @@ def _load_params(self, d): raise ValueError("ScriptTask cannot both FIZZLE and DEFUSE a bad returncode!") @classmethod - def from_str(cls, shell_cmd, parameters=None): + def from_str(cls, shell_cmd: str, parameters: Optional[MutableMapping[str, Any]] = None) -> "ScriptTask": parameters = parameters if parameters else {} parameters["script"] = [shell_cmd] parameters["use_shell"] = True @@ -163,7 +164,7 @@ class PyTask(FiretaskBase): # note that we are not using "optional_params" because we do not want to do # strict parameter checking in FiretaskBase due to "auto_kwargs" option - def run_task(self, fw_spec: Dict[str, Union[List[int], int]]) -> Optional[FWAction]: + def run_task(self, fw_spec: Spec) -> Optional[FWAction]: toks = self["func"].rsplit(".", 1) if len(toks) == 2: modname, funcname = toks diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py index c66da1685..dcfb2ceb4 100644 --- a/fireworks/utilities/dagflow.py +++ b/fireworks/utilities/dagflow.py @@ -5,10 +5,13 @@ __copyright__ = "Copyright 2017, Karlsruhe Institute of Technology" from itertools import combinations +from typing import List, Optional, Tuple import igraph from igraph import Graph +from fireworks import 
Workflow + DF_TASKS = ["PyTask", "CommandLineTask", "ForeachTask", "JoinDictTask", "JoinListTask"] DEFAULT_IGRAPH_VISUAL_STYLE = { @@ -42,7 +45,7 @@ class DAGFlow(Graph): """The purpose of this class is to help construction, validation and visualization of workflows.""" - def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs): + def __init__(self, steps, links=None, nlinks=None, name: Optional[str] = None, **kwargs) -> None: Graph.__init__(self, directed=True, graph_attrs={"name": name}, **kwargs) for step in steps: @@ -56,7 +59,7 @@ def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs): self._add_dataflow_links() @classmethod - def from_fireworks(cls, fireworkflow): + def from_fireworks(cls, fireworkflow: Workflow) -> "DAGFlow": """Converts a fireworks workflow object into a new DAGFlow object""" wfd = fireworkflow.to_dict() if "name" in wfd: @@ -117,7 +120,7 @@ def task_input(task, spec): return cls(steps=steps, links=links, name=name) - def _get_links(self, nlinks): + def _get_links(self, nlinks) -> List[Tuple[List[str], List[str]]]: """Translates named links into links between step ids""" links = [] for link in nlinks: @@ -126,7 +129,7 @@ def _get_links(self, nlinks): links.append((source, target)) return links - def _get_ctrlflow_links(self): + def _get_ctrlflow_links(self) -> List[Tuple[str, str]]: """Returns a list of unique tuples of link ids""" links = [] for ilink in {link.tuple for link in list(self.es)}: @@ -135,7 +138,7 @@ def _get_ctrlflow_links(self): links.append((source, target)) return links - def _add_ctrlflow_links(self, links): + def _add_ctrlflow_links(self, links) -> None: """Adds graph edges corresponding to control flow links""" for link in links: source = self._get_index(link[0]) @@ -262,22 +265,22 @@ def _get_leaves(self): """Returns all leaves (i.e. 
vertices without outgoing edges)""" return [i for i, v in enumerate(self.degree(mode=igraph.OUT)) if v == 0] - def delete_ctrlflow_links(self): + def delete_ctrlflow_links(self) -> None: """Deletes graph edges corresponding to control flow links""" lst = [link.index for link in list(self.es) if link["label"] == " "] self.delete_edges(lst) - def delete_dataflow_links(self): + def delete_dataflow_links(self) -> None: """Deletes graph edges corresponding to data flow links""" lst = [link.index for link in list(self.es) if link["label"] != " "] self.delete_edges(lst) - def add_step_labels(self): + def add_step_labels(self) -> None: """Labels the workflow steps (i.e. graph vertices)""" for vertex in list(self.vs): vertex["label"] = vertex["name"] + ", id: " + str(vertex["id"]) - def check(self): + def check(self) -> None: """Correctness check of the workflow""" try: assert self.is_dag(), "The workflow graph must be a DAG." @@ -288,7 +291,7 @@ def check(self): assert len(self.vs["id"]) == len(set(self.vs["id"])), "Workflow steps must have unique IDs." self.check_dataflow() - def check_dataflow(self): + def check_dataflow(self) -> None: """Checks whether all inputs and outputs match""" # check for shared output data entities diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 6423325d9..ca4a6c53e 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,10 +33,12 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback +from typing import Any, Dict, MutableMapping, Optional, Type, TypeVar import ruamel.yaml as yaml from monty.json import MontyDecoder, MSONable +from fireworks.core.types import FromDict from fireworks.fw_config import ( DECODE_MONTY, ENCODE_MONTY, @@ -55,7 +57,7 @@ # TODO: consider *somehow* switching FireWorks to monty serialization. e.g., numpy serialization is better handled. 
-SAVED_FW_MODULES = {} +SAVED_FW_MODULES: MutableMapping[str, str] = {} DATETIME_HANDLER = lambda obj: obj.isoformat() if isinstance(obj, datetime.datetime) else None ENCODING_PARAMS = {"encoding": "utf-8"} @@ -71,7 +73,7 @@ import fireworks_schema -def recursive_dict(obj, preserve_unicode=True): +def recursive_dict(obj: Any, preserve_unicode: bool = True) -> Any: if obj is None: return None @@ -93,7 +95,7 @@ def recursive_dict(obj, preserve_unicode=True): if isinstance(obj, datetime.datetime): return obj.isoformat() - if preserve_unicode and isinstance(obj, str) and obj != obj.encode("ascii", "ignore"): + if preserve_unicode and isinstance(obj, str) and obj != obj.encode("ascii", "ignore"): # type: ignore return obj if NUMPY_INSTALLED and isinstance(obj, np.ndarray): @@ -103,7 +105,7 @@ def recursive_dict(obj, preserve_unicode=True): # TODO: is reconstitute_dates really needed? Can this method just do everything? -def _recursive_load(obj): +def _recursive_load(obj: Any) -> Any: if obj is None: return None @@ -128,7 +130,7 @@ def _recursive_load(obj): return reconstitute_dates(obj) except Exception: # convert unicode to ASCII if not really unicode - if obj == obj.encode("ascii", "ignore"): + if obj == obj.encode("ascii", "ignore"): # type: ignore return str(obj) return obj @@ -177,6 +179,9 @@ def _decorator(self, *args, **kwargs): return _decorator +T = TypeVar("T", bound="FWSerializable") + + class FWSerializable(metaclass=abc.ABCMeta): """ To create a serializable object within FireWorks, you should subclass this @@ -196,20 +201,20 @@ class and implement the to_dict() and from_dict() methods. 
""" @property - def fw_name(self): + def fw_name(self) -> str: try: - return self._fw_name + return self._fw_name # type: ignore except AttributeError: return get_default_serialization(self.__class__) @abc.abstractmethod - def to_dict(self): + def to_dict(self) -> Dict[Any, Any]: raise NotImplementedError("FWSerializable object did not implement to_dict()!") - def to_db_dict(self): + def to_db_dict(self) -> Dict[Any, Any]: return self.to_dict() - def as_dict(self): + def as_dict(self) -> Dict[Any, Any]: # strictly for pseudo-compatibility with MSONable # Note that FWSerializable is not MSONable, it uses _fw_name instead of __class__ and # __module__ @@ -217,13 +222,13 @@ def as_dict(self): @classmethod @abc.abstractmethod - def from_dict(cls, m_dict): + def from_dict(cls: Type[T], m_dict: FromDict) -> T: raise NotImplementedError("FWSerializable object did not implement from_dict()!") - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self.to_dict(), default=DATETIME_HANDLER) - def to_format(self, f_format="json", **kwargs): + def to_format(self, f_format: str = "json", **kwargs) -> str: """ returns a String representation in the given format @@ -234,12 +239,12 @@ def to_format(self, f_format="json", **kwargs): return json.dumps(self.to_dict(), default=DATETIME_HANDLER, **kwargs) elif f_format == "yaml": # start with the JSON format, and convert to YAML - return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True) + return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True) # type: ignore else: raise ValueError(f"Unsupported format {f_format}") @classmethod - def from_format(cls, f_str, f_format="json"): + def from_format(cls: Type[T], f_str: str, f_format: str = "json") -> T: """ convert from a String representation to its Object. 
@@ -256,11 +261,11 @@ def from_format(cls, f_str, f_format="json"): dct = yaml.safe_load(f_str) else: raise ValueError(f"Unsupported format {f_format}") - if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: + if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: # type: ignore fireworks_schema.validate(dct, cls.__name__) return cls.from_dict(reconstitute_dates(dct)) - def to_file(self, filename, f_format=None, **kwargs): + def to_file(self, filename: str, f_format: Optional[str] = None, **kwargs) -> None: """ Write a serialization of this object to a file. @@ -270,11 +275,11 @@ def to_file(self, filename, f_format=None, **kwargs): """ if f_format is None: f_format = filename.split(".")[-1] - with open(filename, "w", **ENCODING_PARAMS) as f: + with open(filename, "w", **ENCODING_PARAMS) as f: # type: ignore f.write(self.to_format(f_format=f_format, **kwargs)) @classmethod - def from_file(cls, filename, f_format=None): + def from_file(cls: Type[T], filename: str, f_format: Optional[str] = None) -> T: """ Load a serialization of this object from a file. @@ -287,20 +292,20 @@ def from_file(cls, filename, f_format=None): """ if f_format is None: f_format = filename.split(".")[-1] - with open(filename, "r", **ENCODING_PARAMS) as f: + with open(filename, "r", **ENCODING_PARAMS) as f: # type: ignore return cls.from_format(f.read(), f_format=f_format) def __getstate__(self): return self.to_dict() - def __setstate__(self, state): + def __setstate__(self, state) -> None: fw_obj = self.from_dict(state) for k, v in fw_obj.__dict__.items(): self.__dict__[k] = v # TODO: make this quicker the first time around -def load_object(obj_dict): +def load_object(obj_dict: MutableMapping[str, Any]) -> Any: """ Creates an instantiation of a class based on a dictionary representation. We implicitly determine the Class through introspection along with information in the dictionary. 
@@ -370,7 +375,7 @@ def load_object(obj_dict): raise ValueError(f"load_object() could not find a class with cls._fw_name {fw_name}") -def load_object_from_file(filename, f_format=None): +def load_object_from_file(filename: str, f_format: Optional[str] = None) -> Any: """ Implicitly load an object from a file. just a friendly wrapper to load_object() @@ -383,7 +388,7 @@ def load_object_from_file(filename, f_format=None): if f_format is None: f_format = filename.split(".")[-1] - with open(filename, "r", **ENCODING_PARAMS) as f: + with open(filename, "r", **ENCODING_PARAMS) as f: # type: ignore if f_format == "json": dct = json.loads(f.read()) elif f_format == "yaml": @@ -392,8 +397,9 @@ def load_object_from_file(filename, f_format=None): raise ValueError(f"Unknown file format {f_format} cannot be loaded!") classname = FW_NAME_UPDATES.get(dct["_fw_name"], dct["_fw_name"]) - if JSON_SCHEMA_VALIDATE and classname in JSON_SCHEMA_VALIDATE_LIST: - fireworks_schema.validate(dct, classname) + if JSON_SCHEMA_VALIDATE: + if JSON_SCHEMA_VALIDATE_LIST is not None and classname in JSON_SCHEMA_VALIDATE_LIST: + fireworks_schema.validate(dct, classname) return load_object(reconstitute_dates(dct)) @@ -413,7 +419,7 @@ def _search_module_for_obj(m_module, obj_dict): return obj.from_dict(obj_dict) -def reconstitute_dates(obj_dict): +def reconstitute_dates(obj_dict: Any) -> Any: if obj_dict is None: return None @@ -434,7 +440,7 @@ def reconstitute_dates(obj_dict): return obj_dict -def get_default_serialization(cls): +def get_default_serialization(cls: Type[Any]) -> str: root_mod = cls.__module__.split(".")[0] if root_mod == "__main__": raise ValueError( diff --git a/fireworks/utilities/fw_utilities.py b/fireworks/utilities/fw_utilities.py index 261902f51..77713b0a1 100644 --- a/fireworks/utilities/fw_utilities.py +++ b/fireworks/utilities/fw_utilities.py @@ -10,7 +10,7 @@ import traceback from logging import Formatter, Logger from multiprocessing.managers import BaseManager -from typing 
import Tuple +from typing import List, Optional, Tuple from fireworks.fw_config import DS_PASSWORD, FW_BLOCK_FORMAT, FW_LOGGING_FORMAT, FWData @@ -20,14 +20,14 @@ __email__ = "ajain@lbl.gov" __date__ = "Dec 12, 2012" -PREVIOUS_STREAM_LOGGERS = [] # contains the name of loggers that have already been initialized +PREVIOUS_STREAM_LOGGERS: List[Tuple[str, str]] = [] # contains the name of loggers that have already been initialized PREVIOUS_FILE_LOGGERS = [] # contains the name of file loggers that have already been initialized DEFAULT_FORMATTER = Formatter(FW_LOGGING_FORMAT) def get_fw_logger( name: str, - l_dir: None = None, + l_dir: Optional[str] = None, file_levels: Tuple[str, str] = ("DEBUG", "ERROR"), stream_level: str = "DEBUG", formatter: Formatter = DEFAULT_FORMATTER, @@ -49,7 +49,7 @@ def get_fw_logger( stream_level = stream_level if stream_level else "CRITICAL" # add handlers for the file_levels - if l_dir: + if l_dir is not None: for lvl in file_levels: f_name = os.path.join(l_dir, name.replace(".", "_") + "-" + lvl.lower() + ".log") mode = "w" if clear_logs else "a" diff --git a/fireworks/utilities/tests/test_dagflow.py b/fireworks/utilities/tests/test_dagflow.py index 3e570d534..c928ee644 100644 --- a/fireworks/utilities/tests/test_dagflow.py +++ b/fireworks/utilities/tests/test_dagflow.py @@ -14,7 +14,7 @@ class DAGFlowTest(unittest.TestCase): """run tests for DAGFlow class""" - def setUp(self): + def setUp(self) -> None: try: __import__("igraph", fromlist=["Graph"]) except (ImportError, ModuleNotFoundError): @@ -32,7 +32,7 @@ def setUp(self): ) self.fw3 = Firework(PyTask(func="print", inputs=["second power"]), name="the third one") - def test_dagflow_ok(self): + def test_dagflow_ok(self) -> None: """construct and replicate""" from fireworks.utilities.dagflow import DAGFlow @@ -40,7 +40,7 @@ def test_dagflow_ok(self): dagf = DAGFlow.from_fireworks(wfl) DAGFlow(**dagf.to_dict()) - def test_dagflow_loop(self): + def test_dagflow_loop(self) -> None: """loop 
in graph""" from fireworks.utilities.dagflow import DAGFlow @@ -50,7 +50,7 @@ def test_dagflow_loop(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_cut(self): + def test_dagflow_cut(self) -> None: """disconnected graph""" from fireworks.utilities.dagflow import DAGFlow @@ -60,7 +60,7 @@ def test_dagflow_cut(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_link(self): + def test_dagflow_link(self) -> None: """wrong links""" from fireworks.utilities.dagflow import DAGFlow @@ -70,7 +70,7 @@ def test_dagflow_link(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_missing_input(self): + def test_dagflow_missing_input(self) -> None: """missing input""" from fireworks.utilities.dagflow import DAGFlow @@ -87,7 +87,7 @@ def test_dagflow_missing_input(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_clashing_inputs(self): + def test_dagflow_clashing_inputs(self) -> None: """parent firework output overwrites an input in spec""" from fireworks.utilities.dagflow import DAGFlow @@ -105,7 +105,7 @@ def test_dagflow_clashing_inputs(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_race_condition(self): + def test_dagflow_race_condition(self) -> None: """two parent firework outputs overwrite each other""" from fireworks.utilities.dagflow import DAGFlow @@ -121,7 +121,7 @@ def test_dagflow_race_condition(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_clashing_outputs(self): + def test_dagflow_clashing_outputs(self) -> None: """subsequent task overwrites output of a task""" from fireworks.utilities.dagflow import DAGFlow @@ -135,7 +135,7 @@ def test_dagflow_clashing_outputs(self): DAGFlow.from_fireworks(Workflow([fwk], 
{})).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_non_dataflow_tasks(self): + def test_dagflow_non_dataflow_tasks(self) -> None: """non-dataflow tasks using outputs and inputs keys do not fail""" from fireworks.core.firework import FiretaskBase from fireworks.utilities.dagflow import DAGFlow @@ -154,7 +154,7 @@ def run_task(self, fw_spec): wfl = Workflow([self.fw1, fw2], {self.fw1: [fw2], fw2: []}) DAGFlow.from_fireworks(wfl).check() - def test_dagflow_view(self): + def test_dagflow_view(self) -> None: """visualize the workflow graph""" from fireworks.utilities.dagflow import DAGFlow diff --git a/pyproject.toml b/pyproject.toml index 55ec8d784..220f5b595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,63 @@ [tool.black] line-length = 120 + +[tool.mypy] +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true + +scripts_are_modules = true +warn_unused_configs = true + +# disallow_any_unimported = true +# disallow_any_decorated = true +# disallow_any_explicit = true +# disallow_any_expr = true +# disallow_any_generics = true +# disallow_subclassing_any = true + +# disallow_untyped_calls = true +# disallow_untyped_defs = true + +# check_untyped_defs = true + +# disallow_untyped_decorators = true + +no_implicit_optional = true + +strict_equality = true + +[[tool.mypy.overrides]] +module = [ + "argcomplete", + "atomate", + "atomate.vasp", + "atomate.vasp.workflows", + "fabric", + "fireworks_schema", + "flask_paginate", + "graphviz", + # https://github.com/benoitc/gunicorn/pull/2377 + "gunicorn", + "gunicorn.app", + "gunicorn.app.base", + "igraph", + "invoke", + "matplotlib", + "matplotlib.backends.backend_agg", + "matplotlib.figure", + "matplotlib.pyplot", + "matplotlib.ticker", + "monty", + "monty.design_patterns", + "monty.dev", + "monty.io", + "monty.json", + "monty.os", + "monty.os.path", + "monty.serialization", + "monty.shutil", + "pymatgen.core", + "pytest", +] +ignore_missing_imports = true