From 884c3ecd8755fe5ab83a6c8a29794ae592b2ea4f Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 15:49:58 +0200 Subject: [PATCH] Revert "simplify code" This reverts commit 249a543190776f9d2f131e94c3ac2080e185404b. --- fireworks/core/firework.py | 18 +++---- fireworks/core/launchpad.py | 47 ++++++++++++------- fireworks/features/stats.py | 10 ++-- fireworks/queue/queue_launcher.py | 2 +- .../queue_adapters/common_adapter.py | 2 +- 5 files changed, 45 insertions(+), 34 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 58ee39624..b043d0102 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -153,9 +153,9 @@ def __init__( not only to direct children, but to all dependent FireWorks down to the Workflow's leaves. """ - mod_spec = mod_spec or [] - additions = additions or [] - detours = detours or [] + mod_spec = mod_spec if mod_spec is not None else [] + additions = additions if additions is not None else [] + detours = detours if detours is not None else [] self.stored_data = stored_data if stored_data else {} self.exit = exit @@ -267,13 +267,13 @@ def __init__( NEGATIVE_FWID_CTR -= 1 self.fw_id = NEGATIVE_FWID_CTR - self.launches = launches or [] - self.archived_launches = archived_launches or [] + self.launches = launches if launches else [] + self.archived_launches = archived_launches if archived_launches else [] self.created_on = created_on or datetime.utcnow() self.updated_on = updated_on or datetime.utcnow() parents = [parents] if isinstance(parents, Firework) else parents - self.parents = parents or [] + self.parents = parents if parents else [] self._state = state @@ -476,9 +476,9 @@ def __init__( self.fworker = fworker or FWorker() self.host = host or get_my_host() self.ip = ip or get_my_ip() - self.trackers = trackers or [] + self.trackers = trackers if trackers else [] self.action = action if action else None - self.state_history = state_history or [] + self.state_history = state_history if state_history else [] self.state = state self.launch_id = launch_id self.fw_id = fw_id @@ -643,7 +643,7 @@ def _update_state_history(self, state) -> None: now_time = datetime.utcnow() new_history_entry = {"state": state, "created_on": now_time} if state != "COMPLETED" and last_checkpoint: - new_history_entry.update(checkpoint=last_checkpoint) + new_history_entry.update({"checkpoint": last_checkpoint}) self.state_history.append(new_history_entry) if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 4afc4ec01..c43a343dd 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -185,15 +185,15 @@ def __init__( self.password = password self.authsource = authsource or self.name self.mongoclient_kwargs = mongoclient_kwargs or {} - self.uri_mode = bool(uri_mode) + self.uri_mode = uri_mode # set up logger self.logdir = logdir self.strm_lvl = strm_lvl if strm_lvl else "INFO" self.m_logger = get_fw_logger("launchpad", l_dir=self.logdir, stream_level=self.strm_lvl) - self.user_indices = user_indices or [] - self.wf_user_indices = wf_user_indices or [] + self.user_indices = user_indices if user_indices else [] + self.wf_user_indices = wf_user_indices if wf_user_indices else [] # get connection if uri_mode: @@ -267,20 +267,31 @@ def update_spec(self, fw_ids, spec_document, mongo=False) -> None: ) @classmethod - def from_dict(cls, dct): + def from_dict(cls, d): + port = d.get("port", None) + name = d.get("name", None) + username = d.get("username", None) + password = d.get("password", None) + logdir = d.get("logdir", None) + strm_lvl = d.get("strm_lvl", None) + user_indices = d.get("user_indices", []) + wf_user_indices = d.get("wf_user_indices", []) + authsource = d.get("authsource", None) + uri_mode = d.get("uri_mode", False) + mongoclient_kwargs = d.get("mongoclient_kwargs", None) return LaunchPad( - dct["host"], - port=dct.get("port"), - name=dct.get("name"), - username=dct.get("username"), - password=dct.get("password"), - logdir=dct.get("logdir"), - strm_lvl=dct.get("strm_lvl"), - user_indices=dct.get("user_indices"), - wf_user_indices=dct.get("wf_user_indices"), - authsource=dct.get("authsource"), - uri_mode=dct.get("uri_mode", False), - mongoclient_kwargs=dct.get("mongoclient_kwargs"), + d["host"], + port, + name, + username, + password, + logdir, + strm_lvl, + user_indices, + wf_user_indices, + authsource, + uri_mode, + mongoclient_kwargs, ) @classmethod @@ -1659,7 +1670,7 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo # Launch recovery if recover_launch is not None: recovery = self.get_recovery(fw_id, recover_launch) - recovery.update(_mode=recover_mode) + recovery.update({"_mode": recover_mode}) set_spec = recursive_dict({"$set": {"spec._recovery": recovery}}) if recover_mode == "prev_dir": prev_dir = self.get_launch_by_id(recovery.get("_launch_id")).launch_dir @@ -1703,7 +1714,7 @@ def get_recovery(self, fw_id, launch_id="last"): m_fw = self.get_fw_by_id(fw_id) launch = m_fw.launches[-1] if launch_id == "last" else self.get_launch_by_id(launch_id) recovery = launch.state_history[-1].get("checkpoint") - recovery.update(_prev_dir=launch.launch_dir, _launch_id=launch.launch_id) + recovery.update({"_prev_dir": launch.launch_dir, "_launch_id": launch.launch_id}) return recovery def _refresh_wf(self, fw_id) -> None: diff --git a/fireworks/features/stats.py b/fireworks/features/stats.py index 402e5ef6b..c381e64d6 100644 --- a/fireworks/features/stats.py +++ b/fireworks/features/stats.py @@ -197,8 +197,8 @@ def group_fizzled_fireworks( "created_on": self._query_datetime_range(start_time=query_start, end_time=query_end, **args), } if include_ids: - project_query.update(fw_id=1) - group_query.update(fw_id={"$push": "$fw_id"}) + project_query.update({"fw_id": 1}) + group_query.update({"fw_id": {"$push": "$fw_id"}}) if query: match_query.update(query) return self._aggregate( @@ -306,11 +306,11 @@ def _get_summary( } match_query.update(query) if runtime_stats: - project_query.update(runtime_secs=1) + project_query.update({"runtime_secs": 1}) group_query.update(RUNTIME_STATS) if include_ids: project_query.update({id_field: 1}) - group_query.update(ids={"$push": "$" + id_field}) + group_query.update({"ids": {"$push": "$" + id_field}}) return self._aggregate( coll=coll, match=match_query, @@ -357,7 +357,7 @@ def _aggregate( for arg in [match, project, unwind, group_op]: if arg is None: arg = {} - group_op.update(_id=f"${group_by}") + group_op.update({"_id": "$" + group_by}) if sort is None: sort_query = ("_id", 1) query = [{"$match": match}, {"$project": project}, {"$group": group_op}, {"$sort": SON([sort_query])}] diff --git a/fireworks/queue/queue_launcher.py b/fireworks/queue/queue_launcher.py index 781ee59c6..b99827c93 100644 --- a/fireworks/queue/queue_launcher.py +++ b/fireworks/queue/queue_launcher.py @@ -96,7 +96,7 @@ def launch_rocket_to_queue( # update qadapter job_name based on FW name job_name = get_slug(fw.name)[0:QUEUE_JOBNAME_MAXLEN] - qadapter.update(job_name=job_name) + qadapter.update({"job_name": job_name}) if "_queueadapter" in fw.spec: l_logger.debug("updating queue params using Firework spec..") diff --git a/fireworks/user_objects/queue_adapters/common_adapter.py b/fireworks/user_objects/queue_adapters/common_adapter.py index 259d6ada7..6e1380b25 100644 --- a/fireworks/user_objects/queue_adapters/common_adapter.py +++ b/fireworks/user_objects/queue_adapters/common_adapter.py @@ -66,7 +66,7 @@ def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwar ) self.q_name = q_name or q_type self.timeout = timeout or 5 - self.update(kwargs) + self.update(dict(kwargs)) self.q_commands = copy.deepcopy(CommonAdapter.default_q_commands) if "_q_commands_override" in self: