From dfdcde92c88f10ff560099e7264064cfce6df7c0 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 15 Oct 2021 22:05:48 -0400 Subject: [PATCH 01/18] add initial mypy config --- .pre-commit-config.yaml | 20 ++++++++++++++++++++ pyproject.toml | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c4227fd18..ee0dcff91 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,26 @@ repos: hooks: - id: isort + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.910-1 + hooks: + - id: mypy + additional_dependencies: + - jinja2 + - numpy + - pymongo-stubs + - ruamel.yaml + - types-flask + - types-paramiko + - types-prettytable + - types-python-dateutil + - types-pyyaml + - types-requests + - types-setuptools + - types-six + - types-tabulate + args: [] + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 55ec8d784..af9a637bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,40 @@ [tool.black] line-length = 120 + +[tool.mypy] +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true + +scripts_are_modules = true +warn_unused_configs = true + +[[tool.mypy.overrides]] +module = [ + "argcomplete", + "fabric", + "flask_paginate", + "graphviz", + # https://github.com/benoitc/gunicorn/pull/2377 + "gunicorn", + "gunicorn.app", + "gunicorn.app.base", + "invoke", + "matplotlib", + "matplotlib.backends.backend_agg", + "matplotlib.figure", + "matplotlib.pyplot", + "matplotlib.ticker", + "monty", + "monty.design_patterns", + "monty.dev", + "monty.io", + "monty.json", + "monty.os", + "monty.os.path", + "monty.serialization", + "monty.shutil", + # https://github.com/tqdm/tqdm/issues/260 + "tqdm", +] +ignore_missing_imports = true From e9a9efe1efe62f7bd91e3e24043614b3e46636e2 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 15 Oct 2021 22:09:38 -0400 Subject: [PATCH 02/18] mypy: forgot igraph --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index af9a637bd..edbe5187e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ module = [ "gunicorn", "gunicorn.app", "gunicorn.app.base", + "igraph", "invoke", "matplotlib", "matplotlib.backends.backend_agg", From 4f064d1fcf5d102349f6bdf778374a3ec577938c Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Sat, 16 Oct 2021 12:46:05 -0400 Subject: [PATCH 03/18] add initial type aliases, fix import and format call in logging --- fireworks/core/rocket.py | 5 +++-- fireworks/core/types.py | 5 +++++ pyproject.toml | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 fireworks/core/types.py diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py index 3c3d7da36..8352678a5 100644 --- a/fireworks/core/rocket.py +++ b/fireworks/core/rocket.py @@ -409,8 +409,9 @@ def run(self, pdb_on_exception: bool = False) -> bool: l_logger.log(logging.DEBUG, traceback.format_exc()) l_logger.log( logging.WARNING, - f"Firework {self.fw_id} fizzled but couldn't complete the update of the database. " - f"Reason: {e}\nRefresh the WF to recover the result (lpad admin refresh -i {self.fw_id}).", + "Firework {} fizzled but couldn't complete the update of the database." 
+ " Reason: {}\nRefresh the WF to recover the result " + "(lpad admin refresh -i {}).".format(self.fw_id, e, self.fw_id), ) return True else: diff --git a/fireworks/core/types.py b/fireworks/core/types.py new file mode 100644 index 000000000..8f0bdf1e5 --- /dev/null +++ b/fireworks/core/types.py @@ -0,0 +1,5 @@ +from typing import Any, MutableMapping + + +Checkpoint = MutableMapping[Any, Any] +Spec = MutableMapping[Any, Any] diff --git a/pyproject.toml b/pyproject.toml index edbe5187e..b6e9b461d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ warn_unused_configs = true module = [ "argcomplete", "fabric", + "fireworks_schema", "flask_paginate", "graphviz", # https://github.com/benoitc/gunicorn/pull/2377 From 7b7919b39a2efcc558dece79376225faa68513ce Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Wed, 27 Oct 2021 09:46:06 -0400 Subject: [PATCH 04/18] If you have Hypothesis installed as a pytest plugin --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 379bd2972..81f54f990 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.py[cod] +.hypothesis .mypy_cache # doc builds From 984358373a67b5160e17b852884bea5a40ddb969 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Wed, 27 Oct 2021 09:51:41 -0400 Subject: [PATCH 05/18] Start adding types --- docs_rst/conf.py | 3 +- fireworks/core/firework.py | 149 +++++++++--------- fireworks/core/fworker.py | 17 +- fireworks/core/launchpad.py | 54 ++++--- fireworks/core/rocket.py | 17 +- fireworks/core/rocket_launcher.py | 37 +++-- fireworks/core/tests/tasks.py | 3 +- fireworks/core/types.py | 4 +- fireworks/flask_site/gunicorn.py | 7 +- fireworks/flask_site/helpers.py | 13 +- fireworks/scripts/fwtool | 10 +- fireworks/scripts/lpad_run.py | 38 +++-- .../user_objects/firetasks/dataflow_tasks.py | 16 +- fireworks/utilities/fw_serializers.py | 43 ++--- 14 files changed, 227 insertions(+), 184 deletions(-) diff --git a/docs_rst/conf.py b/docs_rst/conf.py index 7aebfda79..74265cb4d 100644 --- a/docs_rst/conf.py +++ b/docs_rst/conf.py @@ -12,6 +12,7 @@ import os import sys +from typing import Mapping if sys.version_info < (3, 8): import importlib_metadata as metadata @@ -185,7 +186,7 @@ # -- Options for LaTeX output -------------------------------------------------- -latex_elements = { +latex_elements: Mapping[str, str] = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). 
diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 616f4d463..1f362bb70 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -14,12 +14,13 @@ from collections import defaultdict from copy import deepcopy from datetime import datetime -from typing import Any, Dict, Iterable, List, Sequence +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union from monty.io import reverse_readline, zopen from monty.os.path import zpath from fireworks.core.fworker import FWorker +from fireworks.core.types import Checkpoint, FromDict, Spec from fireworks.fw_config import ( EXCEPT_DETAILS_ON_RERUN, NEGATIVE_FWID_CTR, @@ -58,7 +59,7 @@ class FiretaskBase(defaultdict, FWSerializable, metaclass=abc.ABCMeta): # if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init optional_params = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: dict.__init__(self, *args, **kwargs) required_params = self.required_params or [] @@ -77,7 +78,7 @@ def __init__(self, *args, **kwargs): ) @abc.abstractmethod - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> Optional["FWAction"]: """ This method gets called when the Firetask is run. It can take in a Firework spec, perform some task using that data, and then return an @@ -102,27 +103,27 @@ def run_task(self, fw_spec): @serialize_fw @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[Any, Any]: return dict(self) @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[Any, Any]) -> "FiretaskBase": return cls(m_dict) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.fw_name}>:{dict(self)}" # not strictly needed here for pickle/unpickle, but complements __setstate__ - def __getstate__(self): + def __getstate__(self) -> Dict[Any, Any]: return self.to_dict() # added to support pickle/unpickle - def __setstate__(self, state): + def __setstate__(self, state) -> None: self.__init__(state) # added to support pickle/unpickle - def __reduce__(self): + def __reduce__(self) -> Tuple: t = defaultdict.__reduce__(self) return (t[0], (self.to_dict(),), t[2], t[3], t[4]) @@ -136,15 +137,15 @@ class FWAction(FWSerializable): def __init__( self, - stored_data=None, - exit=False, - update_spec=None, - mod_spec=None, - additions=None, - detours=None, - defuse_children=False, - defuse_workflow=False, - propagate=False, + stored_data: Optional[Dict[Any, Any]] = None, + exit: bool = False, + update_spec: Optional[Mapping[str, Any]] = None, + mod_spec: Optional[Mapping[str, Any]] = None, + additions: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, + detours: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, + defuse_children: bool = False, + defuse_workflow: bool = False, + propagate: bool = False, ): """ Args: @@ -177,7 +178,7 @@ def __init__( self.propagate = propagate @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "stored_data": self.stored_data, "exit": self.exit, @@ -192,7 +193,7 @@ def to_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "FWAction": d = m_dict additions = [Workflow.from_dict(f) for f in d["additions"]] detours = [Workflow.from_dict(f) for f in d["detours"]] @@ -209,7 +210,7 @@ def from_dict(cls, m_dict): ) @property - def skip_remaining_tasks(self): + def 
skip_remaining_tasks(self) -> bool: """ If the FWAction gives any dynamic action, we skip the subsequent Firetasks @@ -218,7 +219,7 @@ def skip_remaining_tasks(self): """ return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow - def __str__(self): + def __str__(self) -> str: return "FWAction\n" + pprint.pformat(self.to_dict()) @@ -242,16 +243,16 @@ class Firework(FWSerializable): # note: if you modify this signature, you must also modify LazyFirework def __init__( self, - tasks, - spec=None, - name=None, - launches=None, - archived_launches=None, - state="WAITING", - created_on=None, - fw_id=None, - parents=None, - updated_on=None, + tasks: Union["FiretaskBase", Sequence["FiretaskBase"]], + spec: Optional[Dict[Any, Any]] = None, + name: Optional[str] = None, + launches: Optional[Sequence["Launch"]] = None, + archived_launches: Optional[Sequence["Launch"]] = None, + state: str = "WAITING", + created_on: Optional[datetime] = None, + fw_id: Optional[int] = None, + parents: Optional[Union["Firework", Sequence["Firework"]]] = None, + updated_on: Optional[datetime] = None, ): """ Args: @@ -290,7 +291,7 @@ def __init__( self._state = state @property - def state(self): + def state(self) -> str: """ Returns: str: The current state of the Firework @@ -298,7 +299,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state: str) -> None: """ Setter for the the FW state, which triggers updated_on change @@ -309,7 +310,7 @@ def state(self, state): self.updated_on = datetime.utcnow() @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: # put tasks in a special location of the spec spec = self.spec spec["_tasks"] = [t.to_dict() for t in self.tasks] @@ -330,7 +331,7 @@ def to_dict(self): return m_dict - def _rerun(self): + def _rerun(self) -> None: """ Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework can thus be re-run even if it was Launched in the past. This method should be called by @@ -354,7 +355,7 @@ def _rerun(self): self.launches = [] self.state = "WAITING" - def to_db_dict(self): + def to_db_dict(self) -> Dict[str, Any]: """ Return firework dict with updated launches and state. 
""" @@ -368,7 +369,7 @@ def to_db_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "Firework": tasks = m_dict["spec"]["_tasks"] launches = [Launch.from_dict(tmp) for tmp in m_dict.get("launches", [])] archived_launches = [Launch.from_dict(tmp) for tmp in m_dict.get("archived_launches", [])] @@ -381,8 +382,8 @@ def from_dict(cls, m_dict): tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on ) - def __str__(self): - return f"Firework object: (id: {int(self.fw_id)} , name: {self.fw_name})" + def __str__(self) -> str: + return "Firework object: (id: %i , name: %s)" % (self.fw_id, self.fw_name) def __iter__(self) -> Iterable[FiretaskBase]: return self.tasks.__iter__() @@ -401,7 +402,9 @@ class Tracker(FWSerializable): MAX_TRACKER_LINES = 1000 - def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False): + def __init__( + self, filename: str, nlines: int = TRACKER_LINES, content: str = "", allow_zipped: bool = False + ) -> None: """ Args: filename (str) @@ -416,7 +419,7 @@ def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=Fals self.content = content self.allow_zipped = allow_zipped - def track_file(self, launch_dir=None): + def track_file(self, launch_dir: Optional[str] = None) -> str: """ Reads the monitored file and returns back the last N lines @@ -441,19 +444,19 @@ def track_file(self, launch_dir=None): self.content = "\n".join(reversed(lines)) return self.content - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: m_dict = {"filename": self.filename, "nlines": self.nlines, "allow_zipped": self.allow_zipped} if self.content: m_dict["content"] = self.content return m_dict @classmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "Tracker": return Tracker( m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False) ) - def __str__(self): + def __str__(self) -> str: return f"### Filename: {self.filename}\n{self.content}" @@ -464,16 +467,16 @@ class Launch(FWSerializable): def __init__( self, - state, - launch_dir, - fworker=None, - host=None, - ip=None, - trackers=None, - action=None, - state_history=None, - launch_id=None, - fw_id=None, + state: str, + launch_dir: str, + fworker: Optional["FWorker"] = None, + host: Optional[str] = None, + ip: Optional[str] = None, + trackers: Optional[Sequence["Tracker"]] = None, + action: Optional["FWAction"] = None, + state_history: Optional[Dict[Any, Any]] = None, + launch_id: Optional[int] = None, + fw_id: Optional[int] = None, ): """ Args: @@ -501,7 +504,7 @@ def __init__( self.launch_id = launch_id self.fw_id = fw_id - def touch_history(self, update_time=None, checkpoint=None): + def touch_history(self, update_time: Optional[datetime] = None, checkpoint: Optional[Checkpoint] = None) -> None: """ Updates the update_on field of the state history of a Launch. Used to ping that a Launch is still alive. @@ -514,7 +517,7 @@ def touch_history(self, update_time=None, checkpoint=None): self.state_history[-1]["checkpoint"] = checkpoint self.state_history[-1]["updated_on"] = update_time - def set_reservation_id(self, reservation_id): + def set_reservation_id(self, reservation_id: Union[int, str]) -> None: """ Adds the job_id to the reservation. 
@@ -527,7 +530,7 @@ def set_reservation_id(self, reservation_id): break @property - def state(self): + def state(self) -> str: """ Returns: str: The current state of the Launch. @@ -535,7 +538,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state: str) -> None: """ Setter for the the Launch's state. Automatically triggers an update to state_history. @@ -546,7 +549,7 @@ def state(self, state): self._update_state_history(state) @property - def time_start(self): + def time_start(self) -> datetime: """ Returns: datetime: the time the Launch started RUNNING @@ -554,7 +557,7 @@ def time_start(self): return self._get_time("RUNNING") @property - def time_end(self): + def time_end(self) -> datetime: """ Returns: datetime: the time the Launch was COMPLETED or FIZZLED @@ -562,7 +565,7 @@ def time_end(self): return self._get_time(["COMPLETED", "FIZZLED"]) @property - def time_reserved(self): + def time_reserved(self) -> datetime: """ Returns: datetime: the time the Launch was RESERVED in the queue @@ -570,7 +573,7 @@ def time_reserved(self): return self._get_time("RESERVED") @property - def last_pinged(self): + def last_pinged(self) -> datetime: """ Returns: datetime: the time the Launch last pinged a heartbeat that it was still running @@ -578,7 +581,7 @@ def last_pinged(self): return self._get_time("RUNNING", True) @property - def runtime_secs(self): + def runtime_secs(self) -> int: # type: ignore """ Returns: int: the number of seconds that the Launch ran for. @@ -586,10 +589,10 @@ def runtime_secs(self): start = self.time_start end = self.time_end if start and end: - return (end - start).total_seconds() + return int((end - start).total_seconds()) @property - def reservedtime_secs(self): + def reservedtime_secs(self) -> int: # type: ignore """ Returns: int: number of seconds the Launch was stuck as RESERVED in a queue. @@ -597,10 +600,10 @@ def reservedtime_secs(self): start = self.time_reserved if start: end = self.time_start if self.time_start else datetime.utcnow() - return (end - start).total_seconds() + return int((end - start).total_seconds()) @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "fworker": self.fworker, "fw_id": self.fw_id, @@ -664,7 +667,7 @@ def _update_state_history(self, state): if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key - def _get_time(self, states, use_update_time=False): + def _get_time(self, states, use_update_time: bool = False) -> datetime: # type: ignore """ Internal method to help get the time of various events in the Launch (e.g. RUNNING) from the state history. @@ -1332,7 +1335,7 @@ def reset(self, reset_ids: bool = True) -> None: self.fw_states = {key: self.id_fw[key].state for key in self.id_fw} @classmethod - def from_dict(cls, m_dict: Dict[str, Any]) -> "Workflow": + def from_dict(cls, m_dict: FromDict) -> "Workflow": """ Return Workflow from its dict representation. @@ -1358,7 +1361,7 @@ def from_dict(cls, m_dict: Dict[str, Any]) -> "Workflow": return Workflow.from_Firework(Firework.from_dict(m_dict)) @classmethod - def from_Firework(cls, fw: Firework, name: str = None, metadata=None) -> "Workflow": + def from_Firework(cls, fw: "Firework", name: Optional[str] = None, metadata=None) -> "Workflow": """ Return Workflow from the given Firework. 
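As a usage sketch for the annotations above (illustrative only, not part of the patch): the FromDict alias is just Mapping[str, Any], and the dict produced by Workflow.to_dict() satisfies it, so a typed round trip stays consistent:

    from fireworks import Firework, ScriptTask, Workflow

    fw1 = Firework([ScriptTask.from_str("echo step1")], name="step1")
    fw2 = Firework([ScriptTask.from_str("echo step2")], name="step2", parents=[fw1])
    wf = Workflow([fw1, fw2], name="two_step")

    # Workflow.to_dict() returns the Mapping[str, Any] shape described by FromDict
    wf_again = Workflow.from_dict(wf.to_dict())
    assert sorted(fw.name for fw in wf_again) == ["step1", "step2"]
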
@@ -1373,10 +1376,10 @@ def from_Firework(cls, fw: Firework, name: str = None, metadata=None) -> "Workfl name = name if name else fw.name return Workflow([fw], None, name=name, metadata=metadata, created_on=fw.created_on, updated_on=fw.updated_on) - def __str__(self): + def __str__(self) -> str: return f"Workflow object: (fw_ids: {self.id_fw.keys()} , name: {self.name})" - def remove_fws(self, fw_ids): + def remove_fws(self, fw_ids: Sequence[int]) -> None: """ Remove the fireworks corresponding to the input firework ids and update the workflow i.e the parents of the removed fireworks become the parents of the children fireworks (only if the diff --git a/fireworks/core/fworker.py b/fireworks/core/fworker.py index caa4a8070..6ae28e025 100644 --- a/fireworks/core/fworker.py +++ b/fireworks/core/fworker.py @@ -3,6 +3,7 @@ """ import json +from typing import Any, Dict, Optional, Sequence, Union from fireworks.fw_config import FWORKER_LOC from fireworks.utilities.fw_serializers import ( @@ -21,7 +22,13 @@ class FWorker(FWSerializable): - def __init__(self, name="Automatically generated Worker", category="", query=None, env=None): + def __init__( + self, + name: str = "Automatically generated Worker", + category: Union[str, Sequence[str]] = "", + query: Optional[Dict[str, Any]] = None, + env: Optional[Dict[str, Any]] = None, + ) -> None: """ Args: name (str): the name of the resource, should be unique @@ -41,7 +48,7 @@ def __init__(self, name="Automatically generated Worker", category="", query=Non self.env = env if env else {} @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "name": self.name, "category": self.category, @@ -51,11 +58,11 @@ def to_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "FWorker": return FWorker(m_dict["name"], m_dict["category"], json.loads(m_dict["query"]), m_dict.get("env")) @property - def query(self): + def query(self) -> Dict[str, Any]: """ Returns updated query dict. """ @@ -77,7 +84,7 @@ def query(self): return q @classmethod - def auto_load(cls): + def auto_load(cls) -> "FWorker": """ Returns FWorker object from settings file(my_fworker.yaml). """ diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 9e6ffbd16..d59917680 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -12,6 +12,7 @@ import warnings from collections import defaultdict from itertools import chain +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import gridfs from bson import ObjectId @@ -50,7 +51,7 @@ # TODO: lots of duplication reduction and cleanup possible -def sort_aggregation(sort): +def sort_aggregation(sort: Sequence[Tuple[str, int]]) -> List[Mapping[str, Any]]: """Build sorting aggregation pipeline. Args: @@ -97,7 +98,13 @@ class WFLock: Calling functions are responsible for handling the error in order to avoid database inconsistencies. 
""" - def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL): + def __init__( + self, + lp: "LaunchPad", + fw_id: int, + expire_secs: int = WFLOCK_EXPIRATION_SECS, + kill: bool = WFLOCK_EXPIRATION_KILL, + ): """ Args: lp (LaunchPad) @@ -110,7 +117,7 @@ def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EX self.expire_secs = expire_secs self.kill = kill - def __enter__(self): + def __enter__(self) -> None: ctr = 0 waiting_time = 0 # acquire lock @@ -140,7 +147,7 @@ def __enter__(self): {"nodes": self.fw_id, "locked": {"$exists": False}}, {"$set": {"locked": True}} ) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.lp.workflows.find_one_and_update({"nodes": self.fw_id}, {"$unset": {"locked": True}}) @@ -151,17 +158,17 @@ class LaunchPad(FWSerializable): def __init__( self, - host=None, - port=None, - name=None, - username=None, - password=None, - logdir=None, - strm_lvl=None, - user_indices=None, - wf_user_indices=None, - authsource=None, - uri_mode=False, + host: Optional[str] = None, + port: Optional[int] = None, + name: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + logdir: Optional[str] = None, + strm_lvl: Optional[str] = None, + user_indices: Optional[Sequence[int]] = None, + wf_user_indices: Optional[Sequence[int]] = None, + authsource: str = None, + uri_mode: bool = False, mongoclient_kwargs=None, ): """ @@ -580,7 +587,7 @@ def get_wf_by_fw_id_lzyfw(self, fw_id): fw_states, ) - def delete_fws(self, fw_ids, delete_launch_dirs=False): + def delete_fws(self, fw_ids: Sequence[int], delete_launch_dirs: bool = False) -> None: """Delete a set of fireworks identified by their fw_ids. ATTENTION: This function serves maintenance purposes and will leave @@ -724,7 +731,14 @@ def get_wf_summary_dict(self, fw_id, mode="more"): return wf - def get_fw_ids(self, query=None, sort=None, limit=0, count_only=False, launches_mode=False): + def get_fw_ids( + self, + query: Optional[Mapping[str, Any]] = None, + sort=None, + limit: int = 0, + count_only: bool = False, + launches_mode: bool = False, + ) -> Union[int, Sequence[int]]: """ Return all the fw ids that match a query. @@ -2038,7 +2052,7 @@ class LazyFirework: db_fields = ("name", "fw_id", "spec", "created_on", "state") db_launch_fields = ("launches", "archived_launches") - def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs): + def __init__(self, fw_id: int, fw_coll, launch_coll, fallback_fs): """ Args: fw_id (int): firework id @@ -2177,7 +2191,7 @@ def full_fw(self): # Get a type of Launch object - def _get_launch_data(self, name): + def _get_launch_data(self, name: str): """ Pull launch data individually for each field. @@ -2202,7 +2216,7 @@ def _get_launch_data(self, name): return getattr(fw, name) -def get_action_from_gridfs(action_dict, fallback_fs): +def get_action_from_gridfs(action_dict: Mapping[str, Any], fallback_fs: gridfs.GridFS) -> Dict[str, Any]: """ Helper function to obtain the correct dictionary of the FWAction associated with a launch. 
If necessary retrieves the information from gridfs based diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py index 8352678a5..02375ce42 100644 --- a/fireworks/core/rocket.py +++ b/fireworks/core/rocket.py @@ -15,7 +15,7 @@ import traceback from datetime import datetime from threading import Event, Thread, current_thread -from typing import Dict, Union +from typing import Optional from monty.io import zopen from monty.os.path import zpath @@ -23,6 +23,7 @@ from fireworks.core.firework import Firework, FWAction from fireworks.core.fworker import FWorker from fireworks.core.launchpad import LaunchPad, LockedWorkflowError +from fireworks.core.types import Checkpoint, Spec from fireworks.fw_config import ( PING_TIME_SECS, PRINT_FW_JSON, @@ -56,7 +57,7 @@ def ping_launch(launchpad: LaunchPad, launch_id: int, stop_event: Event, master_ stop_event.wait(PING_TIME_SECS) -def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Union[Event, None]: +def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Optional[Event]: fd = FWData() if fd.MULTIPROCESSING: if not launch_id: @@ -70,7 +71,7 @@ def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Union[Event, None return ping_stop -def stop_backgrounds(ping_stop, btask_stops): +def stop_backgrounds(ping_stop: Event, btask_stops) -> None: fd = FWData() if fd.MULTIPROCESSING: fd.Running_IDs[os.getpid()] = None @@ -81,7 +82,7 @@ def stop_backgrounds(ping_stop, btask_stops): b.set() -def background_task(btask, spec, stop_event, master_thread): +def background_task(btask, spec: Spec, stop_event: Event, master_thread: Thread) -> None: num_launched = 0 while not stop_event.is_set() and master_thread.is_alive(): for task in btask.tasks: @@ -93,7 +94,7 @@ def background_task(btask, spec, stop_event, master_thread): break -def start_background_task(btask, spec): +def start_background_task(btask, spec: Spec) -> Event: ping_stop = Event() ping_thread = Thread(target=background_task, args=(btask, spec, ping_stop, current_thread())) ping_thread.start() @@ -427,7 +428,7 @@ def run(self, pdb_on_exception: bool = False) -> bool: return True @staticmethod - def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: Dict[str, any]) -> None: + def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: Checkpoint) -> None: """ Helper function to update checkpoint @@ -447,9 +448,7 @@ def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, che with zopen(fpath, "wt") as f_out: f_out.write(json.dumps(d, ensure_ascii=False)) - def decorate_fwaction( - self, fwaction: FWAction, my_spec: Dict[str, any], m_fw: Firework, launch_dir: str - ) -> FWAction: + def decorate_fwaction(self, fwaction: FWAction, my_spec: Spec, m_fw: Firework, launch_dir: str) -> FWAction: if my_spec.get("_pass_job_info"): job_info = list(my_spec.get("_job_info", [])) diff --git a/fireworks/core/rocket_launcher.py b/fireworks/core/rocket_launcher.py index b9aa5b311..e80f95e27 100644 --- a/fireworks/core/rocket_launcher.py +++ b/fireworks/core/rocket_launcher.py @@ -5,8 +5,10 @@ import os import time from datetime import datetime +from typing import Optional from fireworks.core.fworker import FWorker +from fireworks.core.launchpad import LaunchPad from fireworks.core.rocket import Rocket from fireworks.fw_config import FWORKER_LOC, RAPIDFIRE_SLEEP_SECS from fireworks.utilities.fw_utilities import ( @@ -23,17 +25,23 @@ __date__ = "Feb 22, 2013" -def get_fworker(fworker): - if fworker: +def 
get_fworker(fworker: Optional[FWorker]) -> FWorker: + if fworker is not None: my_fwkr = fworker elif FWORKER_LOC: - my_fwkr = FWorker.from_file(FWORKER_LOC) + my_fwkr = FWorker.from_file(FWORKER_LOC) # type: ignore else: my_fwkr = FWorker() return my_fwkr -def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_exception=False): +def launch_rocket( + launchpad: LaunchPad, + fworker: Optional[FWorker] = None, + fw_id: Optional[int] = None, + strm_lvl: str = "INFO", + pdb_on_exception: bool = False, +) -> bool: """ Run a single rocket in the current directory. @@ -53,6 +61,7 @@ def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_e l_logger = get_fw_logger("rocket.launcher", l_dir=l_dir, stream_level=strm_lvl) log_multi(l_logger, "Launching Rocket") + assert fw_id is not None rocket = Rocket(launchpad, fworker, fw_id) rocket_ran = rocket.run(pdb_on_exception=pdb_on_exception) log_multi(l_logger, "Rocket finished") @@ -60,16 +69,16 @@ def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_e def rapidfire( - launchpad, - fworker=None, - m_dir=None, - nlaunches=0, - max_loops=-1, - sleep_time=None, - strm_lvl="INFO", - timeout=None, - local_redirect=False, - pdb_on_exception=False, + launchpad: LaunchPad, + fworker: Optional[FWorker] = None, + m_dir: Optional[str] = None, + nlaunches: int = 0, + max_loops: int = -1, + sleep_time: Optional[int] = None, + strm_lvl: str = "INFO", + timeout: Optional[int] = None, + local_redirect: bool = False, + pdb_on_exception: bool = False, ): """ Keeps running Rockets in m_dir until we reach an error. Automatically creates subdirectories diff --git a/fireworks/core/tests/tasks.py b/fireworks/core/tests/tasks.py index 83aa0bf9f..23fbbc14a 100644 --- a/fireworks/core/tests/tasks.py +++ b/fireworks/core/tests/tasks.py @@ -2,6 +2,7 @@ from unittest import SkipTest from fireworks import FiretaskBase, Firework, FWAction +from fireworks.core.types import Spec from fireworks.utilities.fw_utilities import explicit_serialize @@ -95,7 +96,7 @@ def run_task(self, fw_spec): class DetoursTask(FiretaskBase): optional_params = ["n_detours", "data_per_detour"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: data_per_detour = self.get("data_per_detour", None) n_detours = self.get("n_detours", 10) fws = [] diff --git a/fireworks/core/types.py b/fireworks/core/types.py index 8f0bdf1e5..da47f644e 100644 --- a/fireworks/core/types.py +++ b/fireworks/core/types.py @@ -1,5 +1,5 @@ -from typing import Any, MutableMapping - +from typing import Any, Mapping, MutableMapping Checkpoint = MutableMapping[Any, Any] +FromDict = Mapping[str, Any] Spec = MutableMapping[Any, Any] diff --git a/fireworks/flask_site/gunicorn.py b/fireworks/flask_site/gunicorn.py index a65c25a3a..d4071ab2c 100755 --- a/fireworks/flask_site/gunicorn.py +++ b/fireworks/flask_site/gunicorn.py @@ -1,23 +1,24 @@ # Based on http://docs.gunicorn.org/en/19.6.0/custom.html import multiprocessing +from typing import Any, Mapping, Optional import gunicorn.app.base from fireworks.flask_site.app import app as handler_app -def number_of_workers(): +def number_of_workers() -> int: return (multiprocessing.cpu_count() * 2) + 1 class StandaloneApplication(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): + def __init__(self, app, options: Optional[Mapping[str, Any]] = None) -> None: self.options = options or {} self.application = app super().__init__() - def load_config(self): + def load_config(self) -> 
None: config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): self.cfg.set(key.lower(), value) diff --git a/fireworks/flask_site/helpers.py b/fireworks/flask_site/helpers.py index 63e91a022..3a97d6ccc 100644 --- a/fireworks/flask_site/helpers.py +++ b/fireworks/flask_site/helpers.py @@ -1,4 +1,9 @@ -def get_totals(states, lp): +from typing import Mapping, Sequence + +from fireworks import LaunchPad + + +def get_totals(states: Sequence[str], lp: LaunchPad) -> Mapping[str, int]: fw_stats = {} wf_stats = {} for state in states: @@ -7,21 +12,21 @@ def get_totals(states, lp): return {"fw_stats": fw_stats, "wf_stats": wf_stats} -def fw_filt_given_wf_filt(filt, lp): +def fw_filt_given_wf_filt(filt, lp: LaunchPad) -> Mapping[str, Mapping[str, Sequence[int]]]: fw_ids = set() for doc in lp.workflows.find(filt, {"_id": 0, "nodes": 1}): fw_ids |= set(doc["nodes"]) return {"fw_id": {"$in": list(fw_ids)}} -def wf_filt_given_fw_filt(filt, lp): +def wf_filt_given_fw_filt(filt, lp: LaunchPad) -> Mapping[str, Mapping[str, Sequence[int]]]: wf_ids = set() for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1}): wf_ids.add(doc["fw_id"]) return {"nodes": {"$in": list(wf_ids)}} -def uses_index(filt, coll): +def uses_index(filt, coll) -> bool: ii = coll.index_information() fields_filtered = set(filt.keys()) fields_indexed = {v["key"][0][0] for v in ii.values()} diff --git a/fireworks/scripts/fwtool b/fireworks/scripts/fwtool index e71d5b5c1..94e5d2176 100755 --- a/fireworks/scripts/fwtool +++ b/fireworks/scripts/fwtool @@ -3,7 +3,7 @@ import os import shutil import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import yaml @@ -18,8 +18,10 @@ __maintainer__ = "Shyue Ping Ong" __email__ = "ongsp@ucsd.edu" __date__ = "1/6/14" +from typing import Sequence -def create_fw_single(args, fnames, yaml_fname): + +def create_fw_single(args: Namespace, fnames: Sequence[str], yaml_fname: str) -> None: tasks = [] if fnames: files = [] @@ -34,7 +36,7 @@ def create_fw_single(args, fnames, yaml_fname): yaml.dump(fw.to_dict(), f, default_flow_style=False) -def create_fw(args): +def create_fw(args: Namespace) -> None: if not args.directory_mode: create_fw_single(args, args.files_or_dirs, args.output) else: @@ -43,7 +45,7 @@ def create_fw(args): create_fw_single(args, [os.path.join(d, f) for f in os.listdir(d)], output_fname.format(i)) -def do_cleanup(args): +def do_cleanup(args: Namespace) -> None: lp = LaunchPad.from_file(args.launchpad_file) if args.launchpad_file else LaunchPad(strm_lvl=args.loglvl) to_delete = [] for i in lp.get_fw_ids({}): diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 0545189c4..d6b364476 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -55,7 +55,7 @@ DEFAULT_LPAD_YAML = "my_launchpad.yaml" -def pw_check(ids: List[int], args: Namespace, skip_pw: bool = False) -> List[int]: +def pw_check(ids: Sequence[int], args: Namespace, skip_pw: bool = False) -> List[int]: if len(ids) > PW_CHECK_NUM and not skip_pw: m_password = datetime.datetime.now().strftime("%Y-%m-%d") if not args.password: @@ -249,7 +249,7 @@ def add_wf_dir(args: Namespace) -> None: lp.add_wf(fwf) -def print_fws(ids, lp, args: Namespace) -> None: +def print_fws(ids: Sequence[int], lp: LaunchPad, args: Namespace) -> None: """Prints results of some FireWorks query to stdout.""" fws = [] if args.display_format == "ids": @@ -327,9 +327,7 @@ def 
get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, No return ids -def get_fws_helper( - lp: LaunchPad, ids: List[int], args: Namespace -) -> Union[List[int], int, List[Dict[str, Union[str, int, bool]]], Union[str, int, bool]]: +def get_fws_helper(lp: LaunchPad, ids: Sequence[int], args: Namespace) -> Union[List[int], int, List[Dict[str, Union[str, int, bool]]], Union[str, int, bool]]: """Get fws from ids in a representation according to args.display_format.""" fws = [] if args.display_format == "ids": @@ -493,15 +491,15 @@ def delete_wfs(args: Namespace) -> None: lp.m_logger.info(f"Finished deleting {len(fw_ids)} WFs") -def get_children(links, start, max_depth): - data = {} - for l, c in links.items(): - if l == start: - if len(c) > 0: - data[l] = [get_children(links, i, max_depth) for i in c] - else: - data[l] = c - return data +# def get_children(links, start): +# data = {} +# for l, c in links.items(): +# if l == start: +# if len(c) > 0: +# data[l] = [get_children(links, i) for i in c] +# else: +# data[l] = c +# return data def detect_lostruns(args: Namespace) -> None: @@ -695,7 +693,7 @@ def set_priority(args: Namespace) -> None: lp.m_logger.info(f"Finished setting priorities of {len(fw_ids)} FWs") -def _open_webbrowser(url): +def _open_webbrowser(url: str) -> None: """Open a web browser after a delay to give the web server more startup time.""" import webbrowser @@ -869,14 +867,14 @@ def orphaned(args: Namespace) -> None: print(args.output(fws)) -def get_output_func(format: Literal["json", "yaml"]) -> Callable[[str], Any]: +def get_output_func(format: Literal["json", "yaml"]) -> Callable[[Any], str]: if format == "json": - return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) + return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) # type: ignore else: - return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False) + return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False) # type: ignore -def arg_positive_int(value: str) -> int: +def arg_positive_int(value: Any) -> int: try: ivalue = int(value) except ValueError: @@ -934,7 +932,7 @@ def lpad(argv: Optional[Sequence[str]] = None) -> int: # enhanced display options allow for value 'none' or None (default) for no output enh_disp_args = copy.deepcopy(disp_args) enh_disp_kwargs = copy.deepcopy(disp_kwargs) - enh_disp_kwargs["choices"].append("none") + enh_disp_kwargs["choices"].append("none") # type: ignore enh_disp_kwargs["default"] = None query_args = ["-q", "--query"] diff --git a/fireworks/user_objects/firetasks/dataflow_tasks.py b/fireworks/user_objects/firetasks/dataflow_tasks.py index c59f7d32a..e149ece8d 100644 --- a/fireworks/user_objects/firetasks/dataflow_tasks.py +++ b/fireworks/user_objects/firetasks/dataflow_tasks.py @@ -5,9 +5,11 @@ __copyright__ = "Copyright 2016, Karlsruhe Institute of Technology" import sys +from typing import Any, List, Mapping, Optional from fireworks import Firework from fireworks.core.firework import FiretaskBase, FWAction +from fireworks.core.types import Spec from fireworks.utilities.fw_serializers import load_object if sys.version_info[0] > 2: @@ -74,7 +76,7 @@ class CommandLineTask(FiretaskBase): required_params = ["command_spec"] optional_params = ["inputs", "outputs", "chunk_number"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: cmd_spec = self["command_spec"] ilabels = self.get("inputs") olabels = self.get("outputs") @@ -136,7 
+138,7 @@ def run_task(self, fw_spec): return FWAction() @staticmethod - def command_line_tool(command, inputs=None, outputs=None): + def command_line_tool(command, inputs: Optional[List[Mapping[Any, Any]]] = None, outputs: Optional[Spec] = None): """ This function composes and executes a command from provided specifications. @@ -163,7 +165,7 @@ def command_line_tool(command, inputs=None, outputs=None): from shutil import copyfile from subprocess import PIPE, Popen - def set_binding(arg): + def set_binding(arg: Mapping[str, Any]) -> str: argstr = "" if "binding" in arg: if "prefix" in arg["binding"]: @@ -283,7 +285,7 @@ class ForeachTask(FiretaskBase): required_params = ["task", "split"] optional_params = ["number of chunks"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["split"], basestring), self["split"] assert isinstance(fw_spec[self["split"]], list) if isinstance(self["task"]["inputs"], list): @@ -321,7 +323,7 @@ class JoinDictTask(FiretaskBase): required_params = ["inputs", "output"] optional_params = ["rename"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["output"], basestring) assert isinstance(self["inputs"], list) @@ -351,7 +353,7 @@ class JoinListTask(FiretaskBase): _fw_name = "JoinListTask" required_params = ["inputs", "output"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["output"], basestring) assert isinstance(self["inputs"], list) if self["output"] not in fw_spec: @@ -377,7 +379,7 @@ class ImportDataTask(FiretaskBase): required_params = ["filename", "mapstring"] optional_params = [] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: import json import operator from functools import reduce diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 6423325d9..4a31fd8b2 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,6 +33,7 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback +from typing import Any, Dict, Mapping, MutableMapping, Optional, Type import ruamel.yaml as yaml from monty.json import MontyDecoder, MSONable @@ -55,7 +56,7 @@ # TODO: consider *somehow* switching FireWorks to monty serialization. e.g., numpy serialization is better handled. -SAVED_FW_MODULES = {} +SAVED_FW_MODULES: MutableMapping[str, str] = {} DATETIME_HANDLER = lambda obj: obj.isoformat() if isinstance(obj, datetime.datetime) else None ENCODING_PARAMS = {"encoding": "utf-8"} @@ -71,7 +72,7 @@ import fireworks_schema -def recursive_dict(obj, preserve_unicode=True): +def recursive_dict(obj: Any, preserve_unicode: bool = True) -> Any: if obj is None: return None @@ -103,7 +104,7 @@ def recursive_dict(obj, preserve_unicode=True): # TODO: is reconstitute_dates really needed? Can this method just do everything? -def _recursive_load(obj): +def _recursive_load(obj: Any) -> Any: if obj is None: return None @@ -196,20 +197,20 @@ class and implement the to_dict() and from_dict() methods. 
""" @property - def fw_name(self): + def fw_name(self) -> str: try: - return self._fw_name + return self._fw_name # type: ignore except AttributeError: return get_default_serialization(self.__class__) @abc.abstractmethod - def to_dict(self): + def to_dict(self) -> Mapping[Any, Any]: raise NotImplementedError("FWSerializable object did not implement to_dict()!") - def to_db_dict(self): + def to_db_dict(self) -> Mapping[Any, Any]: return self.to_dict() - def as_dict(self): + def as_dict(self) -> Mapping[Any, Any]: # strictly for pseudo-compatibility with MSONable # Note that FWSerializable is not MSONable, it uses _fw_name instead of __class__ and # __module__ @@ -217,13 +218,13 @@ def as_dict(self): @classmethod @abc.abstractmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict) -> "FWSerializable": raise NotImplementedError("FWSerializable object did not implement from_dict()!") - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self.to_dict(), default=DATETIME_HANDLER) - def to_format(self, f_format="json", **kwargs): + def to_format(self, f_format: str = "json", **kwargs) -> str: """ returns a String representation in the given format @@ -234,12 +235,12 @@ def to_format(self, f_format="json", **kwargs): return json.dumps(self.to_dict(), default=DATETIME_HANDLER, **kwargs) elif f_format == "yaml": # start with the JSON format, and convert to YAML - return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True) + return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True) # type: ignore else: raise ValueError(f"Unsupported format {f_format}") @classmethod - def from_format(cls, f_str, f_format="json"): + def from_format(cls, f_str: str, f_format: str = "json") -> "FWSerializable": """ convert from a String representation to its Object. @@ -256,11 +257,11 @@ def from_format(cls, f_str, f_format="json"): dct = yaml.safe_load(f_str) else: raise ValueError(f"Unsupported format {f_format}") - if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: + if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: # type: ignore fireworks_schema.validate(dct, cls.__name__) return cls.from_dict(reconstitute_dates(dct)) - def to_file(self, filename, f_format=None, **kwargs): + def to_file(self, filename: str, f_format: Optional[str] = None, **kwargs) -> None: """ Write a serialization of this object to a file. @@ -274,7 +275,7 @@ def to_file(self, filename, f_format=None, **kwargs): f.write(self.to_format(f_format=f_format, **kwargs)) @classmethod - def from_file(cls, filename, f_format=None): + def from_file(cls, filename: str, f_format: Optional[str] = None) -> "FWSerializable": """ Load a serialization of this object from a file. @@ -293,14 +294,14 @@ def from_file(cls, filename, f_format=None): def __getstate__(self): return self.to_dict() - def __setstate__(self, state): + def __setstate__(self, state) -> None: fw_obj = self.from_dict(state) for k, v in fw_obj.__dict__.items(): self.__dict__[k] = v # TODO: make this quicker the first time around -def load_object(obj_dict): +def load_object(obj_dict: Dict[str, Any]) -> Any: """ Creates an instantiation of a class based on a dictionary representation. We implicitly determine the Class through introspection along with information in the dictionary. 
@@ -370,7 +371,7 @@ def load_object(obj_dict): raise ValueError(f"load_object() could not find a class with cls._fw_name {fw_name}") -def load_object_from_file(filename, f_format=None): +def load_object_from_file(filename: str, f_format: Optional[str] = None) -> Any: """ Implicitly load an object from a file. just a friendly wrapper to load_object() @@ -413,7 +414,7 @@ def _search_module_for_obj(m_module, obj_dict): return obj.from_dict(obj_dict) -def reconstitute_dates(obj_dict): +def reconstitute_dates(obj_dict: Any) -> Any: if obj_dict is None: return None @@ -434,7 +435,7 @@ def reconstitute_dates(obj_dict): return obj_dict -def get_default_serialization(cls): +def get_default_serialization(cls: Type[Any]) -> str: root_mod = cls.__module__.split(".")[0] if root_mod == "__main__": raise ValueError( From 06eed06efe18a7b6c3b0d9157cd3535546c29527 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Thu, 27 Jan 2022 22:27:16 -0500 Subject: [PATCH 06/18] Update mypy in pre-commit config --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee0dcff91..16f0df906 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910-1 + rev: v0.931 hooks: - id: mypy additional_dependencies: From 3f3ae5f2f9ae29b86b32cc08375ad0cede3cdace Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Thu, 27 Jan 2022 22:37:29 -0500 Subject: [PATCH 07/18] Fix some imports --- fireworks/scripts/lpad_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index d6b364476..c6ec6eede 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -327,7 +327,7 @@ def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, No return ids -def get_fws_helper(lp: LaunchPad, ids: Sequence[int], args: Namespace) -> Union[List[int], int, List[Dict[str, Union[str, int, bool]]], Union[str, int, bool]]: +def get_fws_helper(lp: LaunchPad, ids: Sequence[int], args: Namespace) -> Union[int, List[int], List[List[int]]]: """Get fws from ids in a representation according to args.display_format.""" fws = [] if args.display_format == "ids": From c9e5120991fd11fb53ff9913219fea4fc482c4cf Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Thu, 3 Feb 2022 22:41:31 -0500 Subject: [PATCH 08/18] continue typing --- fireworks/core/firework.py | 5 +- fireworks/core/launchpad.py | 45 ++-- fireworks/fw_config.py | 4 +- fireworks/scripts/lpad_run.py | 6 + fireworks/utilities/dagflow.py | 255 +++++++++++++++++++++- fireworks/utilities/fw_serializers.py | 2 +- fireworks/utilities/fw_utilities.py | 8 +- fireworks/utilities/tests/test_dagflow.py | 22 +- pyproject.toml | 18 ++ 9 files changed, 316 insertions(+), 49 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 1f362bb70..350399396 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -286,7 +286,8 @@ def __init__( self.updated_on = updated_on or datetime.utcnow() parents = [parents] if isinstance(parents, Firework) else parents - self.parents = parents if parents else [] + self.parents = list(parents) if parents else [] + assert self.parents is not None self._state = state @@ -477,7 +478,7 @@ def __init__( state_history: Optional[Dict[Any, Any]] = None, launch_id: Optional[int] = None, fw_id: Optional[int] = None, - ): + ) -> None: 
""" Args: state (str): the state of the Launch (e.g. RUNNING, COMPLETED) diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index d59917680..53e308828 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -23,6 +23,7 @@ from tqdm import tqdm from fireworks.core.firework import Firework, FWAction, Launch, Tracker, Workflow +from fireworks.core.types import Spec from fireworks.fw_config import ( GRIDFS_FALLBACK_COLLECTION, LAUNCHPAD_LOC, @@ -104,7 +105,7 @@ def __init__( fw_id: int, expire_secs: int = WFLOCK_EXPIRATION_SECS, kill: bool = WFLOCK_EXPIRATION_KILL, - ): + ) -> None: """ Args: lp (LaunchPad) @@ -170,7 +171,7 @@ def __init__( authsource: str = None, uri_mode: bool = False, mongoclient_kwargs=None, - ): + ) -> None: """ Args: host (str): hostname. If uri_mode is True, a MongoDB connection string URI @@ -241,7 +242,7 @@ def __init__( self.backup_launch_data = {} self.backup_fw_data = {} - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Note: usernames/passwords are exported as unencrypted Strings! """ @@ -260,7 +261,7 @@ def to_dict(self): "mongoclient_kwargs": self.mongoclient_kwargs, } - def update_spec(self, fw_ids, spec_document, mongo=False): + def update_spec(self, fw_ids: Sequence[int], spec_document: Spec, mongo: bool = False): """ Update fireworks with a spec. Sometimes you need to modify a firework in progress. @@ -391,7 +392,7 @@ def maintain(self, infinite=True, maintain_interval=None): self.m_logger.debug(f"Sleeping for {maintain_interval} secs...") time.sleep(maintain_interval) - def add_wf(self, wf, reassign_all=True): + def add_wf(self, wf: Union[Workflow, Firework], reassign_all: bool = True) -> Dict[int, int]: """ Add workflow(or firework) to the launchpad. The firework ids will be reassigned. @@ -519,7 +520,7 @@ def get_fw_dict_by_id(self, fw_id): fw_dict["archived_launches"] = launches return fw_dict - def get_fw_by_id(self, fw_id): + def get_fw_by_id(self, fw_id: int) -> Firework: """ Given a Firework id, give back a Firework object. @@ -531,7 +532,7 @@ def get_fw_by_id(self, fw_id): """ return Firework.from_dict(self.get_fw_dict_by_id(fw_id)) - def get_wf_by_fw_id(self, fw_id): + def get_wf_by_fw_id(self, fw_id: int) -> Workflow: """ Given a Firework id, give back the Workflow containing that Firework. @@ -630,7 +631,7 @@ def delete_fws(self, fw_ids: Sequence[int], delete_launch_dirs: bool = False) -> self.offline_runs.delete_many({"launch_id": {"$in": launch_ids}}) self.fireworks.delete_many({"fw_id": {"$in": fw_ids}}) - def delete_wf(self, fw_id, delete_launch_dirs=False): + def delete_wf(self, fw_id: int, delete_launch_dirs: bool = False) -> None: """ Delete the workflow containing firework with the given id. @@ -645,7 +646,7 @@ def delete_wf(self, fw_id, delete_launch_dirs=False): print("Removing workflow.") self.workflows.delete_one({"nodes": fw_id}) - def get_wf_summary_dict(self, fw_id, mode="more"): + def get_wf_summary_dict(self, fw_id: int, mode: str = "more") -> Dict[str, Any]: """ A much faster way to get summary information about a Workflow by querying only for needed information. @@ -738,7 +739,7 @@ def get_fw_ids( limit: int = 0, count_only: bool = False, launches_mode: bool = False, - ) -> Union[int, Sequence[int]]: + ) -> List[int]: """ Return all the fw ids that match a query. 
@@ -828,8 +829,14 @@ def get_wf_ids(self, query=None, sort=None, limit=0, count_only=False): return [fw["nodes"][0] for fw in cursor] def get_fw_ids_in_wfs( - self, wf_query=None, fw_query=None, sort=None, limit=0, count_only=False, launches_mode=False - ): + self, + wf_query: Optional[Dict[str, Any]] = None, + fw_query: Optional[Dict[str, Any]] = None, + sort: Optional[List[Tuple[str, str]]] = None, + limit: int = 0, + count_only: bool = False, + launches_mode: bool = False, + ) -> List[int]: """ Return all fw ids that match fw_query within workflows that match wf_query. @@ -1603,7 +1610,7 @@ def ping_launch(self, launch_id, ptime=None, checkpoint=None): }, ) - def get_new_fw_id(self, quantity=1): + def get_new_fw_id(self, quantity: int = 1) -> int: """ Checkout the next Firework id @@ -1631,7 +1638,7 @@ def get_new_launch_id(self): "database, please do so by performing a database reset (e.g., lpad reset)" ) - def _upsert_fws(self, fws, reassign_all=False): + def _upsert_fws(self, fws: List[Firework], reassign_all: bool = False) -> Dict[int, int]: """ Insert the fireworks to the 'fireworks' collection. @@ -2137,21 +2144,21 @@ def updated_on(self, value): self.partial_fw.updated_on = value @property - def parents(self): + def parents(self) -> List[Firework]: if self._fw is not None: return self.partial_fw.parents else: return [] @parents.setter - def parents(self, value): + def parents(self, value) -> None: self.partial_fw.parents = value # Properties that shadow FireWork attributes, but which are # fetched individually from the DB (i.e. launch objects) @property - def launches(self): + def launches(self) -> Launch: return self._get_launch_data("launches") @launches.setter @@ -2164,7 +2171,7 @@ def archived_launches(self): return self._get_launch_data("archived_launches") @archived_launches.setter - def archived_launches(self, value): + def archived_launches(self, value) -> None: self._launches["archived_launches"] = True self.partial_fw.archived_launches = value @@ -2191,7 +2198,7 @@ def full_fw(self): # Get a type of Launch object - def _get_launch_data(self, name: str): + def _get_launch_data(self, name: str) -> Launch: """ Pull launch data individually for each field. 
diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index bfdf62cb3..a0cec9775 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -3,7 +3,7 @@ """ import os -from typing import Any, Dict +from typing import Any, Dict, List from monty.design_patterns import singleton from monty.serialization import dumpfn, loadfn @@ -43,7 +43,7 @@ PRINT_FW_YAML = False JSON_SCHEMA_VALIDATE = False -JSON_SCHEMA_VALIDATE_LIST = None +JSON_SCHEMA_VALIDATE_LIST: List[str] = None PING_TIME_SECS = 3600 # while Running a job, how often to ping back the server that we're still alive RUN_EXPIRATION_SECS = PING_TIME_SECS * 4 # mark job as FIZZLED if not pinged in this time diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index c6ec6eede..9626db62f 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -275,7 +275,13 @@ def print_fws(ids: Sequence[int], lp: LaunchPad, args: Namespace) -> None: print(args.output(fws)) +<<<<<<< HEAD def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, None] = None) -> Union[List[int], int]: +||||||| parent of a224c6c5 (continue typing) +def get_fw_ids_helper(lp, args, count_only=None): +======= +def get_fw_ids_helper(lp, args, count_only: Optional[bool] = None) -> List[int]: +>>>>>>> a224c6c5 (continue typing) """Build fws query from command line options and submit. Parameters: diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py index c66da1685..fcf92af4c 100644 --- a/fireworks/utilities/dagflow.py +++ b/fireworks/utilities/dagflow.py @@ -5,6 +5,7 @@ __copyright__ = "Copyright 2017, Karlsruhe Institute of Technology" from itertools import combinations +from typing import List, Optional, Tuple import igraph from igraph import Graph @@ -42,7 +43,7 @@ class DAGFlow(Graph): """The purpose of this class is to help construction, validation and visualization of workflows.""" - def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs): + def __init__(self, steps, links=None, nlinks=None, name: Optional[str] = None, **kwargs) -> None: Graph.__init__(self, directed=True, graph_attrs={"name": name}, **kwargs) for step in steps: @@ -56,7 +57,7 @@ def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs): self._add_dataflow_links() @classmethod - def from_fireworks(cls, fireworkflow): + def from_fireworks(cls, fireworkflow) -> "DAGFlow": """Converts a fireworks workflow object into a new DAGFlow object""" wfd = fireworkflow.to_dict() if "name" in wfd: @@ -117,7 +118,7 @@ def task_input(task, spec): return cls(steps=steps, links=links, name=name) - def _get_links(self, nlinks): + def _get_links(self, nlinks) -> List[Tuple[x, x]]: """Translates named links into links between step ids""" links = [] for link in nlinks: @@ -126,7 +127,7 @@ def _get_links(self, nlinks): links.append((source, target)) return links - def _get_ctrlflow_links(self): + def _get_ctrlflow_links(self) -> List[Tuple[x, x]]: """Returns a list of unique tuples of link ids""" links = [] for ilink in {link.tuple for link in list(self.es)}: @@ -135,7 +136,7 @@ def _get_ctrlflow_links(self): links.append((source, target)) return links - def _add_ctrlflow_links(self, links): + def _add_ctrlflow_links(self, links) -> None: """Adds graph edges corresponding to control flow links""" for link in links: source = self._get_index(link[0]) @@ -262,22 +263,22 @@ def _get_leaves(self): """Returns all leaves (i.e. 
vertices without outgoing edges)""" return [i for i, v in enumerate(self.degree(mode=igraph.OUT)) if v == 0] - def delete_ctrlflow_links(self): + def delete_ctrlflow_links(self) -> None: """Deletes graph edges corresponding to control flow links""" lst = [link.index for link in list(self.es) if link["label"] == " "] self.delete_edges(lst) - def delete_dataflow_links(self): + def delete_dataflow_links(self) -> None: """Deletes graph edges corresponding to data flow links""" lst = [link.index for link in list(self.es) if link["label"] != " "] self.delete_edges(lst) - def add_step_labels(self): + def add_step_labels(self) -> None: """Labels the workflow steps (i.e. graph vertices)""" for vertex in list(self.vs): vertex["label"] = vertex["name"] + ", id: " + str(vertex["id"]) - def check(self): + def check(self) -> None: """Correctness check of the workflow""" try: assert self.is_dag(), "The workflow graph must be a DAG." @@ -288,7 +289,7 @@ def check(self): assert len(self.vs["id"]) == len(set(self.vs["id"])), "Workflow steps must have unique IDs." self.check_dataflow() - def check_dataflow(self): + def check_dataflow(self) -> None: """Checks whether all inputs and outputs match""" # check for shared output data entities @@ -344,3 +345,237 @@ def to_dot(self, filename="wf.dot", view="combined"): if isinstance(val, bool): del vertex[key] graph.write_dot(filename) +<<<<<<< HEAD +||||||| parent of d44d6080 (continue typing) + + +def plot_wf(wf, view="combined", labels=False, **kwargs): + """Plot workflow DAG via igraph.plot. + + Args: + wf (Workflow) + view (str): same as in 'to_dot'. Default: 'combined' + labels (bool): show a FW's name and id as labels in graph + + Other **kwargs can be any igraph plotting style keyword, overrides default. + See https://igraph.org/python/doc/tutorial/tutorial.html for possible + keywords. See `plot_wf` code for defaults. 
+ + Returns: + igraph.drawing.Plot + """ + + dagf = DAGFlow.from_fireworks(wf) + if labels: + dagf.add_step_labels() + + # copied from to_dot + if view == "controlflow": + dagf.delete_dataflow_links() + elif view == "dataflow": + dagf.delete_ctrlflow_links() + elif view == "combined": + dlinks = [] + for vertex1, vertex2 in combinations(dagf.vs.indices, 2): + clinks = list(set(dagf.incident(vertex1, mode="ALL")) & set(dagf.incident(vertex2, mode="ALL"))) + if len(clinks) > 1: + for link in clinks: + if dagf.es[link]["label"] == " ": + dlinks.append(link) + dagf.delete_edges(dlinks) + + # remove non-string, non-numeric attributes because write_dot() warns + for vertex in dagf.vs: + for key, val in vertex.attributes().items(): + if not isinstance(val, (str, int, float, complex)): + del vertex[key] + if isinstance(val, bool): + del vertex[key] + + # plotting defaults + visual_style = DEFAULT_IGRAPH_VISUAL_STYLE.copy() + + # generic plotting defaults + visual_style["layout"] = dagf.layout_kamada_kawai() + + # vertex defaults + dagf_roots = dagf._get_roots() + dagf_leaves = dagf._get_leaves() + + def color_coding(v): + if v in dagf_roots: + return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["root"] + elif v in dagf_leaves: + return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["leaf"] + else: + return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["other"] + + visual_style["vertex_color"] = [color_coding(v) for v in range(dagf.vcount())] + + visual_style.update(kwargs) + + # special treatment + if "layout" in kwargs and isinstance(kwargs["layout"], str): + visual_style["layout"] = dagf.layout(kwargs["layout"]) + + return igraph.plot(dagf, **visual_style) + + +@requires( + graphviz, + "graphviz package required for wf_to_graph.\n" + "Follow the installation instructions here: https://github.com/xflr6/graphviz", +) +def wf_to_graph(wf: Workflow) -> "Digraph": + """ + Renders a graph representation of a workflow or firework. Workflows are + rendered as the control flow of the firework, while Fireworks are + rendered as a sequence of Firetasks. + + Copied from https://git.io/JO6L8. + + Args: + workflow (Workflow|Firework): workflow or Firework + to be rendered. + + Returns: + Digraph: the rendered workflow or firework graph + """ + # Directed Acyclic Graph + dag = Digraph(comment=wf.name, graph_attr={"rankdir": "LR"}) + dag.node_attr["shape"] = "box" + if isinstance(wf, Workflow): + for fw in wf: + dag.node(str(fw.fw_id), label=fw.name, color=state_to_color[fw.state]) + + for start, targets in wf.links.items(): + for target in targets: + dag.edge(str(start), str(target)) + elif isinstance(wf, Firework): + for n, ft in enumerate(wf.tasks): + # Clean up names + name = ft.fw_name.replace("{", "").replace("}", "") + name = name.split(".")[-1] + dag.node(str(n), label=name) + if n >= 1: + dag.edge(str(n - 1), str(n)) + else: + raise ValueError("expected instance of Workflow or Firework") + return dag +======= + + +def plot_wf(wf: Workflow, view: str = "combined", labels: bool = False, **kwargs) -> Plot: + """Plot workflow DAG via igraph.plot. + + Args: + wf (Workflow) + view (str): same as in 'to_dot'. Default: 'combined' + labels (bool): show a FW's name and id as labels in graph + + Other **kwargs can be any igraph plotting style keyword, overrides default. + See https://igraph.org/python/doc/tutorial/tutorial.html for possible + keywords. See `plot_wf` code for defaults. 
+ + Returns: + igraph.drawing.Plot + """ + + dagf = DAGFlow.from_fireworks(wf) + if labels: + dagf.add_step_labels() + + # copied from to_dot + if view == "controlflow": + dagf.delete_dataflow_links() + elif view == "dataflow": + dagf.delete_ctrlflow_links() + elif view == "combined": + dlinks = [] + for vertex1, vertex2 in combinations(dagf.vs.indices, 2): + clinks = list(set(dagf.incident(vertex1, mode="ALL")) & set(dagf.incident(vertex2, mode="ALL"))) + if len(clinks) > 1: + for link in clinks: + if dagf.es[link]["label"] == " ": + dlinks.append(link) + dagf.delete_edges(dlinks) + + # remove non-string, non-numeric attributes because write_dot() warns + for vertex in dagf.vs: + for key, val in vertex.attributes().items(): + if not isinstance(val, (str, int, float, complex)): + del vertex[key] + if isinstance(val, bool): + del vertex[key] + + # plotting defaults + visual_style = DEFAULT_IGRAPH_VISUAL_STYLE.copy() + + # generic plotting defaults + visual_style["layout"] = dagf.layout_kamada_kawai() + + # vertex defaults + dagf_roots = dagf._get_roots() + dagf_leaves = dagf._get_leaves() + + def color_coding(v): + if v in dagf_roots: + return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["root"] + elif v in dagf_leaves: + return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["leaf"] + else: + return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["other"] + + visual_style["vertex_color"] = [color_coding(v) for v in range(dagf.vcount())] + + visual_style.update(kwargs) + + # special treatment + if "layout" in kwargs and isinstance(kwargs["layout"], str): + visual_style["layout"] = dagf.layout(kwargs["layout"]) + + return igraph.plot(dagf, **visual_style) + + +@requires( + graphviz, + "graphviz package required for wf_to_graph.\n" + "Follow the installation instructions here: https://github.com/xflr6/graphviz", +) +def wf_to_graph(wf: Workflow) -> "Digraph": + """ + Renders a graph representation of a workflow or firework. Workflows are + rendered as the control flow of the firework, while Fireworks are + rendered as a sequence of Firetasks. + + Copied from https://git.io/JO6L8. + + Args: + workflow (Workflow|Firework): workflow or Firework + to be rendered. + + Returns: + Digraph: the rendered workflow or firework graph + """ + # Directed Acyclic Graph + dag = Digraph(comment=wf.name, graph_attr={"rankdir": "LR"}) + dag.node_attr["shape"] = "box" + if isinstance(wf, Workflow): + for fw in wf: + dag.node(str(fw.fw_id), label=fw.name, color=state_to_color[fw.state]) + + for start, targets in wf.links.items(): + for target in targets: + dag.edge(str(start), str(target)) + elif isinstance(wf, Firework): + for n, ft in enumerate(wf.tasks): + # Clean up names + name = ft.fw_name.replace("{", "").replace("}", "") + name = name.split(".")[-1] + dag.node(str(n), label=name) + if n >= 1: + dag.edge(str(n - 1), str(n)) + else: + raise ValueError("expected instance of Workflow or Firework") + return dag +>>>>>>> d44d6080 (continue typing) diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 4a31fd8b2..993b572cd 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -301,7 +301,7 @@ def __setstate__(self, state) -> None: # TODO: make this quicker the first time around -def load_object(obj_dict: Dict[str, Any]) -> Any: +def load_object(obj_dict: MutableMapping[str, Any]) -> Any: """ Creates an instantiation of a class based on a dictionary representation. We implicitly determine the Class through introspection along with information in the dictionary. 
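Note on the load_object hunk just above, which widens the parameter from Dict[str, Any] to MutableMapping[str, Any]: an abstract mapping annotation accepts plain dicts, dict subclasses, and non-dict mapping types alike. A short sketch of the difference; the describe function is invented for illustration and is not the FireWorks API.

from collections import UserDict
from typing import Any, MutableMapping


def describe(obj_dict: MutableMapping[str, Any]) -> str:
    # The abstract annotation accepts plain dicts, dict subclasses, and
    # mapping types such as UserDict; Dict[str, Any] would reject the latter.
    return str(obj_dict.get("_fw_name", "<unknown>"))


print(describe({"_fw_name": "ScriptTask"}))        # plain dict
print(describe(UserDict({"_fw_name": "PyTask"})))  # non-dict mutable mapping
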
diff --git a/fireworks/utilities/fw_utilities.py b/fireworks/utilities/fw_utilities.py index 261902f51..77713b0a1 100644 --- a/fireworks/utilities/fw_utilities.py +++ b/fireworks/utilities/fw_utilities.py @@ -10,7 +10,7 @@ import traceback from logging import Formatter, Logger from multiprocessing.managers import BaseManager -from typing import Tuple +from typing import List, Optional, Tuple from fireworks.fw_config import DS_PASSWORD, FW_BLOCK_FORMAT, FW_LOGGING_FORMAT, FWData @@ -20,14 +20,14 @@ __email__ = "ajain@lbl.gov" __date__ = "Dec 12, 2012" -PREVIOUS_STREAM_LOGGERS = [] # contains the name of loggers that have already been initialized +PREVIOUS_STREAM_LOGGERS: List[Tuple[str, str]] = [] # contains the name of loggers that have already been initialized PREVIOUS_FILE_LOGGERS = [] # contains the name of file loggers that have already been initialized DEFAULT_FORMATTER = Formatter(FW_LOGGING_FORMAT) def get_fw_logger( name: str, - l_dir: None = None, + l_dir: Optional[str] = None, file_levels: Tuple[str, str] = ("DEBUG", "ERROR"), stream_level: str = "DEBUG", formatter: Formatter = DEFAULT_FORMATTER, @@ -49,7 +49,7 @@ def get_fw_logger( stream_level = stream_level if stream_level else "CRITICAL" # add handlers for the file_levels - if l_dir: + if l_dir is not None: for lvl in file_levels: f_name = os.path.join(l_dir, name.replace(".", "_") + "-" + lvl.lower() + ".log") mode = "w" if clear_logs else "a" diff --git a/fireworks/utilities/tests/test_dagflow.py b/fireworks/utilities/tests/test_dagflow.py index 3e570d534..c928ee644 100644 --- a/fireworks/utilities/tests/test_dagflow.py +++ b/fireworks/utilities/tests/test_dagflow.py @@ -14,7 +14,7 @@ class DAGFlowTest(unittest.TestCase): """run tests for DAGFlow class""" - def setUp(self): + def setUp(self) -> None: try: __import__("igraph", fromlist=["Graph"]) except (ImportError, ModuleNotFoundError): @@ -32,7 +32,7 @@ def setUp(self): ) self.fw3 = Firework(PyTask(func="print", inputs=["second power"]), name="the third one") - def test_dagflow_ok(self): + def test_dagflow_ok(self) -> None: """construct and replicate""" from fireworks.utilities.dagflow import DAGFlow @@ -40,7 +40,7 @@ def test_dagflow_ok(self): dagf = DAGFlow.from_fireworks(wfl) DAGFlow(**dagf.to_dict()) - def test_dagflow_loop(self): + def test_dagflow_loop(self) -> None: """loop in graph""" from fireworks.utilities.dagflow import DAGFlow @@ -50,7 +50,7 @@ def test_dagflow_loop(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_cut(self): + def test_dagflow_cut(self) -> None: """disconnected graph""" from fireworks.utilities.dagflow import DAGFlow @@ -60,7 +60,7 @@ def test_dagflow_cut(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_link(self): + def test_dagflow_link(self) -> None: """wrong links""" from fireworks.utilities.dagflow import DAGFlow @@ -70,7 +70,7 @@ def test_dagflow_link(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_missing_input(self): + def test_dagflow_missing_input(self) -> None: """missing input""" from fireworks.utilities.dagflow import DAGFlow @@ -87,7 +87,7 @@ def test_dagflow_missing_input(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_clashing_inputs(self): + def test_dagflow_clashing_inputs(self) -> None: """parent firework output overwrites an input in spec""" from fireworks.utilities.dagflow import 
DAGFlow @@ -105,7 +105,7 @@ def test_dagflow_clashing_inputs(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_race_condition(self): + def test_dagflow_race_condition(self) -> None: """two parent firework outputs overwrite each other""" from fireworks.utilities.dagflow import DAGFlow @@ -121,7 +121,7 @@ def test_dagflow_race_condition(self): DAGFlow.from_fireworks(wfl).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_clashing_outputs(self): + def test_dagflow_clashing_outputs(self) -> None: """subsequent task overwrites output of a task""" from fireworks.utilities.dagflow import DAGFlow @@ -135,7 +135,7 @@ def test_dagflow_clashing_outputs(self): DAGFlow.from_fireworks(Workflow([fwk], {})).check() self.assertTrue(msg in str(context.exception)) - def test_dagflow_non_dataflow_tasks(self): + def test_dagflow_non_dataflow_tasks(self) -> None: """non-dataflow tasks using outputs and inputs keys do not fail""" from fireworks.core.firework import FiretaskBase from fireworks.utilities.dagflow import DAGFlow @@ -154,7 +154,7 @@ def run_task(self, fw_spec): wfl = Workflow([self.fw1, fw2], {self.fw1: [fw2], fw2: []}) DAGFlow.from_fireworks(wfl).check() - def test_dagflow_view(self): + def test_dagflow_view(self) -> None: """visualize the workflow graph""" from fireworks.utilities.dagflow import DAGFlow diff --git a/pyproject.toml b/pyproject.toml index b6e9b461d..88235b5b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,24 @@ warn_unreachable = true scripts_are_modules = true warn_unused_configs = true +# disallow_any_unimported = true +# disallow_any_decorated = true +# disallow_any_explicit = true +# disallow_any_expr = true +# disallow_any_generics = true +# disallow_subclassing_any = true + +# disallow_untyped_calls = true +# disallow_untyped_defs = true + +# check_untyped_defs = true + +# disallow_untyped_decorators = true + +no_implicit_optional = true + +strict_equality = true + [[tool.mypy.overrides]] module = [ "argcomplete", From e2c89ad0b573fc879e8018ddfae67adbbeaf4fea Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 4 Feb 2022 08:13:41 -0500 Subject: [PATCH 09/18] hacking --- fireworks/core/firework.py | 48 +++--- fireworks/core/types.py | 3 +- fireworks/fw_config.py | 6 +- fireworks/utilities/dagflow.py | 234 -------------------------- fireworks/utilities/fw_serializers.py | 2 +- pyproject.toml | 4 + 6 files changed, 34 insertions(+), 263 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 350399396..bb32e88ac 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -20,7 +20,7 @@ from monty.os.path import zpath from fireworks.core.fworker import FWorker -from fireworks.core.types import Checkpoint, FromDict, Spec +from fireworks.core.types import Checkpoint, FromDict, Spec, ToDict from fireworks.fw_config import ( EXCEPT_DETAILS_ON_RERUN, NEGATIVE_FWID_CTR, @@ -103,19 +103,19 @@ def run_task(self, fw_spec: Spec) -> Optional["FWAction"]: @serialize_fw @recursive_serialize - def to_dict(self) -> Dict[Any, Any]: + def to_dict(self) -> ToDict: return dict(self) @classmethod @recursive_deserialize - def from_dict(cls, m_dict: Dict[Any, Any]) -> "FiretaskBase": + def from_dict(cls, m_dict: FromDict) -> "FiretaskBase": return cls(m_dict) def __repr__(self) -> str: return f"<{self.fw_name}>:{dict(self)}" # not strictly needed here for pickle/unpickle, but complements __setstate__ - def __getstate__(self) -> Dict[Any, Any]: + def 
__getstate__(self) -> ToDict: return self.to_dict() # added to support pickle/unpickle @@ -178,7 +178,7 @@ def __init__( self.propagate = propagate @recursive_serialize - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: return { "stored_data": self.stored_data, "exit": self.exit, @@ -193,7 +193,7 @@ def to_dict(self) -> Dict[str, Any]: @classmethod @recursive_deserialize - def from_dict(cls, m_dict: Dict[str, Any]) -> "FWAction": + def from_dict(cls, m_dict: FromDict) -> "FWAction": d = m_dict additions = [Workflow.from_dict(f) for f in d["additions"]] detours = [Workflow.from_dict(f) for f in d["detours"]] @@ -244,7 +244,7 @@ class Firework(FWSerializable): def __init__( self, tasks: Union["FiretaskBase", Sequence["FiretaskBase"]], - spec: Optional[Dict[Any, Any]] = None, + spec: Optional[Spec] = None, name: Optional[str] = None, launches: Optional[Sequence["Launch"]] = None, archived_launches: Optional[Sequence["Launch"]] = None, @@ -311,7 +311,7 @@ def state(self, state: str) -> None: self.updated_on = datetime.utcnow() @recursive_serialize - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: # put tasks in a special location of the spec spec = self.spec spec["_tasks"] = [t.to_dict() for t in self.tasks] @@ -356,7 +356,7 @@ def _rerun(self) -> None: self.launches = [] self.state = "WAITING" - def to_db_dict(self) -> Dict[str, Any]: + def to_db_dict(self) -> ToDict: """ Return firework dict with updated launches and state. """ @@ -370,7 +370,7 @@ def to_db_dict(self) -> Dict[str, Any]: @classmethod @recursive_deserialize - def from_dict(cls, m_dict: Dict[str, Any]) -> "Firework": + def from_dict(cls, m_dict: FromDict) -> "Firework": tasks = m_dict["spec"]["_tasks"] launches = [Launch.from_dict(tmp) for tmp in m_dict.get("launches", [])] archived_launches = [Launch.from_dict(tmp) for tmp in m_dict.get("archived_launches", [])] @@ -445,14 +445,14 @@ def track_file(self, launch_dir: Optional[str] = None) -> str: self.content = "\n".join(reversed(lines)) return self.content - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: m_dict = {"filename": self.filename, "nlines": self.nlines, "allow_zipped": self.allow_zipped} if self.content: m_dict["content"] = self.content return m_dict @classmethod - def from_dict(cls, m_dict: Dict[str, Any]) -> "Tracker": + def from_dict(cls, m_dict: FromDict) -> "Tracker": return Tracker( m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False) ) @@ -604,7 +604,7 @@ def reservedtime_secs(self) -> int: # type: ignore return int((end - start).total_seconds()) @recursive_serialize - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: return { "fworker": self.fworker, "fw_id": self.fw_id, @@ -619,7 +619,7 @@ def to_dict(self) -> Dict[str, Any]: } @recursive_serialize - def to_db_dict(self): + def to_db_dict(self) -> ToDict: m_d = self.to_dict() m_d["time_start"] = self.time_start m_d["time_end"] = self.time_end @@ -630,7 +630,7 @@ def to_db_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: FromDict) -> "Launch": fworker = FWorker.from_dict(m_dict["fworker"]) if m_dict["fworker"] else None action = FWAction.from_dict(m_dict["action"]) if m_dict.get("action") else None trackers = [Tracker.from_dict(f) for f in m_dict["trackers"]] if m_dict.get("trackers") else None @@ -789,12 +789,12 @@ def __reduce__(self): def __init__( self, fireworks: Sequence[Firework], - links_dict: Dict[int, List[int]] = None, 
- name: str = None, + links_dict: Optional[Dict[int, List[int]]] = None, + name: Optional[str] = None, metadata: Dict[str, Any] = None, - created_on: datetime = None, - updated_on: datetime = None, - fw_states: Dict[int, str] = None, + created_on: Optional[datetime] = None, + updated_on: Optional[datetime] = None, + fw_states: Optional[Dict[int, str]] = None, ) -> None: """ Args: @@ -1202,7 +1202,7 @@ def leaf_fw_ids(self) -> List[int]: leaf_ids.append(id) return leaf_ids - def _reassign_ids(self, old_new: Dict[int, int]) -> None: + def _reassign_ids(self, old_new: Mapping[int, int]) -> None: """ Internal method to reassign Firework ids, e.g. due to database insertion. @@ -1227,7 +1227,7 @@ def _reassign_ids(self, old_new: Dict[int, int]) -> None: new_fw_states[old_new.get(fwid, fwid)] = fw_state self.fw_states = new_fw_states - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: return { "fws": [f.to_dict() for f in self.id_fw.values()], "links": self.links.to_dict(), @@ -1237,7 +1237,7 @@ def to_dict(self) -> Dict[str, Any]: "created_on": self.created_on, } - def to_db_dict(self) -> Dict[str, Any]: + def to_db_dict(self) -> ToDict: m_dict = self.links.to_db_dict() m_dict["metadata"] = self.metadata m_dict["state"] = self.state @@ -1247,7 +1247,7 @@ def to_db_dict(self) -> Dict[str, Any]: m_dict["fw_states"] = {str(k): v for (k, v) in self.fw_states.items()} return m_dict - def to_display_dict(self): + def to_display_dict(self) -> ToDict: m_dict = self.to_db_dict() nodes = sorted(m_dict["nodes"]) m_dict["name--id"] = self.name + "--" + str(nodes[0]) diff --git a/fireworks/core/types.py b/fireworks/core/types.py index da47f644e..789333cfb 100644 --- a/fireworks/core/types.py +++ b/fireworks/core/types.py @@ -1,5 +1,6 @@ -from typing import Any, Mapping, MutableMapping +from typing import Any, Dict, Mapping, MutableMapping Checkpoint = MutableMapping[Any, Any] FromDict = Mapping[str, Any] Spec = MutableMapping[Any, Any] +ToDict = Dict[str, Any] diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index a0cec9775..69235271a 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -3,7 +3,7 @@ """ import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from monty.design_patterns import singleton from monty.serialization import dumpfn, loadfn @@ -43,7 +43,7 @@ PRINT_FW_YAML = False JSON_SCHEMA_VALIDATE = False -JSON_SCHEMA_VALIDATE_LIST: List[str] = None +JSON_SCHEMA_VALIDATE_LIST: Optional[List[str]] = None PING_TIME_SECS = 3600 # while Running a job, how often to ping back the server that we're still alive RUN_EXPIRATION_SECS = PING_TIME_SECS * 4 # mark job as FIZZLED if not pinged in this time @@ -167,7 +167,7 @@ def config_to_dict() -> Dict[str, Any]: return d -def write_config(path: str = None) -> None: +def write_config(path: Optional[str] = None) -> None: if path is None: path = os.path.join(os.path.expanduser("~"), ".fireworks", "FW_config.yaml") dumpfn(config_to_dict(), path) diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py index fcf92af4c..8c98a1c95 100644 --- a/fireworks/utilities/dagflow.py +++ b/fireworks/utilities/dagflow.py @@ -345,237 +345,3 @@ def to_dot(self, filename="wf.dot", view="combined"): if isinstance(val, bool): del vertex[key] graph.write_dot(filename) -<<<<<<< HEAD -||||||| parent of d44d6080 (continue typing) - - -def plot_wf(wf, view="combined", labels=False, **kwargs): - """Plot workflow DAG via igraph.plot. 
- - Args: - wf (Workflow) - view (str): same as in 'to_dot'. Default: 'combined' - labels (bool): show a FW's name and id as labels in graph - - Other **kwargs can be any igraph plotting style keyword, overrides default. - See https://igraph.org/python/doc/tutorial/tutorial.html for possible - keywords. See `plot_wf` code for defaults. - - Returns: - igraph.drawing.Plot - """ - - dagf = DAGFlow.from_fireworks(wf) - if labels: - dagf.add_step_labels() - - # copied from to_dot - if view == "controlflow": - dagf.delete_dataflow_links() - elif view == "dataflow": - dagf.delete_ctrlflow_links() - elif view == "combined": - dlinks = [] - for vertex1, vertex2 in combinations(dagf.vs.indices, 2): - clinks = list(set(dagf.incident(vertex1, mode="ALL")) & set(dagf.incident(vertex2, mode="ALL"))) - if len(clinks) > 1: - for link in clinks: - if dagf.es[link]["label"] == " ": - dlinks.append(link) - dagf.delete_edges(dlinks) - - # remove non-string, non-numeric attributes because write_dot() warns - for vertex in dagf.vs: - for key, val in vertex.attributes().items(): - if not isinstance(val, (str, int, float, complex)): - del vertex[key] - if isinstance(val, bool): - del vertex[key] - - # plotting defaults - visual_style = DEFAULT_IGRAPH_VISUAL_STYLE.copy() - - # generic plotting defaults - visual_style["layout"] = dagf.layout_kamada_kawai() - - # vertex defaults - dagf_roots = dagf._get_roots() - dagf_leaves = dagf._get_leaves() - - def color_coding(v): - if v in dagf_roots: - return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["root"] - elif v in dagf_leaves: - return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["leaf"] - else: - return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["other"] - - visual_style["vertex_color"] = [color_coding(v) for v in range(dagf.vcount())] - - visual_style.update(kwargs) - - # special treatment - if "layout" in kwargs and isinstance(kwargs["layout"], str): - visual_style["layout"] = dagf.layout(kwargs["layout"]) - - return igraph.plot(dagf, **visual_style) - - -@requires( - graphviz, - "graphviz package required for wf_to_graph.\n" - "Follow the installation instructions here: https://github.com/xflr6/graphviz", -) -def wf_to_graph(wf: Workflow) -> "Digraph": - """ - Renders a graph representation of a workflow or firework. Workflows are - rendered as the control flow of the firework, while Fireworks are - rendered as a sequence of Firetasks. - - Copied from https://git.io/JO6L8. - - Args: - workflow (Workflow|Firework): workflow or Firework - to be rendered. - - Returns: - Digraph: the rendered workflow or firework graph - """ - # Directed Acyclic Graph - dag = Digraph(comment=wf.name, graph_attr={"rankdir": "LR"}) - dag.node_attr["shape"] = "box" - if isinstance(wf, Workflow): - for fw in wf: - dag.node(str(fw.fw_id), label=fw.name, color=state_to_color[fw.state]) - - for start, targets in wf.links.items(): - for target in targets: - dag.edge(str(start), str(target)) - elif isinstance(wf, Firework): - for n, ft in enumerate(wf.tasks): - # Clean up names - name = ft.fw_name.replace("{", "").replace("}", "") - name = name.split(".")[-1] - dag.node(str(n), label=name) - if n >= 1: - dag.edge(str(n - 1), str(n)) - else: - raise ValueError("expected instance of Workflow or Firework") - return dag -======= - - -def plot_wf(wf: Workflow, view: str = "combined", labels: bool = False, **kwargs) -> Plot: - """Plot workflow DAG via igraph.plot. - - Args: - wf (Workflow) - view (str): same as in 'to_dot'. 
Default: 'combined' - labels (bool): show a FW's name and id as labels in graph - - Other **kwargs can be any igraph plotting style keyword, overrides default. - See https://igraph.org/python/doc/tutorial/tutorial.html for possible - keywords. See `plot_wf` code for defaults. - - Returns: - igraph.drawing.Plot - """ - - dagf = DAGFlow.from_fireworks(wf) - if labels: - dagf.add_step_labels() - - # copied from to_dot - if view == "controlflow": - dagf.delete_dataflow_links() - elif view == "dataflow": - dagf.delete_ctrlflow_links() - elif view == "combined": - dlinks = [] - for vertex1, vertex2 in combinations(dagf.vs.indices, 2): - clinks = list(set(dagf.incident(vertex1, mode="ALL")) & set(dagf.incident(vertex2, mode="ALL"))) - if len(clinks) > 1: - for link in clinks: - if dagf.es[link]["label"] == " ": - dlinks.append(link) - dagf.delete_edges(dlinks) - - # remove non-string, non-numeric attributes because write_dot() warns - for vertex in dagf.vs: - for key, val in vertex.attributes().items(): - if not isinstance(val, (str, int, float, complex)): - del vertex[key] - if isinstance(val, bool): - del vertex[key] - - # plotting defaults - visual_style = DEFAULT_IGRAPH_VISUAL_STYLE.copy() - - # generic plotting defaults - visual_style["layout"] = dagf.layout_kamada_kawai() - - # vertex defaults - dagf_roots = dagf._get_roots() - dagf_leaves = dagf._get_leaves() - - def color_coding(v): - if v in dagf_roots: - return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["root"] - elif v in dagf_leaves: - return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["leaf"] - else: - return DEFAULT_IGRAPH_VERTEX_COLOR_CODING["other"] - - visual_style["vertex_color"] = [color_coding(v) for v in range(dagf.vcount())] - - visual_style.update(kwargs) - - # special treatment - if "layout" in kwargs and isinstance(kwargs["layout"], str): - visual_style["layout"] = dagf.layout(kwargs["layout"]) - - return igraph.plot(dagf, **visual_style) - - -@requires( - graphviz, - "graphviz package required for wf_to_graph.\n" - "Follow the installation instructions here: https://github.com/xflr6/graphviz", -) -def wf_to_graph(wf: Workflow) -> "Digraph": - """ - Renders a graph representation of a workflow or firework. Workflows are - rendered as the control flow of the firework, while Fireworks are - rendered as a sequence of Firetasks. - - Copied from https://git.io/JO6L8. - - Args: - workflow (Workflow|Firework): workflow or Firework - to be rendered. 
- - Returns: - Digraph: the rendered workflow or firework graph - """ - # Directed Acyclic Graph - dag = Digraph(comment=wf.name, graph_attr={"rankdir": "LR"}) - dag.node_attr["shape"] = "box" - if isinstance(wf, Workflow): - for fw in wf: - dag.node(str(fw.fw_id), label=fw.name, color=state_to_color[fw.state]) - - for start, targets in wf.links.items(): - for target in targets: - dag.edge(str(start), str(target)) - elif isinstance(wf, Firework): - for n, ft in enumerate(wf.tasks): - # Clean up names - name = ft.fw_name.replace("{", "").replace("}", "") - name = name.split(".")[-1] - dag.node(str(n), label=name) - if n >= 1: - dag.edge(str(n - 1), str(n)) - else: - raise ValueError("expected instance of Workflow or Firework") - return dag ->>>>>>> d44d6080 (continue typing) diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 993b572cd..6c5a766f4 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,7 +33,7 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback -from typing import Any, Dict, Mapping, MutableMapping, Optional, Type +from typing import Any, Mapping, MutableMapping, Optional, Type import ruamel.yaml as yaml from monty.json import MontyDecoder, MSONable diff --git a/pyproject.toml b/pyproject.toml index 88235b5b0..001427573 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ strict_equality = true [[tool.mypy.overrides]] module = [ "argcomplete", + "atomate", + "atomate.vasp", "fabric", "fireworks_schema", "flask_paginate", @@ -54,6 +56,8 @@ module = [ "monty.os.path", "monty.serialization", "monty.shutil", + "pymatgen.core", + "pytest", # https://github.com/tqdm/tqdm/issues/260 "tqdm", ] From caba8e604c6e736e84e2f4ac7577983be733f8ad Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 4 Feb 2022 12:48:52 -0500 Subject: [PATCH 10/18] hacking --- fireworks/core/firework.py | 4 +-- fireworks/core/fworker.py | 11 ++++--- fireworks/core/launchpad.py | 32 ++++++++++++------- fireworks/core/types.py | 2 +- fireworks/fw_config.py | 2 +- fireworks/queue/queue_adapter.py | 3 +- fireworks/scripts/fwtool | 5 +-- fireworks/scripts/lpad_run.py | 2 +- .../user_objects/firetasks/script_task.py | 11 ++++--- fireworks/utilities/dagflow.py | 4 ++- fireworks/utilities/fw_serializers.py | 15 +++++---- 11 files changed, 53 insertions(+), 38 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index bb32e88ac..77334f035 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -54,10 +54,10 @@ class FiretaskBase(defaultdict, FWSerializable, metaclass=abc.ABCMeta): You can set parameters of a Firetask like you'd use a dict. 
""" - required_params = None # list of str of required parameters to check for consistency upon init + required_params: Optional[List[str]] = None # list of str of required parameters to check for consistency upon init # if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init - optional_params = None + optional_params: Optional[List[str]] = None def __init__(self, *args, **kwargs) -> None: dict.__init__(self, *args, **kwargs) diff --git a/fireworks/core/fworker.py b/fireworks/core/fworker.py index 6ae28e025..b0155688a 100644 --- a/fireworks/core/fworker.py +++ b/fireworks/core/fworker.py @@ -5,6 +5,7 @@ import json from typing import Any, Dict, Optional, Sequence, Union +from fireworks.core.types import FromDict, ToDict from fireworks.fw_config import FWORKER_LOC from fireworks.utilities.fw_serializers import ( DATETIME_HANDLER, @@ -27,7 +28,7 @@ def __init__( name: str = "Automatically generated Worker", category: Union[str, Sequence[str]] = "", query: Optional[Dict[str, Any]] = None, - env: Optional[Dict[str, Any]] = None, + env: Optional[Dict[str, str]] = None, ) -> None: """ Args: @@ -48,7 +49,7 @@ def __init__( self.env = env if env else {} @recursive_serialize - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: return { "name": self.name, "category": self.category, @@ -58,11 +59,11 @@ def to_dict(self) -> Dict[str, Any]: @classmethod @recursive_deserialize - def from_dict(cls, m_dict: Dict[str, Any]) -> "FWorker": + def from_dict(cls, m_dict: FromDict) -> "FWorker": return FWorker(m_dict["name"], m_dict["category"], json.loads(m_dict["query"]), m_dict.get("env")) @property - def query(self) -> Dict[str, Any]: + def query(self) -> ToDict: """ Returns updated query dict. """ @@ -89,5 +90,5 @@ def auto_load(cls) -> "FWorker": Returns FWorker object from settings file(my_fworker.yaml). """ if FWORKER_LOC: - return FWorker.from_file(FWORKER_LOC) + return FWorker.from_file(FWORKER_LOC) # type: ignore return FWorker() diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 53e308828..75ce03a73 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -23,7 +23,8 @@ from tqdm import tqdm from fireworks.core.firework import Firework, FWAction, Launch, Tracker, Workflow -from fireworks.core.types import Spec +from fireworks.core.fworker import FWorker +from fireworks.core.types import FromDict, Spec, ToDict from fireworks.fw_config import ( GRIDFS_FALLBACK_COLLECTION, LAUNCHPAD_LOC, @@ -242,7 +243,7 @@ def __init__( self.backup_launch_data = {} self.backup_fw_data = {} - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ToDict: """ Note: usernames/passwords are exported as unencrypted Strings! """ @@ -261,7 +262,7 @@ def to_dict(self) -> Dict[str, Any]: "mongoclient_kwargs": self.mongoclient_kwargs, } - def update_spec(self, fw_ids: Sequence[int], spec_document: Spec, mongo: bool = False): + def update_spec(self, fw_ids: Sequence[int], spec_document: Spec, mongo: bool = False) -> None: """ Update fireworks with a spec. Sometimes you need to modify a firework in progress. 
@@ -288,7 +289,7 @@ def update_spec(self, fw_ids: Sequence[int], spec_document: Spec, mongo: bool = ) @classmethod - def from_dict(cls, d): + def from_dict(cls, d: FromDict) -> "LaunchPad": port = d.get("port", None) name = d.get("name", None) username = d.get("username", None) @@ -316,7 +317,7 @@ def from_dict(cls, d): ) @classmethod - def auto_load(cls): + def auto_load(cls) -> "LaunchPad": if LAUNCHPAD_LOC: return LaunchPad.from_file(LAUNCHPAD_LOC) return LaunchPad() @@ -476,7 +477,7 @@ def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=True): with WFLock(self, fw_ids[0]): self._update_wf(wf, updated_ids) - def get_launch_by_id(self, launch_id): + def get_launch_by_id(self, launch_id: int) -> Launch: """ Given a Launch id, return details of the Launch. @@ -492,7 +493,7 @@ def get_launch_by_id(self, launch_id): return Launch.from_dict(m_launch) raise ValueError(f"No Launch exists with launch_id: {launch_id}") - def get_fw_dict_by_id(self, fw_id): + def get_fw_dict_by_id(self, fw_id: int): """ Given firework id, return firework dict. @@ -836,7 +837,7 @@ def get_fw_ids_in_wfs( limit: int = 0, count_only: bool = False, launches_mode: bool = False, - ) -> List[int]: + ) -> Union[List[int], int]: """ Return all fw ids that match fw_query within workflows that match wf_query. @@ -910,7 +911,7 @@ def get_fw_ids_in_wfs( cursor = self.workflows.aggregate(aggregation) return [fw["fw_id"] for fw in cursor] - def run_exists(self, fworker=None): + def run_exists(self, fworker: Optional[FWorker] = None) -> bool: """ Checks to see if the database contains any FireWorks that are ready to run. @@ -920,7 +921,7 @@ def run_exists(self, fworker=None): q = fworker.query if fworker else {} return bool(self._get_a_fw_to_run(query=q, checkout=False)) - def future_run_exists(self, fworker=None): + def future_run_exists(self, fworker: Optional[FWorker] = None) -> bool: """Check if database has any current OR future Fireworks available Returns: @@ -1199,7 +1200,7 @@ def _get_a_fw_to_run(self, query=None, fw_id=None, checkout=True): if self._check_fw_for_uniqueness(m_fw): return m_fw - def _get_active_launch_ids(self): + def _get_active_launch_ids(self) -> List[int]: """ Get all the launch ids. @@ -1211,7 +1212,14 @@ def _get_active_launch_ids(self): all_launch_ids.extend(l["launches"]) return all_launch_ids - def reserve_fw(self, fworker, launch_dir, host=None, ip=None, fw_id=None): + def reserve_fw( + self, + fworker: FWorker, + launch_dir: str, + host: Optional[str] = None, + ip: Optional[str] = None, + fw_id: Optional[int] = None, + ): """ Checkout the next ready firework and mark the launch reserved. diff --git a/fireworks/core/types.py b/fireworks/core/types.py index 789333cfb..51e95e95f 100644 --- a/fireworks/core/types.py +++ b/fireworks/core/types.py @@ -2,5 +2,5 @@ Checkpoint = MutableMapping[Any, Any] FromDict = Mapping[str, Any] -Spec = MutableMapping[Any, Any] +Spec = MutableMapping[str, Any] ToDict = Dict[str, Any] diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index 69235271a..bc8691267 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -59,7 +59,7 @@ RAPIDFIRE_SLEEP_SECS = 60 # seconds to sleep between rapidfire loops LAUNCHPAD_LOC = None # where to find the my_launchpad.yaml file -FWORKER_LOC = None # where to find the my_fworker.yaml file +FWORKER_LOC: Optional[str] = None # where to find the my_fworker.yaml file QUEUEADAPTER_LOC = None # where to find the my_qadapter.yaml file CONFIG_FILE_DIR = "." 
# directory containing config files (if not individually set) diff --git a/fireworks/queue/queue_adapter.py b/fireworks/queue/queue_adapter.py index b28f508ea..760506b09 100644 --- a/fireworks/queue/queue_adapter.py +++ b/fireworks/queue/queue_adapter.py @@ -11,6 +11,7 @@ import threading import traceback import warnings +from typing import Any, Dict from fireworks.utilities.fw_serializers import FWSerializable, serialize_fw from fireworks.utilities.fw_utilities import get_fw_logger @@ -106,7 +107,7 @@ class QueueAdapterBase(collections.defaultdict, FWSerializable): template_file = "OVERRIDE_ME" # path to template file for a queue script submit_cmd = "OVERRIDE_ME" # command to submit jobs, e.g. "qsub" or "squeue" q_name = "OVERRIDE_ME" # (arbitrary) name, e.g. "pbs" or "slurm" - defaults = {} # default parameter values for template + defaults: Dict[str, Any] = {} # default parameter values for template def get_script_str(self, launch_dir): """ diff --git a/fireworks/scripts/fwtool b/fireworks/scripts/fwtool index 94e5d2176..7be00c13a 100755 --- a/fireworks/scripts/fwtool +++ b/fireworks/scripts/fwtool @@ -4,10 +4,11 @@ import os import shutil import sys from argparse import ArgumentParser, Namespace +from typing import List import yaml -from fireworks.core.firework import Firework +from fireworks.core.firework import FiretaskBase, Firework from fireworks.core.launchpad import LaunchPad from fireworks.user_objects.firetasks.fileio_tasks import FileWriteTask from fireworks.user_objects.firetasks.script_task import ScriptTask @@ -22,7 +23,7 @@ from typing import Sequence def create_fw_single(args: Namespace, fnames: Sequence[str], yaml_fname: str) -> None: - tasks = [] + tasks: List[FiretaskBase] = [] if fnames: files = [] for fname in fnames: diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 9626db62f..3928d8ba8 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -1,7 +1,7 @@ """ A runnable script for managing a FireWorks database (a command-line interface to launchpad.py) """ - +# mypy: ignore-errors import ast import copy import datetime diff --git a/fireworks/user_objects/firetasks/script_task.py b/fireworks/user_objects/firetasks/script_task.py index 5630dfb78..ca4a2b396 100644 --- a/fireworks/user_objects/firetasks/script_task.py +++ b/fireworks/user_objects/firetasks/script_task.py @@ -4,9 +4,10 @@ import shlex import subprocess import sys -from typing import Dict, List, Optional, Union +from typing import Any, MutableMapping, Optional from fireworks.core.firework import FiretaskBase, FWAction +from fireworks.core.types import Spec if sys.version_info[0] > 2: basestring = str @@ -24,7 +25,7 @@ class ScriptTask(FiretaskBase): required_params = ["script"] _fw_name = "ScriptTask" - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: if self.get("use_global_spec"): self._load_params(fw_spec) else: @@ -37,7 +38,7 @@ def run_task(self, fw_spec): stdin = subprocess.PIPE if self.stdin_key else None return self._run_task_internal(fw_spec, stdin) - def _run_task_internal(self, fw_spec, stdin): + def _run_task_internal(self, fw_spec: Spec, stdin) -> FWAction: # run the program stdout = subprocess.PIPE if self.store_stdout or self.stdout_file else None stderr = subprocess.PIPE if self.store_stderr or self.stderr_file else None @@ -120,7 +121,7 @@ def _load_params(self, d): raise ValueError("ScriptTask cannot both FIZZLE and DEFUSE a bad returncode!") @classmethod - def from_str(cls, shell_cmd, parameters=None): 
+ def from_str(cls, shell_cmd: str, parameters: Optional[MutableMapping[str, Any]] = None) -> "ScriptTask": parameters = parameters if parameters else {} parameters["script"] = [shell_cmd] parameters["use_shell"] = True @@ -163,7 +164,7 @@ class PyTask(FiretaskBase): # note that we are not using "optional_params" because we do not want to do # strict parameter checking in FiretaskBase due to "auto_kwargs" option - def run_task(self, fw_spec: Dict[str, Union[List[int], int]]) -> Optional[FWAction]: + def run_task(self, fw_spec: Spec) -> Optional[FWAction]: toks = self["func"].rsplit(".", 1) if len(toks) == 2: modname, funcname = toks diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py index 8c98a1c95..1ccc6e132 100644 --- a/fireworks/utilities/dagflow.py +++ b/fireworks/utilities/dagflow.py @@ -10,6 +10,8 @@ import igraph from igraph import Graph +from fireworks import Workflow + DF_TASKS = ["PyTask", "CommandLineTask", "ForeachTask", "JoinDictTask", "JoinListTask"] DEFAULT_IGRAPH_VISUAL_STYLE = { @@ -57,7 +59,7 @@ def __init__(self, steps, links=None, nlinks=None, name: Optional[str] = None, * self._add_dataflow_links() @classmethod - def from_fireworks(cls, fireworkflow) -> "DAGFlow": + def from_fireworks(cls, fireworkflow: Workflow) -> "DAGFlow": """Converts a fireworks workflow object into a new DAGFlow object""" wfd = fireworkflow.to_dict() if "name" in wfd: diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 6c5a766f4..1948d20e7 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -94,7 +94,7 @@ def recursive_dict(obj: Any, preserve_unicode: bool = True) -> Any: if isinstance(obj, datetime.datetime): return obj.isoformat() - if preserve_unicode and isinstance(obj, str) and obj != obj.encode("ascii", "ignore"): + if preserve_unicode and isinstance(obj, str) and obj != obj.encode("ascii", "ignore"): # type: ignore return obj if NUMPY_INSTALLED and isinstance(obj, np.ndarray): @@ -129,7 +129,7 @@ def _recursive_load(obj: Any) -> Any: return reconstitute_dates(obj) except Exception: # convert unicode to ASCII if not really unicode - if obj == obj.encode("ascii", "ignore"): + if obj == obj.encode("ascii", "ignore"): # type: ignore return str(obj) return obj @@ -271,7 +271,7 @@ def to_file(self, filename: str, f_format: Optional[str] = None, **kwargs) -> No """ if f_format is None: f_format = filename.split(".")[-1] - with open(filename, "w", **ENCODING_PARAMS) as f: + with open(filename, "w", **ENCODING_PARAMS) as f: # type: ignore f.write(self.to_format(f_format=f_format, **kwargs)) @classmethod @@ -288,7 +288,7 @@ def from_file(cls, filename: str, f_format: Optional[str] = None) -> "FWSerializ """ if f_format is None: f_format = filename.split(".")[-1] - with open(filename, "r", **ENCODING_PARAMS) as f: + with open(filename, "r", **ENCODING_PARAMS) as f: # type: ignore return cls.from_format(f.read(), f_format=f_format) def __getstate__(self): @@ -384,7 +384,7 @@ def load_object_from_file(filename: str, f_format: Optional[str] = None) -> Any: if f_format is None: f_format = filename.split(".")[-1] - with open(filename, "r", **ENCODING_PARAMS) as f: + with open(filename, "r", **ENCODING_PARAMS) as f: # type: ignore if f_format == "json": dct = json.loads(f.read()) elif f_format == "yaml": @@ -393,8 +393,9 @@ def load_object_from_file(filename: str, f_format: Optional[str] = None) -> Any: raise ValueError(f"Unknown file format {f_format} cannot be loaded!") classname = 
FW_NAME_UPDATES.get(dct["_fw_name"], dct["_fw_name"]) - if JSON_SCHEMA_VALIDATE and classname in JSON_SCHEMA_VALIDATE_LIST: - fireworks_schema.validate(dct, classname) + if JSON_SCHEMA_VALIDATE: + if JSON_SCHEMA_VALIDATE_LIST is not None and classname in JSON_SCHEMA_VALIDATE_LIST: + fireworks_schema.validate(dct, classname) return load_object(reconstitute_dates(dct)) From dc251db57beaa4f0321b68fb0c9f8cabac3f8e65 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Mon, 7 Feb 2022 08:42:46 -0500 Subject: [PATCH 11/18] hacking --- fireworks/core/firework.py | 34 ++++++++++++++++----------- fireworks/utilities/dagflow.py | 4 ++-- fireworks/utilities/fw_serializers.py | 12 ++++++---- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 77334f035..b8e7e0383 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -116,11 +116,11 @@ def __repr__(self) -> str: # not strictly needed here for pickle/unpickle, but complements __setstate__ def __getstate__(self) -> ToDict: - return self.to_dict() + return self.to_dict() # type: ignore # added to support pickle/unpickle def __setstate__(self, state) -> None: - self.__init__(state) + self.__init__(state) # type: ignore # added to support pickle/unpickle def __reduce__(self) -> Tuple: @@ -140,7 +140,7 @@ def __init__( stored_data: Optional[Dict[Any, Any]] = None, exit: bool = False, update_spec: Optional[Mapping[str, Any]] = None, - mod_spec: Optional[Mapping[str, Any]] = None, + mod_spec: Optional[Sequence[Mapping[str, Any]]] = None, additions: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, detours: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, defuse_children: bool = False, @@ -163,7 +163,7 @@ def __init__( not only to direct children, but to all dependent FireWorks down to the Workflow's leaves. """ - mod_spec = mod_spec if mod_spec is not None else [] + mod_spec = list(mod_spec) if mod_spec is not None else [] additions = additions if additions is not None else [] detours = detours if detours is not None else [] @@ -217,7 +217,13 @@ def skip_remaining_tasks(self) -> bool: Returns: bool """ - return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow + return ( + self.exit + or bool(self.detours) + or bool(self.additions) + or bool(self.defuse_children) + or bool(self.defuse_workflow) + ) def __str__(self) -> str: return "FWAction\n" + pprint.pformat(self.to_dict()) @@ -267,9 +273,9 @@ def __init__( updated_on (datetime): last time the STATE was updated. 
""" - self.tasks = tasks if isinstance(tasks, (list, tuple)) else [tasks] + self.tasks = list(tasks) if isinstance(tasks, Sequence) else [tasks] - self.spec = spec.copy() if spec else {} + self.spec = dict(spec) if spec else {} self.name = name or "Unnamed FW" # do it this way to prevent None # names @@ -280,8 +286,8 @@ def __init__( NEGATIVE_FWID_CTR -= 1 self.fw_id = NEGATIVE_FWID_CTR - self.launches = launches if launches else [] - self.archived_launches = archived_launches if archived_launches else [] + self.launches = list(launches) if launches else [] + self.archived_launches = list(archived_launches) if archived_launches else [] self.created_on = created_on or datetime.utcnow() self.updated_on = updated_on or datetime.utcnow() @@ -353,7 +359,7 @@ def _rerun(self) -> None: self.archived_launches.extend(self.launches) self.archived_launches = list(set(self.archived_launches)) # filter duplicates - self.launches = [] + self.launches: List[Launch] = [] # type: ignore self.state = "WAITING" def to_db_dict(self) -> ToDict: @@ -668,7 +674,7 @@ def _update_state_history(self, state): if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key - def _get_time(self, states, use_update_time: bool = False) -> datetime: # type: ignore + def _get_time(self, states, use_update_time: bool = False) -> datetime: """ Internal method to help get the time of various events in the Launch (e.g. RUNNING) from the state history. @@ -684,8 +690,8 @@ def _get_time(self, states, use_update_time: bool = False) -> datetime: # type: for data in self.state_history: if data["state"] in states: if use_update_time: - return data["updated_on"] - return data["created_on"] + return data["updated_on"] # type: ignore + return data["created_on"] # type: ignore class Workflow(FWSerializable): @@ -1391,7 +1397,7 @@ def remove_fws(self, fw_ids: Sequence[int]) -> None: """ # not working with the copies, causes spurious behavior - wf_dict = deepcopy(self.as_dict()) + wf_dict = deepcopy(self.as_dict()) # type: ignore orig_parent_links = deepcopy(self.links.parent_links) fws = wf_dict["fws"] diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py index 1ccc6e132..dcfb2ceb4 100644 --- a/fireworks/utilities/dagflow.py +++ b/fireworks/utilities/dagflow.py @@ -120,7 +120,7 @@ def task_input(task, spec): return cls(steps=steps, links=links, name=name) - def _get_links(self, nlinks) -> List[Tuple[x, x]]: + def _get_links(self, nlinks) -> List[Tuple[List[str], List[str]]]: """Translates named links into links between step ids""" links = [] for link in nlinks: @@ -129,7 +129,7 @@ def _get_links(self, nlinks) -> List[Tuple[x, x]]: links.append((source, target)) return links - def _get_ctrlflow_links(self) -> List[Tuple[x, x]]: + def _get_ctrlflow_links(self) -> List[Tuple[str, str]]: """Returns a list of unique tuples of link ids""" links = [] for ilink in {link.tuple for link in list(self.es)}: diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 1948d20e7..1935f588f 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,11 +33,12 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback -from typing import Any, Mapping, MutableMapping, Optional, Type +from typing import Any, Mapping, MutableMapping, Optional, Type, TypeVar import ruamel.yaml as yaml from monty.json import MontyDecoder, MSONable +from fireworks.core.types import FromDict 
from fireworks.fw_config import ( DECODE_MONTY, ENCODE_MONTY, @@ -178,6 +179,9 @@ def _decorator(self, *args, **kwargs): return _decorator +T = TypeVar("T", bound="FWSerializable") + + class FWSerializable(metaclass=abc.ABCMeta): """ To create a serializable object within FireWorks, you should subclass this @@ -218,7 +222,7 @@ def as_dict(self) -> Mapping[Any, Any]: @classmethod @abc.abstractmethod - def from_dict(cls, m_dict) -> "FWSerializable": + def from_dict(cls: Type[T], m_dict: FromDict) -> T: raise NotImplementedError("FWSerializable object did not implement from_dict()!") def __repr__(self) -> str: @@ -240,7 +244,7 @@ def to_format(self, f_format: str = "json", **kwargs) -> str: raise ValueError(f"Unsupported format {f_format}") @classmethod - def from_format(cls, f_str: str, f_format: str = "json") -> "FWSerializable": + def from_format(cls: Type[T], f_str: str, f_format: str = "json") -> T: """ convert from a String representation to its Object. @@ -275,7 +279,7 @@ def to_file(self, filename: str, f_format: Optional[str] = None, **kwargs) -> No f.write(self.to_format(f_format=f_format, **kwargs)) @classmethod - def from_file(cls, filename: str, f_format: Optional[str] = None) -> "FWSerializable": + def from_file(cls: Type[T], filename: str, f_format: Optional[str] = None) -> T: """ Load a serialization of this object from a file. From 3edd478d4898a87cc8f937ccdac58fadecd3981d Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Mon, 7 Feb 2022 08:43:01 -0500 Subject: [PATCH 12/18] hacking --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 001427573..0e11e2f5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ module = [ "argcomplete", "atomate", "atomate.vasp", + "atomate.vasp.workflows", "fabric", "fireworks_schema", "flask_paginate", From 2dc55a020fe541ec800dfc8b8d2b8b628b10affd Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 31 Mar 2023 13:43:48 -0400 Subject: [PATCH 13/18] fix bad rebase --- fireworks/scripts/lpad_run.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 3928d8ba8..9a6340370 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -275,13 +275,7 @@ def print_fws(ids: Sequence[int], lp: LaunchPad, args: Namespace) -> None: print(args.output(fws)) -<<<<<<< HEAD -def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, None] = None) -> Union[List[int], int]: -||||||| parent of a224c6c5 (continue typing) -def get_fw_ids_helper(lp, args, count_only=None): -======= -def get_fw_ids_helper(lp, args, count_only: Optional[bool] = None) -> List[int]: ->>>>>>> a224c6c5 (continue typing) +def get_fw_ids_helper(lp, args, count_only: Optional[bool] = None) -> Union[List[int], int]: """Build fws query from command line options and submit. 
Parameters: From 38f6908b0f0ca6cd64f1d2cc7d837431f0f75542 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 31 Mar 2023 13:44:11 -0400 Subject: [PATCH 14/18] fix black and isort version issues --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16f0df906..25140b9c5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,13 +14,13 @@ repos: args: [--in-place, --remove-all-unused-imports, --remove-unused-variable, --ignore-init-module-imports] - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 22.3.0 hooks: - id: black args: [--line-length, '120'] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort From cc55fe8f8cf01e1a0aef6275aa24a0058f010f7c Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 31 Mar 2023 13:45:29 -0400 Subject: [PATCH 15/18] blacken --- fireworks/core/tests/test_launchpad.py | 227 +++++++++---------------- 1 file changed, 77 insertions(+), 150 deletions(-) diff --git a/fireworks/core/tests/test_launchpad.py b/fireworks/core/tests/test_launchpad.py index 6cb340bb2..647707623 100644 --- a/fireworks/core/tests/test_launchpad.py +++ b/fireworks/core/tests/test_launchpad.py @@ -15,12 +15,11 @@ from monty.os import cd from pymongo import MongoClient -from pymongo.errors import OperationFailure from pymongo import __version__ as PYMONGO_VERSION +from pymongo.errors import OperationFailure import fireworks.fw_config -from fireworks import Firework, FWorker, LaunchPad, Workflow, FiretaskBase, \ - explicit_serialize +from fireworks import Firework, FWorker, LaunchPad, Workflow from fireworks.core.rocket_launcher import launch_rocket, rapidfire from fireworks.core.tests.tasks import ( DetoursTask, @@ -44,33 +43,28 @@ class AuthenticationTest(unittest.TestCase): def setUpClass(cls): try: client = MongoClient() - client.not_the_admin_db.command("createUser", "myuser", - pwd="mypassword", roles=["dbOwner"]) + client.not_the_admin_db.command("createUser", "myuser", pwd="mypassword", roles=["dbOwner"]) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") def test_no_admin_privileges_for_plebs(self): """Normal users can not authenticate against the admin db.""" with self.assertRaises(OperationFailure): - lp = LaunchPad(name="admin", username="myuser", - password="mypassword", authsource="admin") + lp = LaunchPad(name="admin", username="myuser", password="mypassword", authsource="admin") lp.db.collection.count_documents({}) def test_authenticating_to_users_db(self): """A user should be able to authenticate against a database that they are a user of. """ - lp = LaunchPad(name="not_the_admin_db", username="myuser", - password="mypassword", authsource="not_the_admin_db") + lp = LaunchPad(name="not_the_admin_db", username="myuser", password="mypassword", authsource="not_the_admin_db") lp.db.collection.count_documents({}) def test_authsource_infered_from_db_name(self): """The default behavior is to authenticate against the db that the user is trying to access. 
""" - lp = LaunchPad(name="not_the_admin_db", username="myuser", - password="mypassword") + lp = LaunchPad(name="not_the_admin_db", username="myuser", password="mypassword") lp.db.collection.count_documents({}) @@ -83,8 +77,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod def tearDownClass(cls): @@ -98,8 +91,7 @@ def setUp(self): self.lp.to_file(self.LP_LOC) def tearDown(self): - self.lp.reset(password=None, require_password=False, - max_reset_wo_password=1000) + self.lp.reset(password=None, require_password=False, max_reset_wo_password=1000) # Delete launch locations if os.path.exists(os.path.join("FW.json")): os.remove("FW.json") @@ -130,8 +122,7 @@ def test_reset(self): # test failsafe in a strict way for _ in range(30): - self.lp.add_wf( - Workflow([Firework(ScriptTask.from_str('echo "hello"'))])) + self.lp.add_wf(Workflow([Firework(ScriptTask.from_str('echo "hello"'))])) self.assertRaises(ValueError, self.lp.reset, "") self.lp.reset("", False, 100) # reset back @@ -164,10 +155,8 @@ def test_add_wfs(self): wfs = [] for _ in range(50): # Add two workflows with 3 and 5 simple fireworks - wf3 = Workflow([Firework(ftask, name="lorem") for _ in range(3)], - name="lorem wf") - wf5 = Workflow([Firework(ftask, name="lorem") for _ in range(5)], - name="lorem wf") + wf3 = Workflow([Firework(ftask, name="lorem") for _ in range(3)], name="lorem wf") + wf5 = Workflow([Firework(ftask, name="lorem") for _ in range(5)], name="lorem wf") wfs.extend([wf3, wf5]) self.lp.bulk_add_wfs(wfs) num_fws_total = sum(len(wf) for wf in wfs) @@ -186,8 +175,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -198,96 +186,80 @@ def setUp(self): # define the individual FireWorks used in the Workflow # Parent Firework fw_p = Firework( - ScriptTask.from_str('echo "Cronus is the ruler of titans"', - {"store_stdout": True}), name="parent", fw_id=1 + ScriptTask.from_str('echo "Cronus is the ruler of titans"', {"store_stdout": True}), name="parent", fw_id=1 ) # Sibling fireworks fw_s1 = Firework( - ScriptTask.from_str('echo "Zeus is son of Cronus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Zeus is son of Cronus"', {"store_stdout": True}), name="sib1", fw_id=2, parents=fw_p, ) fw_s2 = Firework( - ScriptTask.from_str('echo "Poisedon is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Poisedon is brother of Zeus"', {"store_stdout": True}), name="sib2", fw_id=3, parents=fw_p, ) fw_s3 = Firework( - ScriptTask.from_str('echo "Hades is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hades is brother of Zeus"', {"store_stdout": True}), name="sib3", fw_id=4, parents=fw_p, ) fw_s4 = Firework( - ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', {"store_stdout": True}), name="sib4", fw_id=5, parents=fw_p, ) fw_s5 = Firework( - ScriptTask.from_str('echo "Lapetus is son of Oceanus"', - {"store_stdout": True}), name="cousin1", fw_id=6 + ScriptTask.from_str('echo "Lapetus is son of Oceanus"', {"store_stdout": True}), name="cousin1", fw_id=6 ) # Children fireworks fw_c1 = Firework( - ScriptTask.from_str('echo "Ares is son of Zeus"', - {"store_stdout": True}), name="c1", fw_id=7, - parents=fw_s1 + ScriptTask.from_str('echo "Ares is son of Zeus"', {"store_stdout": True}), name="c1", fw_id=7, parents=fw_s1 ) fw_c2 = Firework( ScriptTask.from_str( - 'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', - {"store_stdout": True} + 'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', {"store_stdout": True} ), name="c2", fw_id=8, parents=[fw_s1, fw_s4], ) fw_c3 = Firework( - ScriptTask.from_str( - 'echo "Makaria is daughter of Hades & Persephone"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Makaria is daughter of Hades & Persephone"', {"store_stdout": True}), name="c3", fw_id=9, parents=[fw_s3, fw_c2], ) fw_c4 = Firework( - ScriptTask.from_str('echo "Dione is descendant of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Dione is descendant of Lapetus"', {"store_stdout": True}), name="c4", fw_id=10, parents=fw_s5, ) fw_c5 = Firework( - ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', {"store_stdout": True}), name="c5", fw_id=11, parents=[fw_s1, fw_c4], ) fw_c6 = Firework( - ScriptTask.from_str('echo "Atlas is son of of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Atlas is son of of Lapetus"', {"store_stdout": True}), name="c6", fw_id=12, parents=fw_s5, ) fw_c7 = Firework( - ScriptTask.from_str('echo "Maia is daughter of Atlas"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Maia is daughter of Atlas"', {"store_stdout": True}), name="c7", fw_id=13, parents=fw_c6, ) fw_c8 = Firework( - ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', {"store_stdout": True}), name="c8", fw_id=14, 
parents=[fw_s1, fw_c7], @@ -295,8 +267,7 @@ def setUp(self): # assemble Workflow from FireWorks and their connections by id workflow = Workflow( - [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, - fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] + [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] ) self.lp.add_wf(workflow) @@ -307,11 +278,11 @@ def setUp(self): self.zeus_sib_fw_ids = {3, 4, 5} self.par_fw_id = 1 self.all_ids = ( - self.zeus_child_fw_ids - | self.lapetus_desc_fw_ids - | self.zeus_sib_fw_ids - | {self.zeus_fw_id} - | {self.par_fw_id} + self.zeus_child_fw_ids + | self.lapetus_desc_fw_ids + | self.zeus_sib_fw_ids + | {self.zeus_fw_id} + | {self.par_fw_id} ) self.old_wd = os.getcwd() @@ -348,8 +319,7 @@ def test_pause_fw(self): self.assertTrue(self.zeus_sib_fw_ids.issubset(completed_ids)) # Check that Zeus and children are subset of incompleted fwids - fws_no_run = set( - self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) + fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) self.assertIn(self.zeus_fw_id, fws_no_run) self.assertTrue(self.zeus_child_fw_ids.issubset(fws_no_run)) @@ -383,8 +353,7 @@ def test_defuse_fw(self): self.assertTrue(self.zeus_sib_fw_ids.issubset(completed_ids)) # Check that Zeus and children are subset of incompleted fwids - fws_no_run = set( - self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) + fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}})) self.assertIn(self.zeus_fw_id, fws_no_run) self.assertTrue(self.zeus_child_fw_ids.issubset(fws_no_run)) except Exception: @@ -585,10 +554,7 @@ def test_rerun_fws2(self): self.assertFalse(fw_start_t > ts) -@unittest.skipIf( - PYMONGO_MAJOR_VERSION > 3, - "detect lostruns test not supported for pymongo major version > 3" -) +@unittest.skipIf(PYMONGO_MAJOR_VERSION > 3, "detect lostruns test not supported for pymongo major version > 3") class LaunchPadLostRunsDetectTest(unittest.TestCase): @classmethod def setUpClass(cls): @@ -598,8 +564,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod def tearDownClass(cls): @@ -681,8 +646,7 @@ def run(self): # Wait for fw to start it = 0 - while not any([f.state == "RUNNING" for f in - self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): + while not any([f.state == "RUNNING" for f in self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): time.sleep(1) # Wait 1 sec it += 1 if it == 10: @@ -714,8 +678,7 @@ def run(self): # Wait for running it = 0 - while not any([f.state == "RUNNING" for f in - self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): + while not any([f.state == "RUNNING" for f in self.lp.get_wf_by_fw_id_lzyfw(self.fw_id).fws]): time.sleep(1) # Wait 1 sec it += 1 if it == 10: @@ -728,6 +691,7 @@ def run(self): self.assertEqual(wf.fw_states[fw_id], "RUNNING") rp.terminate() + class WorkflowFireworkStatesTest(unittest.TestCase): """ Class to test the firework states locally cached in workflow. @@ -742,8 +706,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! 
Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod def tearDownClass(cls): @@ -754,95 +717,79 @@ def setUp(self): # define the individual FireWorks used in the Workflow # Parent Firework fw_p = Firework( - ScriptTask.from_str('echo "Cronus is the ruler of titans"', - {"store_stdout": True}), name="parent", fw_id=1 + ScriptTask.from_str('echo "Cronus is the ruler of titans"', {"store_stdout": True}), name="parent", fw_id=1 ) # Sibling fireworks # fw_s1 = Firework(ScriptTask.from_str( # 'echo "Zeus is son of Cronus"', # {'store_stdout':True}), name="sib1", fw_id=2, parents=fw_p) # Timed firework - fw_s1 = Firework(PyTask(func="time.sleep", args=[5]), name="sib1", - fw_id=2, parents=fw_p) + fw_s1 = Firework(PyTask(func="time.sleep", args=[5]), name="sib1", fw_id=2, parents=fw_p) fw_s2 = Firework( - ScriptTask.from_str('echo "Poisedon is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Poisedon is brother of Zeus"', {"store_stdout": True}), name="sib2", fw_id=3, parents=fw_p, ) fw_s3 = Firework( - ScriptTask.from_str('echo "Hades is brother of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hades is brother of Zeus"', {"store_stdout": True}), name="sib3", fw_id=4, parents=fw_p, ) fw_s4 = Firework( - ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Demeter is sister & wife of Zeus"', {"store_stdout": True}), name="sib4", fw_id=5, parents=fw_p, ) fw_s5 = Firework( - ScriptTask.from_str('echo "Lapetus is son of Oceanus"', - {"store_stdout": True}), name="cousin1", fw_id=6 + ScriptTask.from_str('echo "Lapetus is son of Oceanus"', {"store_stdout": True}), name="cousin1", fw_id=6 ) # Children fireworks fw_c1 = Firework( - ScriptTask.from_str('echo "Ares is son of Zeus"', - {"store_stdout": True}), name="c1", fw_id=7, - parents=fw_s1 + ScriptTask.from_str('echo "Ares is son of Zeus"', {"store_stdout": True}), name="c1", fw_id=7, parents=fw_s1 ) fw_c2 = Firework( ScriptTask.from_str( - 'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', - {"store_stdout": True} + 'echo "Persephone is daughter of Zeus & Demeter and wife of Hades"', {"store_stdout": True} ), name="c2", fw_id=8, parents=[fw_s1, fw_s4], ) fw_c3 = Firework( - ScriptTask.from_str( - 'echo "Makaria is daughter of Hades & Persephone"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Makaria is daughter of Hades & Persephone"', {"store_stdout": True}), name="c3", fw_id=9, parents=[fw_s3, fw_c2], ) fw_c4 = Firework( - ScriptTask.from_str('echo "Dione is descendant of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Dione is descendant of Lapetus"', {"store_stdout": True}), name="c4", fw_id=10, parents=fw_s5, ) fw_c5 = Firework( - ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Aphrodite is son of of Zeus and Dione"', {"store_stdout": True}), name="c5", fw_id=11, parents=[fw_s1, fw_c4], ) fw_c6 = Firework( - ScriptTask.from_str('echo "Atlas is son of of Lapetus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Atlas is son of of Lapetus"', {"store_stdout": True}), name="c6", fw_id=12, parents=fw_s5, ) fw_c7 = Firework( - ScriptTask.from_str('echo "Maia is daughter of Atlas"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Maia is daughter of Atlas"', {"store_stdout": True}), name="c7", fw_id=13, parents=fw_c6, ) fw_c8 = Firework( - 
ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', - {"store_stdout": True}), + ScriptTask.from_str('echo "Hermes is daughter of Maia and Zeus"', {"store_stdout": True}), name="c8", fw_id=14, parents=[fw_s1, fw_c7], @@ -850,8 +797,7 @@ def setUp(self): # assemble Workflow from FireWorks and their connections by id workflow = Workflow( - [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, - fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] + [fw_p, fw_s1, fw_s2, fw_s3, fw_s4, fw_s5, fw_c1, fw_c2, fw_c3, fw_c4, fw_c5, fw_c6, fw_c7, fw_c8] ) self.lp.add_wf(workflow) @@ -862,11 +808,11 @@ def setUp(self): self.zeus_sib_fw_ids = {3, 4, 5} self.par_fw_id = 1 self.all_ids = ( - self.zeus_child_fw_ids - | self.lapetus_desc_fw_ids - | self.zeus_sib_fw_ids - | {self.zeus_fw_id} - | {self.par_fw_id} + self.zeus_child_fw_ids + | self.lapetus_desc_fw_ids + | self.zeus_sib_fw_ids + | {self.zeus_fw_id} + | {self.par_fw_id} ) self.old_wd = os.getcwd() @@ -1023,8 +969,7 @@ def run(self): self.assertEqual(fw_state, fw_cache_state) # Detect lost runs - lost_lids, lost_fwids, inconsistent_fwids = self.lp.detect_lostruns( - expiration_secs=0.5) + lost_lids, lost_fwids, inconsistent_fwids = self.lp.detect_lostruns(expiration_secs=0.5) # Ensure the states are sync wf = self.lp.get_wf_by_fw_id_lzyfw(self.zeus_fw_id) fws = wf.id_fw @@ -1071,8 +1016,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod def tearDownClass(cls): @@ -1086,8 +1030,7 @@ def setUp(self): fw = Firework( [ ExecutionCounterTask(), - ScriptTask.from_str('date +"%s %N"', - parameters={"stdout_file": "date_file"}), + ScriptTask.from_str('date +"%s %N"', parameters={"stdout_file": "date_file"}), ExceptionTestTask(exc_details=self.error_test_dict), ] ) @@ -1141,8 +1084,7 @@ def test_task_level_rerun_cp(self): self.assertEqual(self.lp.get_fw_by_id(1).state, "COMPLETED") self.assertEqual(ExecutionCounterTask.exec_counter, 1) self.assertEqual(ExceptionTestTask.exec_counter, 2) - self.assertTrue(filecmp.cmp(os.path.join(dirs[0], "date_file"), - os.path.join(dirs[1], "date_file"))) + self.assertTrue(filecmp.cmp(os.path.join(dirs[0], "date_file"), os.path.join(dirs[1], "date_file"))) def test_task_level_rerun_prev_dir(self): rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR) @@ -1153,8 +1095,7 @@ def test_task_level_rerun_prev_dir(self): fw = self.lp.get_fw_by_id(1) self.assertEqual(os.getcwd(), MODULE_DIR) self.assertEqual(fw.state, "COMPLETED") - self.assertEqual(fw.launches[0].launch_dir, - fw.archived_launches[0].launch_dir) + self.assertEqual(fw.launches[0].launch_dir, fw.archived_launches[0].launch_dir) self.assertEqual(ExecutionCounterTask.exec_counter, 1) self.assertEqual(ExceptionTestTask.exec_counter, 2) @@ -1168,8 +1109,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): @@ -1178,17 +1118,13 @@ def tearDownClass(cls): def setUp(self): # set the defaults in the init of wflock to break the lock quickly - fireworks.core.launchpad.WFLock(3, - False).__init__.__func__.__defaults__ = ( - 3, False) + fireworks.core.launchpad.WFLock(3, False).__init__.__func__.__defaults__ = (3, False) self.error_test_dict = {"error": "description", "error_code": 1} fw_slow = Firework(SlowAdditionTask(), spec={"seconds": 10}, fw_id=1) - fw_fast = Firework(WaitWFLockTask(), fw_id=2, - spec={"_add_launchpad_and_fw_id": True}) + fw_fast = Firework(WaitWFLockTask(), fw_id=2, spec={"_add_launchpad_and_fw_id": True}) fw_child = Firework(ScriptTask.from_str('echo "child"'), fw_id=3) - wf = Workflow([fw_slow, fw_fast, fw_child], - {fw_slow: fw_child, fw_fast: fw_child}) + wf = Workflow([fw_slow, fw_fast, fw_child], {fw_slow: fw_child, fw_fast: fw_child}) self.lp.add_wf(wf) self.old_wd = os.getcwd() @@ -1226,8 +1162,7 @@ def run(self): fast_fw = self.lp.get_fw_by_id(2) if fast_fw.state == "FIZZLED": - stacktrace = self.lp.launches.find_one({"fw_id": 2}, { - "action.stored_data._exception._stacktrace": 1})[ + stacktrace = self.lp.launches.find_one({"fw_id": 2}, {"action.stored_data._exception._stacktrace": 1})[ "action" ]["stored_data"]["_exception"]["_stacktrace"] if "SkipTest" in stacktrace: @@ -1276,8 +1211,7 @@ def run(self): fast_fw = self.lp.get_fw_by_id(2) if fast_fw.state == "FIZZLED": - stacktrace = self.lp.launches.find_one({"fw_id": 2}, { - "action.stored_data._exception._stacktrace": 1})[ + stacktrace = self.lp.launches.find_one({"fw_id": 2}, {"action.stored_data._exception._stacktrace": 1})[ "action" ]["stored_data"]["_exception"]["_stacktrace"] if "SkipTest" in stacktrace: @@ -1306,8 +1240,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.") @classmethod def tearDownClass(cls): @@ -1318,9 +1251,7 @@ def setUp(self): fireworks.core.firework.EXCEPT_DETAILS_ON_RERUN = True self.error_test_dict = {"error": "description", "error_code": 1} - fw = Firework( - ScriptTask.from_str('echo "test offline"', {"store_stdout": True}), - name="offline_fw", fw_id=1) + fw = Firework(ScriptTask.from_str('echo "test offline"', {"store_stdout": True}), name="offline_fw", fw_id=1) self.lp.add_wf(fw) self.launch_dir = os.path.join(MODULE_DIR, "launcher_offline") @@ -1362,17 +1293,14 @@ def test_recover_errors(self): shutil.rmtree(self.launch_dir) # recover ignoring errors - self.assertIsNotNone( - self.lp.recover_offline(launch_id, ignore_errors=True, - print_errors=True)) + self.assertIsNotNone(self.lp.recover_offline(launch_id, ignore_errors=True, print_errors=True)) fw = self.lp.get_fw_by_id(launch_id) self.assertEqual(fw.state, "RESERVED") # fizzle - self.assertIsNotNone( - self.lp.recover_offline(launch_id, ignore_errors=False)) + self.assertIsNotNone(self.lp.recover_offline(launch_id, ignore_errors=False)) fw = self.lp.get_fw_by_id(launch_id) @@ -1393,8 +1321,7 @@ def setUpClass(cls): cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR") cls.lp.reset(password=None, require_password=False) except Exception: - raise unittest.SkipTest( - "MongoDB is not running in localhost:27017! Skipping tests.") + raise unittest.SkipTest("MongoDB is not running in localhost:27017! 
Skipping tests.") @classmethod def tearDownClass(cls): From d417462c293da17cf84570f3cd2e3c58023178ef Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 31 Mar 2023 13:52:07 -0400 Subject: [PATCH 16/18] update mypy version --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 25140b9c5..85dfd3a91 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v1.1.1 hooks: - id: mypy additional_dependencies: From c8a9d282694762ef5db012b131fae74f32323bd5 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 31 Mar 2023 13:52:29 -0400 Subject: [PATCH 17/18] abstract collection types should not be returned --- fireworks/core/launchpad.py | 2 +- fireworks/utilities/fw_serializers.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 75ce03a73..d0f1ffd4a 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -53,7 +53,7 @@ # TODO: lots of duplication reduction and cleanup possible -def sort_aggregation(sort: Sequence[Tuple[str, int]]) -> List[Mapping[str, Any]]: +def sort_aggregation(sort: Sequence[Tuple[str, int]]) -> List[Dict[str, Any]]: """Build sorting aggregation pipeline. Args: diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 1935f588f..ca4a6c53e 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,7 +33,7 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback -from typing import Any, Mapping, MutableMapping, Optional, Type, TypeVar +from typing import Any, Dict, MutableMapping, Optional, Type, TypeVar import ruamel.yaml as yaml from monty.json import MontyDecoder, MSONable @@ -208,13 +208,13 @@ def fw_name(self) -> str: return get_default_serialization(self.__class__) @abc.abstractmethod - def to_dict(self) -> Mapping[Any, Any]: + def to_dict(self) -> Dict[Any, Any]: raise NotImplementedError("FWSerializable object did not implement to_dict()!") - def to_db_dict(self) -> Mapping[Any, Any]: + def to_db_dict(self) -> Dict[Any, Any]: return self.to_dict() - def as_dict(self) -> Mapping[Any, Any]: + def as_dict(self) -> Dict[Any, Any]: # strictly for pseudo-compatibility with MSONable # Note that FWSerializable is not MSONable, it uses _fw_name instead of __class__ and # __module__ From f35ae5baeda03544a149795d5a59e3dac55237d9 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Fri, 31 Mar 2023 13:55:30 -0400 Subject: [PATCH 18/18] tqdm now has type stubs --- .pre-commit-config.yaml | 1 + pyproject.toml | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 85dfd3a91..254fc9855 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,6 +42,7 @@ repos: - types-setuptools - types-six - types-tabulate + - types-tqdm args: [] - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/pyproject.toml b/pyproject.toml index 0e11e2f5b..220f5b595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,5 @@ module = [ "monty.shutil", "pymatgen.core", "pytest", - # https://github.com/tqdm/tqdm/issues/260 - "tqdm", ] ignore_missing_imports = true