From 97211648e03adb1349426d3f2e61abb25f5bbba4 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Sat, 2 Dec 2023 13:53:10 +0000 Subject: [PATCH] add worktree cli and tree subcommand --- aiida_worktree/cli/__init__.py | 5 + aiida_worktree/cli/cmd_tree.py | 405 +++++++++++++++++++++++++++ aiida_worktree/cli/cmd_worktree.py | 29 ++ aiida_worktree/cli/query_worktree.py | 194 +++++++++++++ aiida_worktree/node.py | 1 + aiida_worktree/worktree.py | 28 ++ pyproject.toml | 6 + 7 files changed, 668 insertions(+) create mode 100644 aiida_worktree/cli/__init__.py create mode 100644 aiida_worktree/cli/cmd_tree.py create mode 100644 aiida_worktree/cli/cmd_worktree.py create mode 100644 aiida_worktree/cli/query_worktree.py diff --git a/aiida_worktree/cli/__init__.py b/aiida_worktree/cli/__init__.py new file mode 100644 index 00000000..38730244 --- /dev/null +++ b/aiida_worktree/cli/__init__.py @@ -0,0 +1,5 @@ +"""Sub commands of the ``verdi`` command line interface. + +The commands need to be imported here for them to be registered with the top-level command group. +""" +from aiida_worktree.cli import cmd_tree diff --git a/aiida_worktree/cli/cmd_tree.py b/aiida_worktree/cli/cmd_tree.py new file mode 100644 index 00000000..6bbe60a8 --- /dev/null +++ b/aiida_worktree/cli/cmd_tree.py @@ -0,0 +1,405 @@ +"""`verdi process` command.""" +import click + +from aiida_worktree.cli.cmd_worktree import worktree +from aiida.cmdline.params import arguments, options, types +from aiida.cmdline.utils import decorators, echo +from aiida.common.log import LOG_LEVELS, capture_logging +from aiida.manage import get_manager +from aiida_worktree.cli.query_worktree import WorkTreeQueryBuilder + +REPAIR_INSTRUCTIONS = """\ +If one ore more processes are unreachable, you can run the following commands to try and repair them: + + verdi daemon stop + verdi process repair + verdi daemon start +""" + + +def valid_projections(): + """Return list of valid projections for the ``--project`` option of ``verdi process list``. + + This indirection is necessary to prevent loading the imported module which slows down tab-completion. + """ + return WorkTreeQueryBuilder.valid_projections + + +def default_projections(): + """Return list of default projections for the ``--project`` option of ``verdi process list``. + + This indirection is necessary to prevent loading the imported module which slows down tab-completion. + """ + return WorkTreeQueryBuilder.default_projections + + +@worktree.group("tree") +def worktree_tree(): + """Inspect and manage processes.""" + + +@worktree_tree.command("list") +@options.PROJECT( + type=types.LazyChoice(valid_projections), default=lambda: default_projections() +) # pylint: disable=unnecessary-lambda +@options.ORDER_BY() +@options.ORDER_DIRECTION() +@options.GROUP(help="Only include entries that are a member of this group.") +@options.ALL(help="Show all entries, regardless of their process state.") +@options.PROCESS_STATE() +@options.PROCESS_LABEL() +@options.PAUSED() +@options.EXIT_STATUS() +@options.FAILED() +@options.PAST_DAYS() +@options.LIMIT() +@options.RAW() +@click.pass_context +@decorators.with_dbenv() +def process_list( + ctx, + all_entries, + group, + process_state, + process_label, + paused, + exit_status, + failed, + past_days, + limit, + project, + raw, + order_by, + order_dir, +): + """Show a list of running or terminated WorkTree processes. + + By default, only those that are still running are shown, but there are options to show also the finished ones. + """ + # pylint: disable=too-many-locals + from tabulate import tabulate + + from aiida.cmdline.commands.cmd_daemon import execute_client_command + from aiida.cmdline.utils.common import print_last_process_state_change + from aiida.engine.daemon.client import get_daemon_client + from aiida.orm import ProcessNode, QueryBuilder + + relationships = {} + + if group: + relationships["with_node"] = group + + builder = WorkTreeQueryBuilder() + filters = builder.get_filters( + all_entries, process_state, process_label, paused, exit_status, failed + ) + query_set = builder.get_query_set( + relationships=relationships, + filters=filters, + order_by={order_by: order_dir}, + past_days=past_days, + limit=limit, + ) + projected = builder.get_projected(query_set, projections=project) + headers = projected.pop(0) + + if raw: + tabulated = tabulate(projected, tablefmt="plain") + echo.echo(tabulated) + return + + tabulated = tabulate(projected, headers=headers) + echo.echo(tabulated) + echo.echo(f"\nTotal results: {len(projected)}\n") + + if "cached" in project: + echo.echo_report( + "\u267B Processes marked with check-mark were not run but taken from the cache." + ) + echo.echo_report( + "Add the option `-P pk cached_from` to the command to display cache source." + ) + + print_last_process_state_change() + + if not get_daemon_client().is_daemon_running: + echo.echo_warning("The daemon is not running", bold=True) + return + + echo.echo_report("Checking daemon load... ", nl=False) + response = execute_client_command("get_numprocesses") + + if not response: + # Daemon could not be reached + return + + try: + active_workers = response["numprocesses"] + except KeyError: + echo.echo_report("No active daemon workers.") + else: + # Second query to get active process count. Currently this is slow but will be fixed with issue #2770. It is + # placed at the end of the command so that the user can Ctrl+C after getting the process table. + slots_per_worker = ctx.obj.config.get_option( + "daemon.worker_process_slots", scope=ctx.obj.profile.name + ) + active_processes = ( + QueryBuilder() + .append( + ProcessNode, + filters={ + "attributes.process_state": { + "in": ("created", "waiting", "running") + } + }, + ) + .count() + ) + available_slots = active_workers * slots_per_worker + percent_load = active_processes / available_slots + if percent_load > 0.9: # 90% + echo.echo_warning( + f"{percent_load * 100:.0f}% of the available daemon worker slots have been used!" + ) + echo.echo_warning( + "Increase the number of workers with `verdi daemon incr`." + ) + else: + echo.echo_report( + f"Using {percent_load * 100:.0f}% of the available daemon worker slots." + ) + + +@worktree_tree.command("show") +@arguments.PROCESSES() +@decorators.with_dbenv() +def process_show(processes): + """Show details for one or multiple processes.""" + from aiida_worktree import WorkTree + + for process in processes: + wt = WorkTree.load(process.pk) + wt.show() + + +@worktree_tree.command("report") +@arguments.PROCESSES() +@click.option( + "-i", + "--indent-size", + type=int, + default=2, + help="Set the number of spaces to indent each level by.", +) +@click.option( + "-l", + "--levelname", + type=click.Choice(list(LOG_LEVELS)), + default="REPORT", + help="Filter the results by name of the log level.", +) +@click.option( + "-m", + "--max-depth", + "max_depth", + type=int, + default=None, + help="Limit the number of levels to be printed.", +) +@decorators.with_dbenv() +def process_report(processes, levelname, indent_size, max_depth): + """Show the log report for one or multiple processes.""" + from aiida.cmdline.utils.common import ( + get_calcjob_report, + get_process_function_report, + get_workchain_report, + ) + from aiida.orm import CalcFunctionNode, CalcJobNode, WorkChainNode, WorkFunctionNode + + for process in processes: + if isinstance(process, CalcJobNode): + echo.echo(get_calcjob_report(process)) + elif isinstance(process, WorkChainNode): + echo.echo(get_workchain_report(process, levelname, indent_size, max_depth)) + elif isinstance(process, (CalcFunctionNode, WorkFunctionNode)): + echo.echo(get_process_function_report(process)) + else: + echo.echo(f"Nothing to show for node type {process.__class__}") + + +@worktree_tree.command("status") +@click.option( + "-c", + "--call-link-label", + "call_link_label", + is_flag=True, + help="Include the call link label if set.", +) +@click.option( + "-m", + "--max-depth", + "max_depth", + type=int, + default=None, + help="Limit the number of levels to be printed.", +) +@arguments.PROCESSES() +def process_status(call_link_label, max_depth, processes): + """Print the status of one or multiple processes.""" + from aiida.cmdline.utils.ascii_vis import format_call_graph + + for process in processes: + graph = format_call_graph( + process, max_depth=max_depth, call_link_label=call_link_label + ) + echo.echo(graph) + + +@worktree_tree.command("kill") +@arguments.PROCESSES() +@options.ALL(help="Kill all processes if no specific processes are specified.") +@options.TIMEOUT() +@options.WAIT() +@decorators.with_dbenv() +def process_kill(processes, all_entries, timeout, wait): + """Kill running processes.""" + from aiida.engine.processes import control + + if processes and all_entries: + raise click.BadOptionUsage( + "all", + "cannot specify individual processes and the `--all` flag at the same time.", + ) + + if all_entries: + click.confirm("Are you sure you want to kill all processes?", abort=True) + + with capture_logging() as stream: + try: + message = "Killed through `verdi process kill`" + control.kill_processes( + processes, + all_entries=all_entries, + timeout=timeout, + wait=wait, + message=message, + ) + except control.ProcessTimeoutException as exception: + echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}") + + if "unreachable" in stream.getvalue(): + echo.echo_report(REPAIR_INSTRUCTIONS) + + +@worktree_tree.command("pause") +@arguments.PROCESSES() +@options.ALL(help="Pause all active processes if no specific processes are specified.") +@options.TIMEOUT() +@options.WAIT() +@decorators.with_dbenv() +def process_pause(processes, all_entries, timeout, wait): + """Pause running processes.""" + from aiida.engine.processes import control + + if processes and all_entries: + raise click.BadOptionUsage( + "all", + "cannot specify individual processes and the `--all` flag at the same time.", + ) + + with capture_logging() as stream: + try: + message = "Paused through `verdi process pause`" + control.pause_processes( + processes, + all_entries=all_entries, + timeout=timeout, + wait=wait, + message=message, + ) + except control.ProcessTimeoutException as exception: + echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}") + + if "unreachable" in stream.getvalue(): + echo.echo_report(REPAIR_INSTRUCTIONS) + + +@worktree_tree.command("play") +@arguments.PROCESSES() +@options.ALL(help="Play all paused processes if no specific processes are specified.") +@options.TIMEOUT() +@options.WAIT() +@decorators.with_dbenv() +def process_play(processes, all_entries, timeout, wait): + """Play (unpause) paused processes.""" + from aiida.engine.processes import control + + if processes and all_entries: + raise click.BadOptionUsage( + "all", + "cannot specify individual processes and the `--all` flag at the same time.", + ) + + with capture_logging() as stream: + try: + control.play_processes( + processes, all_entries=all_entries, timeout=timeout, wait=wait + ) + except control.ProcessTimeoutException as exception: + echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}") + + if "unreachable" in stream.getvalue(): + echo.echo_report(REPAIR_INSTRUCTIONS) + + +@worktree_tree.command("watch") +@arguments.PROCESSES() +@decorators.with_dbenv() +@decorators.only_if_daemon_running( + echo.echo_warning, "daemon is not running, so process may not be reachable" +) +def process_watch(processes): + """Watch the state transitions for a process.""" + from time import sleep + + from kiwipy import BroadcastFilter + + def _print( + communicator, body, sender, subject, correlation_id + ): # pylint: disable=unused-argument + """Format the incoming broadcast data into a message and echo it to stdout.""" + if body is None: + body = "No message specified" + + if correlation_id is None: + correlation_id = "--" + + echo.echo(f"Process<{sender}> [{subject}|{correlation_id}]: {body}") + + communicator = get_manager().get_communicator() + echo.echo_report("watching for broadcasted messages, press CTRL+C to stop...") + + for process in processes: + + if process.is_terminated: + echo.echo_error(f"Process<{process.pk}> is already terminated") + continue + + communicator.add_broadcast_subscriber( + BroadcastFilter(_print, sender=process.pk) + ) + + try: + # Block this thread indefinitely until interrupt + while True: + sleep(2) + except (SystemExit, KeyboardInterrupt): + echo.echo("") # add a new line after the interrupt character + echo.echo_report("received interrupt, exiting...") + try: + communicator.close() + except RuntimeError: + pass + + # Reraise to trigger clicks builtin abort sequence + raise diff --git a/aiida_worktree/cli/cmd_worktree.py b/aiida_worktree/cli/cmd_worktree.py new file mode 100644 index 00000000..1f0f4ac6 --- /dev/null +++ b/aiida_worktree/cli/cmd_worktree.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""The main `verdi` click group.""" +import click + +from aiida_worktree import __version__ + +from aiida.cmdline.groups import VerdiCommandGroup +from aiida.cmdline.params import options, types + + +# Pass the version explicitly to ``version_option`` otherwise editable installs can show the wrong version number +@click.group( + cls=VerdiCommandGroup, context_settings={"help_option_names": ["--help", "-h"]} +) +@options.PROFILE(type=types.ProfileParamType(load_profile=True), expose_value=False) +@options.VERBOSITY() +@click.version_option( + __version__, package_name="aiida_core", message="AiiDA-WorkTree version %(version)s" +) +def worktree(): + """The command line interface of AiiDA-WorkTree.""" diff --git a/aiida_worktree/cli/query_worktree.py b/aiida_worktree/cli/query_worktree.py new file mode 100644 index 00000000..787296a3 --- /dev/null +++ b/aiida_worktree/cli/query_worktree.py @@ -0,0 +1,194 @@ +from aiida.common.lang import classproperty + +from aiida.tools.query.mapping import CalculationProjectionMapper + + +class WorkTreeQueryBuilder: + """Utility class to construct a QueryBuilder instance for WorkTree nodes and project the query set.""" + + # This tuple serves to mark compound projections that cannot explicitly be projected in the QueryBuilder, but will + # have to be manually projected from composing its individual projection constituents + _compound_projections = ("state",) + _default_projections = ( + "pk", + "ctime", + "process_label", + "cached", + "state", + "process_status", + ) + _valid_projections = ( + "pk", + "uuid", + "ctime", + "mtime", + "state", + "process_state", + "process_status", + "exit_status", + "exit_message", + "sealed", + "process_label", + "label", + "description", + "node_type", + "paused", + "process_type", + "job_state", + "scheduler_state", + "exception", + "cached", + "cached_from", + ) + + def __init__(self, mapper=None): + if mapper is None: + self._mapper = CalculationProjectionMapper(self._valid_projections) + else: + self._mapper = mapper + + @property + def mapper(self): + return self._mapper + + @classproperty + def default_projections(self): + return self._default_projections + + @classproperty + def valid_projections(self): + return self._valid_projections + + def get_filters( + self, + all_entries=False, + process_state=None, + process_label=None, + paused=False, + exit_status=None, + failed=False, + node_types=None, + ): + """Return a set of QueryBuilder filters based on typical command line options. + + :param node_types: A tuple of node classes to filter for (must be sub classes of Calculation). + :param all_entries: Boolean to negate filtering for process state. + :param process_state: Filter for this process state attribute. + :param process_label: Filter for this process label attribute. + :param paused: Boolean, if True, filter for processes that are paused. + :param exit_status: Filter for this exit status. + :param failed: Boolean to filter only failed processes. + :return: Dictionary of filters suitable for a QueryBuilder.append() call. + """ + # pylint: disable=too-many-arguments + from aiida.engine import ProcessState + + exit_status_attribute = self.mapper.get_attribute("exit_status") + process_label_attribute = self.mapper.get_attribute("process_label") + process_state_attribute = self.mapper.get_attribute("process_state") + paused_attribute = self.mapper.get_attribute("paused") + + filters = {} + + if node_types is not None: + filters["or"] = [] + for node_class in node_types: + filters["or"].append({"type": node_class.class_node_type}) + + if process_state and not all_entries: + filters[process_state_attribute] = {"in": process_state} + + if process_label is not None: + if "%" in process_label or "_" in process_label: + filters[process_label_attribute] = {"like": process_label} + else: + filters[process_label_attribute] = process_label + + if paused: + filters[paused_attribute] = True + + if failed: + filters[process_state_attribute] = {"==": ProcessState.FINISHED.value} + filters[exit_status_attribute] = {">": 0} + + if exit_status is not None: + filters[process_state_attribute] = {"==": ProcessState.FINISHED.value} + filters[exit_status_attribute] = {"==": exit_status} + + return filters + + def get_query_set( + self, + relationships=None, + filters=None, + order_by=None, + past_days=None, + limit=None, + ): + """Return the query set of calculations for the given filters and query parameters. + + :param relationships: A mapping of relationships to join on, e.g. {'with_node': Group} to join on a Group. The + keys in this dictionary should be the keyword used in the `append` method of the `QueryBuilder` to join the + entity on that is defined as the value. + :param filters: Rules to filter query results with. + :param order_by: Order the query set by this criterion. + :param past_days: Only include entries from the last past days. + :param limit: Limit the query set to this number of entries. + :return: The query set, a list of dictionaries. + """ + import datetime + + from aiida import orm + from aiida.common import timezone + from aiida_worktree.engine.worktree import WorkTree + + # Define the list of projections for the QueryBuilder, which are all valid minus the compound projections + projected_attributes = [ + self.mapper.get_attribute(projection) + for projection in self._valid_projections + if projection not in self._compound_projections + ] + unique_projections = list(set(projected_attributes)) + + if filters is None: + filters = {} + + if past_days is not None: + filters["ctime"] = { + ">": timezone.now() - datetime.timedelta(days=past_days) + } + + builder = orm.QueryBuilder() + builder.append( + WorkTree, filters=filters, project=unique_projections, tag="process" + ) + + if relationships is not None: + for tag, entity in relationships.items(): + builder.append( + cls=type(entity), filters={"id": entity.pk}, **{tag: "process"} + ) + + if order_by is not None: + builder.order_by({"process": order_by}) + else: + builder.order_by({"process": {"ctime": "asc"}}) + + if limit is not None: + builder.limit(limit) + + return builder.iterdict() + + def get_projected(self, query_set, projections): + """Project the query set for the given set of projections.""" + header = [self.mapper.get_label(projection) for projection in projections] + result = [header] + + for query_result in query_set: + result_row = [ + self.mapper.format(projection, query_result["process"]) + for projection in projections + ] + result.append(result_row) + + return result diff --git a/aiida_worktree/node.py b/aiida_worktree/node.py index ea60042a..76d601ca 100644 --- a/aiida_worktree/node.py +++ b/aiida_worktree/node.py @@ -19,6 +19,7 @@ def __init__(self, **kwargs): self.to_ctx = None self.wait = None self.process = None + self.pk = None def to_dict(self): ndata = super().to_dict() diff --git a/aiida_worktree/worktree.py b/aiida_worktree/worktree.py index 090cd8b1..2e818704 100644 --- a/aiida_worktree/worktree.py +++ b/aiida_worktree/worktree.py @@ -148,3 +148,31 @@ def load(cls, pk): wt.process = process wt.update() return wt + + def show(self): + """ + Print the current state of the worktree process. + """ + from tabulate import tabulate + + table = [] + self.update() + for node in self.nodes: + table.append([node.name, node.pk, node.state]) + print("-" * 80) + print("WorkTree: {}, PK: {}, State: {}".format(self.name, self.pk, self.state)) + print("-" * 80) + # show nodes + print("Nodes:") + print(tabulate(table, headers=["Name", "PK", "State"])) + print("-" * 80) + + def pause_nodes(self, nodes): + """ + Pause the given nodes + """ + + def play_nodes(self, nodes): + """ + Play the given nodes + """ diff --git a/pyproject.toml b/pyproject.toml index 7e798b0b..c9ce234c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,12 @@ tests = [ "pytest-cov~=2.7,<2.11", ] +[project.scripts] +worktree = "aiida_worktree.cli.cmd_worktree:worktree" + +[project.entry-points."aiida.cmdline"] +"worktree" = "aiida_worktree.cli.cmd_worktree:worktree" + [project.entry-points."aiida.node"] "process.workflow.worktree" = "aiida_worktree.orm.worktree:WorkTreeNode"