diff --git a/.circleci/config.yml b/.circleci/config.yml
index 60e7b0264..a53874990 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -26,35 +26,7 @@ jobs:
pip install .[workflow-checks,graph-plotting,flask-plotting]
pytest fireworks
- pytest_pymongo4:
- working_directory: ~/fireworks
- docker:
- - image: continuumio/miniconda3:4.6.14
- - image: circleci/mongo:latest
- steps:
- - checkout
- - run:
- command: |
- export PATH=$HOME/miniconda3/bin:$PATH
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- conda create -q -n test-environment python=3.8
- source activate test-environment
- conda update --quiet --all
- pip install --quiet --ignore-installed -r requirements.txt -r requirements-ci.txt
- - run:
- name: Run fireworks tests
- command: |
- export PATH=$HOME/miniconda3/bin:$PATH
- source activate test-environment
- pip install --quiet -e .
- pip install --quiet --upgrade pymongo
- pytest fireworks
-
workflows:
version: 2
build_and_test:
- jobs:
- - pytest
- - pytest_pymongo4
+ jobs: [pytest]
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 000000000..1326ed539
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,36 @@
+name: Test
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+
+jobs:
+ pytest:
+ runs-on: ubuntu-latest
+
+ services:
+ mongodb:
+ image: mongo
+ ports:
+ - 27017:27017
+
+ steps:
+ - name: Checkout repo
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: 3.8
+
+ - name: Install dependencies
+ run: |
+ pip install -r requirements.txt -r requirements-ci.txt
+
+ - name: Run fireworks tests
+ shell: bash -l {0}
+ run: |
+ pip install .[workflow-checks,graph-plotting,flask-plotting]
+ pytest fireworks
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4c8742642..a389706dc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,7 +2,7 @@ exclude: ^docs
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.1.3
+ rev: v0.3.5
hooks:
- id: ruff
args: [--fix, --ignore, D]
diff --git a/README.md b/README.md
index a26744fd6..36f29e505 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
-#
+
+
+
FireWorks stores, executes, and manages calculation workflows.
@@ -9,7 +11,8 @@ FireWorks stores, executes, and manages calculation workflows.
If you like FireWorks, you might also like [rocketsled](https://github.com/hackingmaterials/rocketsled).
If you find FireWorks useful, please consider citing the paper:
-```
+
+```txt
Jain, A., Ong, S. P., Chen, W., Medasani, B., Qu, X., Kocher, M., Brafman, M.,
Petretto, G., Rignanese, G.-M., Hautier, G., Gunter, D., and Persson, K. A.
(2015) FireWorks: a dynamic workflow system designed for high-throughput
diff --git a/docs_rst/_static/fireworks-logo.svg b/docs_rst/_static/fireworks-logo.svg
new file mode 100644
index 000000000..675c5b4d3
--- /dev/null
+++ b/docs_rst/_static/fireworks-logo.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs_rst/conf.py b/docs_rst/conf.py
index 619286ecb..75895e36f 100644
--- a/docs_rst/conf.py
+++ b/docs_rst/conf.py
@@ -183,11 +183,11 @@
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
- #'papersize': 'letterpaper',
+ # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
- #'pointsize': '10pt',
+ # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
- #'preamble': '',
+ # 'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
@@ -309,5 +309,5 @@ def skip(app, what, name, obj, skip, options):
# AJ: a hack found online to get __init__ to show up in docs
-def setup(app):
+def setup(app) -> None:
app.connect("autodoc-skip-member", skip)
diff --git a/docs_rst/config_tutorial.rst b/docs_rst/config_tutorial.rst
index 366a326fa..e8b9f18e7 100644
--- a/docs_rst/config_tutorial.rst
+++ b/docs_rst/config_tutorial.rst
@@ -77,6 +77,7 @@ A few basic parameters that can be tweaked are:
* ``WEBSERVER_PORT: 5000`` - the default port on which to run the web server
* ``QUEUE_JOBNAME_MAXLEN: 20`` - the max length of the job name to send to the queuing system (some queuing systems limit the size of job names)
* ``MONGOMOCK_SERVERSTORE_FILE`` - path to a non-empty JSON file, if set then mongomock will be used instead of MongoDB; this file should be initialized with '{}'
+* ``ROCKET_STREAM_LOGLEVEL: INFO`` - the streaming log level of the rocket launcher logger (valid values: DEBUG, INFO, WARNING, ERROR, CRITICAL)
Parameters that you probably shouldn't change
---------------------------------------------
diff --git a/docs_rst/firetask_tutorial.rst b/docs_rst/firetask_tutorial.rst
index 20061d752..cdcae15f1 100644
--- a/docs_rst/firetask_tutorial.rst
+++ b/docs_rst/firetask_tutorial.rst
@@ -150,7 +150,7 @@ Let's explore custom Firetasks with an example: a custom Python script for addin
input_array = fw_spec['input_array']
m_sum = sum(input_array)
- print(f"The sum of {input_array} is: {m_sum}"
+ print(f"The sum of {input_array} is: {m_sum}")
return FWAction(stored_data={'sum': m_sum}, mod_spec=[{'_push': {'input_array': m_sum}}])
diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py
index f9a4b1ee3..e95015e82 100644
--- a/fireworks/core/firework.py
+++ b/fireworks/core/firework.py
@@ -7,13 +7,16 @@
- A FiretaskBase defines the contract for tasks that run within a Firework (Firetasks).
- A FWAction encapsulates the output of a Firetask and tells FireWorks what to do next after a job completes.
"""
+
+from __future__ import annotations
+
import abc
import os
import pprint
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
-from typing import Any, Dict, Iterator, List, Optional, Sequence
+from typing import Any, Iterator, NoReturn, Sequence
from monty.io import reverse_readline, zopen
from monty.os.path import zpath
@@ -33,7 +36,7 @@
__date__ = "Feb 5, 2013"
-class FiretaskBase(defaultdict, FWSerializable, metaclass=abc.ABCMeta):
+class FiretaskBase(defaultdict, FWSerializable, abc.ABC):
"""
FiretaskBase is used like an abstract class that defines a computing task
(Firetask). All Firetasks should inherit from FiretaskBase.
@@ -46,7 +49,7 @@ class FiretaskBase(defaultdict, FWSerializable, metaclass=abc.ABCMeta):
# if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init
optional_params = None
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kwargs) -> None:
dict.__init__(self, *args, **kwargs)
required_params = self.required_params or []
@@ -65,7 +68,7 @@ def __init__(self, *args, **kwargs):
)
@abc.abstractmethod
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> NoReturn:
"""
This method gets called when the Firetask is run. It can take in a
Firework spec, perform some task using that data, and then return an
@@ -98,7 +101,7 @@ def to_dict(self):
def from_dict(cls, m_dict):
return cls(m_dict)
- def __repr__(self):
+ def __repr__(self) -> str:
return f"<{self.fw_name}>:{dict(self)}"
# not strictly needed here for pickle/unpickle, but complements __setstate__
@@ -133,7 +136,7 @@ def __init__(
defuse_children=False,
defuse_workflow=False,
propagate=False,
- ):
+ ) -> None:
"""
Args:
stored_data (dict): data to store from the run. Does not affect the operation of FireWorks.
@@ -206,7 +209,7 @@ def skip_remaining_tasks(self):
"""
return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow
- def __str__(self):
+ def __str__(self) -> str:
return "FWAction\n" + pprint.pformat(self.to_dict())
@@ -238,7 +241,7 @@ def __init__(
fw_id=None,
parents=None,
updated_on=None,
- ):
+ ) -> None:
"""
Args:
tasks (Firetask or [Firetask]): a list of Firetasks to run in sequence.
@@ -260,7 +263,7 @@ def __init__(
if fw_id is not None:
self.fw_id = fw_id
else:
- global NEGATIVE_FWID_CTR
+ global NEGATIVE_FWID_CTR # noqa: PLW0603
NEGATIVE_FWID_CTR -= 1
self.fw_id = NEGATIVE_FWID_CTR
@@ -283,7 +286,7 @@ def state(self):
return self._state
@state.setter
- def state(self, state):
+ def state(self, state) -> None:
"""
Setter for the FW state, which triggers updated_on change.
@@ -315,7 +318,7 @@ def to_dict(self):
return m_dict
- def _rerun(self):
+ def _rerun(self) -> None:
"""
Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework
can thus be re-run even if it was Launched in the past. This method should be called by
@@ -364,7 +367,7 @@ def from_dict(cls, m_dict):
tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on
)
- def __str__(self):
+ def __str__(self) -> str:
return f"Firework object: (id: {int(self.fw_id)} , name: {self.fw_name})"
def __iter__(self) -> Iterator[FiretaskBase]:
@@ -382,7 +385,7 @@ class Tracker(FWSerializable):
MAX_TRACKER_LINES = 1000
- def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False):
+ def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False) -> None:
"""
Args:
filename (str)
@@ -434,7 +437,7 @@ def from_dict(cls, m_dict):
m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False)
)
- def __str__(self):
+ def __str__(self) -> str:
return f"### Filename: {self.filename}\n{self.content}"
@@ -453,7 +456,7 @@ def __init__(
state_history=None,
launch_id=None,
fw_id=None,
- ):
+ ) -> None:
"""
Args:
state (str): the state of the Launch (e.g. RUNNING, COMPLETED)
@@ -480,7 +483,7 @@ def __init__(
self.launch_id = launch_id
self.fw_id = fw_id
- def touch_history(self, update_time=None, checkpoint=None):
+ def touch_history(self, update_time=None, checkpoint=None) -> None:
"""
Updates the update_on field of the state history of a Launch. Used to ping that a Launch
is still alive.
@@ -493,7 +496,7 @@ def touch_history(self, update_time=None, checkpoint=None):
self.state_history[-1]["checkpoint"] = checkpoint
self.state_history[-1]["updated_on"] = update_time
- def set_reservation_id(self, reservation_id):
+ def set_reservation_id(self, reservation_id) -> None:
"""
Adds the job_id to the reservation.
@@ -514,7 +517,7 @@ def state(self):
return self._state
@state.setter
- def state(self, state):
+ def state(self, state) -> None:
"""
Setter for the Launch's state. Automatically triggers an update to state_history.
@@ -624,7 +627,7 @@ def from_dict(cls, m_dict):
m_dict["fw_id"],
)
- def _update_state_history(self, state):
+ def _update_state_history(self, state) -> None:
"""
Internal method to update the state history whenever the Launch state is modified.
@@ -672,7 +675,7 @@ class Workflow(FWSerializable):
class Links(dict, FWSerializable):
"""An inner class for storing the DAG links between FireWorks."""
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
for k, v in list(self.items()):
@@ -762,12 +765,12 @@ def __reduce__(self):
def __init__(
self,
fireworks: Sequence[Firework],
- links_dict: Optional[Dict[int, List[int]]] = None,
- name: Optional[str] = None,
- metadata: Optional[Dict[str, Any]] = None,
- created_on: Optional[datetime] = None,
- updated_on: Optional[datetime] = None,
- fw_states: Optional[Dict[int, str]] = None,
+ links_dict: dict[int, list[int]] | None = None,
+ name: str | None = None,
+ metadata: dict[str, Any] | None = None,
+ created_on: datetime | None = None,
+ updated_on: datetime | None = None,
+ fw_states: dict[int, str] | None = None,
) -> None:
"""
Args:
@@ -784,7 +787,7 @@ def __init__(
links_dict = links_dict or {}
# main dict containing mapping of an id to a Firework object
- self.id_fw: Dict[int, Firework] = {}
+ self.id_fw: dict[int, Firework] = {}
for fw in fireworks:
if fw.fw_id in self.id_fw:
raise ValueError("FW ids must be unique!")
@@ -825,7 +828,7 @@ def __init__(
self.fw_states = fw_states or {key: val.state for key, val in self.id_fw.items()}
@property
- def fws(self) -> List[Firework]:
+ def fws(self) -> list[Firework]:
"""Return list of all fireworks."""
return list(self.id_fw.values())
@@ -882,7 +885,7 @@ def state(self) -> str:
m_state = "RESERVED"
return m_state
- def apply_action(self, action: FWAction, fw_id: int) -> List[int]:
+ def apply_action(self, action: FWAction, fw_id: int) -> list[int]:
"""
Apply a FWAction on a Firework in the Workflow.
@@ -903,7 +906,7 @@ def apply_action(self, action: FWAction, fw_id: int) -> List[int]:
# Traverse whole sub-workflow down to leaves.
visited_cfid = set() # avoid double-updating for diamond deps
- def recursive_update_spec(fw_id):
+ def recursive_update_spec(fw_id) -> None:
for cfid in self.links[fw_id]:
if cfid not in visited_cfid:
visited_cfid.add(cfid)
@@ -923,7 +926,7 @@ def recursive_update_spec(fw_id):
if action.mod_spec and action.propagate:
visited_cfid = set()
- def recursive_mod_spec(fw_id):
+ def recursive_mod_spec(fw_id) -> None:
for cfid in self.links[fw_id]:
if cfid not in visited_cfid:
visited_cfid.add(cfid)
@@ -1144,7 +1147,7 @@ def refresh(self, fw_id, updated_ids=None):
return updated_ids
@property
- def root_fw_ids(self) -> List[int]:
+ def root_fw_ids(self) -> list[int]:
"""
Gets root FireWorks of this workflow (those with no parents).
@@ -1157,7 +1160,7 @@ def root_fw_ids(self) -> List[int]:
return list(root_ids)
@property
- def leaf_fw_ids(self) -> List[int]:
+ def leaf_fw_ids(self) -> list[int]:
"""
Gets leaf FireWorks of this workflow (those with no children).
@@ -1170,7 +1173,7 @@ def leaf_fw_ids(self) -> List[int]:
leaf_ids.append(id)
return leaf_ids
- def _reassign_ids(self, old_new: Dict[int, int]) -> None:
+ def _reassign_ids(self, old_new: dict[int, int]) -> None:
"""
Internal method to reassign Firework ids, e.g. due to database insertion.
@@ -1195,7 +1198,7 @@ def _reassign_ids(self, old_new: Dict[int, int]) -> None:
new_fw_states[old_new.get(fwid, fwid)] = fw_state
self.fw_states = new_fw_states
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return {
"fws": [f.to_dict() for f in self.id_fw.values()],
"links": self.links.to_dict(),
@@ -1205,7 +1208,7 @@ def to_dict(self) -> Dict[str, Any]:
"created_on": self.created_on,
}
- def to_db_dict(self) -> Dict[str, Any]:
+ def to_db_dict(self) -> dict[str, Any]:
m_dict = self.links.to_db_dict()
m_dict["metadata"] = self.metadata
m_dict["state"] = self.state
@@ -1270,7 +1273,7 @@ def _get_representative_launch(fw):
return m_launch
@classmethod
- def from_wflow(cls, wflow: "Workflow") -> "Workflow":
+ def from_wflow(cls, wflow: Workflow) -> Workflow:
"""
Create a fresh Workflow from an existing one.
@@ -1294,7 +1297,7 @@ def reset(self, reset_ids: bool = True) -> None:
if reset_ids:
old_new = {} # mapping between old and new Firework ids
for fw_id, fw in self.id_fw.items():
- global NEGATIVE_FWID_CTR
+ global NEGATIVE_FWID_CTR # noqa: PLW0603
NEGATIVE_FWID_CTR -= 1
new_id = NEGATIVE_FWID_CTR
old_new[fw_id] = new_id
@@ -1306,7 +1309,7 @@ def reset(self, reset_ids: bool = True) -> None:
self.fw_states = {key: self.id_fw[key].state for key in self.id_fw}
@classmethod
- def from_dict(cls, m_dict: Dict[str, Any]) -> "Workflow":
+ def from_dict(cls, m_dict: dict[str, Any]) -> Workflow:
"""
Return Workflow from its dict representation.
@@ -1331,7 +1334,7 @@ def from_dict(cls, m_dict: Dict[str, Any]) -> "Workflow":
return Workflow.from_Firework(Firework.from_dict(m_dict))
@classmethod
- def from_Firework(cls, fw: Firework, name: Optional[str] = None, metadata=None) -> "Workflow":
+ def from_Firework(cls, fw: Firework, name: str | None = None, metadata=None) -> Workflow:
"""
Return Workflow from the given Firework.
@@ -1346,10 +1349,10 @@ def from_Firework(cls, fw: Firework, name: Optional[str] = None, metadata=None)
name = name if name else fw.name
return Workflow([fw], None, name=name, metadata=metadata, created_on=fw.created_on, updated_on=fw.updated_on)
- def __str__(self):
+ def __str__(self) -> str:
return f"Workflow object: (fw_ids: {[*self.id_fw]} , name: {self.name})"
- def remove_fws(self, fw_ids):
+ def remove_fws(self, fw_ids) -> None:
"""
Remove the fireworks corresponding to the input firework ids and update the workflow i.e the
parents of the removed fireworks become the parents of the children fireworks (only if the
diff --git a/fireworks/core/fworker.py b/fireworks/core/fworker.py
index 559e140ed..08c0395ac 100644
--- a/fireworks/core/fworker.py
+++ b/fireworks/core/fworker.py
@@ -19,7 +19,7 @@
class FWorker(FWSerializable):
- def __init__(self, name="Automatically generated Worker", category="", query=None, env=None):
+ def __init__(self, name="Automatically generated Worker", category="", query=None, env=None) -> None:
"""
Args:
name (str): the name of the resource, should be unique
diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py
index b08cc5f76..3665177bc 100644
--- a/fireworks/core/launchpad.py
+++ b/fireworks/core/launchpad.py
@@ -92,7 +92,7 @@ class WFLock:
Calling functions are responsible for handling the error in order to avoid database inconsistencies.
"""
- def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL):
+ def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL) -> None:
"""
Args:
lp (LaunchPad)
@@ -156,7 +156,7 @@ def __init__(
authsource=None,
uri_mode=False,
mongoclient_kwargs=None,
- ):
+ ) -> None:
"""
Args:
host (str): hostname. If uri_mode is True, a MongoDB connection string URI
@@ -244,7 +244,7 @@ def to_dict(self):
"mongoclient_kwargs": self.mongoclient_kwargs,
}
- def update_spec(self, fw_ids, spec_document, mongo=False):
+ def update_spec(self, fw_ids, spec_document, mongo=False) -> None:
"""
Update fireworks with a spec. Sometimes you need to modify a firework in progress.
@@ -301,7 +301,7 @@ def auto_load(cls):
return LaunchPad.from_file(LAUNCHPAD_LOC)
return LaunchPad()
- def reset(self, password, require_password=True, max_reset_wo_password=25):
+ def reset(self, password, require_password=True, max_reset_wo_password=25) -> None:
"""
Create a new FireWorks database. This will overwrite the existing FireWorks database! To
safeguard against accidentally erasing an existing database, a password must be entered.
@@ -337,7 +337,7 @@ def reset(self, password, require_password=True, max_reset_wo_password=25):
else:
raise ValueError(f"Invalid password! Password is today's date: {m_password}")
- def maintain(self, infinite=True, maintain_interval=None):
+ def maintain(self, infinite=True, maintain_interval=None) -> None:
"""
Perform launchpad maintenance: detect lost runs and unreserved RESERVE launches.
@@ -377,7 +377,9 @@ def add_wf(self, wf, reassign_all=True):
Add workflow(or firework) to the launchpad. The firework ids will be reassigned.
Args:
- wf (Workflow/Firework)
+ wf (Workflow/Firework): Workflow or Firework object
+ reassign_all (bool): If True, the firework ids will be assigned
+ starting from the next available id. Defaults to True.
Returns:
dict: mapping between old and new Firework ids
@@ -398,7 +400,7 @@ def add_wf(self, wf, reassign_all=True):
self.m_logger.info(f"Added a workflow. id_map: {old_new}")
return old_new
- def bulk_add_wfs(self, wfs):
+ def bulk_add_wfs(self, wfs) -> None:
"""
Adds a list of workflows to the fireworks database
using insert_many for both the fws and wfs, is
@@ -439,7 +441,7 @@ def bulk_add_wfs(self, wfs):
self.fireworks.insert_many(fw.to_db_dict() for fw in all_fws)
return
- def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=True):
+ def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=True) -> None:
"""
Append a new workflow on top of an existing workflow.
@@ -461,7 +463,7 @@ def get_launch_by_id(self, launch_id):
Given a Launch id, return details of the Launch.
Args:
- launch_id (int): launch id
+ launch_id (int): launch id.
Returns:
Launch object
@@ -477,7 +479,7 @@ def get_fw_dict_by_id(self, fw_id):
Given firework id, return firework dict.
Args:
- fw_id (int): firework id
+ fw_id (int): Firework id.
Returns:
dict
@@ -513,11 +515,10 @@ def get_fw_by_id(self, fw_id):
return Firework.from_dict(self.get_fw_dict_by_id(fw_id))
def get_wf_by_fw_id(self, fw_id):
- """
- Given a Firework id, give back the Workflow containing that Firework.
+ """Given a Firework id, give back the Workflow containing that Firework.
Args:
- fw_id (int)
+ fw_id (int): Firework id.
Returns:
A Workflow object
@@ -535,12 +536,11 @@ def get_wf_by_fw_id(self, fw_id):
links_dict["updated_on"],
)
- def get_wf_by_fw_id_lzyfw(self, fw_id):
- """
- Given a FireWork id, give back the Workflow containing that FireWork.
+ def get_wf_by_fw_id_lzyfw(self, fw_id: int) -> Workflow:
+ """Given a FireWork id, give back the Workflow containing that FireWork.
Args:
- fw_id (int)
+ fw_id (int): FireWork id.
Returns:
A Workflow object
@@ -549,9 +549,10 @@ def get_wf_by_fw_id_lzyfw(self, fw_id):
if not links_dict:
raise ValueError(f"Could not find a Workflow with fw_id: {fw_id}")
- fws = []
- for fw_id in links_dict["nodes"]:
- fws.append(LazyFirework(fw_id, self.fireworks, self.launches, self.gridfs_fallback))
+ fws = [
+ LazyFirework(fw_id, self.fireworks, self.launches, self.gridfs_fallback) for fw_id in links_dict["nodes"]
+ ]
+
# Check for fw_states in links_dict to conform with pre-optimized workflows
fw_states = {int(k): v for k, v in links_dict["fw_states"].items()} if "fw_states" in links_dict else None
@@ -565,13 +566,13 @@ def get_wf_by_fw_id_lzyfw(self, fw_id):
fw_states,
)
- def delete_fws(self, fw_ids, delete_launch_dirs=False):
+ def delete_fws(self, fw_ids, delete_launch_dirs=False) -> None:
"""Delete a set of fireworks identified by their fw_ids.
ATTENTION: This function serves maintenance purposes and will leave
workflows untouched. Its use will thus result in a corrupted database.
Use 'delete_wf' instead for consistently deleting workflows together
- with theit fireworks.
+ with their fireworks.
Args:
fw_ids ([int]): Firework ids
@@ -580,35 +581,40 @@ def delete_fws(self, fw_ids, delete_launch_dirs=False):
"""
potential_launch_ids = []
launch_ids = []
- for i in fw_ids:
- fw_dict = self.fireworks.find_one({"fw_id": i})
- potential_launch_ids += fw_dict["launches"] + fw_dict["archived_launches"]
-
- for i in potential_launch_ids: # only remove launches if no other fws refer to them
+ for fw_id in fw_ids:
+ fw_dict = self.fireworks.find_one({"fw_id": fw_id})
+ if fw_dict:
+ potential_launch_ids += fw_dict.get("launches", []) + fw_dict.get("archived_launches", [])
+
+ launch_ids = [
+ launch_id
+ for launch_id in potential_launch_ids
if not self.fireworks.find_one(
- {"$or": [{"launches": i}, {"archived_launches": i}], "fw_id": {"$nin": fw_ids}}, {"launch_id": 1}
- ):
- launch_ids.append(i)
+ {"$or": [{"launches": launch_id}, {"archived_launches": launch_id}], "fw_id": {"$nin": fw_ids}},
+ {"launch_id": 1},
+ )
+ ]
if delete_launch_dirs:
- launch_dirs = []
- for i in launch_ids:
- launch_dirs.append(self.launches.find_one({"launch_id": i}, {"launch_dir": 1})["launch_dir"])
+ launch_dirs = [
+ self.launches.find_one({"launch_id": launch_id}, {"launch_dir": 1})["launch_dir"]
+ for launch_id in launch_ids
+ ]
print(f"Remove folders {launch_dirs}")
- for d in launch_dirs:
- shutil.rmtree(d, ignore_errors=True)
+ for launch_dir in launch_dirs:
+ shutil.rmtree(launch_dir, ignore_errors=True)
print(f"Remove fws {fw_ids}")
if self.gridfs_fallback is not None:
- for lid in launch_ids:
- for f in self.gridfs_fallback.find({"metadata.launch_id": lid}):
- self.gridfs_fallback.delete(f._id)
+ for launch_id in launch_ids:
+ for file_id in self.gridfs_fallback.find({"metadata.launch_id": launch_id}):
+ self.gridfs_fallback.delete(file_id._id)
print(f"Remove launches {launch_ids}")
self.launches.delete_many({"launch_id": {"$in": launch_ids}})
self.offline_runs.delete_many({"launch_id": {"$in": launch_ids}})
self.fireworks.delete_many({"fw_id": {"$in": fw_ids}})
- def delete_wf(self, fw_id, delete_launch_dirs=False):
+ def delete_wf(self, fw_id, delete_launch_dirs=False) -> None:
"""
Delete the workflow containing firework with the given id.
@@ -643,8 +649,7 @@ def get_wf_summary_dict(self, fw_id, mode="more"):
if mode != "less":
wf_fields.append("updated_on")
fw_fields.extend(["name", "launches"])
- launch_fields.append("launch_id")
- launch_fields.append("launch_dir")
+ launch_fields.extend(("launch_id", "launch_dir"))
if mode == "reservations":
launch_fields.append("state_history.reservation_id")
@@ -883,7 +888,7 @@ def run_exists(self, fworker=None):
q = fworker.query if fworker else {}
return bool(self._get_a_fw_to_run(query=q, checkout=False))
- def future_run_exists(self, fworker=None):
+ def future_run_exists(self, fworker=None) -> bool:
"""Check if database has any current OR future Fireworks available.
Returns:
@@ -906,7 +911,7 @@ def future_run_exists(self, fworker=None):
# there is no future work to do
return False
- def tuneup(self, bkground=True):
+ def tuneup(self, bkground=True) -> None:
"""Database tuneup: build indexes."""
self.m_logger.info("Performing db tune-up")
@@ -1027,7 +1032,7 @@ def resume_fw(self, fw_id):
self._refresh_wf(fw_id)
return f
- def defuse_wf(self, fw_id, defuse_all_states=True):
+ def defuse_wf(self, fw_id, defuse_all_states=True) -> None:
"""
Defuse the workflow containing the given firework id.
@@ -1040,7 +1045,7 @@ def defuse_wf(self, fw_id, defuse_all_states=True):
if fw.state not in ["COMPLETED", "FIZZLED"] or defuse_all_states:
self.defuse_fw(fw.fw_id)
- def pause_wf(self, fw_id):
+ def pause_wf(self, fw_id) -> None:
"""
Pause the workflow containing the given firework id.
@@ -1052,7 +1057,7 @@ def pause_wf(self, fw_id):
if fw.state not in ["COMPLETED", "FIZZLED", "DEFUSED"]:
self.pause_fw(fw.fw_id)
- def reignite_wf(self, fw_id):
+ def reignite_wf(self, fw_id) -> None:
"""
Reignite the workflow containing the given firework id.
@@ -1063,7 +1068,7 @@ def reignite_wf(self, fw_id):
for fw in wf:
self.reignite_fw(fw.fw_id)
- def archive_wf(self, fw_id):
+ def archive_wf(self, fw_id) -> None:
"""
Archive the workflow containing the given firework id.
@@ -1085,7 +1090,7 @@ def archive_wf(self, fw_id):
)
self._refresh_wf(fw.fw_id)
- def _restart_ids(self, next_fw_id, next_launch_id):
+ def _restart_ids(self, next_fw_id, next_launch_id) -> None:
"""
internal method used to reset firework id counters.
@@ -1099,7 +1104,7 @@ def _restart_ids(self, next_fw_id, next_launch_id):
)
self.m_logger.debug(f"RESTARTED fw_id, launch_id to ({next_fw_id}, {next_launch_id})")
- def _check_fw_for_uniqueness(self, m_fw):
+ def _check_fw_for_uniqueness(self, m_fw) -> bool:
"""
Check if there are duplicates. If not unique, a new id is assigned and the workflow
refreshed.
@@ -1183,7 +1188,7 @@ def reserve_fw(self, fworker, launch_dir, host=None, ip=None, fw_id=None):
fw_id (int): fw_id to be reserved, if desired
Returns:
- (Firework, int): the checked out firework and the new launch id
+ (Firework, int): the checked out firework and the new launch id.
"""
return self.checkout_fw(fworker, launch_dir, host=host, ip=ip, fw_id=fw_id, state="RESERVED")
@@ -1197,13 +1202,10 @@ def get_fw_ids_from_reservation_id(self, reservation_id):
Returns:
[int]: list of firework ids.
"""
- fw_ids = []
l_id = self.launches.find_one({"state_history.reservation_id": reservation_id}, {"launch_id": 1})["launch_id"]
- for fw in self.fireworks.find({"launches": l_id}, {"fw_id": 1}):
- fw_ids.append(fw["fw_id"])
- return fw_ids
+ return [fw["fw_id"] for fw in self.fireworks.find({"launches": l_id}, {"fw_id": 1})]
- def cancel_reservation_by_reservation_id(self, reservation_id):
+ def cancel_reservation_by_reservation_id(self, reservation_id) -> None:
"""Given the reservation id, cancel the reservation and rerun the corresponding fireworks."""
l_id = self.launches.find_one(
{"state_history.reservation_id": reservation_id, "state": "RESERVED"}, {"launch_id": 1}
@@ -1224,7 +1226,7 @@ def get_reservation_id_from_fw_id(self, fw_id):
return None
return None
- def cancel_reservation(self, launch_id):
+ def cancel_reservation(self, launch_id) -> None:
"""Given the launch id, cancel the reservation and rerun the fireworks."""
m_launch = self.get_launch_by_id(launch_id)
m_launch.state = "READY"
@@ -1246,30 +1248,31 @@ def detect_unreserved(self, expiration_secs=RESERVATION_EXPIRATION_SECS, rerun=F
Returns:
[int]: list of expired launch ids
"""
- bad_launch_ids = []
now_time = datetime.datetime.utcnow()
- cutoff_timestr = (now_time - datetime.timedelta(seconds=expiration_secs)).isoformat()
+ cutoff_time_str = (now_time - datetime.timedelta(seconds=expiration_secs)).isoformat()
bad_launch_data = self.launches.find(
{
"state": "RESERVED",
- "state_history": {"$elemMatch": {"state": "RESERVED", "updated_on": {"$lte": cutoff_timestr}}},
+ "state_history": {"$elemMatch": {"state": "RESERVED", "updated_on": {"$lte": cutoff_time_str}}},
},
{"launch_id": 1, "fw_id": 1},
)
- for ld in bad_launch_data:
- if self.fireworks.find_one({"fw_id": ld["fw_id"], "state": "RESERVED"}, {"fw_id": 1}):
- bad_launch_ids.append(ld["launch_id"])
+ bad_launch_ids = [
+ ld["launch_id"]
+ for ld in bad_launch_data
+ if self.fireworks.find_one({"fw_id": ld["fw_id"], "state": "RESERVED"}, {"fw_id": 1})
+ ]
if rerun:
for lid in bad_launch_ids:
self.cancel_reservation(lid)
return bad_launch_ids
- def mark_fizzled(self, launch_id):
+ def mark_fizzled(self, launch_id) -> None:
"""
Mark the launch corresponding to the given id as FIZZLED.
Args:
- launch_id (int): launch id
+ launch_id (int): launch id.
Returns:
dict: updated launch
@@ -1337,10 +1340,10 @@ def detect_lostruns(
potential_lost_fw_ids.append(ld["fw_id"])
for fw_id in potential_lost_fw_ids: # tricky: figure out what's actually lost
- f = self.fireworks.find_one({"fw_id": fw_id}, {"launches": 1, "state": 1})
+ fw = self.fireworks.find_one({"fw_id": fw_id}, {"launches": 1, "state": 1}) or {}
# only RUNNING FireWorks can be "lost", i.e. not defused or archived
- if f["state"] == "RUNNING":
- l_ids = f["launches"]
+ if fw.get("state") == "RUNNING":
+ l_ids = fw["launches"]
not_lost = [x for x in l_ids if x not in lost_launch_ids]
if len(not_lost) == 0: # all launches are lost - we are lost!
lost_fw_ids.append(fw_id)
@@ -1381,7 +1384,7 @@ def detect_lostruns(
return lost_launch_ids, lost_fw_ids, inconsistent_fw_ids
- def set_reservation_id(self, launch_id, reservation_id):
+ def set_reservation_id(self, launch_id, reservation_id) -> None:
"""
Set reservation id to the launch corresponding to the given launch id.
@@ -1407,7 +1410,7 @@ def checkout_fw(self, fworker, launch_dir, fw_id=None, host=None, ip=None, state
state (str): RESERVED or RUNNING, the fetched firework's state will be set to this value.
Returns:
- (Firework, int): firework and the new launch id
+ (Firework, int): firework and the new launch id.
"""
m_fw = self._get_a_fw_to_run(fworker.query, fw_id=fw_id)
if not m_fw:
@@ -1471,7 +1474,7 @@ def checkout_fw(self, fworker, launch_dir, fw_id=None, host=None, ip=None, state
return m_fw, launch_id
- def change_launch_dir(self, launch_id, launch_dir):
+ def change_launch_dir(self, launch_id, launch_dir) -> None:
"""
Change the launch directory corresponding to the given launch id.
@@ -1483,7 +1486,7 @@ def change_launch_dir(self, launch_id, launch_dir):
m_launch.launch_dir = launch_dir
self.launches.find_one_and_replace({"launch_id": m_launch.launch_id}, m_launch.to_db_dict(), upsert=True)
- def restore_backup_data(self, launch_id, fw_id):
+ def restore_backup_data(self, launch_id, fw_id) -> None:
"""For the given launch id and firework id, restore the back up data."""
if launch_id in self.backup_launch_data:
self.launches.find_one_and_replace({"launch_id": launch_id}, self.backup_launch_data[launch_id])
@@ -1541,7 +1544,7 @@ def complete_launch(self, launch_id, action=None, state="COMPLETED"):
# change return type to dict to make return type serializable to support job packing
return m_launch.to_dict()
- def ping_launch(self, launch_id, ptime=None, checkpoint=None):
+ def ping_launch(self, launch_id, ptime=None, checkpoint=None) -> None:
"""
Ping that a Launch is still alive: updates the 'update_on 'field of the state history of a
Launch.
@@ -1655,13 +1658,15 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo
duplicates = []
reruns = []
if rerun_duplicates:
- f = self.fireworks.find_one({"fw_id": fw_id, "spec._dupefinder": {"$exists": True}}, {"launches": 1})
- if f:
- for d in self.fireworks.find(
- {"launches": {"$in": f["launches"]}, "fw_id": {"$ne": fw_id}}, {"fw_id": 1}
- ):
- duplicates.append(d["fw_id"])
- duplicates = list(set(duplicates))
+ fw = self.fireworks.find_one({"fw_id": fw_id, "spec._dupefinder": {"$exists": True}}, {"launches": 1})
+ if fw:
+ duplicates = [
+ fw_dct["fw_id"]
+ for fw_dct in self.fireworks.find(
+ {"launches": {"$in": fw["launches"]}, "fw_id": {"$ne": fw_id}}, {"fw_id": 1}
+ )
+ ]
+ duplicates = list(set(duplicates))
# Launch recovery
if recover_launch is not None:
@@ -1691,10 +1696,10 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo
reruns.append(fw_id)
# rerun duplicated FWs
- for f in duplicates:
- self.m_logger.info(f"Also rerunning duplicate fw_id: {f}")
+ for fw in duplicates:
+ self.m_logger.info(f"Also rerunning duplicate fw_id: {fw}")
# False for speed, True shouldn't be needed
- r = self.rerun_fw(f, rerun_duplicates=False, recover_launch=recover_launch, recover_mode=recover_mode)
+ r = self.rerun_fw(fw, rerun_duplicates=False, recover_launch=recover_launch, recover_mode=recover_mode)
reruns.extend(r)
return reruns
@@ -1713,7 +1718,7 @@ def get_recovery(self, fw_id, launch_id="last"):
recovery.update({"_prev_dir": launch.launch_dir, "_launch_id": launch.launch_id})
return recovery
- def _refresh_wf(self, fw_id):
+ def _refresh_wf(self, fw_id) -> None:
"""
Update the FW state of all jobs in workflow.
@@ -1741,7 +1746,7 @@ def _refresh_wf(self, fw_id):
err_message = f"Error refreshing workflow. The full stack trace is: {traceback.format_exc()}"
raise RuntimeError(err_message)
- def _update_wf(self, wf, updated_ids):
+ def _update_wf(self, wf, updated_ids) -> None:
"""
Update the workflow with the updated firework ids.
Note: must be called within an enclosing WFLock.
@@ -1819,7 +1824,7 @@ def _steal_launches(self, thief_fw):
self.m_logger.info(f"Duplicate found! fwids {thief_fw.fw_id} and {potential_match['fw_id']}")
return stolen
- def set_priority(self, fw_id, priority):
+ def set_priority(self, fw_id, priority) -> None:
"""
Set priority to the firework with the given id.
@@ -1837,12 +1842,12 @@ def get_logdir(self):
"""
return self.logdir
- def add_offline_run(self, launch_id, fw_id, name):
+ def add_offline_run(self, launch_id, fw_id, name) -> None:
"""
Add the launch and firework to the offline_run collection.
Args:
- launch_id (int): launch id
+ launch_id (int): launch id.
fw_id (id): firework id
name (str)
"""
@@ -1860,7 +1865,7 @@ def recover_offline(self, launch_id, ignore_errors=False, print_errors=False):
Update the launch state using the offline data in FW_offline.json file.
Args:
- launch_id (int): launch id
+ launch_id (int): launch id.
ignore_errors (bool)
print_errors (bool)
@@ -1886,7 +1891,7 @@ def recover_offline(self, launch_id, ignore_errors=False, print_errors=False):
if not already_running:
m_launch.state = "RUNNING" # this should also add a history item
- checkpoint = offline_data["checkpoint"] if "checkpoint" in offline_data else None
+ checkpoint = offline_data.get("checkpoint", None)
# look for ping file - update the Firework if this is the case
ping_loc = os.path.join(m_launch.launch_dir, "FW_ping.json")
@@ -1952,7 +1957,7 @@ def recover_offline(self, launch_id, ignore_errors=False, print_errors=False):
self.offline_runs.update_one({"launch_id": launch_id}, {"$set": {"completed": True}})
return m_launch.fw_id
- def forget_offline(self, launchid_or_fwid, launch_mode=True):
+ def forget_offline(self, launchid_or_fwid, launch_mode=True) -> None:
"""
Unmark the offline run for the given launch or firework id.
@@ -1988,7 +1993,7 @@ def get_launchdir(self, fw_id, launch_idx=-1):
fw = self.get_fw_by_id(fw_id)
return fw.launches[launch_idx].launch_dir if len(fw.launches) > 0 else None
- def log_message(self, level, message):
+ def log_message(self, level, message) -> None:
"""
Support for job packing.
@@ -2010,7 +2015,7 @@ class LazyFirework:
db_fields = ("name", "fw_id", "spec", "created_on", "state")
db_launch_fields = ("launches", "archived_launches")
- def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs):
+ def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs) -> None:
"""
Args:
fw_id (int): firework id
@@ -2020,7 +2025,7 @@ def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs):
# This is the only attribute known w/o a DB query
self.fw_id = fw_id
self._fwc, self._lc, self._ffs = fw_coll, launch_coll, fallback_fs
- self._launches = {k: False for k in self.db_launch_fields}
+ self._launches = dict.fromkeys(self.db_launch_fields, False)
self._fw, self._lids, self._state = None, None, None
# FireWork methods
@@ -2036,20 +2041,20 @@ def state(self):
return self._state
@state.setter
- def state(self, state):
+ def state(self, state) -> None:
self.partial_fw._state = state
self.partial_fw.updated_on = datetime.datetime.utcnow()
def to_dict(self):
return self.full_fw.to_dict()
- def _rerun(self):
+ def _rerun(self) -> None:
self.full_fw._rerun()
def to_db_dict(self):
return self.full_fw.to_db_dict()
- def __str__(self):
+ def __str__(self) -> str:
return f"LazyFireWork object: (id: {self.fw_id})"
# Properties that shadow FireWork attributes
@@ -2059,7 +2064,7 @@ def tasks(self):
return self.partial_fw.tasks
@tasks.setter
- def tasks(self, value):
+ def tasks(self, value) -> None:
self.partial_fw.tasks = value
@property
@@ -2067,7 +2072,7 @@ def spec(self):
return self.partial_fw.spec
@spec.setter
- def spec(self, value):
+ def spec(self, value) -> None:
self.partial_fw.spec = value
@property
@@ -2075,7 +2080,7 @@ def name(self):
return self.partial_fw.name
@name.setter
- def name(self, value):
+ def name(self, value) -> None:
self.partial_fw.name = value
@property
@@ -2083,7 +2088,7 @@ def created_on(self):
return self.partial_fw.created_on
@created_on.setter
- def created_on(self, value):
+ def created_on(self, value) -> None:
self.partial_fw.created_on = value
@property
@@ -2091,7 +2096,7 @@ def updated_on(self):
return self.partial_fw.updated_on
@updated_on.setter
- def updated_on(self, value):
+ def updated_on(self, value) -> None:
self.partial_fw.updated_on = value
@property
@@ -2101,7 +2106,7 @@ def parents(self):
return []
@parents.setter
- def parents(self, value):
+ def parents(self, value) -> None:
self.partial_fw.parents = value
# Properties that shadow FireWork attributes, but which are
@@ -2112,7 +2117,7 @@ def launches(self):
return self._get_launch_data("launches")
@launches.setter
- def launches(self, value):
+ def launches(self, value) -> None:
self._launches["launches"] = True
self.partial_fw.launches = value
@@ -2121,7 +2126,7 @@ def archived_launches(self):
return self._get_launch_data("archived_launches")
@archived_launches.setter
- def archived_launches(self, value):
+ def archived_launches(self, value) -> None:
self._launches["archived_launches"] = True
self.partial_fw.archived_launches = value
diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py
index 9d58962a0..a6edd4fcd 100644
--- a/fireworks/core/rocket.py
+++ b/fireworks/core/rocket.py
@@ -3,6 +3,8 @@
completes the Launch.
"""
+from __future__ import annotations
+
import distutils.dir_util
import errno
import glob
@@ -15,13 +17,12 @@
import traceback
from datetime import datetime
from threading import Event, Thread, current_thread
-from typing import Dict, Union
+from typing import TYPE_CHECKING
from monty.io import zopen
from monty.os.path import zpath
from fireworks.core.firework import Firework, FWAction
-from fireworks.core.fworker import FWorker
from fireworks.core.launchpad import LaunchPad, LockedWorkflowError
from fireworks.fw_config import (
PING_TIME_SECS,
@@ -35,6 +36,9 @@
from fireworks.utilities.dict_mods import apply_mod
from fireworks.utilities.fw_utilities import get_fw_logger
+if TYPE_CHECKING:
+ from fireworks.core.fworker import FWorker
+
__author__ = "Anubhav Jain"
__copyright__ = "Copyright 2013, The Materials Project"
__maintainer__ = "Anubhav Jain"
@@ -56,7 +60,7 @@ def ping_launch(launchpad: LaunchPad, launch_id: int, stop_event: Event, master_
stop_event.wait(PING_TIME_SECS)
-def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Union[Event, None]:
+def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Event | None:
fd = FWData()
if fd.MULTIPROCESSING:
if not launch_id:
@@ -69,7 +73,7 @@ def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Union[Event, None
return ping_stop
-def stop_backgrounds(ping_stop, btask_stops):
+def stop_backgrounds(ping_stop, btask_stops) -> None:
fd = FWData()
if fd.MULTIPROCESSING:
fd.Running_IDs[os.getpid()] = None
@@ -80,7 +84,7 @@ def stop_backgrounds(ping_stop, btask_stops):
b.set()
-def background_task(btask, spec, stop_event, master_thread):
+def background_task(btask, spec, stop_event, master_thread) -> None:
num_launched = 0
while not stop_event.is_set() and master_thread.is_alive():
for task in btask.tasks:
@@ -150,7 +154,8 @@ def run(self, pdb_on_exception: bool = False) -> bool:
launch_id = None # we don't need this in offline mode...
if not m_fw:
- print(f"No FireWorks are ready to run and match query! {self.fworker.query}")
+ msg = f"No FireWorks are ready to run and match query! {self.fworker.query}"
+ l_logger.log(logging.INFO, msg)
return False
final_state = None
@@ -228,7 +233,7 @@ def run(self, pdb_on_exception: bool = False) -> bool:
# start background tasks
if "_background_tasks" in my_spec:
for bt in my_spec["_background_tasks"]:
- btask_stops.append(start_background_task(bt, m_fw.spec))
+ btask_stops.append(start_background_task(bt, m_fw.spec)) # noqa: PERF401
# execute the Firetasks!
for t_counter, t in enumerate(m_fw.tasks[starting_task:], start=starting_task):
@@ -422,7 +427,7 @@ def run(self, pdb_on_exception: bool = False) -> bool:
return True
@staticmethod
- def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: Dict[str, any]) -> None:
+ def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: dict[str, any]) -> None:
"""
Helper function to update checkpoint.
@@ -443,7 +448,7 @@ def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, che
f_out.write(json.dumps(d, ensure_ascii=False))
def decorate_fwaction(
- self, fwaction: FWAction, my_spec: Dict[str, any], m_fw: Firework, launch_dir: str
+ self, fwaction: FWAction, my_spec: dict[str, any], m_fw: Firework, launch_dir: str
) -> FWAction:
if my_spec.get("_pass_job_info"):
job_info = list(my_spec.get("_job_info", []))
diff --git a/fireworks/core/rocket_launcher.py b/fireworks/core/rocket_launcher.py
index e03d982b8..15ac7ebb1 100644
--- a/fireworks/core/rocket_launcher.py
+++ b/fireworks/core/rocket_launcher.py
@@ -1,16 +1,20 @@
"""This module contains methods for launching Rockets, both singly and in rapid-fire mode."""
+from __future__ import annotations
+
import os
import time
from datetime import datetime
-from typing import Optional
+from typing import TYPE_CHECKING
from fireworks.core.fworker import FWorker
-from fireworks.core.launchpad import LaunchPad
from fireworks.core.rocket import Rocket
from fireworks.fw_config import FWORKER_LOC, RAPIDFIRE_SLEEP_SECS
from fireworks.utilities.fw_utilities import create_datestamp_dir, get_fw_logger, log_multi, redirect_local
+if TYPE_CHECKING:
+ from fireworks.core.launchpad import LaunchPad
+
__author__ = "Anubhav Jain"
__copyright__ = "Copyright 2013, The Materials Project"
__maintainer__ = "Anubhav Jain"
@@ -56,15 +60,15 @@ def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_e
def rapidfire(
launchpad: LaunchPad,
fworker: FWorker = None,
- m_dir: Optional[str] = None,
+ m_dir: str | None = None,
nlaunches: int = 0,
max_loops: int = -1,
- sleep_time: Optional[int] = None,
+ sleep_time: int | None = None,
strm_lvl: str = "INFO",
- timeout: Optional[int] = None,
+ timeout: int | None = None,
local_redirect: bool = False,
pdb_on_exception: bool = False,
-):
+) -> None:
"""
Keeps running Rockets in m_dir until we reach an error. Automatically creates subdirectories
for each Rocket. Usually stops when we run out of FireWorks from the LaunchPad.
diff --git a/fireworks/core/tests/tasks.py b/fireworks/core/tests/tasks.py
index 83aa0bf9f..f13b5ddef 100644
--- a/fireworks/core/tests/tasks.py
+++ b/fireworks/core/tests/tasks.py
@@ -1,4 +1,5 @@
import time
+from typing import NoReturn
from unittest import SkipTest
from fireworks import FiretaskBase, Firework, FWAction
@@ -6,7 +7,7 @@
class SerializableException(Exception):
- def __init__(self, exc_details):
+ def __init__(self, exc_details) -> None:
self.exc_details = exc_details
def to_dict(self):
@@ -17,7 +18,7 @@ def to_dict(self):
class ExceptionTestTask(FiretaskBase):
exec_counter = 0
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
ExceptionTestTask.exec_counter += 1
if not fw_spec.get("skip_exception", False):
raise SerializableException(self["exc_details"])
@@ -27,7 +28,7 @@ def run_task(self, fw_spec):
class ExecutionCounterTask(FiretaskBase):
exec_counter = 0
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
ExecutionCounterTask.exec_counter += 1
@@ -39,7 +40,7 @@ def run_task(self, fw_spec):
@explicit_serialize
class TodictErrorTask(FiretaskBase):
- def to_dict(self):
+ def to_dict(self) -> NoReturn:
raise RuntimeError("to_dict error")
def run_task(self, fw_spec):
@@ -87,7 +88,7 @@ def run_task(self, fw_spec):
@explicit_serialize
class DoNothingTask(FiretaskBase):
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
pass
diff --git a/fireworks/core/tests/test_firework.py b/fireworks/core/tests/test_firework.py
index e8e041391..7d77ef0b1 100644
--- a/fireworks/core/tests/test_firework.py
+++ b/fireworks/core/tests/test_firework.py
@@ -1,6 +1,5 @@
"""TODO: Modify unittest doc."""
-
__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__maintainer__ = "Shyue Ping Ong"
@@ -17,7 +16,7 @@
class FiretaskBaseTest(unittest.TestCase):
- def test_init(self):
+ def test_init(self) -> None:
class DummyTask(FiretaskBase):
required_params = ["hello"]
@@ -38,7 +37,7 @@ class DummyTask2(FiretaskBase):
with pytest.raises(NotImplementedError):
d.run_task({})
- def test_param_checks(self):
+ def test_param_checks(self) -> None:
class DummyTask(FiretaskBase):
_fw_name = "DummyTask"
required_params = ["param1"]
@@ -60,14 +59,14 @@ def run_task(self, fw_spec):
class FiretaskPickleTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
import pickle
self.task = PickleTask(test=0)
self.pkl_task = pickle.dumps(self.task)
self.upkl_task = pickle.loads(self.pkl_task)
- def test_init(self):
+ def test_init(self) -> None:
assert isinstance(self.upkl_task, PickleTask)
assert PickleTask.from_dict(self.task.to_dict()) == self.upkl_task
assert dir(self.task) == dir(self.upkl_task)
@@ -92,12 +91,12 @@ def run_task(self, fw_spec):
class WorkflowTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.fw1 = Firework(Task1())
self.fw2 = Firework([Task2(), Task2()], parents=self.fw1)
self.fw3 = Firework(Task1(), parents=self.fw1)
- def test_init(self):
+ def test_init(self) -> None:
fws = []
for i in range(5):
fw = Firework([PyTask(func="print", args=[i])], fw_id=i)
@@ -110,7 +109,7 @@ def test_init(self):
with pytest.raises(ValueError):
Workflow(fws, links_dict={0: [1, 2, 3], 1: [4], 2: [100]})
- def test_copy(self):
+ def test_copy(self) -> None:
"""Test that we can produce a copy of a Workflow but that the copy
has unique fw_ids.
"""
@@ -135,7 +134,7 @@ def test_copy(self):
for child_id, orig_child_id in zip(children, orig_children):
assert orig_child_id == wf_copy.id_fw[child_id].name
- def test_remove_leaf_fws(self):
+ def test_remove_leaf_fws(self) -> None:
fw4 = Firework(Task1(), parents=[self.fw2, self.fw3])
fws = [self.fw1, self.fw2, self.fw3, fw4]
wflow = Workflow(fws)
@@ -146,7 +145,7 @@ def test_remove_leaf_fws(self):
wflow.remove_fws(wflow.leaf_fw_ids)
assert wflow.leaf_fw_ids == parents
- def test_remove_root_fws(self):
+ def test_remove_root_fws(self) -> None:
fw4 = Firework(Task1(), parents=[self.fw2, self.fw3])
fws = [self.fw1, self.fw2, self.fw3, fw4]
wflow = Workflow(fws)
@@ -157,7 +156,7 @@ def test_remove_root_fws(self):
wflow.remove_fws(wflow.root_fw_ids)
assert sorted(wflow.root_fw_ids) == sorted(children)
- def test_iter_len_index(self):
+ def test_iter_len_index(self) -> None:
fws = [self.fw1, self.fw2, self.fw3]
wflow = Workflow(fws)
for idx, fw in enumerate(wflow):
diff --git a/fireworks/core/tests/test_launchpad.py b/fireworks/core/tests/test_launchpad.py
index 78161e2b9..2945e5853 100644
--- a/fireworks/core/tests/test_launchpad.py
+++ b/fireworks/core/tests/test_launchpad.py
@@ -40,27 +40,27 @@ class AuthenticationTest(unittest.TestCase):
"""Tests whether users are authenticating against the correct mongo dbs."""
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
try:
client = fireworks.fw_config.MongoClient()
client.not_the_admin_db.command("createUser", "myuser", pwd="mypassword", roles=["dbOwner"])
except Exception:
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
- def test_no_admin_privileges_for_plebs(self):
+ def test_no_admin_privileges_for_plebs(self) -> None:
"""Normal users can not authenticate against the admin db."""
with pytest.raises(OperationFailure):
lp = LaunchPad(name="admin", username="myuser", password="mypassword", authsource="admin")
lp.db.collection.count_documents({})
- def test_authenticating_to_users_db(self):
+ def test_authenticating_to_users_db(self) -> None:
"""A user should be able to authenticate against a database that they
are a user of.
"""
lp = LaunchPad(name="not_the_admin_db", username="myuser", password="mypassword", authsource="not_the_admin_db")
lp.db.collection.count_documents({})
- def test_authsource_infered_from_db_name(self):
+ def test_authsource_infered_from_db_name(self) -> None:
"""The default behavior is to authenticate against the db that the user
is trying to access.
"""
@@ -70,7 +70,7 @@ def test_authsource_infered_from_db_name(self):
class LaunchPadTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -80,17 +80,17 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
cls.lp.connection
- def setUp(self):
+ def setUp(self) -> None:
self.old_wd = os.getcwd()
self.LP_LOC = os.path.join(MODULE_DIR, "launchpad.yaml")
self.lp.to_file(self.LP_LOC)
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False, max_reset_wo_password=1000)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -102,13 +102,13 @@ def tearDown(self):
if os.path.exists(self.LP_LOC):
os.remove(self.LP_LOC)
- def test_dict_from_file(self):
+ def test_dict_from_file(self) -> None:
lp = LaunchPad.from_file(self.LP_LOC)
lp_dict = lp.to_dict()
new_lp = LaunchPad.from_dict(lp_dict)
assert isinstance(new_lp, LaunchPad)
- def test_reset(self):
+ def test_reset(self) -> None:
# Store some test fireworks
# Attempt couple of ways to reset the lp and check
fw = Firework(ScriptTask.from_str('echo "hello"'), name="hello")
@@ -129,14 +129,14 @@ def test_reset(self):
self.lp.reset("")
self.lp.reset("", False, 100) # reset back
- def test_pw_check(self):
+ def test_pw_check(self) -> None:
fw = Firework(ScriptTask.from_str('echo "hello"'), name="hello")
self.lp.add_wf(fw)
args = ("",)
with pytest.raises(ValueError):
self.lp.reset(*args)
- def test_add_wf(self):
+ def test_add_wf(self) -> None:
fw = Firework(ScriptTask.from_str('echo "hello"'), name="hello")
self.lp.add_wf(fw)
wf_id = self.lp.get_wf_ids()
@@ -153,7 +153,7 @@ def test_add_wf(self):
assert len(fw_ids) == 3
self.lp.reset("", require_password=False)
- def test_add_wfs(self):
+ def test_add_wfs(self) -> None:
ftask = ScriptTask.from_str('echo "lorem ipsum"')
wfs = []
for _ in range(50):
@@ -171,7 +171,7 @@ def test_add_wfs(self):
class LaunchPadDefuseReigniteRerunArchiveDeleteTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -181,11 +181,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
# define the individual FireWorks used in the Workflow
# Parent Firework
fw_p = Firework(
@@ -290,7 +290,7 @@ def setUp(self):
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -300,12 +300,12 @@ def tearDown(self):
shutil.rmtree(ldir)
@staticmethod
- def _teardown(dests):
+ def _teardown(dests) -> None:
for f in dests:
if os.path.exists(f):
os.remove(f)
- def test_pause_fw(self):
+ def test_pause_fw(self) -> None:
self.lp.pause_fw(self.zeus_fw_id)
paused_ids = self.lp.get_fw_ids({"state": "PAUSED"})
@@ -338,7 +338,7 @@ def test_pause_fw(self):
except Exception:
raise
- def test_defuse_fw(self):
+ def test_defuse_fw(self) -> None:
# defuse Zeus
self.lp.defuse_fw(self.zeus_fw_id)
@@ -362,7 +362,7 @@ def test_defuse_fw(self):
except Exception:
raise
- def test_defuse_fw_after_completion(self):
+ def test_defuse_fw_after_completion(self) -> None:
# Launch rockets in rapidfire
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
# defuse Zeus
@@ -373,7 +373,7 @@ def test_defuse_fw_after_completion(self):
completed_ids = set(self.lp.get_fw_ids({"state": "COMPLETED"}))
assert not self.zeus_child_fw_ids.issubset(completed_ids)
- def test_reignite_fw(self):
+ def test_reignite_fw(self) -> None:
# Defuse Zeus
self.lp.defuse_fw(self.zeus_fw_id)
defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"})
@@ -391,7 +391,7 @@ def test_reignite_fw(self):
assert self.zeus_fw_id in completed_ids
assert self.zeus_child_fw_ids.issubset(completed_ids)
- def test_pause_wf(self):
+ def test_pause_wf(self) -> None:
# pause Workflow containing Zeus
self.lp.pause_wf(self.zeus_fw_id)
paused_ids = self.lp.get_fw_ids({"state": "PAUSED"})
@@ -404,7 +404,7 @@ def test_pause_wf(self):
fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}}))
assert fws_no_run == self.all_ids
- def test_defuse_wf(self):
+ def test_defuse_wf(self) -> None:
# defuse Workflow containing Zeus
self.lp.defuse_wf(self.zeus_fw_id)
defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"})
@@ -417,7 +417,7 @@ def test_defuse_wf(self):
fws_no_run = set(self.lp.get_fw_ids({"state": {"$nin": ["COMPLETED"]}}))
assert fws_no_run == self.all_ids
- def test_defuse_wf_after_partial_run(self):
+ def test_defuse_wf_after_partial_run(self) -> None:
# Run a firework before defusing Zeus
launch_rocket(self.lp, self.fworker)
print("----------\nafter launch rocket\n--------")
@@ -439,7 +439,7 @@ def test_defuse_wf_after_partial_run(self):
fws_no_run = set(self.lp.get_fw_ids({"state": "COMPLETED"}))
assert len(fws_no_run) == 0
- def test_reignite_wf(self):
+ def test_reignite_wf(self) -> None:
# Defuse workflow containing Zeus
self.lp.defuse_wf(self.zeus_fw_id)
defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"})
@@ -456,7 +456,7 @@ def test_reignite_wf(self):
fws_completed = set(self.lp.get_fw_ids({"state": "COMPLETED"}))
assert fws_completed == self.all_ids
- def test_archive_wf(self):
+ def test_archive_wf(self) -> None:
# Run a firework before archiving Zeus
launch_rocket(self.lp, self.fworker)
@@ -474,7 +474,7 @@ def test_archive_wf(self):
fw = self.lp.get_fw_by_id(self.zeus_fw_id)
assert fw.state == "ARCHIVED"
- def test_delete_wf(self):
+ def test_delete_wf(self) -> None:
# Run a firework before deleting Zeus
rapidfire(self.lp, self.fworker, nlaunches=1)
@@ -496,7 +496,7 @@ def test_delete_wf(self):
# Check that the launch dir has not been deleted
assert os.path.isdir(first_ldir)
- def test_delete_wf_and_files(self):
+ def test_delete_wf_and_files(self) -> None:
# Run a firework before deleting Zeus
rapidfire(self.lp, self.fworker, nlaunches=1)
@@ -518,7 +518,7 @@ def test_delete_wf_and_files(self):
# Check that the launch dir has not been deleted
assert not os.path.isdir(first_ldir)
- def test_rerun_fws2(self):
+ def test_rerun_fws2(self) -> None:
# Launch all fireworks
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
fw = self.lp.get_fw_by_id(self.zeus_fw_id)
@@ -560,7 +560,7 @@ def test_rerun_fws2(self):
@unittest.skipIf(PYMONGO_MAJOR_VERSION > 3, "detect lostruns test not supported for pymongo major version > 3")
class LaunchPadLostRunsDetectTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -570,11 +570,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
# Define a timed fireWork
fw_timer = Firework(PyTask(func="time.sleep", args=[5]), name="timer")
self.lp.add_wf(fw_timer)
@@ -584,7 +584,7 @@ def setUp(self):
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -594,15 +594,15 @@ def tearDown(self):
shutil.rmtree(ldir)
# self.lp.connection.close()
- def test_detect_lostruns(self):
+ def test_detect_lostruns(self) -> None:
# Launch the timed firework in a separate process
class RocketProcess(Process):
- def __init__(self, lpad, fworker):
+ def __init__(self, lpad, fworker) -> None:
super(self.__class__, self).__init__()
self.lpad = lpad
self.fworker = fworker
- def run(self):
+ def run(self) -> None:
launch_rocket(self.lpad, self.fworker)
rp = RocketProcess(self.lp, self.fworker)
@@ -633,15 +633,15 @@ def run(self):
assert (lost_launch_ids, lost_fw_ids) == ([1], [1])
assert self.lp.get_fw_by_id(1).state == "READY"
- def test_detect_lostruns_defuse(self):
+ def test_detect_lostruns_defuse(self) -> None:
# Launch the timed firework in a separate process
class RocketProcess(Process):
- def __init__(self, lpad, fworker):
+ def __init__(self, lpad, fworker) -> None:
super(self.__class__, self).__init__()
self.lpad = lpad
self.fworker = fworker
- def run(self):
+ def run(self) -> None:
launch_rocket(self.lpad, self.fworker)
rp = RocketProcess(self.lp, self.fworker)
@@ -656,24 +656,24 @@ def run(self):
raise ValueError("FW never starts running")
rp.terminate() # Kill the rocket
- lost_launch_ids, lost_fw_ids, i = self.lp.detect_lostruns(0.01)
+ lost_launch_ids, lost_fw_ids, _i = self.lp.detect_lostruns(0.01)
assert (lost_launch_ids, lost_fw_ids) == ([1], [1])
self.lp.defuse_fw(1)
- lost_launch_ids, lost_fw_ids, i = self.lp.detect_lostruns(0.01, rerun=True)
+ lost_launch_ids, lost_fw_ids, _i = self.lp.detect_lostruns(0.01, rerun=True)
assert (lost_launch_ids, lost_fw_ids) == ([1], [])
assert self.lp.get_fw_by_id(1).state == "DEFUSED"
- def test_state_after_run_start(self):
+ def test_state_after_run_start(self) -> None:
# Launch the timed firework in a separate process
class RocketProcess(Process):
- def __init__(self, lpad, fworker):
+ def __init__(self, lpad, fworker) -> None:
super(self.__class__, self).__init__()
self.lpad = lpad
self.fworker = fworker
- def run(self):
+ def run(self) -> None:
launch_rocket(self.lpad, self.fworker)
rp = RocketProcess(self.lp, self.fworker)
@@ -702,7 +702,7 @@ class WorkflowFireworkStatesTest(unittest.TestCase):
"""
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -712,11 +712,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
# define the individual FireWorks used in the Workflow
# Parent Firework
fw_p = Firework(
@@ -820,7 +820,7 @@ def setUp(self):
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -830,12 +830,12 @@ def tearDown(self):
shutil.rmtree(ldir)
@staticmethod
- def _teardown(dests):
+ def _teardown(dests) -> None:
for f in dests:
if os.path.exists(f):
os.remove(f)
- def test_defuse_fw(self):
+ def test_defuse_fw(self) -> None:
# defuse Zeus
self.lp.defuse_fw(self.zeus_fw_id)
# Ensure the states are sync after defusing fw
@@ -859,7 +859,7 @@ def test_defuse_fw(self):
except Exception:
raise
- def test_defuse_fw_after_completion(self):
+ def test_defuse_fw_after_completion(self) -> None:
# Launch rockets in rapidfire
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
# defuse Zeus
@@ -872,7 +872,7 @@ def test_defuse_fw_after_completion(self):
fw_cache_state = wf.fw_states[fw_id]
assert fw_state == fw_cache_state
- def test_reignite_fw(self):
+ def test_reignite_fw(self) -> None:
# Defuse Zeus and launch remaining fireworks
self.lp.defuse_fw(self.zeus_fw_id)
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
@@ -887,7 +887,7 @@ def test_reignite_fw(self):
fw_cache_state = wf.fw_states[fw_id]
assert fw_state == fw_cache_state
- def test_defuse_wf(self):
+ def test_defuse_wf(self) -> None:
# defuse Workflow containing Zeus
self.lp.defuse_wf(self.zeus_fw_id)
defused_ids = self.lp.get_fw_ids({"state": "DEFUSED"})
@@ -901,7 +901,7 @@ def test_defuse_wf(self):
fw_cache_state = wf.fw_states[fw_id]
assert fw_state == fw_cache_state
- def test_reignite_wf(self):
+ def test_reignite_wf(self) -> None:
# Defuse workflow containing Zeus
self.lp.defuse_wf(self.zeus_fw_id)
@@ -918,7 +918,7 @@ def test_reignite_wf(self):
fw_cache_state = wf.fw_states[fw_id]
assert fw_state == fw_cache_state
- def test_archive_wf(self):
+ def test_archive_wf(self) -> None:
# Run a firework before archiving Zeus
launch_rocket(self.lp, self.fworker)
# archive Workflow containing Zeus.
@@ -931,7 +931,7 @@ def test_archive_wf(self):
fw_cache_state = wf.fw_states[fw_id]
assert fw_state == fw_cache_state
- def test_rerun_fws(self):
+ def test_rerun_fws(self) -> None:
# Launch all fireworks
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
fw = self.lp.get_fw_by_id(self.zeus_fw_id)
@@ -948,15 +948,15 @@ def test_rerun_fws(self):
fw_cache_state = wf.fw_states[fw_id]
assert fw_state == fw_cache_state
- def test_rerun_timed_fws(self):
+ def test_rerun_timed_fws(self) -> None:
# Launch all fireworks in a separate process
class RapidfireProcess(Process):
- def __init__(self, lpad, fworker):
+ def __init__(self, lpad, fworker) -> None:
super(self.__class__, self).__init__()
self.lpad = lpad
self.fworker = fworker
- def run(self):
+ def run(self) -> None:
rapidfire(self.lpad, self.fworker)
rp = RapidfireProcess(self.lp, self.fworker)
@@ -972,7 +972,7 @@ def run(self):
assert fw_state == fw_cache_state
# Detect lost runs
- lost_lids, lost_fwids, inconsistent_fwids = self.lp.detect_lostruns(expiration_secs=0.5)
+ _lost_lids, lost_fwids, _inconsistent_fwids = self.lp.detect_lostruns(expiration_secs=0.5)
# Ensure the states are sync
wf = self.lp.get_wf_by_fw_id_lzyfw(self.zeus_fw_id)
fws = wf.id_fw
@@ -1012,7 +1012,7 @@ def run(self):
class LaunchPadRerunExceptionTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -1022,11 +1022,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
fireworks.core.firework.EXCEPT_DETAILS_ON_RERUN = True
self.error_test_dict = {"error": "description", "error_code": 1}
@@ -1043,7 +1043,7 @@ def setUp(self):
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -1052,14 +1052,14 @@ def tearDown(self):
for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")):
shutil.rmtree(ldir)
- def test_except_details_on_rerun(self):
+ def test_except_details_on_rerun(self) -> None:
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
assert os.getcwd() == MODULE_DIR
self.lp.rerun_fw(1)
fw = self.lp.get_fw_by_id(1)
assert fw.spec["_exception_details"] == self.error_test_dict
- def test_task_level_rerun(self):
+ def test_task_level_rerun(self) -> None:
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
assert os.getcwd() == MODULE_DIR
self.lp.rerun_fw(1, recover_launch="last")
@@ -1076,7 +1076,7 @@ def test_task_level_rerun(self):
fw = self.lp.get_fw_by_id(1)
assert "_recovery" not in fw.spec
- def test_task_level_rerun_cp(self):
+ def test_task_level_rerun_cp(self) -> None:
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
assert os.getcwd() == MODULE_DIR
self.lp.rerun_fw(1, recover_launch="last", recover_mode="cp")
@@ -1089,7 +1089,7 @@ def test_task_level_rerun_cp(self):
assert ExceptionTestTask.exec_counter == 2
assert filecmp.cmp(os.path.join(dirs[0], "date_file"), os.path.join(dirs[1], "date_file"))
- def test_task_level_rerun_prev_dir(self):
+ def test_task_level_rerun_prev_dir(self) -> None:
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
assert os.getcwd() == MODULE_DIR
self.lp.rerun_fw(1, recover_launch="last", recover_mode="prev_dir")
@@ -1105,7 +1105,7 @@ def test_task_level_rerun_prev_dir(self):
class WFLockTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -1115,11 +1115,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
# set the defaults in the init of wflock to break the lock quickly
fireworks.core.launchpad.WFLock(3, False).__init__.__func__.__defaults__ = (3, False)
@@ -1132,7 +1132,7 @@ def setUp(self):
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -1141,15 +1141,15 @@ def tearDown(self):
for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")):
shutil.rmtree(ldir)
- def test_fix_db_inconsistencies_completed(self):
+ def test_fix_db_inconsistencies_completed(self) -> None:
class RocketProcess(Process):
- def __init__(self, lpad, fworker, fw_id):
+ def __init__(self, lpad, fworker, fw_id) -> None:
super(self.__class__, self).__init__()
self.lpad = lpad
self.fworker = fworker
self.fw_id = fw_id
- def run(self):
+ def run(self) -> None:
launch_rocket(self.lpad, self.fworker, fw_id=self.fw_id)
# Launch the slow firework in a separate process
@@ -1188,15 +1188,15 @@ def run(self):
assert fast_fw.state == "COMPLETED"
- def test_fix_db_inconsistencies_fizzled(self):
+ def test_fix_db_inconsistencies_fizzled(self) -> None:
class RocketProcess(Process):
- def __init__(self, lpad, fworker, fw_id):
+ def __init__(self, lpad, fworker, fw_id) -> None:
super(self.__class__, self).__init__()
self.lpad = lpad
self.fworker = fworker
self.fw_id = fw_id
- def run(self):
+ def run(self) -> None:
launch_rocket(self.lpad, self.fworker, fw_id=self.fw_id)
self.lp.update_spec([2], {"fizzle": True})
@@ -1236,7 +1236,7 @@ def run(self):
class LaunchPadOfflineTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -1246,11 +1246,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
fireworks.core.firework.EXCEPT_DETAILS_ON_RERUN = True
self.error_test_dict = {"error": "description", "error_code": 1}
@@ -1262,7 +1262,7 @@ def setUp(self):
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -1271,7 +1271,7 @@ def tearDown(self):
for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")):
shutil.rmtree(ldir, ignore_errors=True)
- def test__recover_completed(self):
+ def test__recover_completed(self) -> None:
fw, launch_id = self.lp.reserve_fw(self.fworker, self.launch_dir)
fw = self.lp.get_fw_by_id(1)
with cd(self.launch_dir):
@@ -1286,7 +1286,7 @@ def test__recover_completed(self):
assert fw.state == "COMPLETED"
- def test_recover_errors(self):
+ def test_recover_errors(self) -> None:
fw, launch_id = self.lp.reserve_fw(self.fworker, self.launch_dir)
fw = self.lp.get_fw_by_id(1)
with cd(self.launch_dir):
@@ -1317,7 +1317,7 @@ class GridfsStoredDataTest(unittest.TestCase):
"""
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -1327,14 +1327,14 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TEST_DB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
@@ -1343,7 +1343,7 @@ def tearDown(self):
for ldir in glob.glob(os.path.join(MODULE_DIR, "launcher_*")):
shutil.rmtree(ldir)
- def test_many_detours(self):
+ def test_many_detours(self) -> None:
task = DetoursTask(n_detours=2000, data_per_detour=["a" * 100] * 100)
fw = Firework([task])
self.lp.add_wf(fw)
@@ -1366,7 +1366,7 @@ def test_many_detours(self):
wf = self.lp.get_wf_by_fw_id_lzyfw(1)
assert len(wf.id_fw[1].launches[0].action.detours) == 2000
- def test_many_detours_offline(self):
+ def test_many_detours_offline(self) -> None:
task = DetoursTask(n_detours=2000, data_per_detour=["a" * 100] * 100)
fw = Firework([task])
self.lp.add_wf(fw)
diff --git a/fireworks/core/tests/test_rocket.py b/fireworks/core/tests/test_rocket.py
index 75dc02a89..5dac1fa09 100644
--- a/fireworks/core/tests/test_rocket.py
+++ b/fireworks/core/tests/test_rocket.py
@@ -11,7 +11,7 @@
class RocketTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -21,20 +21,20 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TESTDB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
pass
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
# Delete launch locations
if os.path.exists(os.path.join("FW.json")):
os.remove("FW.json")
- def test_serializable_exception(self):
+ def test_serializable_exception(self) -> None:
error_test_dict = {"error": "description", "error_code": 1}
fw = Firework(ExceptionTestTask(exc_details=error_test_dict))
self.lp.add_wf(fw)
@@ -45,7 +45,7 @@ def test_serializable_exception(self):
launches = fw.launches
assert launches[0].action.stored_data["_exception"]["_details"] == error_test_dict
- def test_postproc_exception(self):
+ def test_postproc_exception(self) -> None:
fw = Firework(MalformedAdditionTask())
self.lp.add_wf(fw)
launch_rocket(self.lp, self.fworker)
diff --git a/fireworks/core/tests/test_tracker.py b/fireworks/core/tests/test_tracker.py
index 369ea2b56..3914310ac 100644
--- a/fireworks/core/tests/test_tracker.py
+++ b/fireworks/core/tests/test_tracker.py
@@ -1,6 +1,5 @@
"""Tracker unitest."""
-
__author__ = "Bharat medasani"
__copyright__ = "Copyright 2012, The Materials Project"
__maintainer__ = "Bharat medasani"
@@ -24,7 +23,7 @@
class TrackerTest(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -34,11 +33,11 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TESTDB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
self.old_wd = os.getcwd()
self.dest1 = os.path.join(MODULE_DIR, "numbers1.txt")
self.dest2 = os.path.join(MODULE_DIR, "numbers2.txt")
@@ -46,7 +45,7 @@ def setUp(self):
self.tracker1 = Tracker(self.dest1, nlines=2)
self.tracker2 = Tracker(self.dest2, nlines=2)
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
if os.path.exists(os.path.join("FW.json")):
os.remove("FW.json")
@@ -55,12 +54,12 @@ def tearDown(self):
shutil.rmtree(i)
@staticmethod
- def _teardown(dests):
+ def _teardown(dests) -> None:
for f in dests:
if os.path.exists(f):
os.remove(f)
- def test_tracker(self):
+ def test_tracker(self) -> None:
"""Launch a workflow and track the files."""
self._teardown([self.dest1])
try:
@@ -79,7 +78,7 @@ def test_tracker(self):
finally:
self._teardown([self.dest1])
- def test_tracker_failed_fw(self):
+ def test_tracker_failed_fw(self) -> None:
"""Add a bad firetask to workflow and test the tracking."""
self._teardown([self.dest1])
try:
@@ -108,12 +107,12 @@ def test_tracker_failed_fw(self):
finally:
self._teardown([self.dest1])
- def test_tracker_mlaunch(self):
+ def test_tracker_mlaunch(self) -> None:
"""Test the tracker for mlaunch."""
self._teardown([self.dest1, self.dest2])
try:
- def add_wf(j, dest, tracker, name):
+ def add_wf(j, dest, tracker, name) -> None:
fts = []
for i in range(j, j + 25):
ft = ScriptTask.from_str('echo "' + str(i) + '" >> ' + dest, {"store_stdout": True})
diff --git a/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py b/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py
index 04c30462f..b6dc31592 100644
--- a/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py
+++ b/fireworks/examples/custom_firetasks/hello_world/hello_world_task.py
@@ -3,5 +3,5 @@
@explicit_serialize
class HelloTask(FiretaskBase):
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
print("Hello, world!")
diff --git a/fireworks/examples/custom_firetasks/merge_task/merge_task.py b/fireworks/examples/custom_firetasks/merge_task/merge_task.py
index 273e3f92d..c356e5d41 100644
--- a/fireworks/examples/custom_firetasks/merge_task/merge_task.py
+++ b/fireworks/examples/custom_firetasks/merge_task/merge_task.py
@@ -24,7 +24,7 @@ def run_task(self, fw_spec):
@explicit_serialize
class TaskC(FiretaskBase):
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
print("This is task C.")
print(f"Task A gave me: {fw_spec['param_A']}")
print(f"Task B gave me: {fw_spec['param_B']}")
diff --git a/fireworks/features/background_task.py b/fireworks/features/background_task.py
index af5f9e6ad..88fc6c722 100644
--- a/fireworks/features/background_task.py
+++ b/fireworks/features/background_task.py
@@ -10,7 +10,7 @@
class BackgroundTask(FWSerializable):
_fw_name = "BackgroundTask"
- def __init__(self, tasks, num_launches=0, sleep_time=60, run_on_finish=False):
+ def __init__(self, tasks, num_launches=0, sleep_time=60, run_on_finish=False) -> None:
"""
Args:
tasks [Firetask]: a list of Firetasks to perform
diff --git a/fireworks/features/dupefinder.py b/fireworks/features/dupefinder.py
index 62027f657..a543df4a2 100644
--- a/fireworks/features/dupefinder.py
+++ b/fireworks/features/dupefinder.py
@@ -1,5 +1,7 @@
"""This module contains the base class for implementing Duplicate Finders."""
+from typing import NoReturn
+
from fireworks.utilities.fw_serializers import FWSerializable, serialize_fw
__author__ = "Anubhav Jain"
@@ -12,10 +14,10 @@
class DupeFinderBase(FWSerializable):
"""This serves an Abstract class for implementing Duplicate Finders."""
- def __init__(self):
+ def __init__(self) -> None:
pass
- def verify(self, spec1, spec2):
+ def verify(self, spec1, spec2) -> NoReturn:
"""
Method that checks whether two specs are identical enough to be
considered duplicates. Return true if duplicated. Note that
@@ -31,7 +33,7 @@ def verify(self, spec1, spec2):
"""
raise NotImplementedError
- def query(self, spec):
+ def query(self, spec) -> NoReturn:
"""
Given a spec, returns a database query that gives potential candidates for duplicated Fireworks.
diff --git a/fireworks/features/fw_report.py b/fireworks/features/fw_report.py
index 3cb8333ce..d40ea1dde 100644
--- a/fireworks/features/fw_report.py
+++ b/fireworks/features/fw_report.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
from datetime import datetime
-from typing import List
from dateutil.relativedelta import relativedelta
@@ -24,7 +25,7 @@
class FWReport:
- def __init__(self, lpad):
+ def __init__(self, lpad) -> None:
"""
Args:
lpad (LaunchPad).
@@ -54,7 +55,7 @@ def get_stats(self, coll="fireworks", interval="days", num_intervals=5, addition
# initialize collection
if coll.lower() in ["fws", "fireworks"]:
coll = "fireworks"
- elif coll.lower() in ["launches"]:
+ elif coll.lower() == "launches":
coll = "launches"
elif coll.lower() in ["wflows", "workflows"]:
coll = "workflows"
@@ -75,29 +76,27 @@ def get_stats(self, coll="fireworks", interval="days", num_intervals=5, addition
date_q = {"$gte": start_time.isoformat()} if string_type_dates else {"$gte": start_time}
match_q.update({time_field: date_q})
- pipeline.append({"$match": match_q})
- pipeline.append(
- {"$project": {"state": 1, "_id": 0, "date_key": {"$substr": ["$" + time_field, 0, date_key_idx]}}}
- )
- pipeline.append(
- {
- "$group": {
- "_id": {"state:": "$state", "date_key": "$date_key"},
- "count": {"$sum": 1},
- "state": {"$first": "$state"},
- }
- }
- )
- pipeline.append(
- {
- "$group": {
- "_id": {"_id_date_key": "$_id.date_key"},
- "date_key": {"$first": "$_id.date_key"},
- "states": {"$push": {"count": "$count", "state": "$state"}},
- }
- }
+ pipeline.extend(
+ (
+ {"$match": match_q},
+ {"$project": {"state": 1, "_id": 0, "date_key": {"$substr": ["$" + time_field, 0, date_key_idx]}}},
+ {
+ "$group": {
+ "_id": {"state:": "$state", "date_key": "$date_key"},
+ "count": {"$sum": 1},
+ "state": {"$first": "$state"},
+ }
+ },
+ {
+ "$group": {
+ "_id": {"_id_date_key": "$_id.date_key"},
+ "date_key": {"$first": "$_id.date_key"},
+ "states": {"$push": {"count": "$count", "state": "$state"}},
+ }
+ },
+ {"$sort": {"date_key": -1}},
+ )
)
- pipeline.append({"$sort": {"date_key": -1}})
# add in missing states and more fields
decorated_list = []
@@ -183,7 +182,7 @@ def plot_stats(self, coll="fireworks", interval="days", num_intervals=5, states=
return fig
@staticmethod
- def get_stats_str(decorated_stat_list: List[dict]) -> str:
+ def get_stats_str(decorated_stat_list: list[dict]) -> str:
"""
Convert the list of stats from FWReport.get_stats() to a string representation for viewing.
diff --git a/fireworks/features/introspect.py b/fireworks/features/introspect.py
index 09f0de785..ecc82c2e8 100644
--- a/fireworks/features/introspect.py
+++ b/fireworks/features/introspect.py
@@ -80,7 +80,7 @@ def compare_stats(stats_dict1, n_samples1, stats_dict2, n_samples2, threshold=5)
class Introspector:
- def __init__(self, lpad):
+ def __init__(self, lpad) -> None:
"""
Args:
lpad (LaunchPad).
@@ -94,7 +94,7 @@ def introspect_fizzled(self, coll="fws", rsort=True, threshold=10, limit=100):
coll = "fireworks"
state_key = "spec"
- elif coll.lower() in ["tasks"]:
+ elif coll.lower() == "tasks":
coll = "fireworks"
state_key = "spec._tasks"
@@ -102,7 +102,7 @@ def introspect_fizzled(self, coll="fws", rsort=True, threshold=10, limit=100):
coll = "workflows"
state_key = "metadata"
- elif coll.lower() in ["launches"]:
+ elif coll.lower() == "launches":
coll = "launches"
state_key = "action.stored_data._exception._stacktrace"
@@ -171,14 +171,14 @@ def introspect_fizzled(self, coll="fws", rsort=True, threshold=10, limit=100):
return table
@staticmethod
- def print_report(table, coll):
+ def print_report(table, coll) -> None:
if coll.lower() in ["fws", "fireworks"]:
header_txt = "fireworks.spec"
- elif coll.lower() in ["tasks"]:
+ elif coll.lower() == "tasks":
header_txt = "fireworks.spec._tasks"
elif coll.lower() in ["wflows", "workflows"]:
header_txt = "workflows.metadata"
- elif coll.lower() in ["launches"]:
+ elif coll.lower() == "launches":
header_txt = "launches.actions.stored_data._exception._stacktrace"
header_txt = f"Introspection report for {header_txt}"
@@ -192,4 +192,4 @@ def print_report(table, coll):
for row in table:
print(f"----{row[3]} Failures have the following stack trace--------------")
print(row[1])
- print("")
+ print()
diff --git a/fireworks/features/multi_launcher.py b/fireworks/features/multi_launcher.py
index 03b93ca34..546b846a4 100644
--- a/fireworks/features/multi_launcher.py
+++ b/fireworks/features/multi_launcher.py
@@ -16,7 +16,7 @@
__date__ = "Aug 19, 2013"
-def ping_multilaunch(port, stop_event):
+def ping_multilaunch(port, stop_event) -> None:
"""
A single manager to ping all launches during multiprocess launches.
@@ -43,7 +43,7 @@ def ping_multilaunch(port, stop_event):
def rapidfire_process(
fworker, nlaunches, sleep, loglvl, port, node_list, sub_nproc, timeout, running_ids_dict, local_redirect
-):
+) -> None:
"""
Initializes shared data with multiprocessing parameters and starts a rapidfire.
@@ -205,7 +205,7 @@ def launch_multiprocess(
timeout=None,
exclude_current_node=False,
local_redirect=False,
-):
+) -> None:
"""
Launch the jobs in the job packing mode.
diff --git a/fireworks/features/stats.py b/fireworks/features/stats.py
index 39b395880..c381e64d6 100644
--- a/fireworks/features/stats.py
+++ b/fireworks/features/stats.py
@@ -22,7 +22,7 @@
class FWStats:
- def __init__(self, lpad):
+ def __init__(self, lpad) -> None:
"""
Object to get Fireworks running stats from a LaunchPad.
diff --git a/fireworks/features/tests/test_introspect.py b/fireworks/features/tests/test_introspect.py
index 2137ad829..74f78ded4 100644
--- a/fireworks/features/tests/test_introspect.py
+++ b/fireworks/features/tests/test_introspect.py
@@ -6,7 +6,7 @@
class IntrospectTest(unittest.TestCase):
- def test_flatten_dict(self):
+ def test_flatten_dict(self) -> None:
assert set(flatten_to_keys({"d": {"e": {"f": 4}, "f": 10}}, max_recurs=1)) == {
f"d{separator_str}"
}
diff --git a/fireworks/flask_site/gunicorn.py b/fireworks/flask_site/gunicorn.py
index a65c25a3a..7618f23ff 100755
--- a/fireworks/flask_site/gunicorn.py
+++ b/fireworks/flask_site/gunicorn.py
@@ -12,12 +12,12 @@ def number_of_workers():
class StandaloneApplication(gunicorn.app.base.BaseApplication):
- def __init__(self, app, options=None):
+ def __init__(self, app, options=None) -> None:
self.options = options or {}
self.application = app
super().__init__()
- def load_config(self):
+ def load_config(self) -> None:
config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None}
for key, value in config.items():
self.cfg.set(key.lower(), value)
diff --git a/fireworks/flask_site/helpers.py b/fireworks/flask_site/helpers.py
index 454acd00b..19196fd9d 100644
--- a/fireworks/flask_site/helpers.py
+++ b/fireworks/flask_site/helpers.py
@@ -16,8 +16,7 @@ def fw_filt_given_wf_filt(filt, lp):
def wf_filt_given_fw_filt(filt, lp):
wf_ids = set()
- for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1}):
- wf_ids.add(doc["fw_id"])
+ wf_ids.update(doc["fw_id"] for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1}))
return {"nodes": {"$in": list(wf_ids)}}
diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py
index ebd77d59d..7b2b964b6 100644
--- a/fireworks/fw_config.py
+++ b/fireworks/fw_config.py
@@ -1,7 +1,9 @@
"""A set of global constants for FireWorks (Python code as a config file)."""
+from __future__ import annotations
+
import os
-from typing import Any, Dict, Optional
+from typing import Any
from monty.design_patterns import singleton
from monty.serialization import dumpfn, loadfn
@@ -175,7 +177,7 @@ def override_user_settings() -> None:
override_user_settings()
-def config_to_dict() -> Dict[str, Any]:
+def config_to_dict() -> dict[str, Any]:
d = {}
for k, v in globals().items():
if k.upper() == k and k != "NEGATIVE_FWID_CTR":
@@ -183,7 +185,7 @@ def config_to_dict() -> Dict[str, Any]:
return d
-def write_config(path: Optional[str] = None) -> None:
+def write_config(path: str | None = None) -> None:
if path is None:
path = os.path.join(os.path.expanduser("~"), ".fireworks", "FW_config.yaml")
dumpfn(config_to_dict(), path)
@@ -193,7 +195,7 @@ def write_config(path: Optional[str] = None) -> None:
class FWData:
"""This class stores data that a Firetask might want to access, e.g. to see the runtime params."""
- def __init__(self):
+ def __init__(self) -> None:
self.MULTIPROCESSING = None # default single process framework
self.NODE_LIST = None # the node list for sub jobs
self.SUB_NPROCS = None # the number of process of the sub job
diff --git a/fireworks/queue/queue_adapter.py b/fireworks/queue/queue_adapter.py
index 87296c129..37b806d18 100644
--- a/fireworks/queue/queue_adapter.py
+++ b/fireworks/queue/queue_adapter.py
@@ -34,7 +34,7 @@ class Command:
status = None
output, error = "", ""
- def __init__(self, command):
+ def __init__(self, command) -> None:
"""
initialize the object.
@@ -57,7 +57,7 @@ def run(self, timeout=None, **kwargs):
(status, output, error)
"""
- def target(**kwargs):
+ def target(**kwargs) -> None:
try:
self.process = subprocess.Popen(self.command, **kwargs)
self.output, self.error = self.process.communicate()
diff --git a/fireworks/queue/queue_launcher.py b/fireworks/queue/queue_launcher.py
index 4b2b5b497..b99827c93 100644
--- a/fireworks/queue/queue_launcher.py
+++ b/fireworks/queue/queue_launcher.py
@@ -179,7 +179,7 @@ def rapidfire(
strm_lvl="INFO",
timeout=None,
fill_mode=False,
-):
+) -> None:
"""
Submit many jobs to the queue.
@@ -330,7 +330,7 @@ def _get_number_of_jobs_in_queue(qadapter, njobs_queue, l_logger):
raise RuntimeError("Unable to determine number of jobs in queue, check queue adapter and queue server status!")
-def setup_offline_job(launchpad, fw, launch_id):
+def setup_offline_job(launchpad, fw, launch_id) -> None:
# separate this function out for reuse in unit testing
fw.to_file("FW.json")
with open("FW_offline.json", "w") as f:
diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py
index b6748dda2..647d73dd3 100644
--- a/fireworks/scripts/lpad_run.py
+++ b/fireworks/scripts/lpad_run.py
@@ -1,18 +1,21 @@
"""A runnable script for managing a FireWorks database (a command-line interface to launchpad.py)."""
+from __future__ import annotations
+
import ast
import copy
import datetime
import json
import os
import re
+import sys
import time
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from importlib import metadata
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
+from typing import Any, Sequence
-import ruamel.yaml as yaml
from pymongo import ASCENDING, DESCENDING
+from ruamel.yaml import YAML
from fireworks import FW_INSTALL_DIR
from fireworks.core.firework import Firework, Workflow
@@ -46,7 +49,7 @@
DEFAULT_LPAD_YAML = "my_launchpad.yaml"
-def pw_check(ids: List[int], args: Namespace, skip_pw: bool = False) -> List[int]:
+def pw_check(ids: list[int], args: Namespace, skip_pw: bool = False) -> list[int]:
if len(ids) > PW_CHECK_NUM and not skip_pw:
m_password = datetime.datetime.now().strftime("%Y-%m-%d")
if not args.password:
@@ -62,7 +65,7 @@ def pw_check(ids: List[int], args: Namespace, skip_pw: bool = False) -> List[int
return ids
-def parse_helper(lp: LaunchPad, args: Namespace, wf_mode: bool = False, skip_pw: bool = False) -> List[int]:
+def parse_helper(lp: LaunchPad, args: Namespace, wf_mode: bool = False, skip_pw: bool = False) -> list[int]:
"""
Helper method to parse args that can take either id, name, state or query.
@@ -160,7 +163,7 @@ def init_yaml(args: Namespace) -> None:
),
)
- doc: Dict[str, Union[str, int, bool, None]] = {}
+ doc: dict[str, str | int | bool | None] = {}
if args.uri_mode:
print(
"Note 1: You are in URI format mode. This means that all database parameters (username, password, host, "
@@ -260,11 +263,10 @@ def print_fws(ids, lp, args: Namespace) -> None:
fws.append(d)
if len(fws) == 1:
fws = fws[0]
-
- print(args.output(fws))
+ get_output(args, fws)
-def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, None] = None) -> Union[List[int], int]:
+def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: bool | None = None) -> list[int] | int:
"""Build fws query from command line options and submit.
Parameters:
@@ -316,8 +318,8 @@ def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, No
def get_fws_helper(
- lp: LaunchPad, ids: List[int], args: Namespace
-) -> Union[List[int], int, List[Dict[str, Union[str, int, bool]]], Union[str, int, bool]]:
+ lp: LaunchPad, ids: list[int], args: Namespace
+) -> list[int] | int | list[dict[str, str | int | bool]] | str | bool:
"""Get fws from ids in a representation according to args.display_format."""
fws = []
if args.display_format == "ids":
@@ -343,7 +345,7 @@ def get_fws(args: Namespace) -> None:
lp = get_lp(args)
ids = get_fw_ids_helper(lp, args)
fws = get_fws_helper(lp, ids, args)
- print(args.output(fws))
+ get_output(args, fws)
def get_fws_in_wfs(args: Namespace) -> None:
@@ -467,7 +469,7 @@ def get_wfs(args: Namespace) -> None:
else:
if len(wfs) == 1:
wfs = wfs[0]
- print(args.output(wfs))
+ get_output(args, wfs)
def delete_wfs(args: Namespace) -> None:
@@ -681,7 +683,7 @@ def set_priority(args: Namespace) -> None:
lp.m_logger.info(f"Finished setting priorities of {len(fw_ids)} FWs")
-def _open_webbrowser(url):
+def _open_webbrowser(url) -> None:
"""Open a web browser after a delay to give the web server more startup time."""
import webbrowser
@@ -793,10 +795,10 @@ def introspect(args: Namespace) -> None:
isp = Introspector(lp)
for coll in ["launches", "tasks", "fireworks", "workflows"]:
print(f"generating report for {coll}...please wait...")
- print("")
+ print()
table = isp.introspect_fizzled(coll=coll, threshold=args.threshold, limit=args.max)
isp.print_report(table, coll)
- print("")
+ print()
def get_launchdir(args: Namespace) -> None:
@@ -817,8 +819,7 @@ def track_fws(args: Namespace) -> None:
for d in data:
for t in d["trackers"]:
if (not include or t.filename in include) and (not exclude or t.filename not in exclude):
- output.append(f"## Launch id: {d['launch_id']}")
- output.append(str(t))
+ output.extend((f"## Launch id: {d['launch_id']}", str(t)))
if output:
name = lp.fireworks.find_one({"fw_id": f}, {"name": 1})["name"]
output.insert(0, f"# FW id: {f}, FW name: {name}")
@@ -852,13 +853,18 @@ def orphaned(args: Namespace) -> None:
lp.m_logger.info(f"Found {len(orphaned_fw_ids)} orphaned fw_ids: {orphaned_fw_ids}")
lp.delete_fws(orphaned_fw_ids, delete_launch_dirs=args.delete_launch_dirs)
else:
- print(args.output(fws))
+ get_output(args, fws)
-def get_output_func(format: Literal["json", "yaml"]) -> Callable[[str], Any]:
- if format == "json":
- return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4)
- return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False)
+def get_output(args: Namespace, objs: list[Any]) -> None:
+ """Prints output on stdout"""
+ if args.output == "json":
+ json.dump(objs, sys.stdout, default=DATETIME_HANDLER, indent=4)
+ else:
+ yaml = YAML(typ="safe", pure=True)
+ yaml.default_flow_style = False
+ yaml.dump(recursive_dict(objs, preserve_unicode=False), sys.stdout)
+ print()
def arg_positive_int(value: str) -> int:
@@ -872,7 +878,7 @@ def arg_positive_int(value: str) -> int:
return ivalue
-def lpad(argv: Optional[Sequence[str]] = None) -> int:
+def lpad(argv: Sequence[str] | None = None) -> int:
m_description = (
"A command line interface to FireWorks. For more help on a specific command, type 'lpad -h'."
)
@@ -1539,8 +1545,6 @@ def lpad(argv: Optional[Sequence[str]] = None) -> int:
cfg_files_to_check.append(("fworker", "-w", False, FWORKER_LOC))
_validate_config_file_paths(args, cfg_files_to_check)
- args.output = get_output_func(args.output)
-
if args.command is None:
# if no command supplied, print help
parser.print_help()
diff --git a/fireworks/scripts/mlaunch_run.py b/fireworks/scripts/mlaunch_run.py
index 0292f6e4c..3e6bd42ce 100644
--- a/fireworks/scripts/mlaunch_run.py
+++ b/fireworks/scripts/mlaunch_run.py
@@ -1,9 +1,11 @@
"""A runnable script to launch Job Packing (Multiple) Rockets."""
+from __future__ import annotations
+
import os
from argparse import ArgumentParser
from importlib import metadata
-from typing import Optional, Sequence
+from typing import Sequence
from fireworks.core.fworker import FWorker
from fireworks.core.launchpad import LaunchPad
@@ -19,7 +21,7 @@
__date__ = "Aug 19, 2013"
-def mlaunch(argv: Optional[Sequence[str]] = None) -> int:
+def mlaunch(argv: Sequence[str] | None = None) -> int:
m_description = "This program launches multiple Rockets simultaneously"
parser = ArgumentParser("mlaunch", description=m_description)
@@ -92,7 +94,7 @@ def mlaunch(argv: Optional[Sequence[str]] = None) -> int:
if args.nodefile in os.environ:
args.nodefile = os.environ[args.nodefile]
with open(args.nodefile) as f:
- total_node_list = [line.strip() for line in f.readlines()]
+ total_node_list = [line.strip() for line in f]
launch_multiprocess(
launchpad,
diff --git a/fireworks/scripts/qlaunch_run.py b/fireworks/scripts/qlaunch_run.py
index a7b5b9851..1f3a8f9fe 100644
--- a/fireworks/scripts/qlaunch_run.py
+++ b/fireworks/scripts/qlaunch_run.py
@@ -1,8 +1,11 @@
"""A runnable script for launching rockets (a command-line interface to queue_launcher.py)."""
+
+from __future__ import annotations
+
import os
import time
from argparse import ArgumentParser
-from typing import Optional, Sequence
+from typing import Sequence
try:
import fabric
@@ -31,7 +34,7 @@
__date__ = "Jan 14, 2013"
-def do_launch(args):
+def do_launch(args) -> None:
cfg_files_to_check = [
("launchpad", "-l", False, LAUNCHPAD_LOC),
("fworker", "-w", False, FWORKER_LOC),
@@ -74,7 +77,7 @@ def do_launch(args):
)
-def qlaunch(argv: Optional[Sequence[str]] = None) -> int:
+def qlaunch(argv: Sequence[str] | None = None) -> int:
m_description = (
"This program is used to submit jobs to a queueing system. "
"Details of the job and queue interaction are handled by the "
diff --git a/fireworks/scripts/rlaunch_run.py b/fireworks/scripts/rlaunch_run.py
index 68e8d9301..ce3150c40 100644
--- a/fireworks/scripts/rlaunch_run.py
+++ b/fireworks/scripts/rlaunch_run.py
@@ -1,11 +1,13 @@
"""A runnable script to launch a single Rocket (a command-line interface to rocket_launcher.py)."""
+from __future__ import annotations
+
import os
import signal
import sys
from argparse import ArgumentParser
from importlib import metadata
-from typing import Optional, Sequence
+from typing import Sequence
from fireworks.core.fworker import FWorker
from fireworks.core.launchpad import LaunchPad
@@ -24,12 +26,12 @@
__date__ = "Feb 7, 2013"
-def handle_interrupt(signum, frame):
+def handle_interrupt(signum, frame) -> None:
sys.stderr.write(f"Interrupted by signal {signum:d}\n")
sys.exit(1)
-def rlaunch(argv: Optional[Sequence[str]] = None) -> int:
+def rlaunch(argv: Sequence[str] | None = None) -> int:
m_description = (
"This program launches one or more Rockets. A Rocket retrieves a job from the "
'central database and runs it. The "single-shot" option launches a single Rocket, '
@@ -165,7 +167,7 @@ def rlaunch(argv: Optional[Sequence[str]] = None) -> int:
if args.nodefile in os.environ:
args.nodefile = os.environ[args.nodefile]
with open(args.nodefile) as f:
- total_node_list = [line.strip() for line in f.readlines()]
+ total_node_list = [line.strip() for line in f]
launch_multiprocess(
launchpad,
fworker,
diff --git a/fireworks/scripts/tests/test_lpad_run.py b/fireworks/scripts/tests/test_lpad_run.py
index 4f2451d2f..d5244c676 100644
--- a/fireworks/scripts/tests/test_lpad_run.py
+++ b/fireworks/scripts/tests/test_lpad_run.py
@@ -20,7 +20,7 @@ def lp(capsys):
@pytest.mark.parametrize(("detail", "expected_1", "expected_2"), [("count", "0\n", "1\n"), ("ids", "[]\n", "1\n")])
-def test_lpad_get_fws(capsys, lp, detail, expected_1, expected_2):
+def test_lpad_get_fws(capsys, lp, detail, expected_1, expected_2) -> None:
"""Test lpad CLI get_fws command."""
ret_code = lpad(["get_fws", "-d", detail])
assert ret_code == 0
@@ -45,7 +45,7 @@ def test_lpad_get_fws(capsys, lp, detail, expected_1, expected_2):
@pytest.mark.parametrize("arg", ["-v", "--version"])
-def test_lpad_report_version(capsys, arg):
+def test_lpad_report_version(capsys, arg) -> None:
"""Test lpad CLI version flag."""
with pytest.raises(SystemExit, match="0"):
lpad([arg])
@@ -56,7 +56,7 @@ def test_lpad_report_version(capsys, arg):
assert stderr == ""
-def test_lpad_config_file_flags():
+def test_lpad_config_file_flags() -> None:
"""Test lpad CLI throws errors on missing config file flags."""
with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"):
lpad(["-l", "", "get_fws"])
diff --git a/fireworks/scripts/tests/test_mlaunch_run.py b/fireworks/scripts/tests/test_mlaunch_run.py
index 1c7827bd4..aa6f2f9a4 100644
--- a/fireworks/scripts/tests/test_mlaunch_run.py
+++ b/fireworks/scripts/tests/test_mlaunch_run.py
@@ -6,7 +6,7 @@
@pytest.mark.parametrize("arg", ["-v", "--version"])
-def test_mlaunch_report_version(capsys, arg):
+def test_mlaunch_report_version(capsys, arg) -> None:
"""Test mlaunch CLI version flag."""
with pytest.raises(SystemExit, match="0"):
mlaunch([arg])
@@ -17,7 +17,7 @@ def test_mlaunch_report_version(capsys, arg):
assert stderr == ""
-def test_mlaunch_config_file_flags():
+def test_mlaunch_config_file_flags() -> None:
"""Test mlaunch CLI throws errors on missing config file flags."""
num_jobs = "1"
diff --git a/fireworks/scripts/tests/test_qlaunch_run.py b/fireworks/scripts/tests/test_qlaunch_run.py
index 6c04a7318..3200f0980 100644
--- a/fireworks/scripts/tests/test_qlaunch_run.py
+++ b/fireworks/scripts/tests/test_qlaunch_run.py
@@ -8,7 +8,7 @@
@pytest.mark.parametrize("arg", ["-v", "--version"])
-def test_qlaunch_report_version(capsys, arg):
+def test_qlaunch_report_version(capsys, arg) -> None:
"""Test qlaunch CLI version flag."""
with pytest.raises(SystemExit):
qlaunch([arg])
@@ -19,7 +19,7 @@ def test_qlaunch_report_version(capsys, arg):
assert stderr == ""
-def test_qlaunch_config_file_flags():
+def test_qlaunch_config_file_flags() -> None:
"""Test qlaunch CLI throws errors on missing config file flags."""
# qadapter.yaml is mandatory, test for ValueError if missing
with pytest.raises(ValueError, match="No path specified for qadapter_file."):
diff --git a/fireworks/scripts/tests/test_rlaunch_run.py b/fireworks/scripts/tests/test_rlaunch_run.py
index 2895e4e1d..bc1f60327 100644
--- a/fireworks/scripts/tests/test_rlaunch_run.py
+++ b/fireworks/scripts/tests/test_rlaunch_run.py
@@ -6,7 +6,7 @@
@pytest.mark.parametrize("arg", ["-v", "--version"])
-def test_rlaunch_report_version(capsys, arg):
+def test_rlaunch_report_version(capsys, arg) -> None:
"""Test rlaunch CLI version flag."""
with pytest.raises(SystemExit, match="0"):
rlaunch([arg])
@@ -17,7 +17,7 @@ def test_rlaunch_report_version(capsys, arg):
assert stderr == ""
-def test_rlaunch_config_file_flags():
+def test_rlaunch_config_file_flags() -> None:
"""Test rlaunch CLI throws errors on missing config file flags."""
with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"):
rlaunch(["-l", ""])
diff --git a/fireworks/tests/master_tests.py b/fireworks/tests/master_tests.py
index 73e369b5d..e1dbbc3ba 100644
--- a/fireworks/tests/master_tests.py
+++ b/fireworks/tests/master_tests.py
@@ -23,7 +23,7 @@
class TestImports(unittest.TestCase):
"""Make sure that required external libraries can be imported."""
- def test_imports(self):
+ def test_imports(self) -> None:
pass
# test that MongoClient is available (newer pymongo)
@@ -31,7 +31,7 @@ def test_imports(self):
class BasicTests(unittest.TestCase):
"""Make sure that required external libraries can be imported."""
- def test_fwconnector(self):
+ def test_fwconnector(self) -> None:
fw1 = Firework(ScriptTask.from_str('echo "1"'))
fw2 = Firework(ScriptTask.from_str('echo "1"'))
@@ -44,7 +44,7 @@ def test_fwconnector(self):
wf3 = Workflow([fw1, fw2])
assert wf3.links == {fw1.fw_id: [], fw2.fw_id: []}
- def test_parentconnector(self):
+ def test_parentconnector(self) -> None:
fw1 = Firework(ScriptTask.from_str('echo "1"'))
fw2 = Firework(ScriptTask.from_str('echo "1"'), parents=fw1)
fw3 = Firework(ScriptTask.from_str('echo "1"'), parents=[fw1, fw2])
@@ -69,7 +69,7 @@ def get_data(obj_dict):
return cls_.from_dict(obj_dict)
return None
- def test_serialization_details(self):
+ def test_serialization_details(self) -> None:
# This detects a weird bug found in early version of serializers
pbs = CommonAdapter("PBS")
@@ -78,7 +78,7 @@ def test_serialization_details(self):
assert isinstance(load_object(pbs.to_dict()), CommonAdapter)
assert isinstance(self.get_data(pbs.to_dict()), CommonAdapter) # repeated test on purpose!
- def test_recursive_deserialize(self):
+ def test_recursive_deserialize(self) -> None:
my_dict = {
"update_spec": {},
"mod_spec": [],
diff --git a/fireworks/tests/mongo_tests.py b/fireworks/tests/mongo_tests.py
index bef1aa983..d3c0fd950 100644
--- a/fireworks/tests/mongo_tests.py
+++ b/fireworks/tests/mongo_tests.py
@@ -6,6 +6,7 @@
import time
import unittest
from multiprocessing import Pool
+from typing import NoReturn
import pytest
@@ -36,14 +37,14 @@
NCORES_PARALLEL_TEST = 4
-def random_launch(lp_creds):
+def random_launch(lp_creds) -> None:
lp = LaunchPad.from_dict(lp_creds)
while lp.run_exists(None):
launch_rocket(lp)
time.sleep(random.random() / 3 + 0.1)
-def throw_error(msg):
+def throw_error(msg) -> NoReturn:
raise ValueError(msg)
@@ -74,7 +75,7 @@ def run_task(self, fw_spec):
class MongoTests(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
cls.fworker = FWorker()
try:
@@ -84,21 +85,21 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TESTDB_NAME)
@staticmethod
- def _teardown(dests):
+ def _teardown(dests) -> None:
for f in dests:
if os.path.exists(f):
os.remove(f)
- def setUp(self):
+ def setUp(self) -> None:
self.lp.reset(password=None, require_password=False)
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
if os.path.exists(os.path.join("FW.json")):
os.remove("FW.json")
@@ -110,14 +111,14 @@ def tearDown(self):
for i in glob.glob(os.path.join(MODULE_DIR, "launcher*")):
shutil.rmtree(i)
- def test_basic_fw(self):
+ def test_basic_fw(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1)
self.lp.add_wf(fw)
launch_rocket(self.lp, self.fworker)
assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test1\n"
- def test_basic_fw_offline(self):
+ def test_basic_fw_offline(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1)
self.lp.add_wf(fw)
@@ -152,7 +153,7 @@ def test_basic_fw_offline(self):
self.lp.recover_offline(launch["launch_id"])
assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test1\n"
- def test_offline_fw_passinfo(self):
+ def test_offline_fw_passinfo(self) -> None:
fw1 = Firework([AdditionTask()], {"input_array": [1, 1]}, name="1")
fw2 = Firework([AdditionTask()], {"input_array": [2, 2]}, name="2")
fw3 = Firework([AdditionTask()], {"input_array": [3]}, parents=[fw1, fw2], name="3")
@@ -198,7 +199,7 @@ def test_offline_fw_passinfo(self):
assert set(child_fw.spec["input_array"]) == {2, 3, 4}
assert child_fw.launches[0].action.stored_data["sum"] == 9
- def test_multi_fw(self):
+ def test_multi_fw(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
test2 = ScriptTask.from_str("python -c 'print(\"test2\")'", {"store_stdout": True})
fw = Firework([test1, test2])
@@ -206,7 +207,7 @@ def test_multi_fw(self):
launch_rocket(self.lp, self.fworker)
assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test2\n"
- def test_multi_fw_complex(self):
+ def test_multi_fw_complex(self) -> None:
dest1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "inputs.txt")
dest2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_file.txt")
self._teardown([dest1, dest2])
@@ -234,7 +235,7 @@ def test_multi_fw_complex(self):
finally:
self._teardown([dest1, dest2])
- def test_backgroundtask(self):
+ def test_backgroundtask(self) -> None:
dest1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hello.txt")
self._teardown([dest1])
@@ -256,13 +257,13 @@ def test_backgroundtask(self):
finally:
self._teardown([dest1])
- def test_add_fw(self):
+ def test_add_fw(self) -> None:
fw = Firework(AdditionTask(), {"input_array": [5, 7]})
self.lp.add_wf(fw)
rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
assert self.lp.get_launch_by_id(1).action.stored_data["sum"] == 12
- def test_org_wf(self):
+ def test_org_wf(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
test2 = ScriptTask.from_str("python -c 'print(\"test2\")'", {"store_stdout": True})
fw1 = Firework(test1, fw_id=-1)
@@ -274,7 +275,7 @@ def test_org_wf(self):
launch_rocket(self.lp, self.fworker)
assert self.lp.get_launch_by_id(2).action.stored_data["stdout"] == "test2\n"
- def test_fibadder(self):
+ def test_fibadder(self) -> None:
fib = FibonacciAdderTask()
fw = Firework(fib, {"smaller": 0, "larger": 1, "stop_point": 3})
self.lp.add_wf(fw)
@@ -285,7 +286,7 @@ def test_fibadder(self):
assert self.lp.get_launch_by_id(3).action.stored_data == {}
assert not self.lp.run_exists()
- def test_parallel_fibadder(self):
+ def test_parallel_fibadder(self) -> None:
# this is really testing to see if a Workflow can handle multiple FWs updating it at once
parent = Firework(ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True}))
fib1 = Firework(FibonacciAdderTask(), {"smaller": 0, "larger": 1, "stop_point": 30}, parents=[parent])
@@ -300,7 +301,7 @@ def test_parallel_fibadder(self):
creds_array = [self.lp.to_dict()] * NCORES_PARALLEL_TEST
p.map(random_launch, creds_array)
- def test_multi_detour(self):
+ def test_multi_detour(self) -> None:
fw1 = Firework([MultipleDetourTask()], fw_id=1)
fw2 = Firework([ScriptTask.from_str('echo "DONE"')], parents=[fw1], fw_id=2)
self.lp.add_wf(Workflow([fw1, fw2]))
@@ -312,7 +313,7 @@ def test_multi_detour(self):
assert set(links[4]) == {2}
assert set(links[5]) == {2}
- def test_fw_env(self):
+ def test_fw_env(self) -> None:
t = DummyFWEnvTask()
fw = Firework(t)
self.lp.add_wf(fw)
@@ -322,7 +323,7 @@ def test_fw_env(self):
launch_rocket(self.lp, FWorker(env={"hello": "world"}))
assert self.lp.get_launch_by_id(2).action.stored_data["data"] == "world"
- def test_job_info(self):
+ def test_job_info(self) -> None:
fw1 = Firework([ScriptTask.from_str('echo "Testing job info"')], spec={"_pass_job_info": True}, fw_id=1)
fw2 = Firework([DummyJobPassTask()], parents=[fw1], spec={"_pass_job_info": True, "target": 1}, fw_id=2)
fw3 = Firework([DummyJobPassTask()], parents=[fw2], spec={"target": 2}, fw_id=3)
@@ -361,7 +362,7 @@ def test_job_info(self):
assert len(modified_spec["_job_info"]) == 2
- def test_files_in_out(self):
+ def test_files_in_out(self) -> None:
# create the Workflow that passes files_in and files_out
fw1 = Firework(
[ScriptTask.from_str('echo "This is the first FireWork" > test1')],
@@ -390,7 +391,7 @@ def test_files_in_out(self):
for f in ["test1", "hello.gz", "fwtest.2"]:
os.remove(f)
- def test_preserve_fworker(self):
+ def test_preserve_fworker(self) -> None:
fw1 = Firework(
[ScriptTask.from_str('echo "Testing preserve FWorker"')], spec={"_preserve_fworker": True}, fw_id=1
)
@@ -415,14 +416,14 @@ def test_preserve_fworker(self):
assert modified_spec["_fworker"] is not None
- def test_add_lp_and_fw_id(self):
+ def test_add_lp_and_fw_id(self) -> None:
fw1 = Firework([DummyLPTask()], spec={"_add_launchpad_and_fw_id": True})
self.lp.add_wf(fw1)
launch_rocket(self.lp, self.fworker)
assert self.lp.get_launch_by_id(1).action.stored_data["fw_id"] == 1
assert self.lp.get_launch_by_id(1).action.stored_data["host"] is not None
- def test_spec_copy(self):
+ def test_spec_copy(self) -> None:
task1 = ScriptTask.from_str('echo "Task 1"')
task2 = ScriptTask.from_str('echo "Task 2"')
@@ -436,7 +437,7 @@ def test_spec_copy(self):
assert self.lp.get_fw_by_id(1).tasks[0]["script"][0] == 'echo "Task 1"'
assert self.lp.get_fw_by_id(2).tasks[0]["script"][0] == 'echo "Task 2"'
- def test_category(self):
+ def test_category(self) -> None:
task1 = ScriptTask.from_str('echo "Task 1"')
task2 = ScriptTask.from_str('echo "Task 2"')
@@ -453,7 +454,7 @@ def test_category(self):
assert self.lp.run_exists(FWorker()) # can run any category
assert self.lp.run_exists(FWorker(category=["dummy_category", "other category"]))
- def test_category_pt2(self):
+ def test_category_pt2(self) -> None:
task1 = ScriptTask.from_str('echo "Task 1"')
task2 = ScriptTask.from_str('echo "Task 2"')
@@ -467,7 +468,7 @@ def test_category_pt2(self):
assert self.lp.run_exists(FWorker()) # can run any category
assert not self.lp.run_exists(FWorker(category=["dummy_category", "other category"]))
- def test_delete_fw(self):
+ def test_delete_fw(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1)
self.lp.add_wf(fw)
@@ -479,7 +480,7 @@ def test_delete_fw(self):
with pytest.raises(ValueError):
self.lp.get_launch_by_id(1)
- def test_duplicate_delete_fw(self):
+ def test_duplicate_delete_fw(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1, {"_dupefinder": DupeFinderExact()})
self.lp.add_wf(fw)
@@ -496,7 +497,7 @@ def test_duplicate_delete_fw(self):
self.lp.get_fw_by_id(del_id)
assert self.lp.get_launch_by_id(1).action.stored_data["stdout"] == "test1\n"
- def test_dupefinder(self):
+ def test_dupefinder(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1, {"_dupefinder": DupeFinderExact()})
self.lp.add_wf(fw)
@@ -516,7 +517,7 @@ def test_dupefinder(self):
print("--------")
assert self.lp.launches.count_documents({}) == 1
- def test_append_wf(self):
+ def test_append_wf(self) -> None:
fw1 = Firework([UpdateSpecTask()])
fw2 = Firework([ModSpecTask()])
self.lp.add_wf(Workflow([fw1, fw2]))
@@ -548,7 +549,7 @@ def test_append_wf(self):
with pytest.raises(ValueError):
self.lp.append_wf(new_wf, [4], detour=True)
- def test_append_wf_detour(self):
+ def test_append_wf_detour(self) -> None:
fw1 = Firework([ModSpecTask()], fw_id=1)
fw2 = Firework([ModSpecTask()], fw_id=2, parents=[fw1])
self.lp.add_wf(Workflow([fw1, fw2]))
@@ -561,7 +562,7 @@ def test_append_wf_detour(self):
assert self.lp.get_fw_by_id(2).spec["dummy2"] == [True, True]
- def test_force_lock_removal(self):
+ def test_force_lock_removal(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1, {"_dupefinder": DupeFinderExact()}, fw_id=1)
self.lp.add_wf(fw)
@@ -569,7 +570,7 @@ def test_force_lock_removal(self):
with WFLock(self.lp, 1), WFLock(self.lp, 1, kill=True, expire_secs=1):
assert True # dummy to make sure we got here
- def test_fizzle(self):
+ def test_fizzle(self) -> None:
p = PyTask(func="fireworks.tests.mongo_tests.throw_error", args=["Testing; this error is normal."])
fw = Firework(p)
self.lp.add_wf(fw)
@@ -577,21 +578,21 @@ def test_fizzle(self):
assert self.lp.get_fw_by_id(1).state == "FIZZLED"
assert not launch_rocket(self.lp, self.fworker)
- def test_defuse(self):
+ def test_defuse(self) -> None:
p = PyTask(func="fireworks.tests.mongo_tests.throw_error", args=["This should not happen"])
fw = Firework(p)
self.lp.add_wf(fw)
self.lp.defuse_fw(fw.fw_id)
assert not launch_rocket(self.lp, self.fworker)
- def test_archive(self):
+ def test_archive(self) -> None:
p = PyTask(func="fireworks.tests.mongo_tests.throw_error", args=["This should not happen"])
fw = Firework(p)
self.lp.add_wf(fw)
self.lp.archive_wf(fw.fw_id)
assert not launch_rocket(self.lp, self.fworker)
- def test_stats(self):
+ def test_stats(self) -> None:
test1 = ScriptTask.from_str("python -c 'print(\"test1\")'", {"store_stdout": True})
fw = Firework(test1)
self.lp.add_wf(fw)
diff --git a/fireworks/tests/multiprocessing_tests.py b/fireworks/tests/multiprocessing_tests.py
index 48ca49705..e6a0a1e8d 100644
--- a/fireworks/tests/multiprocessing_tests.py
+++ b/fireworks/tests/multiprocessing_tests.py
@@ -17,7 +17,7 @@
class TestLinks(TestCase):
- def test_pickle(self):
+ def test_pickle(self) -> None:
links1 = Workflow.Links({1: 2, 3: [5, 7, 8]})
s = pickle.dumps(links1)
links2 = pickle.loads(s)
@@ -28,7 +28,7 @@ class TestCheckoutFW(TestCase):
lp = None
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.fworker = FWorker()
try:
cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR")
@@ -37,14 +37,14 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost: 27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TESTDB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
os.chdir(self.old_wd)
if os.path.exists(os.path.join("FW.json")):
@@ -53,7 +53,7 @@ def tearDown(self):
for i in glob.glob(os.path.join(MODULE_DIR, "launcher*")):
shutil.rmtree(i)
- def test_checkout_fw(self):
+ def test_checkout_fw(self) -> None:
os.chdir(MODULE_DIR)
self.lp.add_wf(
Firework(ScriptTask.from_str(shell_cmd='echo "hello 1"', parameters={"stdout_file": "task.out"}), fw_id=1)
@@ -76,7 +76,7 @@ class TestEarlyExit(TestCase):
lp = None
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.fworker = FWorker()
try:
cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR")
@@ -85,14 +85,14 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TESTDB_NAME)
- def setUp(self):
+ def setUp(self) -> None:
self.old_wd = os.getcwd()
- def tearDown(self):
+ def tearDown(self) -> None:
self.lp.reset(password=None, require_password=False)
os.chdir(self.old_wd)
if os.path.exists(os.path.join("FW.json")):
@@ -101,7 +101,7 @@ def tearDown(self):
for i in glob.glob(os.path.join(MODULE_DIR, "launcher*")):
shutil.rmtree(i)
- def test_early_exit(self):
+ def test_early_exit(self) -> None:
os.chdir(MODULE_DIR)
script_text = "echo hello from process $PPID; sleep 2"
fw1 = Firework(ScriptTask.from_str(shell_cmd=script_text, parameters={"stdout_file": "task.out"}), fw_id=1)
diff --git a/fireworks/tests/tasks.py b/fireworks/tests/tasks.py
index 1c8a8c744..c7e3452ee 100644
--- a/fireworks/tests/tasks.py
+++ b/fireworks/tests/tasks.py
@@ -1,6 +1,5 @@
"""TODO: Modify module doc."""
-
__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__maintainer__ = "Shyue Ping Ong"
diff --git a/fireworks/tests/test_fw_config.py b/fireworks/tests/test_fw_config.py
index d08677565..ab8a0ef46 100644
--- a/fireworks/tests/test_fw_config.py
+++ b/fireworks/tests/test_fw_config.py
@@ -10,7 +10,7 @@
class ConfigTest(unittest.TestCase):
- def test_config(self):
+ def test_config(self) -> None:
d = config_to_dict()
assert "NEGATIVE_FWID_CTR" not in d
diff --git a/fireworks/tests/test_workflow.py b/fireworks/tests/test_workflow.py
index 655361be6..875bcd194 100644
--- a/fireworks/tests/test_workflow.py
+++ b/fireworks/tests/test_workflow.py
@@ -4,55 +4,55 @@
class TestWorkflowState(unittest.TestCase):
- def test_completed(self):
+ def test_completed(self) -> None:
# all leaves complete
one = fw.Firework([], state="COMPLETED", fw_id=1)
two = fw.Firework([], state="COMPLETED", fw_id=2)
assert fw.Workflow([one, two]).state == "COMPLETED"
- def test_archived(self):
+ def test_archived(self) -> None:
one = fw.Firework([], state="ARCHIVED", fw_id=1)
two = fw.Firework([], state="ARCHIVED", fw_id=2)
assert fw.Workflow([one, two]).state == "ARCHIVED"
- def test_defused(self):
+ def test_defused(self) -> None:
# any defused == defused
one = fw.Firework([], state="COMPLETED", fw_id=1)
two = fw.Firework([], state="DEFUSED", fw_id=2)
assert fw.Workflow([one, two]).state == "DEFUSED"
- def test_paused(self):
+ def test_paused(self) -> None:
# any paused == paused
one = fw.Firework([], state="COMPLETED", fw_id=1)
two = fw.Firework([], state="PAUSED", fw_id=2)
assert fw.Workflow([one, two]).state == "PAUSED"
- def test_fizzled_1(self):
+ def test_fizzled_1(self) -> None:
# WF(Fizzled -> Waiting(no fizz parents)) == FIZZLED
one = fw.Firework([], state="FIZZLED", fw_id=1)
two = fw.Firework([], state="WAITING", fw_id=2, parents=one)
assert fw.Workflow([one, two]).state == "FIZZLED"
- def test_fizzled_2(self):
+ def test_fizzled_2(self) -> None:
# WF(Fizzled -> Ready(allow fizz parents)) == RUNNING
one = fw.Firework([], state="FIZZLED", fw_id=1)
two = fw.Firework([], state="READY", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one)
assert fw.Workflow([one, two]).state == "RUNNING"
- def test_fizzled_3(self):
+ def test_fizzled_3(self) -> None:
# WF(Fizzled -> Completed(allow fizz parents)) == COMPLETED
one = fw.Firework([], state="FIZZLED", fw_id=1)
two = fw.Firework([], state="COMPLETED", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one)
assert fw.Workflow([one, two]).state == "COMPLETED"
- def test_fizzled_4(self):
+ def test_fizzled_4(self) -> None:
# one child doesn't allow fizzled parents
one = fw.Firework([], state="FIZZLED", fw_id=1)
two = fw.Firework([], state="READY", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one)
@@ -60,39 +60,39 @@ def test_fizzled_4(self):
assert fw.Workflow([one, two, three]).state == "FIZZLED"
- def test_fizzled_5(self):
+ def test_fizzled_5(self) -> None:
# leaf is fizzled, wf is fizzled
one = fw.Firework([], state="COMPLETED", fw_id=1)
two = fw.Firework([], state="FIZZLED", fw_id=2, parents=one)
assert fw.Workflow([one, two]).state == "FIZZLED"
- def test_fizzled_6(self):
+ def test_fizzled_6(self) -> None:
# deep fizzled fireworks, but still RUNNING
one = fw.Firework([], state="FIZZLED", fw_id=1)
two = fw.Firework([], state="FIZZLED", fw_id=2, spec={"_allow_fizzled_parents": True}, parents=one)
three = fw.Firework([], state="READY", fw_id=3, spec={"_allow_fizzled_parents": True}, parents=two)
assert fw.Workflow([one, two, three]).state == "RUNNING"
- def test_running_1(self):
+ def test_running_1(self) -> None:
one = fw.Firework([], state="COMPLETED", fw_id=1)
two = fw.Firework([], state="READY", fw_id=2, parents=one)
assert fw.Workflow([one, two]).state == "RUNNING"
- def test_running_2(self):
+ def test_running_2(self) -> None:
one = fw.Firework([], state="RUNNING", fw_id=1)
two = fw.Firework([], state="WAITING", fw_id=2, parents=one)
assert fw.Workflow([one, two]).state == "RUNNING"
- def test_reserved(self):
+ def test_reserved(self) -> None:
one = fw.Firework([], state="RESERVED", fw_id=1)
two = fw.Firework([], state="READY", fw_id=2, parents=one)
assert fw.Workflow([one, two]).state == "RESERVED"
- def test_ready(self):
+ def test_ready(self) -> None:
one = fw.Firework([], state="READY", fw_id=1)
two = fw.Firework([], state="READY", fw_id=2, parents=one)
diff --git a/fireworks/user_objects/firetasks/dataflow_tasks.py b/fireworks/user_objects/firetasks/dataflow_tasks.py
index e2a8d4992..8467207d5 100644
--- a/fireworks/user_objects/firetasks/dataflow_tasks.py
+++ b/fireworks/user_objects/firetasks/dataflow_tasks.py
@@ -6,6 +6,8 @@
import sys
+from ruamel.yaml import YAML
+
from fireworks import Firework
from fireworks.core.firework import FiretaskBase, FWAction
from fireworks.utilities.fw_serializers import load_object
@@ -381,8 +383,6 @@ def run_task(self, fw_spec):
import operator
from functools import reduce
- import ruamel.yaml as yaml
-
filename = self["filename"]
mapstring = self["mapstring"]
assert isinstance(filename, basestring)
@@ -392,7 +392,7 @@ def run_task(self, fw_spec):
fmt = filename.split(".")[-1]
assert fmt in ["json", "yaml"]
with open(filename) as inp:
- data = json.load(inp) if fmt == "json" else yaml.safe_load(inp)
+ data = json.load(inp) if fmt == "json" else YAML(typ="safe", pure=True).load(inp)
leaf = reduce(operator.getitem, maplist[:-1], fw_spec)
if isinstance(data, dict):
diff --git a/fireworks/user_objects/firetasks/fileio_tasks.py b/fireworks/user_objects/firetasks/fileio_tasks.py
index 29bcc8452..7ff5923e8 100755
--- a/fireworks/user_objects/firetasks/fileio_tasks.py
+++ b/fireworks/user_objects/firetasks/fileio_tasks.py
@@ -31,7 +31,7 @@ class FileWriteTask(FiretaskBase):
required_params = ["files_to_write"]
optional_params = ["dest"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
pth = self.get("dest", os.getcwd())
for d in self["files_to_write"]:
with open(os.path.join(pth, d["filename"]), "w") as f:
@@ -54,7 +54,7 @@ class FileDeleteTask(FiretaskBase):
required_params = ["files_to_delete"]
optional_params = ["dest", "ignore_errors"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
pth = self.get("dest", os.getcwd())
ignore_errors = self.get("ignore_errors", True)
for f in self["files_to_delete"]:
@@ -98,7 +98,7 @@ class FileTransferTask(FiretaskBase):
"copyfile": shutil.copyfile,
}
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
shell_interpret = self.get("shell_interpret", True)
ignore_errors = self.get("ignore_errors", False)
max_retry = self.get("max_retry", 0)
@@ -162,7 +162,7 @@ def run_task(self, fw_spec):
ssh.close()
@staticmethod
- def _rexists(sftp, path):
+ def _rexists(sftp, path) -> bool:
"""os.path.exists for paramiko's SCP object."""
try:
sftp.stat(path)
@@ -187,7 +187,7 @@ class CompressDirTask(FiretaskBase):
_fw_name = "CompressDirTask"
optional_params = ["compression", "dest", "ignore_errors"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
ignore_errors = self.get("ignore_errors", False)
dest = self.get("dest", os.getcwd())
compression = self.get("compression", "gz")
@@ -211,7 +211,7 @@ class DecompressDirTask(FiretaskBase):
_fw_name = "DecompressDirTask"
optional_params = ["dest", "ignore_errors"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
ignore_errors = self.get("ignore_errors", False)
dest = self.get("dest", os.getcwd())
try:
@@ -235,5 +235,5 @@ class ArchiveDirTask(FiretaskBase):
required_params = ["base_name"]
optional_params = ["format"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
shutil.make_archive(self["base_name"], format=self.get("format", "gztar"), root_dir=".")
diff --git a/fireworks/user_objects/firetasks/filepad_tasks.py b/fireworks/user_objects/firetasks/filepad_tasks.py
index 55ec172cf..0302f887a 100644
--- a/fireworks/user_objects/firetasks/filepad_tasks.py
+++ b/fireworks/user_objects/firetasks/filepad_tasks.py
@@ -31,7 +31,7 @@ class AddFilesTask(FiretaskBase):
required_params = ["paths"]
optional_params = ["identifiers", "directory", "filepad_file", "compress", "metadata"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
directory = os.path.abspath(self.get("directory", "."))
@@ -71,7 +71,7 @@ class GetFilesTask(FiretaskBase):
required_params = ["identifiers"]
optional_params = ["filepad_file", "dest_dir", "new_file_names"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
fpad = get_fpad(self.get("filepad_file", None))
dest_dir = self.get("dest_dir", os.path.abspath("."))
new_file_names = self.get("new_file_names", [])
@@ -145,7 +145,7 @@ class GetFilesByQueryTask(FiretaskBase):
"sort_key",
]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
fpad = get_fpad(self.get("filepad_file", None))
dest_dir = self.get("dest_dir", os.path.abspath("."))
new_file_names = self.get("new_file_names", [])
@@ -202,7 +202,7 @@ class DeleteFilesTask(FiretaskBase):
required_params = ["identifiers"]
optional_params = ["filepad_file"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
fpad = get_fpad(self.get("filepad_file", None))
for file in self["identifiers"]:
fpad.delete_file(file)
diff --git a/fireworks/user_objects/firetasks/script_task.py b/fireworks/user_objects/firetasks/script_task.py
index 26f09a958..b86dc3294 100644
--- a/fireworks/user_objects/firetasks/script_task.py
+++ b/fireworks/user_objects/firetasks/script_task.py
@@ -1,10 +1,11 @@
"""This module includes tasks to integrate scripts and python functions."""
+from __future__ import annotations
+
import builtins
import shlex
import subprocess
import sys
-from typing import Dict, List, Optional, Union
from fireworks.core.firework import FiretaskBase, FWAction
@@ -91,7 +92,7 @@ def _run_task_internal(self, fw_spec, stdin):
return FWAction(stored_data=output)
- def _load_params(self, d):
+ def _load_params(self, d) -> None:
if d.get("stdin_file") and d.get("stdin_key"):
raise ValueError("ScriptTask cannot process both a key and file as the standard in!")
@@ -163,7 +164,7 @@ class PyTask(FiretaskBase):
# note that we are not using "optional_params" because we do not want to do
# strict parameter checking in FiretaskBase due to "auto_kwargs" option
- def run_task(self, fw_spec: Dict[str, Union[List[int], int]]) -> Optional[FWAction]:
+ def run_task(self, fw_spec: dict[str, list[int] | int]) -> FWAction | None:
toks = self["func"].rsplit(".", 1)
if len(toks) == 2:
mod_name, funcname = toks
@@ -177,8 +178,7 @@ def run_task(self, fw_spec: Dict[str, Union[List[int], int]]) -> Optional[FWActi
inputs = self.get("inputs", [])
assert isinstance(inputs, list)
- for item in inputs:
- args.append(fw_spec[item])
+ args += [fw_spec[item] for item in inputs]
if self.get("auto_kwargs"):
kwargs = {
diff --git a/fireworks/user_objects/firetasks/templatewriter_task.py b/fireworks/user_objects/firetasks/templatewriter_task.py
index 2fd9e0438..22dc082ef 100644
--- a/fireworks/user_objects/firetasks/templatewriter_task.py
+++ b/fireworks/user_objects/firetasks/templatewriter_task.py
@@ -31,7 +31,7 @@ class TemplateWriterTask(FiretaskBase):
_fw_name = "TemplateWriterTask"
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
if self.get("use_global_spec"):
self._load_params(fw_spec)
else:
@@ -45,7 +45,7 @@ def run_task(self, fw_spec):
with open(self.output_file, write_mode) as of:
of.write(output)
- def _load_params(self, d):
+ def _load_params(self, d) -> None:
self.context = d["context"]
self.output_file = d["output_file"]
self.append_file = d.get("append") # append to output file?
diff --git a/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py b/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py
index 58932d5b8..098213fa1 100644
--- a/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py
+++ b/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py
@@ -5,6 +5,8 @@
import uuid
from unittest import SkipTest
+from ruamel.yaml import YAML
+
from fireworks.user_objects.firetasks.dataflow_tasks import (
CommandLineTask,
ForeachTask,
@@ -26,7 +28,7 @@ def afunc(array, power):
class CommandLineTaskTest(unittest.TestCase):
"""run tests for CommandLineTask."""
- def test_command_line_task_1(self):
+ def test_command_line_task_1(self) -> None:
"""Input from string to stdin, output from stdout to string."""
params = {
"command_spec": {
@@ -47,7 +49,7 @@ def test_command_line_task_1(self):
output_string = action.mod_spec[0]["_push"]["output string"]["value"]
assert output_string == "Hello world!"
- def test_command_line_task_2(self):
+ def test_command_line_task_2(self) -> None:
"""
input from string to data, output from stdout to file;
input from file to stdin, output from stdout to string and from file.
@@ -92,7 +94,7 @@ def test_command_line_task_2(self):
os.remove(filename)
os.remove(output_file)
- def test_command_line_task_3(self):
+ def test_command_line_task_3(self) -> None:
"""Input from string to data with command line options."""
import platform
@@ -153,7 +155,7 @@ def test_command_line_task_3(self):
assert time_stamp_1[11:19] == time_stamp_2[11:19]
os.remove(filename)
- def test_command_line_task_4(self):
+ def test_command_line_task_4(self) -> None:
"""Multiple string inputs, multiple file outputs."""
params = {
"command_spec": {
@@ -186,7 +188,7 @@ def test_command_line_task_4(self):
class ForeachTaskTest(unittest.TestCase):
"""run tests for ForeachTask."""
- def test_foreach_pytask(self):
+ def test_foreach_pytask(self) -> None:
"""Run PyTask for a list of numbers."""
numbers = [0, 1, 2, 3, 4]
power = 2
@@ -207,7 +209,7 @@ def test_foreach_pytask(self):
for number, result in zip(numbers, results):
assert result == pow(number, power)
- def test_foreach_commandlinetask(self):
+ def test_foreach_commandlinetask(self) -> None:
"""Run CommandLineTask for a list of input data."""
inputs = ["black", "white", 2.5, 17]
worklist = [{"source": {"type": "data", "value": s}} for s in inputs]
@@ -240,7 +242,7 @@ def test_foreach_commandlinetask(self):
class JoinDictTaskTest(unittest.TestCase):
"""run tests for JoinDictTask."""
- def test_join_dict_task(self):
+ def test_join_dict_task(self) -> None:
"""Joins dictionaries into a new or existing dict in spec."""
temperature = {"value": 273.15, "units": "Kelvin"}
pressure = {"value": 1.2, "units": "bar"}
@@ -266,7 +268,7 @@ def test_join_dict_task(self):
class JoinListTaskTest(unittest.TestCase):
"""run tests for JoinListTask."""
- def test_join_list_task(self):
+ def test_join_list_task(self) -> None:
"""Joins items into a new or existing list in spec."""
temperature = {"value": 273.15, "units": "Kelvin"}
pressure = {"value": 1.2, "units": "bar"}
@@ -289,15 +291,13 @@ def test_join_list_task(self):
class ImportDataTaskTest(unittest.TestCase):
"""run tests for ImportDataTask."""
- def test_import_data_task(self):
+ def test_import_data_task(self) -> None:
"""Loads data from a file into spec."""
import json
- import ruamel.yaml as yaml
-
temperature = {"value": 273.15, "units": "Kelvin"}
spec = {"state parameters": {}}
- formats = {"json": json, "yaml": yaml}
+ formats = {"json": json, "yaml": YAML(typ="safe", pure=True)}
params = {"mapstring": "state parameters/temperature"}
for fmt in formats:
filename = str(uuid.uuid4()) + "." + fmt
diff --git a/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py b/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py
index 067e9ec6f..81f34c0e2 100644
--- a/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py
+++ b/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py
@@ -21,13 +21,13 @@
class FileWriteDeleteTest(unittest.TestCase):
- def test_init(self):
+ def test_init(self) -> None:
FileWriteTask(files_to_write="hello")
FileWriteTask({"files_to_write": "hello"})
with pytest.raises(RuntimeError):
FileWriteTask()
- def test_run(self):
+ def test_run(self) -> None:
t = load_object_from_file(os.path.join(module_dir, "write.yaml"))
t.run_task({})
for i in range(2):
@@ -41,11 +41,11 @@ def test_run(self):
class CompressDecompressArchiveDirTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.cwd = os.getcwd()
os.chdir(module_dir)
- def test_compress_dir(self):
+ def test_compress_dir(self) -> None:
c = CompressDirTask(compression="gz")
c.run_task({})
assert os.path.exists("delete.yaml.gz")
@@ -55,13 +55,13 @@ def test_compress_dir(self):
assert not os.path.exists("delete.yaml.gz")
assert os.path.exists("delete.yaml")
- def test_archive_dir(self):
+ def test_archive_dir(self) -> None:
a = ArchiveDirTask(base_name="archive", format="gztar")
a.run_task({})
assert os.path.exists("archive.tar.gz")
os.remove("archive.tar.gz")
- def tearDown(self):
+ def tearDown(self) -> None:
os.chdir(self.cwd)
diff --git a/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py b/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py
index 8350d26a7..7b807a681 100644
--- a/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py
+++ b/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py
@@ -18,12 +18,12 @@
class FilePadTasksTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.paths = [os.path.join(module_dir, "write.yaml"), os.path.join(module_dir, "delete.yaml")]
self.identifiers = ["write", "delete"]
self.fp = FilePad.auto_load()
- def test_addfilestask_run(self):
+ def test_addfilestask_run(self) -> None:
t = AddFilesTask(paths=self.paths, identifiers=self.identifiers)
t.run_task({})
write_file_contents, _ = self.fp.get_file("write")
@@ -33,7 +33,7 @@ def test_addfilestask_run(self):
with open(self.paths[1]) as f:
assert del_file_contents == f.read().encode()
- def test_deletefilestask_run(self):
+ def test_deletefilestask_run(self) -> None:
t = DeleteFilesTask(identifiers=self.identifiers)
t.run_task({})
file_contents, doc = self.fp.get_file("write")
@@ -43,7 +43,7 @@ def test_deletefilestask_run(self):
assert file_contents is None
assert doc is None
- def test_getfilestask_run(self):
+ def test_getfilestask_run(self) -> None:
t = AddFilesTask(paths=self.paths, identifiers=self.identifiers)
t.run_task({})
dest_dir = os.path.abspath(".")
@@ -56,7 +56,7 @@ def test_getfilestask_run(self):
assert write_file_contents == f.read().encode()
os.remove(os.path.join(dest_dir, new_file_names[0]))
- def test_getfilesbyquerytask_run(self):
+ def test_getfilesbyquerytask_run(self) -> None:
"""Tests querying objects from FilePad by metadata."""
t = AddFilesTask(paths=self.paths, identifiers=self.identifiers, metadata={"key": "value"})
t.run_task({})
@@ -69,7 +69,7 @@ def test_getfilesbyquerytask_run(self):
assert test_file_contents == file.read().encode()
os.remove(os.path.join(dest_dir, new_file_names[0]))
- def test_getfilesbyquerytask_run(self):
+ def test_getfilesbyquerytask_run(self) -> None:
"""Tests querying objects from FilePad by metadata."""
with open("original_test_file.txt", "w") as f:
f.write("Some file with some content")
@@ -87,7 +87,7 @@ def test_getfilesbyquerytask_run(self):
assert test_file_contents == f.read().encode()
os.remove(os.path.join(dest_dir, "queried_test_file.txt"))
- def test_getfilesbyquerytask_metafile_run(self):
+ def test_getfilesbyquerytask_metafile_run(self) -> None:
"""Tests writing metadata to a yaml file."""
with open("original_test_file.txt", "w") as f:
f.write("Some file with some content")
@@ -113,7 +113,7 @@ def test_getfilesbyquerytask_metafile_run(self):
os.remove(os.path.join(dest_dir, "queried_test_file.txt"))
os.remove(os.path.join(dest_dir, "queried_test_file.txt.meta.yaml"))
- def test_getfilesbyquerytask_ignore_empty_result_run(self):
+ def test_getfilesbyquerytask_ignore_empty_result_run(self) -> None:
"""Tests on ignoring empty results from FilePad query."""
dest_dir = os.path.abspath(".")
t = GetFilesByQueryTask(
@@ -125,7 +125,7 @@ def test_getfilesbyquerytask_ignore_empty_result_run(self):
t.run_task({})
# test successful if no exception raised
- def test_getfilesbyquerytask_raise_empty_result_run(self):
+ def test_getfilesbyquerytask_raise_empty_result_run(self) -> None:
"""Tests on raising exception on empty results from FilePad query."""
dest_dir = os.path.abspath(".")
t = GetFilesByQueryTask(
@@ -138,7 +138,7 @@ def test_getfilesbyquerytask_raise_empty_result_run(self):
t.run_task({})
# test successful if exception raised
- def test_getfilesbyquerytask_ignore_degenerate_file_name(self):
+ def test_getfilesbyquerytask_ignore_degenerate_file_name(self) -> None:
"""Tests on ignoring degenerate file name in result from FilePad query."""
with open("degenerate_file.txt", "w") as f:
f.write("Some file with some content")
@@ -158,7 +158,7 @@ def test_getfilesbyquerytask_ignore_degenerate_file_name(self):
t.run_task({})
# test successful if no exception raised
- def test_getfilesbyquerytask_raise_degenerate_file_name(self):
+ def test_getfilesbyquerytask_raise_degenerate_file_name(self) -> None:
"""Tests on raising exception on degenerate file name from FilePad query."""
with open("degenerate_file.txt", "w") as f:
f.write("Some file with some content")
@@ -179,7 +179,7 @@ def test_getfilesbyquerytask_raise_degenerate_file_name(self):
t.run_task({})
# test successful if exception raised
- def test_getfilesbyquerytask_sort_ascending_name_run(self):
+ def test_getfilesbyquerytask_sort_ascending_name_run(self) -> None:
"""Tests on sorting queried files in ascending order."""
file_contents = ["Some file with some content", "Some other file with some other content"]
@@ -209,7 +209,7 @@ def test_getfilesbyquerytask_sort_ascending_name_run(self):
with open("degenerate_file.txt") as f:
assert file_contents[-1] == f.read()
- def test_getfilesbyquerytask_sort_descending_name_run(self):
+ def test_getfilesbyquerytask_sort_descending_name_run(self) -> None:
"""Tests on sorting queried files in descending order."""
file_contents = ["Some file with some content", "Some other file with some other content"]
@@ -244,17 +244,17 @@ def test_getfilesbyquerytask_sort_descending_name_run(self):
os.remove("degenerate_file.txt")
- def test_addfilesfrompatterntask_run(self):
+ def test_addfilesfrompatterntask_run(self) -> None:
t = AddFilesTask(paths="*.yaml", directory=module_dir)
t.run_task({})
write_file_contents, _ = self.fp.get_file(self.paths[0])
with open(self.paths[0]) as f:
assert write_file_contents == f.read().encode()
- del_file_contents, wdoc = self.fp.get_file(self.paths[1])
+ del_file_contents, _wdoc = self.fp.get_file(self.paths[1])
with open(self.paths[1]) as f:
assert del_file_contents == f.read().encode()
- def tearDown(self):
+ def tearDown(self) -> None:
self.fp.reset()
diff --git a/fireworks/user_objects/firetasks/tests/test_script_task.py b/fireworks/user_objects/firetasks/tests/test_script_task.py
index bc643df4d..f5d82c2b8 100644
--- a/fireworks/user_objects/firetasks/tests/test_script_task.py
+++ b/fireworks/user_objects/firetasks/tests/test_script_task.py
@@ -17,7 +17,7 @@ def afunc(y, z, a):
class ScriptTaskTest(unittest.TestCase):
- def test_scripttask(self):
+ def test_scripttask(self) -> None:
if os.path.exists("hello.txt"):
os.remove("hello.txt")
s = ScriptTask({"script": 'echo "hello world"', "stdout_file": "hello.txt"})
@@ -30,7 +30,7 @@ def test_scripttask(self):
class PyTaskTest(unittest.TestCase):
- def test_task(self):
+ def test_task(self) -> None:
p = PyTask(func="json.dumps", kwargs={"obj": {"hello": "world"}}, stored_data_varname="json")
a = p.run_task({})
assert a.stored_data["json"] == '{"hello": "world"}'
@@ -40,7 +40,7 @@ def test_task(self):
p = PyTask(func="print", args=[3])
p.run_task({})
- def test_task_auto_kwargs(self):
+ def test_task_auto_kwargs(self) -> None:
p = PyTask(func="json.dumps", obj={"hello": "world"}, stored_data_varname="json", auto_kwargs=True)
a = p.run_task({})
assert a.stored_data["json"] == '{"hello": "world"}'
@@ -50,7 +50,7 @@ def test_task_auto_kwargs(self):
p = PyTask(func="print", args=[3])
p.run_task({})
- def test_task_data_flow(self):
+ def test_task_data_flow(self) -> None:
"""Test dataflow parameters: inputs, outputs and chunk_number."""
params = {"func": "pow", "inputs": ["arg", "power", "modulo"], "stored_data_varname": "data"}
spec = {"arg": 2, "power": 3, "modulo": None}
diff --git a/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py b/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py
index c3206b1b2..6f2b40f20 100644
--- a/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py
+++ b/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py
@@ -1,6 +1,5 @@
"""TODO: Modify unittest doc."""
-
__author__ = "Bharat Medasani"
__copyright__ = "Copyright 2012, The Materials Project"
__maintainer__ = "Bharat Medasani"
@@ -14,7 +13,7 @@
class TemplateWriterTaskTest(unittest.TestCase):
- def test_task(self):
+ def test_task(self) -> None:
with open("test_template.txt", "w") as fp:
fp.write("option1 = {{opt1}}\noption2 = {{opt2}}")
t = TemplateWriterTask(
diff --git a/fireworks/user_objects/firetasks/unittest_tasks.py b/fireworks/user_objects/firetasks/unittest_tasks.py
index aee84717d..8c4b3b41f 100644
--- a/fireworks/user_objects/firetasks/unittest_tasks.py
+++ b/fireworks/user_objects/firetasks/unittest_tasks.py
@@ -12,7 +12,7 @@
class TestSerializer(FWSerializable):
_fw_name = "TestSerializer Name"
- def __init__(self, a, m_date):
+ def __init__(self, a, m_date) -> None:
if not isinstance(m_date, datetime.datetime):
raise ValueError("m_date must be a datetime instance!")
@@ -34,7 +34,7 @@ def from_dict(cls, m_dict):
class ExportTestSerializer(FWSerializable):
_fw_name = "TestSerializer Export Name"
- def __init__(self, a):
+ def __init__(self, a) -> None:
self.a = a
def __eq__(self, other):
diff --git a/fireworks/user_objects/queue_adapters/common_adapter.py b/fireworks/user_objects/queue_adapters/common_adapter.py
index 2d25e16d6..6e1380b25 100644
--- a/fireworks/user_objects/queue_adapters/common_adapter.py
+++ b/fireworks/user_objects/queue_adapters/common_adapter.py
@@ -39,7 +39,7 @@ class CommonAdapter(QueueAdapterBase):
"MOAB": {"submit_cmd": "msub", "status_cmd": "showq"},
}
- def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwargs):
+ def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwargs) -> None:
"""
Initializes a new QueueAdapter object.
diff --git a/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py b/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py
index 60f8cda51..d966af1c2 100644
--- a/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py
+++ b/fireworks/user_objects/queue_adapters/pbs_newt_adapter.py
@@ -40,7 +40,7 @@ def get_njobs_in_queue(self, username=None):
return len(r.json())
@staticmethod
- def _init_auth_session(max_pw_requests=3):
+ def _init_auth_session(max_pw_requests=3) -> None:
"""
Initialize the _session class var with an authorized session. Asks for a /
password in new sessions, skips PW check for previously authenticated sessions.
diff --git a/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py b/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py
index c08d0289b..6f89c0cb5 100644
--- a/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py
+++ b/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py
@@ -1,20 +1,22 @@
"""TODO: Modify unittest doc."""
-
__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyuep@gmail.com"
__date__ = "12/31/13"
+import sys
import unittest
+from ruamel.yaml import YAML
+
from fireworks.user_objects.queue_adapters.common_adapter import CommonAdapter, os
from fireworks.utilities.fw_serializers import load_object, load_object_from_file
class CommonAdapterTest(unittest.TestCase):
- def test_serialization(self):
+ def test_serialization(self) -> None:
p = CommonAdapter(
q_type="PBS",
q_name="hello",
@@ -36,16 +38,17 @@ def test_serialization(self):
assert p.get_script_str("here").split("\n")[-1] != "# world"
assert "_fw_template_file" not in p.to_dict()
- def test_yaml_load(self):
+ def test_yaml_load(self) -> None:
# Test yaml loading.
p = load_object_from_file(os.path.join(os.path.dirname(__file__), "pbs.yaml"))
p = CommonAdapter(q_type="PBS", q_name="hello", ppnode="8:ib", nnodes=1, hello="world", queue="random")
print(p.get_script_str("."))
- import ruamel.yaml as yaml
-
- print(yaml.safe_dump(p.to_dict(), default_flow_style=False))
+ yaml = YAML(typ="safe", pure=True)
+ yaml.default_flow_style = False
+ yaml.dump(p.to_dict(), sys.stdout)
+ print()
- def test_parse_njobs(self):
+ def test_parse_njobs(self) -> None:
pbs = """
tscc-mgr.sdsc.edu:
Req'd Req'd Elap
@@ -86,7 +89,7 @@ def test_parse_njobs(self):
p = CommonAdapter(q_type="SGE", q_name="hello", queue="all.q", hello="world")
assert p._parse_njobs(sge, "ongsp") == 3
- def test_parse_jobid(self):
+ def test_parse_jobid(self) -> None:
p = CommonAdapter(q_type="SLURM", q_name="hello", queue="home-ong", hello="world")
sbatch_output = """
SOME PREAMBLE
@@ -104,14 +107,14 @@ def test_parse_jobid(self):
qsub_output = 'Your job 44275 ("jobname") has been submitted'
assert p._parse_jobid(qsub_output) == "44275"
- def test_status_cmd_pbs(self):
+ def test_status_cmd_pbs(self) -> None:
p = load_object_from_file(
os.path.join(os.path.dirname(__file__), "pbs_override.yaml") # intentional red herring to test deepcopy
)
p = CommonAdapter(q_type="PBS")
assert p._get_status_cmd("my_name") == ["qstat", "-u", "my_name"]
- def test_override(self):
+ def test_override(self) -> None:
p = load_object_from_file(os.path.join(os.path.dirname(__file__), "pbs_override.yaml"))
assert p._get_status_cmd("my_name") == ["my_qstatus", "-u", "my_name"]
diff --git a/fireworks/utilities/dagflow.py b/fireworks/utilities/dagflow.py
index a2bf2a882..82c7cc7bf 100644
--- a/fireworks/utilities/dagflow.py
+++ b/fireworks/utilities/dagflow.py
@@ -43,7 +43,7 @@ class DAGFlow(Graph):
visualization of workflows.
"""
- def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs):
+ def __init__(self, steps, links=None, nlinks=None, name=None, **kwargs) -> None:
Graph.__init__(self, directed=True, graph_attrs={"name": name}, **kwargs)
for step in steps:
@@ -69,7 +69,7 @@ def from_fireworks(cls, fireworkflow):
step = {}
step["name"] = fwk["name"]
step["id"] = fwk["fw_id"]
- step["state"] = fwk["state"] if "state" in fwk else None
+ step["state"] = fwk.get("state", None)
steps.append(step)
links = []
@@ -108,7 +108,7 @@ def task_input(task, spec):
step_data = []
for task in step["_tasks"]:
- true_task = task["task"] if "task" in task else task
+ true_task = task.get("task", task)
step_data.extend(task_input(true_task, fwk["spec"]))
if "outputs" in true_task:
assert isinstance(true_task["outputs"], list), "outputs must be a list in fw_id " + str(step["id"])
@@ -136,14 +136,14 @@ def _get_ctrlflow_links(self):
links.append((source, target))
return links
- def _add_ctrlflow_links(self, links):
+ def _add_ctrlflow_links(self, links) -> None:
"""Adds graph edges corresponding to control flow links."""
for link in links:
source = self._get_index(link[0])
target = self._get_index(link[1])
self.add_edge(source, target, label=" ")
- def _add_dataflow_links(self, step_id=None, mode="both"):
+ def _add_dataflow_links(self, step_id=None, mode="both") -> None:
"""Adds graph edges corresponding to data flow links."""
if step_id:
vidx = self._get_index(step_id)
@@ -205,13 +205,13 @@ def _get_targets(self, step, entity):
return lst
@staticmethod
- def _set_io_fields(step):
+ def _set_io_fields(step) -> None:
"""Set io keys as step attributes."""
for item in ["inputs", "outputs", "output"]:
step[item] = []
for task in step["_tasks"]:
# test the case of meta-tasks
- true_task = task["task"] if "task" in task else task
+ true_task = task.get("task", task)
if item in true_task:
if isinstance(true_task[item], list):
step[item].extend(true_task[item])
@@ -266,22 +266,22 @@ def _get_leaves(self):
"""Returns all leaves (i.e. vertices without outgoing edges)."""
return [i for i, v in enumerate(self.degree(mode=igraph.OUT)) if v == 0]
- def delete_ctrlflow_links(self):
+ def delete_ctrlflow_links(self) -> None:
"""Deletes graph edges corresponding to control flow links."""
lst = [link.index for link in list(self.es) if link["label"] == " "]
self.delete_edges(lst)
- def delete_dataflow_links(self):
+ def delete_dataflow_links(self) -> None:
"""Deletes graph edges corresponding to data flow links."""
lst = [link.index for link in list(self.es) if link["label"] != " "]
self.delete_edges(lst)
- def add_step_labels(self):
+ def add_step_labels(self) -> None:
"""Labels the workflow steps (i.e. graph vertices)."""
for vertex in list(self.vs):
vertex["label"] = vertex["name"] + ", id: " + str(vertex["id"])
- def check(self):
+ def check(self) -> None:
"""Correctness check of the workflow."""
try:
assert self.is_dag(), "The workflow graph must be a DAG."
@@ -292,7 +292,7 @@ def check(self):
assert len(self.vs["id"]) == len(set(self.vs["id"])), "Workflow steps must have unique IDs."
self.check_dataflow()
- def check_dataflow(self):
+ def check_dataflow(self) -> None:
"""Checks whether all inputs and outputs match."""
# check for shared output data entities
for vertex in list(self.vs):
@@ -323,7 +323,7 @@ def to_dict(self):
dct["links"] = self._get_ctrlflow_links()
return dct
- def to_dot(self, filename="wf.dot", view="combined"):
+ def to_dot(self, filename="wf.dot", view="combined") -> None:
"""Writes the workflow into a file in DOT format."""
graph = DAGFlow(**self.to_dict())
if view == "controlflow":
diff --git a/fireworks/utilities/dict_mods.py b/fireworks/utilities/dict_mods.py
index 098a89bb1..afced3f50 100644
--- a/fireworks/utilities/dict_mods.py
+++ b/fireworks/utilities/dict_mods.py
@@ -70,26 +70,26 @@ class DictMods:
supported using a special "->" keyword, e.g. {"a->b": 1}
"""
- def __init__(self):
+ def __init__(self) -> None:
self.supported_actions = {}
for i in dir(self):
if (not re.match(r"__\w+__", i)) and callable(getattr(self, i)):
self.supported_actions["_" + i] = getattr(self, i)
@staticmethod
- def set(input_dict, settings):
+ def set(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
d[key] = v
@staticmethod
- def unset(input_dict, settings):
+ def unset(input_dict, settings) -> None:
for k in settings:
(d, key) = get_nested_dict(input_dict, k)
del d[key]
@staticmethod
- def push(input_dict, settings):
+ def push(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d:
@@ -98,7 +98,7 @@ def push(input_dict, settings):
d[key] = [v]
@staticmethod
- def push_all(input_dict, settings):
+ def push_all(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d:
@@ -107,7 +107,7 @@ def push_all(input_dict, settings):
d[key] = v
@staticmethod
- def inc(input_dict, settings):
+ def inc(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d:
@@ -116,14 +116,14 @@ def inc(input_dict, settings):
d[key] = v
@staticmethod
- def rename(input_dict, settings):
+ def rename(input_dict, settings) -> None:
for k, v in settings.items():
if k in input_dict:
input_dict[v] = input_dict[k]
del input_dict[k]
@staticmethod
- def add_to_set(input_dict, settings):
+ def add_to_set(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d and (not isinstance(d[key], (list, tuple))):
@@ -134,7 +134,7 @@ def add_to_set(input_dict, settings):
d[key] = v
@staticmethod
- def pull(input_dict, settings):
+ def pull(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d and (not isinstance(d[key], (list, tuple))):
@@ -143,7 +143,7 @@ def pull(input_dict, settings):
d[key] = [i for i in d[key] if i != v]
@staticmethod
- def pull_all(input_dict, settings):
+ def pull_all(input_dict, settings) -> None:
for k, v in settings.items():
if k in input_dict and (not isinstance(input_dict[k], (list, tuple))):
raise ValueError(f"Keyword {k} does not refer to an array.")
@@ -151,7 +151,7 @@ def pull_all(input_dict, settings):
DictMods.pull(input_dict, {k: i})
@staticmethod
- def pop(input_dict, settings):
+ def pop(input_dict, settings) -> None:
for k, v in settings.items():
(d, key) = get_nested_dict(input_dict, k)
if key in d and (not isinstance(d[key], (list, tuple))):
@@ -162,7 +162,7 @@ def pop(input_dict, settings):
d[key].pop(0)
-def apply_mod(modification, obj):
+def apply_mod(modification, obj) -> None:
"""
Note that modify makes actual in-place modifications. It does not
return a copy.
diff --git a/fireworks/utilities/filepad.py b/fireworks/utilities/filepad.py
index c60a539e3..47676eded 100644
--- a/fireworks/utilities/filepad.py
+++ b/fireworks/utilities/filepad.py
@@ -37,7 +37,7 @@ def __init__(
logdir=None,
strm_lvl=None,
text_mode=False,
- ):
+ ) -> None:
"""
Args:
host (str): hostname
@@ -106,7 +106,7 @@ def __init__(
# build indexes
self.build_indexes()
- def build_indexes(self, indexes=None, background=True):
+ def build_indexes(self, indexes=None, background=True) -> None:
"""
Build the indexes.
@@ -194,7 +194,7 @@ def get_file_by_query(self, query, sort_key=None, sort_direction=DESCENDING):
cursor = self.filepad.find(query).sort(sort_key, sort_direction)
return [self._get_file_contents(d) for d in cursor]
- def delete_file(self, identifier):
+ def delete_file(self, identifier) -> None:
"""
Delete the document with the matching identifier. The contents in the gridfs as well as the
associated document in the filepad are deleted.
@@ -224,7 +224,7 @@ def update_file(self, identifier, path, compress=True):
doc = self.filepad.find_one({"identifier": identifier})
return self._update_file_contents(doc, path, compress)
- def delete_file_by_id(self, gfs_id):
+ def delete_file_by_id(self, gfs_id) -> None:
"""
Args:
gfs_id (str): the file id.
@@ -232,7 +232,7 @@ def delete_file_by_id(self, gfs_id):
self.gridfs.delete(gfs_id)
self.filepad.delete_one({"gfs_id": gfs_id})
- def delete_file_by_query(self, query):
+ def delete_file_by_query(self, query) -> None:
"""
Args:
query (dict): pymongo query dict.
@@ -366,7 +366,7 @@ def auto_load(cls):
return FilePad.from_db_file(LAUNCHPAD_LOC)
return FilePad()
- def reset(self):
+ def reset(self) -> None:
"""Reset filepad and the gridfs collections."""
self.filepad.delete_many({})
self.db[self.gridfs_coll_name].files.delete_many({})
diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py
index 36c879ba2..aec94e1f0 100644
--- a/fireworks/utilities/fw_serializers.py
+++ b/fireworks/utilities/fw_serializers.py
@@ -33,9 +33,11 @@
import json # note that ujson is faster, but at this time does not support "default" in dumps()
import pkgutil
import traceback
+from typing import NoReturn
-import ruamel.yaml as yaml
from monty.json import MontyDecoder, MSONable
+from ruamel.yaml import YAML
+from ruamel.yaml.compat import StringIO
from fireworks.fw_config import (
DECODE_MONTY,
@@ -179,7 +181,7 @@ def _decorator(self, *args, **kwargs):
return _decorator
-class FWSerializable(metaclass=abc.ABCMeta):
+class FWSerializable(abc.ABC):
"""
To create a serializable object within FireWorks, you should subclass this
class and implement the to_dict() and from_dict() methods.
@@ -205,7 +207,7 @@ def fw_name(self):
return get_default_serialization(self.__class__)
@abc.abstractmethod
- def to_dict(self):
+ def to_dict(self) -> NoReturn:
raise NotImplementedError("FWSerializable object did not implement to_dict()!")
def to_db_dict(self):
@@ -219,10 +221,10 @@ def as_dict(self):
@classmethod
@abc.abstractmethod
- def from_dict(cls, m_dict):
+ def from_dict(cls, m_dict) -> NoReturn:
raise NotImplementedError("FWSerializable object did not implement from_dict()!")
- def __repr__(self):
+ def __repr__(self) -> str:
return json.dumps(self.to_dict(), default=DATETIME_HANDLER)
def to_format(self, f_format="json", **kwargs):
@@ -235,8 +237,12 @@ def to_format(self, f_format="json", **kwargs):
if f_format == "json":
return json.dumps(self.to_dict(), default=DATETIME_HANDLER, **kwargs)
if f_format == "yaml":
- # start with the JSON format, and convert to YAML
- return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True)
+ yaml = YAML(typ="safe", pure=True)
+ yaml.default_flow_style = YAML_STYLE
+ yaml.allow_unicode = True
+ strm = StringIO()
+ yaml.dump(self.to_dict(), strm)
+ return strm.getvalue()
raise ValueError(f"Unsupported format {f_format}")
@classmethod
@@ -254,14 +260,14 @@ def from_format(cls, f_str, f_format="json"):
if f_format == "json":
dct = json.loads(f_str)
elif f_format == "yaml":
- dct = yaml.safe_load(f_str)
+ dct = YAML(typ="safe", pure=True).load(f_str)
else:
raise ValueError(f"Unsupported format {f_format}")
if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST:
fireworks_schema.validate(dct, cls.__name__)
return cls.from_dict(reconstitute_dates(dct))
- def to_file(self, filename, f_format=None, **kwargs):
+ def to_file(self, filename, f_format=None, **kwargs) -> None:
"""
Write a serialization of this object to a file.
@@ -271,8 +277,16 @@ def to_file(self, filename, f_format=None, **kwargs):
"""
if f_format is None:
f_format = filename.split(".")[-1]
- with open(filename, "w", **ENCODING_PARAMS) as f:
- f.write(self.to_format(f_format=f_format, **kwargs))
+ with open(filename, "w", **ENCODING_PARAMS) as f_out:
+ if f_format == "json":
+ json.dump(self.to_dict(), f_out, default=DATETIME_HANDLER, **kwargs)
+ elif f_format == "yaml":
+ yaml = YAML(typ="safe", pure=True)
+ yaml.default_flow_style = YAML_STYLE
+ yaml.allow_unicode = True
+ yaml.dump(self.to_dict(), f_out)
+ else:
+ raise ValueError(f"Unsupported format {f_format}")
@classmethod
def from_file(cls, filename, f_format=None):
@@ -387,7 +401,7 @@ def load_object_from_file(filename, f_format=None):
if f_format == "json":
dct = json.loads(f.read())
elif f_format == "yaml":
- dct = yaml.safe_load(f)
+ dct = YAML(typ="safe", pure=True).load(f.read())
else:
raise ValueError(f"Unknown file format {f_format} cannot be loaded!")
diff --git a/fireworks/utilities/fw_utilities.py b/fireworks/utilities/fw_utilities.py
index cf01deb7a..3b63ce2ae 100644
--- a/fireworks/utilities/fw_utilities.py
+++ b/fireworks/utilities/fw_utilities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import contextlib
import datetime
import errno
@@ -10,7 +12,6 @@
import traceback
from logging import Formatter, Logger
from multiprocessing.managers import BaseManager
-from typing import Tuple
from fireworks.fw_config import DS_PASSWORD, FW_BLOCK_FORMAT, FW_LOGGING_FORMAT, FWData
@@ -28,7 +29,7 @@
def get_fw_logger(
name: str,
l_dir: None = None,
- file_levels: Tuple[str, str] = ("DEBUG", "ERROR"),
+ file_levels: tuple[str, str] = ("DEBUG", "ERROR"),
stream_level: str = "DEBUG",
formatter: Formatter = DEFAULT_FORMATTER,
clear_logs: bool = False,
@@ -71,7 +72,7 @@ def get_fw_logger(
return logger
-def log_multi(m_logger, msg, log_lvl="info"):
+def log_multi(m_logger, msg, log_lvl="info") -> None:
"""
Args:
m_logger (logger): The logger object
@@ -85,7 +86,7 @@ def log_multi(m_logger, msg, log_lvl="info"):
_log_fnc(msg)
-def log_fancy(m_logger, msgs, log_lvl="info", add_traceback=False):
+def log_fancy(m_logger, msgs, log_lvl="info", add_traceback=False) -> None:
"""
A wrapper around the logger messages useful for multi-line logs.
Helps to group log messages by adding a fancy border around it,
@@ -167,7 +168,7 @@ def get_path():
def get_my_ip():
- global _g_ip
+ global _g_ip # noqa: PLW0603
if _g_ip is None:
try:
_g_ip = socket.gethostbyname(socket.gethostname())
@@ -177,7 +178,7 @@ def get_my_ip():
def get_my_host():
- global _g_host
+ global _g_host # noqa: PLW0603
if _g_host is None:
_g_host = socket.gethostname()
return _g_host
diff --git a/fireworks/utilities/tests/test_dagflow.py b/fireworks/utilities/tests/test_dagflow.py
index b4b8e63eb..a2b94d551 100644
--- a/fireworks/utilities/tests/test_dagflow.py
+++ b/fireworks/utilities/tests/test_dagflow.py
@@ -16,7 +16,7 @@
class DAGFlowTest(unittest.TestCase):
"""run tests for DAGFlow class."""
- def setUp(self):
+ def setUp(self) -> None:
try:
__import__("igraph", fromlist=["Graph"])
except (ImportError, ModuleNotFoundError):
@@ -34,7 +34,7 @@ def setUp(self):
)
self.fw3 = Firework(PyTask(func="print", inputs=["second power"]), name="the third one")
- def test_dagflow_ok(self):
+ def test_dagflow_ok(self) -> None:
"""Construct and replicate."""
from fireworks.utilities.dagflow import DAGFlow
@@ -42,7 +42,7 @@ def test_dagflow_ok(self):
dagf = DAGFlow.from_fireworks(wfl)
DAGFlow(**dagf.to_dict())
- def test_dagflow_loop(self):
+ def test_dagflow_loop(self) -> None:
"""Loop in graph."""
from fireworks.utilities.dagflow import DAGFlow
@@ -52,7 +52,7 @@ def test_dagflow_loop(self):
DAGFlow.from_fireworks(wfl).check()
assert msg in str(exc.value)
- def test_dagflow_cut(self):
+ def test_dagflow_cut(self) -> None:
"""Disconnected graph."""
from fireworks.utilities.dagflow import DAGFlow
@@ -62,7 +62,7 @@ def test_dagflow_cut(self):
DAGFlow.from_fireworks(wfl).check()
assert msg in str(exc.value)
- def test_dagflow_link(self):
+ def test_dagflow_link(self) -> None:
"""Wrong links."""
from fireworks.utilities.dagflow import DAGFlow
@@ -72,7 +72,7 @@ def test_dagflow_link(self):
DAGFlow.from_fireworks(wfl).check()
assert msg in str(exc.value)
- def test_dagflow_missing_input(self):
+ def test_dagflow_missing_input(self) -> None:
"""Missing input."""
from fireworks.utilities.dagflow import DAGFlow
@@ -89,7 +89,7 @@ def test_dagflow_missing_input(self):
DAGFlow.from_fireworks(wfl).check()
assert msg in str(exc.value)
- def test_dagflow_clashing_inputs(self):
+ def test_dagflow_clashing_inputs(self) -> None:
"""Parent firework output overwrites an input in spec."""
from fireworks.utilities.dagflow import DAGFlow
@@ -107,7 +107,7 @@ def test_dagflow_clashing_inputs(self):
DAGFlow.from_fireworks(wfl).check()
assert msg in str(exc.value)
- def test_dagflow_race_condition(self):
+ def test_dagflow_race_condition(self) -> None:
"""Two parent firework outputs overwrite each other."""
from fireworks.utilities.dagflow import DAGFlow
@@ -123,7 +123,7 @@ def test_dagflow_race_condition(self):
DAGFlow.from_fireworks(wfl).check()
assert msg in str(exc.value)
- def test_dagflow_clashing_outputs(self):
+ def test_dagflow_clashing_outputs(self) -> None:
"""Subsequent task overwrites output of a task."""
from fireworks.utilities.dagflow import DAGFlow
@@ -137,7 +137,7 @@ def test_dagflow_clashing_outputs(self):
DAGFlow.from_fireworks(Workflow([fwk], {})).check()
assert msg in str(exc.value)
- def test_dagflow_non_dataflow_tasks(self):
+ def test_dagflow_non_dataflow_tasks(self) -> None:
"""non-dataflow tasks using outputs and inputs keys do not fail."""
from fireworks.core.firework import FiretaskBase
from fireworks.utilities.dagflow import DAGFlow
@@ -148,7 +148,7 @@ class NonDataFlowTask(FiretaskBase):
_fw_name = "NonDataFlowTask"
required_params = ["inputs", "outputs"]
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
pass
task = NonDataFlowTask(inputs=["first power", "exponent"], outputs=["second power"])
@@ -156,7 +156,7 @@ def run_task(self, fw_spec):
wfl = Workflow([self.fw1, fw2], {self.fw1: [fw2], fw2: []})
DAGFlow.from_fireworks(wfl).check()
- def test_dagflow_view(self):
+ def test_dagflow_view(self) -> None:
"""Visualize the workflow graph."""
from fireworks.utilities.dagflow import DAGFlow
diff --git a/fireworks/utilities/tests/test_filepad.py b/fireworks/utilities/tests/test_filepad.py
index f501d7124..f2d628f1e 100644
--- a/fireworks/utilities/tests/test_filepad.py
+++ b/fireworks/utilities/tests/test_filepad.py
@@ -7,22 +7,22 @@
class FilePadTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.chgcar_file = os.path.join(module_dir, "CHGCAR.Fe3O4")
self.fp = FilePad.auto_load()
self.identifier = "Fe3O4"
- def test_add_file(self):
+ def test_add_file(self) -> None:
gfs_id, file_identifier = self.fp.add_file(self.chgcar_file, identifier=self.identifier)
assert file_identifier == self.identifier
assert gfs_id is not None
- def test_add_file_with_no_identifier(self):
+ def test_add_file_with_no_identifier(self) -> None:
gfs_id, file_identifier = self.fp.add_file(self.chgcar_file)
assert gfs_id is not None
assert file_identifier == gfs_id
- def test_get_file(self):
+ def test_get_file(self) -> None:
_, file_identifier = self.fp.add_file(self.chgcar_file, identifier="xxx", metadata={"author": "Kiran Mathew"})
file_contents, doc = self.fp.get_file(file_identifier)
with open(self.chgcar_file) as file:
@@ -35,27 +35,27 @@ def test_get_file(self):
assert doc["original_file_path"] == abspath
assert doc["compressed"] is True
- def test_delete_file(self):
+ def test_delete_file(self) -> None:
_, file_identifier = self.fp.add_file(self.chgcar_file)
self.fp.delete_file(file_identifier)
contents, doc = self.fp.get_file(file_identifier)
assert contents is None
assert doc is None
- def test_update_file(self):
+ def test_update_file(self) -> None:
gfs_id, _ = self.fp.add_file(self.chgcar_file, identifier="test_update_file")
old_id, new_id = self.fp.update_file("test_update_file", self.chgcar_file)
assert old_id == gfs_id
assert new_id != gfs_id
assert not self.fp.gridfs.exists(old_id)
- def test_update_file_by_id(self):
+ def test_update_file_by_id(self) -> None:
gfs_id, _ = self.fp.add_file(self.chgcar_file, identifier="some identifier")
old, new = self.fp.update_file_by_id(gfs_id, self.chgcar_file)
assert old == gfs_id
assert new != gfs_id
- def tearDown(self):
+ def tearDown(self) -> None:
self.fp.reset()
diff --git a/fireworks/utilities/tests/test_fw_serializers.py b/fireworks/utilities/tests/test_fw_serializers.py
index 6b6ae81ff..9d447d14e 100644
--- a/fireworks/utilities/tests/test_fw_serializers.py
+++ b/fireworks/utilities/tests/test_fw_serializers.py
@@ -20,7 +20,7 @@
@explicit_serialize
class ExplicitTestSerializer(FWSerializable):
- def __init__(self, a):
+ def __init__(self, a) -> None:
self.a = a
def __eq__(self, other):
@@ -35,7 +35,7 @@ def from_dict(cls, m_dict):
class SerializationTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
test_date = datetime.datetime.utcnow()
# A basic datetime test serialized object
self.obj_1 = TestSerializer("prop1", test_date)
@@ -55,50 +55,50 @@ def setUp(self):
self.module_dir = os.path.dirname(os.path.abspath(__file__))
- def tearDown(self):
+ def tearDown(self) -> None:
os.remove("test.json")
os.remove("test.yaml")
- def test_sanity(self):
+ def test_sanity(self) -> None:
assert self.obj_1 == self.obj_1_copy, "The __eq__() method of the TestSerializer is not set up properly!"
assert self.obj_1 != self.obj_2, "The __ne__() method of the TestSerializer is not set up properly!"
assert self.obj_1 == self.obj_1.from_dict(
self.obj_1.to_dict()
), "The to/from_dict() methods of the TestSerializer are not set up properly!"
- def test_serialize_fw_decorator(self):
+ def test_serialize_fw_decorator(self) -> None:
m_dict = self.obj_1.to_dict()
assert m_dict["_fw_name"] == "TestSerializer Name"
- def test_json(self):
+ def test_json(self) -> None:
obj1_json_string = str(self.obj_1.to_format()) # default format is JSON, make sure this is true
assert self.obj_1.from_format(obj1_json_string) == self.obj_1, "JSON format export / import fails!"
- def test_yaml(self):
+ def test_yaml(self) -> None:
obj1_yaml_string = str(self.obj_1.to_format("yaml"))
assert self.obj_1.from_format(obj1_yaml_string, "yaml") == self.obj_1, "YAML format export / import fails!"
- def test_complex_json(self):
+ def test_complex_json(self) -> None:
obj2_json_string = str(self.obj_2.to_format()) # default format is JSON, make sure this is true
assert self.obj_2.from_format(obj2_json_string) == self.obj_2, "Complex JSON format export / import fails!"
- def test_complex_yaml(self):
+ def test_complex_yaml(self) -> None:
obj2_yaml_string = str(self.obj_2.to_format("yaml"))
assert (
self.obj_2.from_format(obj2_yaml_string, "yaml") == self.obj_2
), "Complex YAML format export / import fails!"
- def test_unicode_json(self):
+ def test_unicode_json(self) -> None:
obj3_json_string = str(self.obj_3.to_format()) # default format is JSON, make sure this is true
assert self.obj_3.from_format(obj3_json_string) == self.obj_3, "Unicode JSON format export / import fails!"
- def test_unicode_yaml(self):
+ def test_unicode_yaml(self) -> None:
obj3_yaml_string = str(self.obj_3.to_format("yaml"))
assert (
self.obj_3.from_format(obj3_yaml_string, "yaml") == self.obj_3
), "Unicode YAML format export / import fails!"
- def test_unicode_json_file(self):
+ def test_unicode_json_file(self) -> None:
with open(os.path.join(self.module_dir, "test_reference.json")) as f, open(
"test.json", **ENCODING_PARAMS
) as f2:
@@ -108,22 +108,22 @@ def test_unicode_json_file(self):
assert self.obj_3.from_file("test.json") == self.obj_3, "Unicode JSON file import fails!"
- def test_unicode_yaml_file(self):
+ def test_unicode_yaml_file(self) -> None:
ref_path = os.path.join(self.module_dir, "test_reference.yaml")
with open(ref_path, **ENCODING_PARAMS) as f, open("test.yaml", **ENCODING_PARAMS) as f2:
assert f.read() == f2.read(), "Unicode JSON file export fails"
assert self.obj_3.from_file("test.yaml") == self.obj_3, "Unicode YAML file import fails!"
- def test_implicit_serialization(self):
+ def test_implicit_serialization(self) -> None:
assert (
load_object({"a": {"p1": {"p2": 3}}, "_fw_name": "TestSerializer Export Name"}) == self.obj_4
), "Implicit import fails!"
- def test_as_dict(self):
+ def test_as_dict(self) -> None:
assert self.obj_1.as_dict() == self.obj_1.to_dict()
- def test_numpy_array(self):
+ def test_numpy_array(self) -> None:
try:
import numpy as np
except Exception:
@@ -135,11 +135,11 @@ def test_numpy_array(self):
class ExplicitSerializationTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.s_obj = ExplicitTestSerializer(1)
self.s_dict = self.s_obj.to_dict()
- def test_explicit_serialization(self):
+ def test_explicit_serialization(self) -> None:
assert load_object(self.s_dict) == self.s_obj
diff --git a/fireworks/utilities/tests/test_update_collection.py b/fireworks/utilities/tests/test_update_collection.py
index a7abb0873..620694bf1 100644
--- a/fireworks/utilities/tests/test_update_collection.py
+++ b/fireworks/utilities/tests/test_update_collection.py
@@ -15,7 +15,7 @@
class UpdateCollectionTests(unittest.TestCase):
@classmethod
- def setUpClass(cls):
+ def setUpClass(cls) -> None:
cls.lp = None
try:
cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl="ERROR")
@@ -24,25 +24,25 @@ def setUpClass(cls):
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
@classmethod
- def tearDownClass(cls):
+ def tearDownClass(cls) -> None:
if cls.lp:
cls.lp.connection.drop_database(TESTDB_NAME)
- def test_update_path(cls):
- cls.lp.db.test_coll.insert_one({"foo": "bar", "foo_list": [{"foo1": "bar1"}, {"foo2": "foo/old/path/bar"}]})
+ def test_update_path(self) -> None:
+ self.lp.db.test_coll.insert_one({"foo": "bar", "foo_list": [{"foo1": "bar1"}, {"foo2": "foo/old/path/bar"}]})
update_path_in_collection(
- cls.lp.db,
+ self.lp.db,
collection_name="test_coll",
replacements={"old/path": "new/path"},
query=None,
dry_run=False,
force_clear=False,
)
- ndocs = cls.lp.db.test_coll.count_documents({})
- assert ndocs == 1
- test_doc = cls.lp.db.test_coll.find_one({"foo": "bar"})
+ n_docs = self.lp.db.test_coll.count_documents({})
+ assert n_docs == 1
+ test_doc = self.lp.db.test_coll.find_one({"foo": "bar"})
assert test_doc["foo_list"][1]["foo2"] == "foo/new/path/bar"
- test_doc_archived = cls.lp.db[f"test_coll_xiv_{datetime.date.today()}"].find_one()
+ test_doc_archived = self.lp.db[f"test_coll_xiv_{datetime.date.today()}"].find_one()
assert test_doc_archived["foo_list"][1]["foo2"] == "foo/old/path/bar"
diff --git a/fireworks/utilities/tests/test_visualize.py b/fireworks/utilities/tests/test_visualize.py
index 330d334dd..1d82a01ac 100644
--- a/fireworks/utilities/tests/test_visualize.py
+++ b/fireworks/utilities/tests/test_visualize.py
@@ -21,7 +21,7 @@ def power_wf():
return Workflow([fw1, fw2, fw3], {fw1: [fw2], fw2: [fw3], fw3: []})
-def test_wf_to_graph(power_wf):
+def test_wf_to_graph(power_wf) -> None:
dag = wf_to_graph(power_wf)
assert isinstance(dag, Digraph)
@@ -31,7 +31,7 @@ def test_wf_to_graph(power_wf):
assert isinstance(dag, Digraph)
-def test_plot_wf(power_wf):
+def test_plot_wf(power_wf) -> None:
plot_wf(power_wf)
plot_wf(power_wf, depth_factor=0.5, breadth_factor=1)
diff --git a/fireworks/utilities/update_collection.py b/fireworks/utilities/update_collection.py
index 9d63f28ac..25070376b 100644
--- a/fireworks/utilities/update_collection.py
+++ b/fireworks/utilities/update_collection.py
@@ -8,7 +8,7 @@
__date__ = "Dec 08, 2016"
-def update_launchpad_data(lp, replacements, **kwargs):
+def update_launchpad_data(lp, replacements, **kwargs) -> None:
"""
If you want to update a text string in your entire FireWorks database with a replacement, use this method.
For example, you might want to update a directory name preamble like "/scratch/user1" to "/project/user2".
@@ -26,7 +26,7 @@ def update_launchpad_data(lp, replacements, **kwargs):
print("Update launchpad data complete.")
-def update_path_in_collection(db, collection_name, replacements, query=None, dry_run=False, force_clear=False):
+def update_path_in_collection(db, collection_name, replacements, query=None, dry_run=False, force_clear=False) -> None:
"""
updates the text specified in replacements for the documents in a MongoDB collection.
This can be used to mass-update an outdated value (e.g., a directory path or tag) in that collection.
diff --git a/fireworks/utilities/visualize.py b/fireworks/utilities/visualize.py
index c28e2f346..0c079c112 100644
--- a/fireworks/utilities/visualize.py
+++ b/fireworks/utilities/visualize.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, Optional
+from __future__ import annotations
+
+from typing import Any
from monty.dev import requires
@@ -23,7 +25,7 @@ def plot_wf(
markersize=10,
markerfacecolor="blue",
fontsize=12,
-):
+) -> None:
"""
Generate a visual representation of the workflow. Useful for checking whether the firework
connections are in order before launching the workflow.
@@ -60,7 +62,7 @@ def plot_wf(
# the rest
for k in keys:
for i, j in enumerate(wf.links[k]):
- if not points_map.get(j, None):
+ if not points_map.get(j):
points_map[j] = ((i - len(wf.links[k]) / 2.0) * breadth_factor, k * depth_factor)
# connect the dots
@@ -98,7 +100,7 @@ def plot_wf(
"graphviz package required for wf_to_graph.\n"
"Follow the installation instructions here: https://github.com/xflr6/graphviz",
)
-def wf_to_graph(wf: Workflow, dag_kwargs: Optional[Dict[str, Any]] = None, wf_show_tasks: bool = True) -> Digraph:
+def wf_to_graph(wf: Workflow, dag_kwargs: dict[str, Any] | None = None, wf_show_tasks: bool = True) -> Digraph:
"""Renders a graph representation of a workflow or firework. Workflows are rendered as the
control flow of the firework, while Fireworks are rendered as a sequence of Firetasks.
@@ -150,7 +152,7 @@ def wf_to_graph(wf: Workflow, dag_kwargs: Optional[Dict[str, Any]] = None, wf_sh
if idx == 0:
subgraph.edge(str(fw.fw_id), node_id)
else:
- subgraph.edge(f"{fw.fw_id}-{idx-1}", node_id)
+ subgraph.edge(f"{fw.fw_id}-{idx - 1}", node_id)
dag.subgraph(subgraph)
diff --git a/fw_tutorials/dynamic_wf/printjob_task.py b/fw_tutorials/dynamic_wf/printjob_task.py
index 0427d1907..2081402f1 100644
--- a/fw_tutorials/dynamic_wf/printjob_task.py
+++ b/fw_tutorials/dynamic_wf/printjob_task.py
@@ -10,7 +10,7 @@
class PrintJobTask(FiretaskBase):
_fw_name = "Print Job Task"
- def run_task(self, fw_spec):
+ def run_task(self, fw_spec) -> None:
job_info_array = fw_spec["_job_info"]
prev_job_info = job_info_array[-1]
diff --git a/fw_tutorials/python/python_examples.py b/fw_tutorials/python/python_examples.py
index 9a11d2ba0..b59bc3134 100644
--- a/fw_tutorials/python/python_examples.py
+++ b/fw_tutorials/python/python_examples.py
@@ -19,7 +19,7 @@ def setup():
return launchpad
-def basic_fw_ex():
+def basic_fw_ex() -> None:
print("--- BASIC FIREWORK EXAMPLE ---")
# setup
@@ -34,7 +34,7 @@ def basic_fw_ex():
launch_rocket(launchpad, FWorker())
-def rapid_fire_ex():
+def rapid_fire_ex() -> None:
print("--- RAPIDFIRE EXAMPLE ---")
# setup
@@ -55,7 +55,7 @@ def rapid_fire_ex():
rapidfire(launchpad, FWorker())
-def multiple_tasks_ex():
+def multiple_tasks_ex() -> None:
print("--- MULTIPLE FIRETASKS EXAMPLE ---")
# setup
@@ -72,7 +72,7 @@ def multiple_tasks_ex():
rapidfire(launchpad, FWorker())
-def basic_wf_ex():
+def basic_wf_ex() -> None:
print("--- BASIC WORKFLOW EXAMPLE ---")
# setup
diff --git a/pyproject.toml b/pyproject.toml
index 98e80f9e7..28a4cdfb3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,8 @@ line-length = 120
[tool.ruff]
target-version = "py38"
line-length = 120
+
+[tool.ruff.lint]
select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
@@ -46,6 +48,7 @@ ignore = [
"PD901", # pandas-df-variable-name
"PERF203", # try-except-in-loop
# "PERF401", # manual-list-comprehension (TODO fix these or wait for autofix)
+ "ISC001",
"PLR", # pylint refactor
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PT013", # pytest-incorrect-pytest-import
@@ -56,7 +59,7 @@ ignore = [
pydocstyle.convention = "google"
isort.split-on-trailing-comma = false
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/**" = ["D"]
"tasks.py" = ["D"]
diff --git a/requirements.txt b/requirements.txt
index 0aea8ed6d..66f62c17b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
ruamel.yaml==0.16.5
-pymongo==3.10.0
+pymongo==4.0.0
Jinja2
monty==3.0.2
python-dateutil==2.8.1
tabulate==0.8.6
-flask==1.1.1
+flask==2.2.5
flask-paginate==0.5.5
gunicorn==20.0.4
tqdm==4.41.0
diff --git a/setup.py b/setup.py
index 23b9acd86..83c18152e 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
name="FireWorks",
version="2.0.3",
description="FireWorks workflow software",
- long_description=open("README.md").read(), # noqa: SIM115
+ long_description=open("README.md", encoding="utf-8").read(), # noqa: SIM115
url="https://github.com/materialsproject/fireworks",
author="Anubhav Jain",
author_email="anubhavster@gmail.com",
@@ -28,7 +28,7 @@
python_requires=">=3.8",
install_requires=[
"ruamel.yaml>=0.15.35",
- "pymongo>=3.3.0",
+ "pymongo>=4.0.0",
"Jinja2>=2.8.0",
"monty>=1.0.1",
"python-dateutil>=2.5.3",
diff --git a/tasks.py b/tasks.py
index a36980eaa..bbdc164ed 100644
--- a/tasks.py
+++ b/tasks.py
@@ -21,7 +21,7 @@
@task
-def make_doc(ctx):
+def make_doc(ctx) -> None:
with cd("docs_rst"):
ctx.run("sphinx-apidoc -o . -f ../fireworks")
ctx.run("make html")
@@ -36,7 +36,7 @@ def make_doc(ctx):
@task
-def update_doc(ctx):
+def update_doc(ctx) -> None:
make_doc(ctx)
with cd("docs"):
ctx.run("git add .")
@@ -45,12 +45,12 @@ def update_doc(ctx):
@task
-def publish(ctx):
+def publish(ctx) -> None:
ctx.run("python setup.py release")
@task
-def release_github(ctx):
+def release_github(ctx) -> None:
payload = {
"tag_name": fw_version,
"target_commitish": "master",
@@ -71,13 +71,13 @@ def release_github(ctx):
@task
-def release(ctx):
+def release(ctx) -> None:
publish(ctx)
update_doc(ctx)
release_github(ctx)
@task
-def open_doc(ctx):
+def open_doc(ctx) -> None:
pth = os.path.abspath("docs/index.html")
webbrowser.open("file://" + pth)