diff --git a/docs_rst/conf.py b/docs_rst/conf.py index dfcee48e7..e0cc73049 100644 --- a/docs_rst/conf.py +++ b/docs_rst/conf.py @@ -12,6 +12,7 @@ import os import sys +from typing import Mapping from fireworks import __version__ @@ -181,7 +182,7 @@ # -- Options for LaTeX output -------------------------------------------------- -latex_elements = { +latex_elements: Mapping[str, str] = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 833ebfb8a..eedc49d62 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -16,11 +16,13 @@ import pprint from collections import OrderedDict, defaultdict from datetime import datetime +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union from monty.io import reverse_readline, zopen from monty.os.path import zpath from fireworks.core.fworker import FWorker +from fireworks.core.types import Checkpoint, Spec from fireworks.fw_config import ( EXCEPT_DETAILS_ON_RERUN, NEGATIVE_FWID_CTR, @@ -64,7 +66,7 @@ class FiretaskBase(defaultdict, FWSerializable, metaclass=abc.ABCMeta): # if set to a list of str, only required and optional kwargs are allowed; consistency checked upon init optional_params = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: dict.__init__(self, *args, **kwargs) required_params = self.required_params or [] @@ -84,7 +86,7 @@ def __init__(self, *args, **kwargs): ) @abc.abstractmethod - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> Optional["FWAction"]: """ This method gets called when the Firetask is run. 
It can take in a Firework spec, perform some task using that data, and then return an @@ -109,27 +111,27 @@ def run_task(self, fw_spec): @serialize_fw @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[Any, Any]: return dict(self) @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[Any, Any]) -> "FiretaskBase": return cls(m_dict) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.fw_name}>:{dict(self)}" # not strictly needed here for pickle/unpickle, but complements __setstate__ - def __getstate__(self): + def __getstate__(self) -> Dict[Any, Any]: return self.to_dict() # added to support pickle/unpickle - def __setstate__(self, state): + def __setstate__(self, state) -> None: self.__init__(state) # added to support pickle/unpickle - def __reduce__(self): + def __reduce__(self) -> Tuple: t = defaultdict.__reduce__(self) return (t[0], (self.to_dict(),), t[2], t[3], t[4]) @@ -143,15 +145,15 @@ class FWAction(FWSerializable): def __init__( self, - stored_data=None, - exit=False, - update_spec=None, - mod_spec=None, - additions=None, - detours=None, - defuse_children=False, - defuse_workflow=False, - propagate=False, + stored_data: Optional[Dict[Any, Any]] = None, + exit: bool = False, + update_spec: Optional[Mapping[str, Any]] = None, + mod_spec: Optional[Mapping[str, Any]] = None, + additions: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, + detours: Optional[Union[Sequence["Firework"], Sequence["Workflow"]]] = None, + defuse_children: bool = False, + defuse_workflow: bool = False, + propagate: bool = False, ): """ Args: @@ -184,7 +186,7 @@ def __init__( self.propagate = propagate @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "stored_data": self.stored_data, "exit": self.exit, @@ -199,7 +201,7 @@ def to_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, 
Any]) -> "FWAction": d = m_dict additions = [Workflow.from_dict(f) for f in d["additions"]] detours = [Workflow.from_dict(f) for f in d["detours"]] @@ -216,7 +218,7 @@ def from_dict(cls, m_dict): ) @property - def skip_remaining_tasks(self): + def skip_remaining_tasks(self) -> bool: """ If the FWAction gives any dynamic action, we skip the subsequent Firetasks @@ -225,7 +227,7 @@ def skip_remaining_tasks(self): """ return self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow - def __str__(self): + def __str__(self) -> str: return "FWAction\n" + pprint.pformat(self.to_dict()) @@ -249,16 +251,16 @@ class Firework(FWSerializable): # note: if you modify this signature, you must also modify LazyFirework def __init__( self, - tasks, - spec=None, - name=None, - launches=None, - archived_launches=None, - state="WAITING", - created_on=None, - fw_id=None, - parents=None, - updated_on=None, + tasks: Union["FiretaskBase", Sequence["FiretaskBase"]], + spec: Optional[Dict[Any, Any]] = None, + name: Optional[str] = None, + launches: Optional[Sequence["Launch"]] = None, + archived_launches: Optional[Sequence["Launch"]] = None, + state: str = "WAITING", + created_on: Optional[datetime] = None, + fw_id: Optional[int] = None, + parents: Optional[Union["Firework", Sequence["Firework"]]] = None, + updated_on: Optional[datetime] = None, ): """ Args: @@ -298,7 +300,7 @@ def __init__( self._state = state @property - def state(self): + def state(self) -> str: """ Returns: str: The current state of the Firework @@ -306,7 +308,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state: str) -> None: """ Setter for the the FW state, which triggers updated_on change @@ -317,7 +319,7 @@ def state(self, state): self.updated_on = datetime.utcnow() @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: # put tasks in a special location of the spec spec = self.spec spec["_tasks"] = 
[t.to_dict() for t in self.tasks] @@ -338,7 +340,7 @@ def to_dict(self): return m_dict - def _rerun(self): + def _rerun(self) -> None: """ Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework can thus be re-run even if it was Launched in the past. This method should be called by @@ -362,7 +364,7 @@ def _rerun(self): self.launches = [] self.state = "WAITING" - def to_db_dict(self): + def to_db_dict(self) -> Dict[str, Any]: """ Return firework dict with updated launches and state. """ @@ -376,7 +378,7 @@ def to_db_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "Firework": tasks = m_dict["spec"]["_tasks"] launches = [Launch.from_dict(tmp) for tmp in m_dict.get("launches", [])] archived_launches = [Launch.from_dict(tmp) for tmp in m_dict.get("archived_launches", [])] @@ -389,7 +391,7 @@ def from_dict(cls, m_dict): tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on ) - def __str__(self): + def __str__(self) -> str: return "Firework object: (id: %i , name: %s)" % (self.fw_id, self.fw_name) @@ -400,7 +402,9 @@ class Tracker(FWSerializable): MAX_TRACKER_LINES = 1000 - def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False): + def __init__( + self, filename: str, nlines: int = TRACKER_LINES, content: str = "", allow_zipped: bool = False + ) -> None: """ Args: filename (str) @@ -415,7 +419,7 @@ def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=Fals self.content = content self.allow_zipped = allow_zipped - def track_file(self, launch_dir=None): + def track_file(self, launch_dir: Optional[str] = None) -> str: """ Reads the monitored file and returns back the last N lines @@ -440,19 +444,19 @@ def track_file(self, launch_dir=None): self.content = "\n".join(reversed(lines)) return self.content - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: m_dict = 
{"filename": self.filename, "nlines": self.nlines, "allow_zipped": self.allow_zipped} if self.content: m_dict["content"] = self.content return m_dict @classmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "Tracker": return Tracker( m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False) ) - def __str__(self): + def __str__(self) -> str: return f"### Filename: {self.filename}\n{self.content}" @@ -463,16 +467,16 @@ class Launch(FWSerializable): def __init__( self, - state, - launch_dir, - fworker=None, - host=None, - ip=None, - trackers=None, - action=None, - state_history=None, - launch_id=None, - fw_id=None, + state: str, + launch_dir: str, + fworker: Optional["FWorker"] = None, + host: Optional[str] = None, + ip: Optional[str] = None, + trackers: Optional[Sequence["Tracker"]] = None, + action: Optional["FWAction"] = None, + state_history: Optional[Dict[Any, Any]] = None, + launch_id: Optional[int] = None, + fw_id: Optional[int] = None, ): """ Args: @@ -500,7 +504,7 @@ def __init__( self.launch_id = launch_id self.fw_id = fw_id - def touch_history(self, update_time=None, checkpoint=None): + def touch_history(self, update_time: Optional[datetime] = None, checkpoint: Optional[Checkpoint] = None) -> None: """ Updates the update_on field of the state history of a Launch. Used to ping that a Launch is still alive. @@ -513,7 +517,7 @@ def touch_history(self, update_time=None, checkpoint=None): self.state_history[-1]["checkpoint"] = checkpoint self.state_history[-1]["updated_on"] = update_time - def set_reservation_id(self, reservation_id): + def set_reservation_id(self, reservation_id: Union[int, str]) -> None: """ Adds the job_id to the reservation. @@ -526,7 +530,7 @@ def set_reservation_id(self, reservation_id): break @property - def state(self): + def state(self) -> str: """ Returns: str: The current state of the Launch. 
@@ -534,7 +538,7 @@ def state(self): return self._state @state.setter - def state(self, state): + def state(self, state: str) -> None: """ Setter for the the Launch's state. Automatically triggers an update to state_history. @@ -545,7 +549,7 @@ def state(self, state): self._update_state_history(state) @property - def time_start(self): + def time_start(self) -> datetime: """ Returns: datetime: the time the Launch started RUNNING @@ -553,7 +557,7 @@ def time_start(self): return self._get_time("RUNNING") @property - def time_end(self): + def time_end(self) -> datetime: """ Returns: datetime: the time the Launch was COMPLETED or FIZZLED @@ -561,7 +565,7 @@ def time_end(self): return self._get_time(["COMPLETED", "FIZZLED"]) @property - def time_reserved(self): + def time_reserved(self) -> datetime: """ Returns: datetime: the time the Launch was RESERVED in the queue @@ -569,7 +573,7 @@ def time_reserved(self): return self._get_time("RESERVED") @property - def last_pinged(self): + def last_pinged(self) -> datetime: """ Returns: datetime: the time the Launch last pinged a heartbeat that it was still running @@ -577,7 +581,7 @@ def last_pinged(self): return self._get_time("RUNNING", True) @property - def runtime_secs(self): + def runtime_secs(self) -> int: # type: ignore """ Returns: int: the number of seconds that the Launch ran for. @@ -585,10 +589,10 @@ def runtime_secs(self): start = self.time_start end = self.time_end if start and end: - return (end - start).total_seconds() + return int((end - start).total_seconds()) @property - def reservedtime_secs(self): + def reservedtime_secs(self) -> int: # type: ignore """ Returns: int: number of seconds the Launch was stuck as RESERVED in a queue. 
@@ -596,10 +600,10 @@ def reservedtime_secs(self): start = self.time_reserved if start: end = self.time_start if self.time_start else datetime.utcnow() - return (end - start).total_seconds() + return int((end - start).total_seconds()) @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "fworker": self.fworker, "fw_id": self.fw_id, @@ -663,7 +667,7 @@ def _update_state_history(self, state): if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key - def _get_time(self, states, use_update_time=False): + def _get_time(self, states, use_update_time: bool = False) -> datetime: # type: ignore """ Internal method to help get the time of various events in the Launch (e.g. RUNNING) from the state history. @@ -1284,7 +1288,7 @@ def _get_representative_launch(fw): return m_launch @classmethod - def from_wflow(cls, wflow): + def from_wflow(cls, wflow: "Workflow") -> "Workflow": """ Create a fresh Workflow from an existing one. @@ -1298,7 +1302,7 @@ def from_wflow(cls, wflow): new_wf.reset(reset_ids=True) return new_wf - def reset(self, reset_ids=True): + def reset(self, reset_ids: bool = True) -> None: """ Reset the states of all Fireworks in this workflow to 'WAITING'. @@ -1320,7 +1324,7 @@ def reset(self, reset_ids=True): self.fw_states = {key: self.id_fw[key].state for key in self.id_fw} @classmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict) -> "Workflow": """ Return Workflow from its dict representation. @@ -1346,7 +1350,7 @@ def from_dict(cls, m_dict): return Workflow.from_Firework(Firework.from_dict(m_dict)) @classmethod - def from_Firework(cls, fw, name=None, metadata=None): + def from_Firework(cls, fw: "Firework", name: Optional[str] = None, metadata=None) -> "Workflow": """ Return Workflow from the given Firework. 
@@ -1361,10 +1365,10 @@ def from_Firework(cls, fw, name=None, metadata=None): name = name if name else fw.name return Workflow([fw], None, name=name, metadata=metadata, created_on=fw.created_on, updated_on=fw.updated_on) - def __str__(self): + def __str__(self) -> str: return f"Workflow object: (fw_ids: {self.id_fw.keys()} , name: {self.name})" - def remove_fws(self, fw_ids): + def remove_fws(self, fw_ids: Sequence[int]) -> None: """ Remove the fireworks corresponding to the input firework ids and update the workflow i.e the parents of the removed fireworks become the parents of the children fireworks (only if the diff --git a/fireworks/core/fworker.py b/fireworks/core/fworker.py index 610d6c132..c444e24b6 100644 --- a/fireworks/core/fworker.py +++ b/fireworks/core/fworker.py @@ -3,6 +3,7 @@ """ import json +from typing import Any, Dict, Optional, Sequence, Union from fireworks.fw_config import FWORKER_LOC from fireworks.utilities.fw_serializers import ( @@ -22,7 +23,13 @@ class FWorker(FWSerializable): - def __init__(self, name="Automatically generated Worker", category="", query=None, env=None): + def __init__( + self, + name: str = "Automatically generated Worker", + category: Union[str, Sequence[str]] = "", + query: Optional[Dict[str, Any]] = None, + env: Optional[Dict[str, Any]] = None, + ) -> None: """ Args: name (str): the name of the resource, should be unique @@ -42,7 +49,7 @@ def __init__(self, name="Automatically generated Worker", category="", query=Non self.env = env if env else {} @recursive_serialize - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "name": self.name, "category": self.category, @@ -52,11 +59,11 @@ def to_dict(self): @classmethod @recursive_deserialize - def from_dict(cls, m_dict): + def from_dict(cls, m_dict: Dict[str, Any]) -> "FWorker": return FWorker(m_dict["name"], m_dict["category"], json.loads(m_dict["query"]), m_dict.get("env")) @property - def query(self): + def query(self) -> Dict[str, Any]: """ Returns 
updated query dict. """ @@ -78,7 +85,7 @@ def query(self): return q @classmethod - def auto_load(cls): + def auto_load(cls) -> "FWorker": """ Returns FWorker object from settings file(my_fworker.yaml). """ diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 43ba7308f..b6d3ae083 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -15,6 +15,7 @@ import traceback from collections import OrderedDict, defaultdict from itertools import chain +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import gridfs from bson import ObjectId @@ -59,7 +60,7 @@ # TODO: lots of duplication reduction and cleanup possible -def sort_aggregation(sort): +def sort_aggregation(sort: Sequence[Tuple[str, int]]) -> List[Mapping[str, Any]]: """Build sorting aggregation pipeline. Args: @@ -106,7 +107,13 @@ class WFLock: Calling functions are responsible for handling the error in order to avoid database inconsistencies. """ - def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EXPIRATION_KILL): + def __init__( + self, + lp: "LaunchPad", + fw_id: int, + expire_secs: int = WFLOCK_EXPIRATION_SECS, + kill: bool = WFLOCK_EXPIRATION_KILL, + ): """ Args: lp (LaunchPad) @@ -119,7 +126,7 @@ def __init__(self, lp, fw_id, expire_secs=WFLOCK_EXPIRATION_SECS, kill=WFLOCK_EX self.expire_secs = expire_secs self.kill = kill - def __enter__(self): + def __enter__(self) -> None: ctr = 0 waiting_time = 0 # acquire lock @@ -149,7 +156,7 @@ def __enter__(self): {"nodes": self.fw_id, "locked": {"$exists": False}}, {"$set": {"locked": True}} ) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.lp.workflows.find_one_and_update({"nodes": self.fw_id}, {"$unset": {"locked": True}}) @@ -160,22 +167,22 @@ class LaunchPad(FWSerializable): def __init__( self, - host=None, - port=None, - name=None, - username=None, - password=None, - logdir=None, - strm_lvl=None, - 
user_indices=None, - wf_user_indices=None, - ssl=False, - ssl_ca_certs=None, - ssl_certfile=None, - ssl_keyfile=None, - ssl_pem_passphrase=None, - authsource=None, - uri_mode=False, + host: Optional[str] = None, + port: Optional[int] = None, + name: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + logdir: Optional[str] = None, + strm_lvl: Optional[str] = None, + user_indices: Optional[Sequence[int]] = None, + wf_user_indices: Optional[Sequence[int]] = None, + ssl: bool = False, + ssl_ca_certs: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_keyfile: Optional[str] = None, + ssl_pem_passphrase: Optional[str] = None, + authsource: str = None, + uri_mode: bool = False, mongoclient_kwargs=None, ): """ @@ -627,7 +634,7 @@ def get_wf_by_fw_id_lzyfw(self, fw_id): fw_states, ) - def delete_fws(self, fw_ids, delete_launch_dirs=False): + def delete_fws(self, fw_ids: Sequence[int], delete_launch_dirs: bool = False) -> None: """Delete a set of fireworks identified by their fw_ids. ATTENTION: This function serves maintenance purposes and will leave @@ -771,7 +778,14 @@ def get_wf_summary_dict(self, fw_id, mode="more"): return wf - def get_fw_ids(self, query=None, sort=None, limit=0, count_only=False, launches_mode=False): + def get_fw_ids( + self, + query: Optional[Mapping[str, Any]] = None, + sort=None, + limit: int = 0, + count_only: bool = False, + launches_mode: bool = False, + ) -> Union[int, Sequence[int]]: """ Return all the fw ids that match a query. 
@@ -2089,7 +2103,7 @@ class LazyFirework: db_fields = ("name", "fw_id", "spec", "created_on", "state") db_launch_fields = ("launches", "archived_launches") - def __init__(self, fw_id, fw_coll, launch_coll, fallback_fs): + def __init__(self, fw_id: int, fw_coll, launch_coll, fallback_fs): """ Args: fw_id (int): firework id @@ -2228,7 +2242,7 @@ def full_fw(self): # Get a type of Launch object - def _get_launch_data(self, name): + def _get_launch_data(self, name: str): """ Pull launch data individually for each field. @@ -2253,7 +2267,7 @@ def _get_launch_data(self, name): return getattr(fw, name) -def get_action_from_gridfs(action_dict, fallback_fs): +def get_action_from_gridfs(action_dict: Mapping[str, Any], fallback_fs: gridfs.GridFS) -> Dict[str, Any]: """ Helper function to obtain the correct dictionary of the FWAction associated with a launch. If necessary retrieves the information from gridfs based diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py index 7f1361b67..6118bc483 100644 --- a/fireworks/core/rocket.py +++ b/fireworks/core/rocket.py @@ -1,5 +1,7 @@ from monty.os.path import zpath +from fireworks.core.fworker import FWorker + """ A Rocket fetches a Firework from the database, runs the sequence of Firetasks inside, and then completes the Launch @@ -17,11 +19,13 @@ import threading import traceback from datetime import datetime +from typing import Optional from monty.io import zopen from fireworks.core.firework import Firework, FWAction from fireworks.core.launchpad import LaunchPad, LockedWorkflowError +from fireworks.core.types import Checkpoint, Spec from fireworks.fw_config import ( PING_TIME_SECS, PRINT_FW_JSON, @@ -42,7 +46,7 @@ __date__ = "Feb 7, 2013" -def do_ping(launchpad, launch_id): +def do_ping(launchpad: LaunchPad, launch_id: int) -> None: if launchpad: launchpad.ping_launch(launch_id) else: @@ -50,13 +54,13 @@ def do_ping(launchpad, launch_id): f.write('{"ping_time": "%s"}' % datetime.utcnow().isoformat()) -def 
ping_launch(launchpad, launch_id, stop_event, master_thread): +def ping_launch(launchpad: LaunchPad, launch_id: int, stop_event, master_thread) -> None: while not stop_event.is_set() and master_thread.is_alive(): do_ping(launchpad, launch_id) stop_event.wait(PING_TIME_SECS) -def start_ping_launch(launchpad, launch_id): +def start_ping_launch(launchpad: LaunchPad, launch_id: int) -> Optional[threading.Event]: fd = FWData() if fd.MULTIPROCESSING: if not launch_id: @@ -72,7 +76,7 @@ def start_ping_launch(launchpad, launch_id): return ping_stop -def stop_backgrounds(ping_stop, btask_stops): +def stop_backgrounds(ping_stop: threading.Event, btask_stops) -> None: fd = FWData() if fd.MULTIPROCESSING: fd.Running_IDs[os.getpid()] = None @@ -83,7 +87,7 @@ def stop_backgrounds(ping_stop, btask_stops): b.set() -def background_task(btask, spec, stop_event, master_thread): +def background_task(btask, spec: Spec, stop_event: threading.Event, master_thread) -> None: num_launched = 0 while not stop_event.is_set() and master_thread.is_alive(): for task in btask.tasks: @@ -95,7 +99,7 @@ def background_task(btask, spec, stop_event, master_thread): break -def start_background_task(btask, spec): +def start_background_task(btask, spec: Spec) -> threading.Event: ping_stop = threading.Event() ping_thread = threading.Thread(target=background_task, args=(btask, spec, ping_stop, threading.current_thread())) ping_thread.start() @@ -107,7 +111,7 @@ class Rocket: The Rocket fetches a workflow step from the FireWorks database and executes it. """ - def __init__(self, launchpad, fworker, fw_id): + def __init__(self, launchpad: LaunchPad, fworker: FWorker, fw_id: int) -> None: """ Args: launchpad (LaunchPad): A LaunchPad object for interacting with the FW database. 
@@ -119,7 +123,7 @@ def __init__(self, launchpad, fworker, fw_id): self.fworker = fworker self.fw_id = fw_id - def run(self, pdb_on_exception=False): + def run(self, pdb_on_exception: bool = False) -> bool: """ Run the rocket (check out a job from the database and execute it) @@ -428,7 +432,7 @@ def run(self, pdb_on_exception=False): return True @staticmethod - def update_checkpoint(launchpad, launch_dir, launch_id, checkpoint): + def update_checkpoint(launchpad: LaunchPad, launch_dir: str, launch_id: int, checkpoint: Checkpoint) -> None: """ Helper function to update checkpoint @@ -448,7 +452,7 @@ def update_checkpoint(launchpad, launch_dir, launch_id, checkpoint): with zopen(fpath, "wt") as f_out: f_out.write(json.dumps(d, ensure_ascii=False)) - def decorate_fwaction(self, fwaction, my_spec, m_fw, launch_dir): + def decorate_fwaction(self, fwaction: FWAction, my_spec: Spec, m_fw: Firework, launch_dir: str) -> FWAction: if my_spec.get("_pass_job_info"): job_info = list(my_spec.get("_job_info", [])) diff --git a/fireworks/core/rocket_launcher.py b/fireworks/core/rocket_launcher.py index 7295b9f09..0e7641f98 100644 --- a/fireworks/core/rocket_launcher.py +++ b/fireworks/core/rocket_launcher.py @@ -5,8 +5,10 @@ import os import time from datetime import datetime +from typing import Optional from fireworks.core.fworker import FWorker +from fireworks.core.launchpad import LaunchPad from fireworks.core.rocket import Rocket from fireworks.fw_config import FWORKER_LOC, RAPIDFIRE_SLEEP_SECS from fireworks.utilities.fw_utilities import ( @@ -24,17 +26,23 @@ __date__ = "Feb 22, 2013" -def get_fworker(fworker): - if fworker: +def get_fworker(fworker: Optional[FWorker]) -> FWorker: + if fworker is not None: my_fwkr = fworker elif FWORKER_LOC: - my_fwkr = FWorker.from_file(FWORKER_LOC) + my_fwkr = FWorker.from_file(FWORKER_LOC) # type: ignore else: my_fwkr = FWorker() return my_fwkr -def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", 
pdb_on_exception=False): +def launch_rocket( + launchpad: LaunchPad, + fworker: Optional[FWorker] = None, + fw_id: Optional[int] = None, + strm_lvl: str = "INFO", + pdb_on_exception: bool = False, +) -> bool: """ Run a single rocket in the current directory. @@ -54,6 +62,7 @@ def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_e l_logger = get_fw_logger("rocket.launcher", l_dir=l_dir, stream_level=strm_lvl) log_multi(l_logger, "Launching Rocket") + assert fw_id is not None rocket = Rocket(launchpad, fworker, fw_id) rocket_ran = rocket.run(pdb_on_exception=pdb_on_exception) log_multi(l_logger, "Rocket finished") @@ -61,16 +70,16 @@ def launch_rocket(launchpad, fworker=None, fw_id=None, strm_lvl="INFO", pdb_on_e def rapidfire( - launchpad, - fworker=None, - m_dir=None, - nlaunches=0, - max_loops=-1, - sleep_time=None, - strm_lvl="INFO", - timeout=None, - local_redirect=False, - pdb_on_exception=False, + launchpad: LaunchPad, + fworker: Optional[FWorker] = None, + m_dir: Optional[str] = None, + nlaunches: int = 0, + max_loops: int = -1, + sleep_time: Optional[int] = None, + strm_lvl: str = "INFO", + timeout: Optional[int] = None, + local_redirect: bool = False, + pdb_on_exception: bool = False, ): """ Keeps running Rockets in m_dir until we reach an error. 
Automatically creates subdirectories diff --git a/fireworks/core/tests/tasks.py b/fireworks/core/tests/tasks.py index 83aa0bf9f..23fbbc14a 100644 --- a/fireworks/core/tests/tasks.py +++ b/fireworks/core/tests/tasks.py @@ -2,6 +2,7 @@ from unittest import SkipTest from fireworks import FiretaskBase, Firework, FWAction +from fireworks.core.types import Spec from fireworks.utilities.fw_utilities import explicit_serialize @@ -95,7 +96,7 @@ def run_task(self, fw_spec): class DetoursTask(FiretaskBase): optional_params = ["n_detours", "data_per_detour"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: data_per_detour = self.get("data_per_detour", None) n_detours = self.get("n_detours", 10) fws = [] diff --git a/fireworks/core/types.py b/fireworks/core/types.py index 8f0bdf1e5..8ee2a35b5 100644 --- a/fireworks/core/types.py +++ b/fireworks/core/types.py @@ -1,5 +1,4 @@ from typing import Any, MutableMapping - Checkpoint = MutableMapping[Any, Any] Spec = MutableMapping[Any, Any] diff --git a/fireworks/flask_site/gunicorn.py b/fireworks/flask_site/gunicorn.py index abb9de292..c562f55ab 100755 --- a/fireworks/flask_site/gunicorn.py +++ b/fireworks/flask_site/gunicorn.py @@ -1,23 +1,24 @@ # Based on http://docs.gunicorn.org/en/19.6.0/custom.html import multiprocessing +from typing import Any, Mapping, Optional import gunicorn.app.base from fireworks.flask_site.app import app as handler_app -def number_of_workers(): +def number_of_workers() -> int: return (multiprocessing.cpu_count() * 2) + 1 class StandaloneApplication(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): + def __init__(self, app, options: Optional[Mapping[str, Any]] = None) -> None: self.options = options or {} self.application = app super().__init__() - def load_config(self): + def load_config(self) -> None: config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): 
self.cfg.set(key.lower(), value) diff --git a/fireworks/flask_site/helpers.py b/fireworks/flask_site/helpers.py index 63e91a022..3a97d6ccc 100644 --- a/fireworks/flask_site/helpers.py +++ b/fireworks/flask_site/helpers.py @@ -1,4 +1,9 @@ -def get_totals(states, lp): +from typing import Mapping, Sequence + +from fireworks import LaunchPad + + +def get_totals(states: Sequence[str], lp: LaunchPad) -> Mapping[str, int]: fw_stats = {} wf_stats = {} for state in states: @@ -7,21 +12,21 @@ def get_totals(states, lp): return {"fw_stats": fw_stats, "wf_stats": wf_stats} -def fw_filt_given_wf_filt(filt, lp): +def fw_filt_given_wf_filt(filt, lp: LaunchPad) -> Mapping[str, Mapping[str, Sequence[int]]]: fw_ids = set() for doc in lp.workflows.find(filt, {"_id": 0, "nodes": 1}): fw_ids |= set(doc["nodes"]) return {"fw_id": {"$in": list(fw_ids)}} -def wf_filt_given_fw_filt(filt, lp): +def wf_filt_given_fw_filt(filt, lp: LaunchPad) -> Mapping[str, Mapping[str, Sequence[int]]]: wf_ids = set() for doc in lp.fireworks.find(filt, {"_id": 0, "fw_id": 1}): wf_ids.add(doc["fw_id"]) return {"nodes": {"$in": list(wf_ids)}} -def uses_index(filt, coll): +def uses_index(filt, coll) -> bool: ii = coll.index_information() fields_filtered = set(filt.keys()) fields_indexed = {v["key"][0][0] for v in ii.values()} diff --git a/fireworks/scripts/fwtool b/fireworks/scripts/fwtool index df1f60c11..314c2aee4 100755 --- a/fireworks/scripts/fwtool +++ b/fireworks/scripts/fwtool @@ -3,7 +3,7 @@ import os import shutil import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import yaml @@ -19,8 +19,10 @@ __maintainer__ = "Shyue Ping Ong" __email__ = "ongsp@ucsd.edu" __date__ = "1/6/14" +from typing import Sequence -def create_fw_single(args, fnames, yaml_fname): + +def create_fw_single(args: Namespace, fnames: Sequence[str], yaml_fname: str) -> None: tasks = [] if fnames: files = [] @@ -35,7 +37,7 @@ def create_fw_single(args, fnames, yaml_fname): 
yaml.dump(fw.to_dict(), f, default_flow_style=False) -def create_fw(args): +def create_fw(args: Namespace) -> None: if not args.directory_mode: create_fw_single(args, args.files_or_dirs, args.output) else: @@ -44,7 +46,7 @@ def create_fw(args): create_fw_single(args, [os.path.join(d, f) for f in os.listdir(d)], output_fname.format(i)) -def do_cleanup(args): +def do_cleanup(args: Namespace) -> None: lp = LaunchPad.from_file(args.launchpad_file) if args.launchpad_file else LaunchPad(strm_lvl=args.loglvl) to_delete = [] for i in lp.get_fw_ids({}): diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 10cdc4ab6..b58a0ab76 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -10,7 +10,8 @@ import re import time import traceback -from argparse import ArgumentParser, ArgumentTypeError +from argparse import ArgumentParser, ArgumentTypeError, Namespace +from typing import Any, Callable, Sequence import ruamel.yaml as yaml from pymongo import ASCENDING, DESCENDING @@ -47,7 +48,7 @@ DEFAULT_LPAD_YAML = "my_launchpad.yaml" -def pw_check(ids, args, skip_pw=False): +def pw_check(ids: Sequence[int], args: Namespace, skip_pw: bool = False) -> Sequence[int]: if len(ids) > PW_CHECK_NUM and not skip_pw: m_password = datetime.datetime.now().strftime("%Y-%m-%d") if not args.password: @@ -63,7 +64,7 @@ def pw_check(ids, args, skip_pw=False): return ids -def parse_helper(lp, args, wf_mode=False, skip_pw=False): +def parse_helper(lp: LaunchPad, args: Namespace, wf_mode: bool = False, skip_pw: bool = False) -> Sequence[int]: """ Helper method to parse args that can take either id, name, state or query. 
@@ -103,7 +104,7 @@ def parse_helper(lp, args, wf_mode=False, skip_pw=False): return pw_check(lp.get_fw_ids(query, sort=sort, limit=max, launches_mode=args.launches_mode), args, skip_pw) -def get_lp(args): +def get_lp(args: Namespace) -> LaunchPad: try: if not args.launchpad_file: if os.path.exists(os.path.join(args.config_dir, DEFAULT_LPAD_YAML)): @@ -131,7 +132,7 @@ def get_lp(args): raise ValueError(err_message) -def init_yaml(args): +def init_yaml(args: Namespace) -> None: if args.uri_mode: fields = ( ("host", None, "Example: mongodb+srv://USER:PASSWORD@CLUSTERNAME.mongodb.net/fireworks"), @@ -182,7 +183,7 @@ def init_yaml(args): print(f"\nConfiguration written to {args.config_file}!") -def reset(args): +def reset(args: Namespace) -> None: lp = get_lp(args) if not args.password: if ( @@ -197,7 +198,7 @@ def reset(args): lp.reset(args.password) -def add_wf(args): +def add_wf(args: Namespace) -> None: lp = get_lp(args) if args.dir: files = [] @@ -214,31 +215,31 @@ def add_wf(args): lp.add_wf(fwf) -def append_wf(args): +def append_wf(args: Namespace) -> None: lp = get_lp(args) lp.append_wf(Workflow.from_file(args.wf_file), args.fw_id, detour=args.detour, pull_spec_mods=args.pull_spec_mods) -def dump_wf(args): +def dump_wf(args: Namespace) -> None: lp = get_lp(args) lp.get_wf_by_fw_id(args.fw_id).to_file(args.wf_file) -def check_wf(args): +def check_wf(args: Namespace) -> None: from fireworks.utilities.dagflow import DAGFlow lp = get_lp(args) DAGFlow.from_fireworks(lp.get_wf_by_fw_id(args.fw_id)).check() -def add_wf_dir(args): +def add_wf_dir(args: Namespace) -> None: lp = get_lp(args) for filename in os.listdir(args.wf_dir): fwf = Workflow.from_file(filename) lp.add_wf(fwf) -def print_fws(ids, lp, args): +def print_fws(ids: Sequence[int], lp: LaunchPad, args: Namespace) -> None: """Prints results of some FireWorks query to stdout.""" fws = [] if args.display_format == "ids": @@ -316,7 +317,7 @@ def get_fw_ids_helper(lp, args, count_only=None): return ids -def 
get_fws_helper(lp, ids, args): +def get_fws_helper(lp: LaunchPad, ids: Sequence[int], args: Namespace) -> List[Any]: """Get fws from ids in a representation according to args.display_format.""" fws = [] if args.display_format == "ids": @@ -341,14 +342,14 @@ def get_fws_helper(lp, ids, args): return fws -def get_fws(args): +def get_fws(args: Namespace) -> None: lp = get_lp(args) ids = get_fw_ids_helper(lp, args) fws = get_fws_helper(lp, ids, args) print(args.output(fws)) -def get_fws_in_wfs(args): +def get_fws_in_wfs(args: Namespace) -> None: # get_wfs lp = get_lp(args) if sum(bool(x) for x in [args.wf_fw_id, args.wf_name, args.wf_state, args.wf_query]) > 1: @@ -414,13 +415,13 @@ def get_fws_in_wfs(args): print_fws(ids, lp, args) -def update_fws(args): +def update_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) lp.update_spec(fw_ids, json.loads(args.update), args.mongo) -def get_wfs(args): +def get_wfs(args: Namespace) -> None: lp = get_lp(args) if sum(bool(x) for x in [args.fw_id, args.name, args.state, args.query]) > 1: raise ValueError("Please specify exactly one of (fw_id, name, state, query)") @@ -473,7 +474,7 @@ def get_wfs(args): print(args.output(wfs)) -def delete_wfs(args): +def delete_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -482,18 +483,18 @@ def delete_wfs(args): lp.m_logger.info(f"Finished deleting {len(fw_ids)} WFs") -def get_children(links, start, max_depth): - data = {} - for l, c in links.items(): - if l == start: - if len(c) > 0: - data[l] = [get_children(links, i, max_depth) for i in c] - else: - data[l] = c - return data +# def get_children(links, start): +# data = {} +# for l, c in links.items(): +# if l == start: +# if len(c) > 0: +# data[l] = [get_children(links, i) for i in c] +# else: +# data[l] = c +# return data -def detect_lostruns(args): +def detect_lostruns(args: Namespace) -> None: lp = get_lp(args) query = 
ast.literal_eval(args.query) if args.query else None launch_query = ast.literal_eval(args.launch_query) if args.launch_query else None @@ -520,7 +521,7 @@ def detect_lostruns(args): print("You can fix inconsistent FWs using the --refresh argument to the " "detect_lostruns command") -def detect_unreserved(args): +def detect_unreserved(args: Namespace) -> None: lp = get_lp(args) if args.display_format is not None and args.display_format != "none": unreserved = lp.detect_unreserved(expiration_secs=args.time, rerun=False) @@ -533,12 +534,12 @@ def detect_unreserved(args): print(lp.detect_unreserved(expiration_secs=args.time, rerun=args.rerun)) -def tuneup(args): +def tuneup(args: Namespace) -> None: lp = get_lp(args) lp.tuneup(bkground=not args.full) -def defuse_wfs(args): +def defuse_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -552,7 +553,7 @@ def defuse_wfs(args): ) -def pause_wfs(args): +def pause_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -561,7 +562,7 @@ def pause_wfs(args): lp.m_logger.info(f"Finished defusing {len(fw_ids)} FWs.") -def archive(args): +def archive(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -570,7 +571,7 @@ def archive(args): lp.m_logger.info(f"Finished archiving {len(fw_ids)} WFs") -def reignite_wfs(args): +def reignite_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -579,7 +580,7 @@ def reignite_wfs(args): lp.m_logger.info(f"Finished reigniting {len(fw_ids)} Workflows") -def defuse_fws(args): +def defuse_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -588,7 +589,7 @@ def defuse_fws(args): lp.m_logger.info(f"Finished defusing {len(fw_ids)} FWs") -def pause_fws(args): +def pause_fws(args: Namespace) -> None: lp = get_lp(args) 
fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -597,7 +598,7 @@ def pause_fws(args): lp.m_logger.info(f"Finished pausing {len(fw_ids)} FWs") -def reignite_fws(args): +def reignite_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -606,7 +607,7 @@ def reignite_fws(args): lp.m_logger.info(f"Finished reigniting {len(fw_ids)} FWs") -def resume_fws(args): +def resume_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -615,7 +616,7 @@ def resume_fws(args): lp.m_logger.info(f"Finished resuming {len(fw_ids)} FWs") -def rerun_fws(args): +def rerun_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) if args.task_level: @@ -632,7 +633,7 @@ def rerun_fws(args): lp.m_logger.info(f"Finished setting {len(fw_ids)} FWs to rerun") -def refresh(args): +def refresh(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -643,7 +644,7 @@ def refresh(args): lp.m_logger.info(f"Finished refreshing {len(fw_ids)} Workflows") -def unlock(args): +def unlock(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -653,13 +654,13 @@ def unlock(args): lp.m_logger.info(f"Finished unlocking {len(fw_ids)} Workflows") -def get_qid(args): +def get_qid(args: Namespace) -> None: lp = get_lp(args) for f in args.fw_id: print(lp.get_reservation_id_from_fw_id(f)) -def cancel_qid(args): +def cancel_qid(args: Namespace) -> None: lp = get_lp(args) lp.m_logger.warning( "WARNING: cancel_qid does not actually remove jobs from the queue " @@ -668,7 +669,7 @@ def cancel_qid(args): lp.cancel_reservation_by_reservation_id(args.qid) -def set_priority(args): +def set_priority(args: Namespace) -> None: wf_mode = args.wf lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=wf_mode) @@ -684,7 +685,7 @@ def set_priority(args): lp.m_logger.info(f"Finished setting priorities of 
{len(fw_ids)} FWs") -def _open_webbrowser(url): +def _open_webbrowser(url: str) -> None: """Open a web browser after a delay to give the web server more startup time.""" import webbrowser @@ -692,7 +693,7 @@ def _open_webbrowser(url): webbrowser.open(url) -def webgui(args): +def webgui(args: Namespace) -> None: from fireworks.flask_site.app import app app.lp = get_lp(args) @@ -736,7 +737,7 @@ def webgui(args): StandaloneApplication(app, options).run() -def add_scripts(args): +def add_scripts(args: Namespace) -> None: lp = get_lp(args) args.names = args.names if args.names else [None] * len(args.scripts) args.wf_name = args.wf_name if args.wf_name else args.names[0] @@ -750,7 +751,7 @@ def add_scripts(args): lp.add_wf(Workflow(fws, links, args.wf_name)) -def recover_offline(args): +def recover_offline(args: Namespace) -> None: lp = get_lp(args) fworker_name = FWorker.from_file(args.fworker_file).name if args.fworker_file else None failed_fws = [] @@ -770,7 +771,7 @@ def recover_offline(args): lp.m_logger.info(f"FAILED to recover offline fw_ids: {failed_fws}") -def forget_offline(args): +def forget_offline(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -779,7 +780,7 @@ def forget_offline(args): lp.m_logger.info(f"Finished forget_offine, processed {len(fw_ids)} FWs") -def report(args): +def report(args: Namespace) -> None: lp = get_lp(args) query = ast.literal_eval(args.query) if args.query else None fwr = FWReport(lp) @@ -794,7 +795,7 @@ def report(args): print(fwr.get_stats_str(stats)) -def introspect(args): +def introspect(args: Namespace) -> None: print("NOTE: This feature is in beta mode...") lp = get_lp(args) isp = Introspector(lp) @@ -806,13 +807,13 @@ def introspect(args): print("") -def get_launchdir(args): +def get_launchdir(args: Namespace) -> None: lp = get_lp(args) ld = lp.get_launchdir(args.fw_id, args.launch_idx) print(ld) -def track_fws(args): +def track_fws(args: Namespace) -> None: lp = get_lp(args) 
fw_ids = parse_helper(lp, args, skip_pw=True) include = args.include @@ -836,17 +837,17 @@ def track_fws(args): print("\n".join(output)) -def version(args): +def version(args: Namespace) -> None: print("FireWorks version:", FW_VERSION) print("located in:", FW_INSTALL_DIR) -def maintain(args): +def maintain(args: Namespace) -> None: lp = get_lp(args) lp.maintain(args.infinite, args.maintain_interval) -def orphaned(args): +def orphaned(args: Namespace) -> None: # get_fws lp = get_lp(args) fw_ids = get_fw_ids_helper(lp, args, count_only=False) @@ -867,14 +868,14 @@ def orphaned(args): print(args.output(fws)) -def get_output_func(format): +def get_output_func(format: str) -> Callable[[Any], str]: if format == "json": - return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) + return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) # type: ignore else: - return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False) + return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False) # type: ignore -def arg_positive_int(value): +def arg_positive_int(value: Any) -> int: try: ivalue = int(value) if ivalue < 1: @@ -884,7 +885,7 @@ def arg_positive_int(value): return ivalue -def lpad(): +def lpad() -> None: m_description = ( "A command line interface to FireWorks. For more help on a specific command, " 'type "lpad -h".' 
) @@ -926,7 +927,7 @@ def lpad(): # enhanced display options allow for value 'none' or None (default) for no output enh_disp_args = copy.deepcopy(disp_args) enh_disp_kwargs = copy.deepcopy(disp_kwargs) - enh_disp_kwargs["choices"].append("none") + enh_disp_kwargs["choices"].append("none") # type: ignore enh_disp_kwargs["default"] = None query_args = ["-q", "--query"] diff --git a/fireworks/user_objects/firetasks/dataflow_tasks.py b/fireworks/user_objects/firetasks/dataflow_tasks.py index c90c1ccb1..f0df3ac3a 100644 --- a/fireworks/user_objects/firetasks/dataflow_tasks.py +++ b/fireworks/user_objects/firetasks/dataflow_tasks.py @@ -5,9 +5,11 @@ __copyright__ = "Copyright 2016, Karlsruhe Institute of Technology" import sys +from typing import Any, List, Mapping, Optional from fireworks import Firework from fireworks.core.firework import FireTaskBase, FWAction +from fireworks.core.types import Spec from fireworks.utilities.fw_serializers import load_object if sys.version_info[0] > 2: @@ -74,7 +76,7 @@ class CommandLineTask(FireTaskBase): required_params = ["command_spec"] optional_params = ["inputs", "outputs", "chunk_number"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: cmd_spec = self["command_spec"] ilabels = self.get("inputs") olabels = self.get("outputs") @@ -136,7 +138,7 @@ def run_task(self, fw_spec): return FWAction() @staticmethod - def command_line_tool(command, inputs=None, outputs=None): + def command_line_tool(command, inputs: Optional[List[Mapping[Any, Any]]] = None, outputs: Optional[Spec] = None): """ This function composes and executes a command from provided specifications. 
@@ -163,7 +165,7 @@ def command_line_tool(command, inputs=None, outputs=None): from shutil import copyfile from subprocess import PIPE, Popen - def set_binding(arg): + def set_binding(arg: Mapping[str, Any]) -> str: argstr = "" if "binding" in arg: if "prefix" in arg["binding"]: @@ -283,7 +285,7 @@ class ForeachTask(FireTaskBase): required_params = ["task", "split"] optional_params = ["number of chunks"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["split"], basestring), self["split"] assert isinstance(fw_spec[self["split"]], list) if isinstance(self["task"]["inputs"], list): @@ -321,7 +323,7 @@ class JoinDictTask(FireTaskBase): required_params = ["inputs", "output"] optional_params = ["rename"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["output"], basestring) assert isinstance(self["inputs"], list) @@ -351,7 +353,7 @@ class JoinListTask(FireTaskBase): _fw_name = "JoinListTask" required_params = ["inputs", "output"] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: assert isinstance(self["output"], basestring) assert isinstance(self["inputs"], list) if self["output"] not in fw_spec: @@ -377,7 +379,7 @@ class ImportDataTask(FireTaskBase): required_params = ["filename", "mapstring"] optional_params = [] - def run_task(self, fw_spec): + def run_task(self, fw_spec: Spec) -> FWAction: import json import operator from functools import reduce diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 1013394bb..d4672046f 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -33,6 +33,7 @@ import json # note that ujson is faster, but at this time does not support "default" in dumps() import pkgutil import traceback +from typing import Any, Dict, Mapping, MutableMapping, Optional, Type import ruamel.yaml as yaml from monty.json import MontyDecoder, 
MSONable @@ -56,7 +57,7 @@ # TODO: consider *somehow* switching FireWorks to monty serialization. e.g., numpy serialization is better handled. -SAVED_FW_MODULES = {} +SAVED_FW_MODULES: MutableMapping[str, str] = {} DATETIME_HANDLER = lambda obj: obj.isoformat() if isinstance(obj, datetime.datetime) else None ENCODING_PARAMS = {"encoding": "utf-8"} @@ -72,7 +73,7 @@ import fireworks_schema -def recursive_dict(obj, preserve_unicode=True): +def recursive_dict(obj: Any, preserve_unicode: bool = True) -> Any: if obj is None: return None @@ -104,7 +105,7 @@ def recursive_dict(obj, preserve_unicode=True): # TODO: is reconstitute_dates really needed? Can this method just do everything? -def _recursive_load(obj): +def _recursive_load(obj: Any) -> Any: if obj is None: return None @@ -197,20 +198,20 @@ class and implement the to_dict() and from_dict() methods. """ @property - def fw_name(self): + def fw_name(self) -> str: try: - return self._fw_name + return self._fw_name # type: ignore except AttributeError: return get_default_serialization(self.__class__) @abc.abstractmethod - def to_dict(self): + def to_dict(self) -> Mapping[Any, Any]: raise NotImplementedError("FWSerializable object did not implement to_dict()!") - def to_db_dict(self): + def to_db_dict(self) -> Mapping[Any, Any]: return self.to_dict() - def as_dict(self): + def as_dict(self) -> Mapping[Any, Any]: # strictly for pseudo-compatibility with MSONable # Note that FWSerializable is not MSONable, it uses _fw_name instead of __class__ and # __module__ @@ -218,13 +219,13 @@ def as_dict(self): @classmethod @abc.abstractmethod - def from_dict(cls, m_dict): + def from_dict(cls, m_dict) -> "FWSerializable": raise NotImplementedError("FWSerializable object did not implement from_dict()!") - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self.to_dict(), default=DATETIME_HANDLER) - def to_format(self, f_format="json", **kwargs): + def to_format(self, f_format: str = "json", **kwargs) -> str: """ 
returns a String representation in the given format @@ -235,12 +236,12 @@ def to_format(self, f_format="json", **kwargs): return json.dumps(self.to_dict(), default=DATETIME_HANDLER, **kwargs) elif f_format == "yaml": # start with the JSON format, and convert to YAML - return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True) + return yaml.safe_dump(self.to_dict(), default_flow_style=YAML_STYLE, allow_unicode=True) # type: ignore else: raise ValueError(f"Unsupported format {f_format}") @classmethod - def from_format(cls, f_str, f_format="json"): + def from_format(cls, f_str: str, f_format: str = "json") -> "FWSerializable": """ convert from a String representation to its Object. @@ -257,11 +258,11 @@ def from_format(cls, f_str, f_format="json"): dct = yaml.safe_load(f_str) else: raise ValueError(f"Unsupported format {f_format}") - if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: + if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: # type: ignore fireworks_schema.validate(dct, cls.__name__) return cls.from_dict(reconstitute_dates(dct)) - def to_file(self, filename, f_format=None, **kwargs): + def to_file(self, filename: str, f_format: Optional[str] = None, **kwargs) -> None: """ Write a serialization of this object to a file. @@ -275,7 +276,7 @@ def to_file(self, filename, f_format=None, **kwargs): f.write(self.to_format(f_format=f_format, **kwargs)) @classmethod - def from_file(cls, filename, f_format=None): + def from_file(cls, filename: str, f_format: Optional[str] = None) -> "FWSerializable": """ Load a serialization of this object from a file. 
@@ -294,14 +295,14 @@ def from_file(cls, filename, f_format=None): def __getstate__(self): return self.to_dict() - def __setstate__(self, state): + def __setstate__(self, state) -> None: fw_obj = self.from_dict(state) for k, v in fw_obj.__dict__.items(): self.__dict__[k] = v # TODO: make this quicker the first time around -def load_object(obj_dict): +def load_object(obj_dict: Dict[str, Any]) -> Any: """ Creates an instantiation of a class based on a dictionary representation. We implicitly determine the Class through introspection along with information in the dictionary. @@ -371,7 +372,7 @@ def load_object(obj_dict): raise ValueError(f"load_object() could not find a class with cls._fw_name {fw_name}") -def load_object_from_file(filename, f_format=None): +def load_object_from_file(filename: str, f_format: Optional[str] = None) -> Any: """ Implicitly load an object from a file. just a friendly wrapper to load_object() @@ -414,7 +415,7 @@ def _search_module_for_obj(m_module, obj_dict): return obj.from_dict(obj_dict) -def reconstitute_dates(obj_dict): +def reconstitute_dates(obj_dict: Any) -> Any: if obj_dict is None: return None @@ -435,7 +436,7 @@ def reconstitute_dates(obj_dict): return obj_dict -def get_default_serialization(cls): +def get_default_serialization(cls: Type[Any]) -> str: root_mod = cls.__module__.split(".")[0] if root_mod == "__main__": raise ValueError(