From ea084454601023eb7cc95b311819fb5a487ecfb2 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 5 May 2021 16:42:37 +0200 Subject: [PATCH 01/18] Fix directory creation race condition in `Folder` and `SandboxFolder` (#4912) The `Folder` and `SandboxFolder` classes of `aiida.common.folders` used the following paradigm to create the required folders: if not os.path.exists(filepath): os.makedirs(filepath) However, this is susceptible to a race condition. If two processes call the same piece of code almost at the same time, they may both evaluate the conditional to be True if the filepath does not yet exist, but one of the two will actually get to the creation first, causing the second process to except with a `FileExistsError`. The solution is to replace it with `os.makedirs(filepath, exist_ok=True)` which will swallow the exception if the path already exists. Cherry-pick: dc686c5aaede1944bc40a06becbef34cde3ed282 --- aiida/common/folders.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/aiida/common/folders.py b/aiida/common/folders.py index df185faa79..bedb183929 100644 --- a/aiida/common/folders.py +++ b/aiida/common/folders.py @@ -342,8 +342,7 @@ def create(self): It is always safe to call it, it will do nothing if the folder already exists. """ - if not self.exists(): - os.makedirs(self.abspath, mode=self.mode_dir) + os.makedirs(self.abspath, mode=self.mode_dir, exist_ok=True) def replace_with_folder(self, srcdir, move=False, overwrite=False): """This routine copies or moves the source folder 'srcdir' to the local folder pointed to by this Folder. @@ -370,8 +369,7 @@ def replace_with_folder(self, srcdir, move=False, overwrite=False): # Create parent dir, if needed, with the right mode pardir = os.path.dirname(self.abspath) - if not os.path.exists(pardir): - os.makedirs(pardir, mode=self.mode_dir) + os.makedirs(pardir, mode=self.mode_dir, exist_ok=True) if move: shutil.move(srcdir, self.abspath) @@ -417,8 +415,7 @@ def __init__(self, sandbox_in_repo=True): # First check if the sandbox folder already exists if sandbox_in_repo: sandbox = os.path.join(get_profile().repository_path, 'sandbox') - if not os.path.exists(sandbox): - os.makedirs(sandbox) + os.makedirs(sandbox, exist_ok=True) abspath = tempfile.mkdtemp(dir=sandbox) else: abspath = tempfile.mkdtemp() From 21c8743aab51702f2643ee7cf7c2e35693fc69b5 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 4 May 2021 10:11:08 +0200 Subject: [PATCH 02/18] CLI: set `localhost` as default for database hostname in `verdi setup` (#4908) The default was actually being defined on the option, however, it was taken from the `pgsu.DEFAULT_DSN` dictionary, which defines the database hostname to be `None`. Still, 9 out of 10 times the database is on the localhost so not having this as a default is kind of annoying and unnecessary. Note that `pgsu` specifies `None` as the default because this is at times the only way the psql shell can be accessed without a password. Cherry-pick: 6c4ced3331b389cebd01a59d55eb4b07f9452672 --- aiida/cmdline/params/options/__init__.py | 2 +- aiida/cmdline/params/options/commands/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index 16ae5eb95c..accd78c65f 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -249,7 +249,7 @@ def decorator(command): '--db-host', type=types.HostnameType(), help='Database server host. 
Leave empty for "peer" authentication.', - default=DEFAULT_DBINFO['host'] + default='localhost' ) DB_PORT = OverridableOption( diff --git a/aiida/cmdline/params/options/commands/setup.py b/aiida/cmdline/params/options/commands/setup.py index b5cb9f974d..1ec43c82ed 100644 --- a/aiida/cmdline/params/options/commands/setup.py +++ b/aiida/cmdline/params/options/commands/setup.py @@ -259,7 +259,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume SETUP_DATABASE_HOSTNAME = QUICKSETUP_DATABASE_HOSTNAME.clone( prompt='Database host', - contextual_default=functools.partial(get_profile_attribute_default, ('database_hostname', DEFAULT_DBINFO['host'])), + contextual_default=functools.partial(get_profile_attribute_default, ('database_hostname', 'localhost')), cls=options.interactive.InteractiveOption ) From 2683d71e4a8757a3b682a309346c8e9e750990e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tiziano=20M=C3=BCller?= Date: Mon, 19 Jul 2021 16:42:26 +0200 Subject: [PATCH 03/18] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20transports/ssh:?= =?UTF-8?q?=20support=20proxy=5Fjump=20(#4951)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SSH provides multiple ways to forward connections. The legacy way is via SSHProxyCommand which spawns a separate process for each jump host/proxy. Controlling those processes is error prone and lingering/hanging processes have been observed (#4940 and others, depending on the setup). This commit adds support for the SSHProxyJump feature which permits to setup an arbitrary number of proxy jumps without additional processes by creating TCP channels over existing (Paramiko) connections. This gives a good control over the lifetime of the different connections and since a users SSH config is not re-read after the initial setup gives a controlled environment. Hence it has been decided to make this new directive the recommended default in the documentation while still supporting both ways. Co-authored-by: Marnik Bercx Co-authored-by: Leopold Talirz Cherry-pick: da179dceef71fe52856b93e756e8bde3b89e2c4e --- aiida/transports/plugins/ssh.py | 142 ++++++++++++++++++++++++++------ docs/source/howto/ssh.rst | 62 ++++++++++---- tests/transports/test_ssh.py | 68 +++++++++++++++ 3 files changed, 232 insertions(+), 40 deletions(-) diff --git a/aiida/transports/plugins/ssh.py b/aiida/transports/plugins/ssh.py index ada3650186..9f354ffe4c 100644 --- a/aiida/transports/plugins/ssh.py +++ b/aiida/transports/plugins/ssh.py @@ -12,9 +12,11 @@ import glob import io import os +import re from stat import S_ISDIR, S_ISREG import click +import paramiko from aiida.cmdline.params import options from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType @@ -33,7 +35,6 @@ def parse_sshconfig(computername): :param computername: the computer name for which we want the configuration. """ - import paramiko config = paramiko.SSHConfig() try: with open(os.path.expanduser('~/.ssh/config'), encoding='utf8') as fhandle: @@ -118,12 +119,29 @@ class SshTransport(Transport): # pylint: disable=too-many-public-methods 'non_interactive_default': True } ), + ( + 'proxy_jump', { + 'prompt': + 'SSH proxy jump', + 'help': + 'SSH proxy jump for tunneling through other SSH hosts.' + ' Use a comma-separated list of hosts of the form [user@]host[:port].' + ' If user or port are not specified for a host, the user & port values from the target host are used.' 
+ ' This option must be provided explicitly and is not parsed from the SSH config file when left empty.', + 'non_interactive_default': + True + } + ), # Managed 'manually' in connect ( 'proxy_command', { - 'prompt': 'SSH proxy command', - 'help': 'SSH proxy command for tunneling through a proxy server.' + 'prompt': + 'SSH proxy command', + 'help': + 'SSH proxy command for tunneling through a proxy server.' + ' For tunneling through another SSH host, consider using the "SSH proxy jump" option instead!' ' Leave empty to parse the proxy command from the SSH config file.', - 'non_interactive_default': True + 'non_interactive_default': + True } ), # Managed 'manually' in connect ( @@ -309,6 +327,13 @@ def _get_proxy_command_suggestion_string(cls, computer): return ' '.join(new_pieces) + @classmethod + def _get_proxy_jump_suggestion_string(cls, _): + """ + Return an empty suggestion since Paramiko does not parse ProxyJump from the SSH config. + """ + return '' + @classmethod def _get_compress_suggestion_string(cls, computer): # pylint: disable=unused-argument """ @@ -377,11 +402,11 @@ def __init__(self, *args, **kwargs): function (as port, username, password, ...); taken from the accepted paramiko.SSHClient.connect() params. """ - import paramiko super().__init__(*args, **kwargs) self._sftp = None self._proxy = None + self._proxies = [] self._machine = kwargs.pop('machine') @@ -410,7 +435,7 @@ def __init__(self, *args, **kwargs): except KeyError: pass - def open(self): + def open(self): # pylint: disable=too-many-branches,too-many-statements """ Open a SSHClient to the machine possibly using the parameters given in the __init__. @@ -420,6 +445,7 @@ def open(self): :raise aiida.common.InvalidOperation: if the channel is already open """ + from paramiko.ssh_exception import SSHException from aiida.common.exceptions import InvalidOperation from aiida.transports.util import _DetachedProxyCommand @@ -429,9 +455,65 @@ def open(self): connection_arguments = self._connect_args.copy() if 'key_filename' in connection_arguments and not connection_arguments['key_filename']: connection_arguments.pop('key_filename') - proxystring = connection_arguments.pop('proxy_command', None) - if proxystring: - self._proxy = _DetachedProxyCommand(proxystring) + + proxyjumpstring = connection_arguments.pop('proxy_jump', None) + proxycmdstring = connection_arguments.pop('proxy_command', None) + + if proxyjumpstring and proxycmdstring: + raise ValueError('The SSH proxy jump and SSH proxy command options can not be used together') + + if proxyjumpstring: + matcher = re.compile(r'^(?:(?P[^@]+)@)?(?P[^@:]+)(?::(?P\d+))?\s*$') + try: + # don't use a generator here to have everything evaluated + proxies = [matcher.match(s).groupdict() for s in proxyjumpstring.split(',')] + except AttributeError: + raise ValueError('The given configuration for the SSH proxy jump option could not be parsed') + + # proxy_jump supports a list of jump hosts, each jump host is another Paramiko SSH connection + # but when opening a forward channel on a connection, we have to give the next hop. + # So we go through adjacent pairs and by adding the final target to the list we make it universal. 
+ for proxy, target in zip( + proxies, proxies[1:] + [{ + 'host': self._machine, + 'port': connection_arguments.get('port', 22), + }] + ): + proxy_connargs = connection_arguments.copy() + + if proxy['username']: + proxy_connargs['username'] = proxy['username'] + if proxy['port']: + proxy_connargs['port'] = int(proxy['port']) + if not target['port']: # the target port for the channel can not be None + target['port'] = connection_arguments.get('port', 22) + + proxy_client = paramiko.SSHClient() + if self._load_system_host_keys: + proxy_client.load_system_host_keys() + if self._missing_key_policy == 'RejectPolicy': + proxy_client.set_missing_host_key_policy(paramiko.RejectPolicy()) + elif self._missing_key_policy == 'WarningPolicy': + proxy_client.set_missing_host_key_policy(paramiko.WarningPolicy()) + elif self._missing_key_policy == 'AutoAddPolicy': + proxy_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + proxy_client.connect(proxy['host'], **proxy_connargs) + except Exception as exc: + self.logger.error( + f"Error connecting to proxy '{proxy['host']}' through SSH: [{self.__class__.__name__}] {exc}, " + f'connect_args were: {proxy_connargs}' + ) + self._close_proxies() # close all since we're going to start anew on the next open() (if any) + raise + connection_arguments['sock'] = proxy_client.get_transport().open_channel( + 'direct-tcpip', (target['host'], target['port']), ('', 0) + ) + self._proxies.append(proxy_client) + + if proxycmdstring: + self._proxy = _DetachedProxyCommand(proxycmdstring) connection_arguments['sock'] = self._proxy try: @@ -441,22 +523,14 @@ def open(self): f"Error connecting to '{self._machine}' through SSH: " + f'[{self.__class__.__name__}] {exc}, ' + f'connect_args were: {self._connect_args}' ) + self._close_proxies() raise - # Open also a File transport client. SFTP by default, pure SSH in ssh_only - self.open_file_transport() - - return self - - def open_file_transport(self): - """ - Open the SFTP channel, and handle error by directing customer to try another transport - """ - from aiida.common.exceptions import InvalidOperation - from paramiko.ssh_exception import SSHException + # Open the SFTP channel, and handle error by directing customer to try another transport try: self._sftp = self._client.open_sftp() except SSHException: + self._close_proxies() raise InvalidOperation( 'Error in ssh transport plugin. This may be due to the remote computer not supporting SFTP. ' 'Try setting it up with the aiida.transports:ssh_only transport from the aiida-sshonly plugin instead.' @@ -467,6 +541,21 @@ def open_file_transport(self): # Set the current directory to a explicit path, and not to None self._sftp.chdir(self._sftp.normalize('.')) + return self + + def _close_proxies(self): + """Close all proxy connections (proxy_jump and proxy_command)""" + + # Paramiko only closes the channel when closing the main connection, but not the connection itself. + while self._proxies: + self._proxies.pop().close() + + if self._proxy: + # Paramiko should close this automatically when closing the channel, + # but since the process is started in __init__this might not happen correctly. + self._proxy.close() + self._proxy = None + def close(self): """ Close the SFTP channel, and the SSHClient. 
@@ -482,6 +571,8 @@ def close(self): self._sftp.close() self._client.close() + self._close_proxies() + self._is_open = False @property @@ -1156,7 +1247,6 @@ def _local_listdir(path, pattern=None): """ if not pattern: return os.listdir(path) - import re if path.startswith('/'): # always this is the case in the local case base_dir = path else: @@ -1177,7 +1267,6 @@ def listdir(self, path='.', pattern=None): """ if not pattern: return self.sftp.listdir(path) - import re if path.startswith('/'): base_dir = path else: @@ -1334,13 +1423,16 @@ def gotocomputer_command(self, remotedir): if 'username' in self._connect_args: further_params.append(f"-l {escape_for_bash(self._connect_args['username'])}") - if 'port' in self._connect_args and self._connect_args['port']: + if self._connect_args.get('port'): further_params.append(f"-p {self._connect_args['port']}") - if 'key_filename' in self._connect_args and self._connect_args['key_filename']: + if self._connect_args.get('key_filename'): further_params.append(f"-i {escape_for_bash(self._connect_args['key_filename'])}") - if 'proxy_command' in self._connect_args and self._connect_args['proxy_command']: + if self._connect_args.get('proxy_jump'): + further_params.append(f"-o ProxyJump={escape_for_bash(self._connect_args['proxy_jump'])}") + + if self._connect_args.get('proxy_command'): further_params.append(f"-o ProxyCommand={escape_for_bash(self._connect_args['proxy_command'])}") further_params_str = ' '.join(further_params) diff --git a/docs/source/howto/ssh.rst b/docs/source/howto/ssh.rst index 326d2a26af..6bf7b09e42 100644 --- a/docs/source/howto/ssh.rst +++ b/docs/source/howto/ssh.rst @@ -23,7 +23,7 @@ Very briefly, first create a new private/public keypair (``aiida``/``aiida.pub`` $ ssh-keygen -t rsa -b 4096 -f ~/.ssh/aiida -Copy the public key to the remote machine, normally this will add the public key to the rmote machine's ``~/.ssh/authorized_keys``: +Copy the public key to the remote machine, normally this will add the public key to the remote machine's ``~/.ssh/authorized_keys``: .. code-block:: console @@ -39,7 +39,7 @@ Add the following lines to your ``~/.ssh/config`` file (or create it, if it does .. note:: - If your cluster needs you to connect to another computer *PROXY* first, you can use the ``proxy_command`` feature of ssh, see :ref:`how-to:ssh:proxy`. + If your cluster needs you to connect to another computer *PROXY* first, you can use the ``ProxyJump`` or ``ProxyCommand`` feature of SSH, see :ref:`how-to:ssh:proxy`. You should now be able to access the remote computer (without the need to type a password) *via*: @@ -185,47 +185,79 @@ Connecting to a remote computer *via* a proxy server ==================================================== Some compute clusters require you to connect to an intermediate server *PROXY*, from which you can then connect to the cluster *TARGET* on which you run your calculations. -This section explains how to use the ``proxy_command`` feature of ``ssh`` in order to make this jump automatically. +This section explains how to use the ``ProxyJump`` or ``ProxyCommand`` feature of ``ssh`` in order to make this jump automatically. .. tip:: - This method can also be used to automatically tunnel into virtual private networks, if you have an account on a proxy/jumphost server with access to the network. 
- + This method can also be used to avoid having to start a virtual private network (VPN) client if you have an SSH account on a proxy/jumphost server which is accessible from your current network **and** from which you can access the *TARGET* machine directly. SSH configuration ^^^^^^^^^^^^^^^^^ -Edit the ``~/.ssh/config`` file on the computer on which you installed AiiDA (or create it if missing) and add the following lines:: +To decide whether to use the ``ProxyJump`` (recommended) or the ``ProxyCommand`` directive, please check the version of your SSH client first with ``ssh -V``. +The ``ProxyJump`` directive has been added in version 7.3 of OpenSSH, hence if you are using an older version of SSH (on your machine or the *PROXY*) you have to use the older ``ProxyCommand``. + +To setup the proxy configuration with ``ProxyJump``, edit the ``~/.ssh/config`` file on the computer on which you installed AiiDA (or create it if missing) +and add the following lines:: Host SHORTNAME_TARGET Hostname FULLHOSTNAME_TARGET User USER_TARGET IdentityFile ~/.ssh/aiida - ProxyCommand ssh -W %h:%p USER_PROXY@FULLHOSTNAME_PROXY + ProxyJump USER_PROXY@FULLHOSTNAME_PROXY + + Host FULLHOSTNAME_PROXY + IdentityFile ~/.ssh/aiida + +Replace the ``..._TARGET`` and ``..._PROXY`` variables with the host/user names of the respective servers. + +.. dropdown:: :fa:`plus-circle` Alternative setup with ``ProxyCommand`` + + To setup the proxy configuration with ``ProxyCommand`` **instead**, edit the ``~/.ssh/config`` file on the computer on which you installed AiiDA (or create it if missing) + and add the following lines:: + + Host SHORTNAME_TARGET + Hostname FULLHOSTNAME_TARGET + User USER_TARGET + IdentityFile ~/.ssh/aiida + ProxyCommand ssh -W %h:%p USER_PROXY@FULLHOSTNAME_PROXY -replacing the ``..._TARGET`` and ``..._PROXY`` variables with the host/user names of the respective servers. + Host FULLHOSTNAME_PROXY + IdentityFile ~/.ssh/aiida -This should allow you to directly connect to the *TARGET* server using + Replace the ``..._TARGET`` and ``..._PROXY`` variables with the host/user names of the respective servers. + +In both cases, this should allow you to directly connect to the *TARGET* server using .. code-block:: console $ ssh SHORTNAME_TARGET -For a *passwordless* connection, you need to follow the instructions :ref:`how-to:ssh:passwordless` *twice*: once for the connection from your computer to the *PROXY* server, and once for the connection from the *PROXY* server to the *TARGET* server. -.. dropdown:: Specifying an SSH key for the proxy - If you need to specify a separate SSH key for the proxy, provide it *after* the ``-W`` directive, e.g.:: +.. note :: + + If the user directory is not shared between the *PROXY* and the *TARGET* (in most supercomputing facilities your user directory is shared between the machines), you need to follow the :ref:`instructions for a passwordless connection ` *twice*: once for the connection from your computer to the *PROXY* server, and once for the connection from the *PROXY* server to the *TARGET* server (e.g. the public key must be listed in the ``~/.ssh/authorized_keys`` file of both the *PROXY* and the *TARGET* server). - ssh -W fidis.epfl.ch:22 -i /home/ubuntu/.ssh/proxy user@proxy.epfl.ch AiiDA configuration ^^^^^^^^^^^^^^^^^^^ -When :ref:`configuring the computer in AiiDA `, AiiDA will automatically parse the required information from your ``~/.ssh/config`` file. 
+When :ref:`configuring the computer in AiiDA `, AiiDA will automatically parse most of required information from your ``~/.ssh/config`` file. A notable exception to this is the ``proxy_jump`` directive, which **must** be specified manually. + +Simply copy & paste the same instructions as you have used for ``ProxyJump`` in your ``~/.ssh/config`` to the input for ``proxy_jump``: + +.. code-block:: console + + $ verdi computer configure ssh SHORTNAME_TARGET + ... + Allow ssh agent [True]: + SSH proxy jump []: USER_PROXY@FULLHOSTNAME_PROXY + +.. note:: A chain of proxies can be specified as a comma-separated list. If you need to specify a different username, you can so with ``USER_PROXY@...``. If no username is specified for the proxy the same username as for the *TARGET* is used. -.. dropdown:: Specifying the proxy_command manually +.. important:: Specifying the ``proxy_command`` manually When specifying or updating the ``proxy_command`` option via ``verdi computer configure ssh``, please **do not use placeholders** ``%h`` and ``%p`` but provide the *actual* hostname and port. AiiDA replaces them only when parsing from the ``~/.ssh/config`` file. diff --git a/tests/transports/test_ssh.py b/tests/transports/test_ssh.py index 98dcf8a113..8b2043f878 100644 --- a/tests/transports/test_ssh.py +++ b/tests/transports/test_ssh.py @@ -44,6 +44,55 @@ def test_auto_add_policy(): with SshTransport(machine='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy'): pass + @staticmethod + def test_proxy_jump(): + """Test the connection with a proxy jump or several""" + with SshTransport( + machine='localhost', + proxy_jump='localhost', + timeout=30, + load_system_host_keys=True, + key_policy='AutoAddPolicy' + ): + pass + + # kind of pointless, but should work and to check that proxy chaining works + with SshTransport( + machine='localhost', + proxy_jump='localhost,localhost,localhost', + timeout=30, + load_system_host_keys=True, + key_policy='AutoAddPolicy' + ): + pass + + def test_proxy_jump_invalid(self): + """Test proper error reporting when invalid host as a proxy""" + + # import is also that when Python is running with debug warnings `-Wd` + # no unclosed files are reported. + with self.assertRaises(paramiko.SSHException): + with SshTransport( + machine='localhost', + proxy_jump='localhost,nohost', + timeout=30, + load_system_host_keys=True, + key_policy='AutoAddPolicy' + ): + pass + + @staticmethod + def test_proxy_command(): + """Test the connection with a proxy command""" + with SshTransport( + machine='localhost', + proxy_command='ssh -W localhost:22 localhost', + timeout=30, + load_system_host_keys=True, + key_policy='AutoAddPolicy' + ): + pass + def test_no_host_key(self): """Test if there is no host key.""" # Disable logging to avoid output during test @@ -74,3 +123,22 @@ def test_gotocomputer(): """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """ ) assert cmd_str == expected_str + + +def test_gotocomputer_proxyjump(): + """Test gotocomputer""" + with SshTransport( + machine='localhost', + timeout=30, + use_login_shell=False, + key_policy='AutoAddPolicy', + proxy_jump='localhost', + ) as transport: + cmd_str = transport.gotocomputer_command('/remote_dir/') + + expected_str = ( + """ssh -t localhost -o ProxyJump='localhost' "if [ -d '/remote_dir/' ] ;""" + """ then cd '/remote_dir/' ; bash ; else echo ' ** The directory' ; """ + """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' 
; fi" """ + ) + assert cmd_str == expected_str From e02933c1c07f2f9c775cce7b8c92ebb125a98175 Mon Sep 17 00:00:00 2001 From: "Jason.Eu" Date: Thu, 22 Jul 2021 20:04:03 +0800 Subject: [PATCH 04/18] FIX: namespaced outputs in `BaseRestartWorkChain` (#4961) The `BaseRestartWorkChain` did not return an `output_namespace` of its `_process_class` as described in #4623. It happened because in its `results` method, only the output keys are obtained from the call to `node.get_outgoing` (checked and returned by the parent WorkChain). This was changed for a call to `exposed_outputs`, which instead returns the whole nested namespace. The `out_many` method is not used here in order to make a post-check for ports that allows to keep the original exit code check and report. Cherry-pick: e1abe0aad8c88a844889f33b8340bdaf63a1c415 --- aiida/engine/processes/workchains/restart.py | 6 ++- .../processes/workchains/test_restart.py | 47 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 5719e1496f..12a3a05dc4 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -303,11 +303,13 @@ def results(self) -> Optional['ExitCode']: self.report(f'work chain completed after {self.ctx.iteration} iterations') + exposed_outputs = self.exposed_outputs(node, self.process_class) + for name, port in self.spec().outputs.items(): try: - output = node.get_outgoing(link_label_filter=name).one().node - except ValueError: + output = exposed_outputs[name] + except KeyError: if port.required: self.report(f"required output '{name}' was not an output of {self.ctx.process_name}<{node.pk}>") else: diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py index 034a244bba..fb01549342 100644 --- a/tests/engine/processes/workchains/test_restart.py +++ b/tests/engine/processes/workchains/test_restart.py @@ -11,7 +11,7 @@ # pylint: disable=invalid-name,no-self-use,no-member import pytest -from aiida import engine +from aiida import engine, orm from aiida.engine.processes.workchains.awaitable import Awaitable @@ -146,3 +146,48 @@ def mock_submit(_, process_class, **kwargs): assert isinstance(result, engine.ToContext) assert isinstance(result['children'], Awaitable) assert process.node.get_extra(SomeWorkChain._considered_handlers_extra) == [[]] # pylint: disable=protected-access + + +class OutputNamespaceWorkChain(engine.WorkChain): + """A WorkChain has namespaced output""" + + @classmethod + def define(cls, spec): + super().define(spec) + spec.output_namespace('sub', valid_type=orm.Int, dynamic=True) + spec.outline(cls.finalize) + + def finalize(self): + self.out('sub.result', orm.Int(1).store()) + + +class CustomBRWorkChain(engine.BaseRestartWorkChain): + """`BaseRestartWorkChain` of `OutputNamespaceWorkChain`""" + + _process_class = OutputNamespaceWorkChain + + @classmethod + def define(cls, spec): + super().define(spec) + spec.expose_outputs(cls._process_class) + spec.output('extra', valid_type=orm.Int) + + spec.outline( + cls.setup, + engine.while_(cls.should_run_process)( + cls.run_process, + cls.inspect_process, + ), + cls.results, + ) + + def setup(self): + super().setup() + self.ctx.inputs = {} + + +@pytest.mark.requires_rmq +def test_results(): + res, node = engine.launch.run_get_node(CustomBRWorkChain) + assert res['sub'].result.value == 1 + assert node.exit_status == 11 From 
a3efc53301110aad6607a35b3e92a469e2cab241 Mon Sep 17 00:00:00 2001 From: Francisco Ramirez Date: Wed, 21 Jul 2021 10:07:44 +0200 Subject: [PATCH 05/18] Engine: only call `get_detailed_job_info` if there is a job id (#4967) In the retrieve task of `CalcJobs` the `get_detailed_job_info` was called always. This would lead to problems if the node did not have an associated job id. Normally this doesn't happen because without a job id the engine would not even have been able to confirm that the job was ready for retrieval, however, this can happen in artificial situations where the whole calcjob process is mocked and for example an already completed job is passed through the system. When there is no job id, the `get_detailed_job_info` method should not be called because it requires the job id to get any information. Without it, the method would except and since it is called within the exponential backoff mechanism, the job would get stuck in the paused state. Cherry-pick: 2d513868c20af7f9d259b04dff5215ec1663c493 --- aiida/engine/processes/calcjobs/tasks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 95fb4b0f8e..2b2c270015 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -256,6 +256,10 @@ async def do_retrieve(): scheduler = node.computer.get_scheduler() # type: ignore[union-attr] scheduler.set_transport(transport) + if node.get_job_id() is None: + logger.warning(f'there is no job id for CalcJobNoe<{node.pk}>: skipping `get_detailed_job_info`') + return execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + try: detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id()) except FeatureNotAvailable: From 4dccd5f3f093376e9334a0a16b6937ae3f544f88 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 16 Jun 2021 21:14:31 +0200 Subject: [PATCH 06/18] =?UTF-8?q?=F0=9F=90=9B=20FIX:=20Initialising=20a=20?= =?UTF-8?q?`Node`=20with=20a=20`User`=20(#4977)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `Node` constructor allows for a specific user to be set other than the current default, except this wasn't actually working. The bug has been fixed and a test added. 
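As an illustration, a minimal sketch of the behavior the fix restores (it mirrors the new
`test_instantiate_with_user` test added in this commit; the example e-mail address is arbitrary):

    from aiida import orm

    user = orm.User(email='a@b.com').store()  # a user other than the profile default
    node = orm.Data(user=user).store()
    assert node.user.pk == user.pk  # holds after the fix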
Cherry-pick: f4f543ce55c82e8ee6ec1d318ba28e66e17e64cf --- aiida/manage/tests/pytest_fixtures.py | 7 + aiida/orm/nodes/node.py | 2 +- tests/orm/node/test_node.py | 287 ++++++++++++++------------ 3 files changed, 162 insertions(+), 134 deletions(-) diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 08c203b358..586d0cdac2 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -70,6 +70,13 @@ def clear_database_before_test(aiida_profile): yield +@pytest.fixture(scope='class') +def clear_database_before_test_class(aiida_profile): + """Clear the database before a test class.""" + aiida_profile.reset_db() + yield + + @pytest.fixture(scope='function') def temporary_event_loop(): """Create a temporary loop for independent test case""" diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index a30a1d1135..fe45ad92da 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -185,7 +185,7 @@ def __init__( raise ValueError('the computer is not stored') computer = computer.backend_entity if computer else None - user = user.backend_entity if user else User.objects(backend).get_default() + user = user if user else User.objects(backend).get_default() if user is None: raise ValueError('the user cannot be None') diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index afa318772b..a3d7a4f923 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-public-methods,no-self-use +# pylint: disable=attribute-defined-outside-init,no-self-use,too-many-public-methods """Tests for the Node ORM class.""" import io import logging @@ -16,44 +16,64 @@ import pytest -from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions, LinkType -from aiida.orm import Data, Log, Node, User, CalculationNode, WorkflowNode, load_node +from aiida.orm import Computer, Data, Log, Node, User, CalculationNode, WorkflowNode, load_node from aiida.orm.utils.links import LinkTriple -class TestNode(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNode: """Tests for generic node functionality.""" - def setUp(self): - super().setUp() + def setup_method(self): + """Setup for methods.""" self.user = User.objects.get_default() + _, self.computer = Computer.objects.get_or_create( + label='localhost', + description='localhost computer set up by test manager', + hostname='localhost', + transport_type='local', + scheduler_type='direct' + ) + self.computer.store() + + def test_instantiate_with_user(self): + """Test a Node can be instantiated with a specific user.""" + new_user = User(email='a@b.com').store() + node = Data(user=new_user).store() + assert node.user.pk == new_user.pk + + def test_instantiate_with_computer(self): + """Test a Node can be instantiated with a specific computer.""" + node = Data(computer=self.computer).store() + assert node.computer.pk == self.computer.pk def test_repository_garbage_collection(self): """Verify that the repository sandbox folder is cleaned after the node instance is garbage collected.""" node = Data() dirpath = node._repository._get_temp_folder().abspath # pylint: disable=protected-access - self.assertTrue(os.path.isdir(dirpath)) + assert 
os.path.isdir(dirpath) del node - self.assertFalse(os.path.isdir(dirpath)) + assert not os.path.isdir(dirpath) def test_computer_user_immutability(self): """Test that computer and user of a node are immutable after storing.""" node = Data().store() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): node.computer = self.computer - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): node.user = self.user -class TestNodeAttributesExtras(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeAttributesExtras: """Test for node attributes and extras.""" - def setUp(self): - super().setUp() + def setup_method(self): + """Setup for methods.""" self.node = Data() def test_attributes(self): @@ -62,10 +82,10 @@ def test_attributes(self): self.node.set_attribute('key', original_attribute) node_attributes = self.node.attributes - self.assertEqual(node_attributes['key'], original_attribute) + assert node_attributes['key'] == original_attribute node_attributes['key']['nested']['a'] = 2 - self.assertEqual(original_attribute['nested']['a'], 2) + assert original_attribute['nested']['a'] == 2 # Now store the node and verify that `attributes` then returns a deep copy self.node.store() @@ -73,7 +93,7 @@ def test_attributes(self): # We change the returned node attributes but the original attribute should remain unchanged node_attributes['key']['nested']['a'] = 3 - self.assertEqual(original_attribute['nested']['a'], 2) + assert original_attribute['nested']['a'] == 2 def test_get_attribute(self): """Test the `Node.get_attribute` method.""" @@ -81,14 +101,14 @@ def test_get_attribute(self): self.node.set_attribute('key', original_attribute) node_attribute = self.node.get_attribute('key') - self.assertEqual(node_attribute, original_attribute) + assert node_attribute == original_attribute node_attribute['nested']['a'] = 2 - self.assertEqual(original_attribute['nested']['a'], 2) + assert original_attribute['nested']['a'] == 2 default = 'default' - self.assertEqual(self.node.get_attribute('not_existing', default=default), default) - with self.assertRaises(AttributeError): + assert self.node.get_attribute('not_existing', default=default) == default + with pytest.raises(AttributeError): self.node.get_attribute('not_existing') # Now store the node and verify that `get_attribute` then returns a deep copy @@ -97,11 +117,11 @@ def test_get_attribute(self): # We change the returned node attributes but the original attribute should remain unchanged node_attribute['nested']['a'] = 3 - self.assertEqual(original_attribute['nested']['a'], 2) + assert original_attribute['nested']['a'] == 2 default = 'default' - self.assertEqual(self.node.get_attribute('not_existing', default=default), default) - with self.assertRaises(AttributeError): + assert self.node.get_attribute('not_existing', default=default) == default + with pytest.raises(AttributeError): self.node.get_attribute('not_existing') def test_get_attribute_many(self): @@ -110,10 +130,10 @@ def test_get_attribute_many(self): self.node.set_attribute('key', original_attribute) node_attribute = self.node.get_attribute_many(['key'])[0] - self.assertEqual(node_attribute, original_attribute) + assert node_attribute == original_attribute node_attribute['nested']['a'] = 2 - self.assertEqual(original_attribute['nested']['a'], 2) + assert original_attribute['nested']['a'] == 2 # Now store the node and verify that 
`get_attribute` then returns a deep copy self.node.store() @@ -121,28 +141,28 @@ def test_get_attribute_many(self): # We change the returned node attributes but the original attribute should remain unchanged node_attribute['nested']['a'] = 3 - self.assertEqual(original_attribute['nested']['a'], 2) + assert original_attribute['nested']['a'] == 2 def test_set_attribute(self): """Test the `Node.set_attribute` method.""" - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.node.set_attribute('illegal.key', 'value') self.node.set_attribute('valid_key', 'value') self.node.store() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): self.node.set_attribute('valid_key', 'value') def test_set_attribute_many(self): """Test the `Node.set_attribute` method.""" - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.node.set_attribute_many({'illegal.key': 'value', 'valid_key': 'value'}) self.node.set_attribute_many({'valid_key': 'value'}) self.node.store() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): self.node.set_attribute_many({'valid_key': 'value'}) def test_reset_attribute(self): @@ -152,32 +172,32 @@ def test_reset_attribute(self): attributes_illegal = {'attribute.illegal': 'value', 'attribute_four': 'value'} self.node.set_attribute_many(attributes_before) - self.assertEqual(self.node.attributes, attributes_before) + assert self.node.attributes == attributes_before self.node.reset_attributes(attributes_after) - self.assertEqual(self.node.attributes, attributes_after) + assert self.node.attributes == attributes_after - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.node.reset_attributes(attributes_illegal) self.node.store() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): self.node.reset_attributes(attributes_after) def test_delete_attribute(self): """Test the `Node.delete_attribute` method.""" self.node.set_attribute('valid_key', 'value') - self.assertEqual(self.node.get_attribute('valid_key'), 'value') + assert self.node.get_attribute('valid_key') == 'value' self.node.delete_attribute('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.delete_attribute('valid_key') # Repeat with stored node self.node.set_attribute('valid_key', 'value') self.node.store() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): self.node.delete_attribute('valid_key') def test_delete_attribute_many(self): @@ -187,28 +207,28 @@ def test_clear_attributes(self): """Test the `Node.clear_attributes` method.""" attributes = {'attribute_one': 'value', 'attribute_two': 'value'} self.node.set_attribute_many(attributes) - self.assertEqual(self.node.attributes, attributes) + assert self.node.attributes == attributes self.node.clear_attributes() - self.assertEqual(self.node.attributes, {}) + assert self.node.attributes == {} # Repeat for stored node self.node.store() - with self.assertRaises(exceptions.ModificationNotAllowed): + with pytest.raises(exceptions.ModificationNotAllowed): self.node.clear_attributes() def test_attributes_items(self): """Test the `Node.attributes_items` generator.""" attributes = {'attribute_one': 'value', 'attribute_two': 
'value'} self.node.set_attribute_many(attributes) - self.assertEqual(dict(self.node.attributes_items()), attributes) + assert dict(self.node.attributes_items()) == attributes def test_attributes_keys(self): """Test the `Node.attributes_keys` generator.""" attributes = {'attribute_one': 'value', 'attribute_two': 'value'} self.node.set_attribute_many(attributes) - self.assertEqual(set(self.node.attributes_keys()), set(attributes)) + assert set(self.node.attributes_keys()) == set(attributes) def test_extras(self): """Test the `Node.extras` property.""" @@ -216,10 +236,10 @@ def test_extras(self): self.node.set_extra('key', original_extra) node_extras = self.node.extras - self.assertEqual(node_extras['key'], original_extra) + assert node_extras['key'] == original_extra node_extras['key']['nested']['a'] = 2 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 # Now store the node and verify that `extras` then returns a deep copy self.node.store() @@ -227,7 +247,7 @@ def test_extras(self): # We change the returned node extras but the original extra should remain unchanged node_extras['key']['nested']['a'] = 3 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 def test_get_extra(self): """Test the `Node.get_extra` method.""" @@ -235,14 +255,14 @@ def test_get_extra(self): self.node.set_extra('key', original_extra) node_extra = self.node.get_extra('key') - self.assertEqual(node_extra, original_extra) + assert node_extra == original_extra node_extra['nested']['a'] = 2 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 default = 'default' - self.assertEqual(self.node.get_extra('not_existing', default=default), default) - with self.assertRaises(AttributeError): + assert self.node.get_extra('not_existing', default=default) == default + with pytest.raises(AttributeError): self.node.get_extra('not_existing') # Now store the node and verify that `get_extra` then returns a deep copy @@ -251,11 +271,11 @@ def test_get_extra(self): # We change the returned node extras but the original extra should remain unchanged node_extra['nested']['a'] = 3 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 default = 'default' - self.assertEqual(self.node.get_extra('not_existing', default=default), default) - with self.assertRaises(AttributeError): + assert self.node.get_extra('not_existing', default=default) == default + with pytest.raises(AttributeError): self.node.get_extra('not_existing') def test_get_extra_many(self): @@ -264,10 +284,10 @@ def test_get_extra_many(self): self.node.set_extra('key', original_extra) node_extra = self.node.get_extra_many(['key'])[0] - self.assertEqual(node_extra, original_extra) + assert node_extra == original_extra node_extra['nested']['a'] = 2 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 # Now store the node and verify that `get_extra` then returns a deep copy self.node.store() @@ -275,29 +295,29 @@ def test_get_extra_many(self): # We change the returned node extras but the original extra should remain unchanged node_extra['nested']['a'] = 3 - self.assertEqual(original_extra['nested']['a'], 2) + assert original_extra['nested']['a'] == 2 def test_set_extra(self): """Test the `Node.set_extra` method.""" - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.node.set_extra('illegal.key', 'value') 
self.node.set_extra('valid_key', 'value') self.node.store() self.node.set_extra('valid_key', 'changed') - self.assertEqual(load_node(self.node.pk).get_extra('valid_key'), 'changed') + assert load_node(self.node.pk).get_extra('valid_key') == 'changed' def test_set_extra_many(self): """Test the `Node.set_extra` method.""" - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.node.set_extra_many({'illegal.key': 'value', 'valid_key': 'value'}) self.node.set_extra_many({'valid_key': 'value'}) self.node.store() self.node.set_extra_many({'valid_key': 'changed'}) - self.assertEqual(load_node(self.node.pk).get_extra('valid_key'), 'changed') + assert load_node(self.node.pk).get_extra('valid_key') == 'changed' def test_reset_extra(self): """Test the `Node.reset_extra` method.""" @@ -306,25 +326,25 @@ def test_reset_extra(self): extras_illegal = {'extra.illegal': 'value', 'extra_four': 'value'} self.node.set_extra_many(extras_before) - self.assertEqual(self.node.extras, extras_before) + assert self.node.extras == extras_before self.node.reset_extras(extras_after) - self.assertEqual(self.node.extras, extras_after) + assert self.node.extras == extras_after - with self.assertRaises(exceptions.ValidationError): + with pytest.raises(exceptions.ValidationError): self.node.reset_extras(extras_illegal) self.node.store() self.node.reset_extras(extras_after) - self.assertEqual(load_node(self.node.pk).extras, extras_after) + assert load_node(self.node.pk).extras == extras_after def test_delete_extra(self): """Test the `Node.delete_extra` method.""" self.node.set_extra('valid_key', 'value') - self.assertEqual(self.node.get_extra('valid_key'), 'value') + assert self.node.get_extra('valid_key') == 'value' self.node.delete_extra('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.delete_extra('valid_key') # Repeat with stored node @@ -332,16 +352,16 @@ def test_delete_extra(self): self.node.store() self.node.delete_extra('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): load_node(self.node.pk).get_extra('valid_key') def test_delete_extra_many(self): """Test the `Node.delete_extra_many` method.""" self.node.set_extra('valid_key', 'value') - self.assertEqual(self.node.get_extra('valid_key'), 'value') + assert self.node.get_extra('valid_key') == 'value' self.node.delete_extra('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.node.delete_extra('valid_key') # Repeat with stored group @@ -349,42 +369,43 @@ def test_delete_extra_many(self): self.node.store() self.node.delete_extra('valid_key') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): load_node(self.node.pk).get_extra('valid_key') def test_clear_extras(self): """Test the `Node.clear_extras` method.""" extras = {'extra_one': 'value', 'extra_two': 'value'} self.node.set_extra_many(extras) - self.assertEqual(self.node.extras, extras) + assert self.node.extras == extras self.node.clear_extras() - self.assertEqual(self.node.extras, {}) + assert self.node.extras == {} # Repeat for stored node self.node.store() self.node.clear_extras() - self.assertEqual(load_node(self.node.pk).extras, {}) + assert load_node(self.node.pk).extras == {} def test_extras_items(self): """Test the `Node.extras_items` generator.""" extras = {'extra_one': 'value', 'extra_two': 'value'} self.node.set_extra_many(extras) - self.assertEqual(dict(self.node.extras_items()), 
extras) + assert dict(self.node.extras_items()) == extras def test_extras_keys(self): """Test the `Node.extras_keys` generator.""" extras = {'extra_one': 'value', 'extra_two': 'value'} self.node.set_extra_many(extras) - self.assertEqual(set(self.node.extras_keys()), set(extras)) + assert set(self.node.extras_keys()) == set(extras) -class TestNodeLinks(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeLinks: """Test for linking from and to Node.""" - def setUp(self): - super().setUp() + def setup_method(self): + """Setup for methods.""" self.node_source = CalculationNode() self.node_target = Data() @@ -397,21 +418,21 @@ def test_get_stored_link_triples(self): calculation.store() stored_triples = calculation.get_stored_link_triples() - self.assertEqual(len(stored_triples), 1) + assert len(stored_triples) == 1 link_triple = stored_triples[0] # Verify the type and value of the tuple elements - self.assertTrue(isinstance(link_triple, LinkTriple)) - self.assertTrue(isinstance(link_triple.node, Node)) - self.assertTrue(isinstance(link_triple.link_type, LinkType)) - self.assertEqual(link_triple.node.uuid, data.uuid) - self.assertEqual(link_triple.link_type, LinkType.INPUT_CALC) - self.assertEqual(link_triple.link_label, 'input') + assert isinstance(link_triple, LinkTriple) + assert isinstance(link_triple.node, Node) + assert isinstance(link_triple.link_type, LinkType) + assert link_triple.node.uuid == data.uuid + assert link_triple.link_type == LinkType.INPUT_CALC + assert link_triple.link_label == 'input' def test_validate_incoming_ipsum(self): """Test the `validate_incoming` method with respect to linking ourselves.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.node_target.validate_incoming(self.node_target, LinkType.CREATE, 'link_label') def test_validate_incoming(self): @@ -420,13 +441,13 @@ def test_validate_incoming(self): For a generic Node all incoming link types are valid as long as the source is also of type Node and the link type is a valid LinkType enum value. 
""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.node_target.validate_incoming(self.node_source, None, 'link_label') - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.node_target.validate_incoming(None, LinkType.CREATE, 'link_label') - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.node_target.validate_incoming(self.node_source, LinkType.CREATE.value, 'link_label') def test_add_incoming_create(self): @@ -438,15 +459,15 @@ def test_add_incoming_create(self): target.add_incoming(source_one, LinkType.CREATE, 'link_label') # Can only have a single incoming CREATE link - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.CREATE, 'link_label') # Even when the source node is different - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_two, LinkType.CREATE, 'link_label') # Or when the link label is different - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.CREATE, 'other_label') def test_add_incoming_call_calc(self): @@ -458,15 +479,15 @@ def test_add_incoming_call_calc(self): target.add_incoming(source_one, LinkType.CALL_CALC, 'link_label') # Can only have a single incoming CALL_CALC link - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.CALL_CALC, 'link_label') # Even when the source node is different - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_two, LinkType.CALL_CALC, 'link_label') # Or when the link label is different - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.CALL_CALC, 'other_label') def test_add_incoming_call_work(self): @@ -478,15 +499,15 @@ def test_add_incoming_call_work(self): target.add_incoming(source_one, LinkType.CALL_WORK, 'link_label') # Can only have a single incoming CALL_WORK link - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.CALL_WORK, 'link_label') # Even when the source node is different - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_two, LinkType.CALL_WORK, 'link_label') # Or when the link label is different - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.CALL_WORK, 'other_label') def test_add_incoming_input_calc(self): @@ -498,14 +519,14 @@ def test_add_incoming_input_calc(self): target.add_incoming(source_one, LinkType.INPUT_CALC, 'link_label') # Can only have a single incoming INPUT_CALC link from each source node if the label is not unique - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.INPUT_CALC, 'link_label') # Using another link label is fine target.validate_incoming(source_one, LinkType.INPUT_CALC, 'other_label') # However, using the same link, even from another node is illegal - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_two, LinkType.INPUT_CALC, 'link_label') def test_add_incoming_input_work(self): @@ -517,14 +538,14 @@ def test_add_incoming_input_work(self): target.add_incoming(source_one, LinkType.INPUT_WORK, 'link_label') # Can only have a single incoming INPUT_WORK link from 
each source node if the label is not unique - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.INPUT_WORK, 'link_label') # Using another link label is fine target.validate_incoming(source_one, LinkType.INPUT_WORK, 'other_label') # However, using the same link, even from another node is illegal - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_two, LinkType.INPUT_WORK, 'link_label') def test_add_incoming_return(self): @@ -536,7 +557,7 @@ def test_add_incoming_return(self): target.add_incoming(source_one, LinkType.RETURN, 'link_label') # Can only have a single incoming RETURN link from each source node if the label is not unique - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.validate_incoming(source_one, LinkType.RETURN, 'link_label') # From another source node or using another label is fine @@ -553,7 +574,7 @@ def test_validate_outgoing_workflow(self): source = WorkflowNode() target = Data() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): target.add_incoming(source, LinkType.RETURN, 'link_label') def test_get_incoming(self): @@ -568,17 +589,17 @@ def test_get_incoming(self): # Without link type incoming_nodes = target.get_incoming().all() incoming_uuids = sorted([neighbor.node.uuid for neighbor in incoming_nodes]) - self.assertEqual(incoming_uuids, sorted([source_one.uuid, source_two.uuid])) + assert incoming_uuids == sorted([source_one.uuid, source_two.uuid]) # Using a single link type incoming_nodes = target.get_incoming(link_type=LinkType.INPUT_CALC).all() incoming_uuids = sorted([neighbor.node.uuid for neighbor in incoming_nodes]) - self.assertEqual(incoming_uuids, sorted([source_one.uuid, source_two.uuid])) + assert incoming_uuids == sorted([source_one.uuid, source_two.uuid]) # Using a link type tuple incoming_nodes = target.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)).all() incoming_uuids = sorted([neighbor.node.uuid for neighbor in incoming_nodes]) - self.assertEqual(incoming_uuids, sorted([source_one.uuid, source_two.uuid])) + assert incoming_uuids == sorted([source_one.uuid, source_two.uuid]) def test_node_indegree_unique_pair(self): """Test that the validation of links with indegree `unique_pair` works correctly @@ -597,7 +618,7 @@ def test_node_indegree_unique_pair(self): uuids_incoming = set(node.uuid for node in called.get_incoming().all_nodes()) uuids_expected = set([caller.uuid, data.uuid]) - self.assertEqual(uuids_incoming, uuids_expected) + assert uuids_incoming == uuids_expected def test_node_indegree_unique_triple(self): """Test that the validation of links with indegree `unique_triple` works correctly @@ -615,7 +636,7 @@ def test_node_indegree_unique_triple(self): uuids_incoming = set(node.uuid for node in data.get_incoming().all_nodes()) uuids_expected = set([return_one.uuid, return_two.uuid]) - self.assertEqual(uuids_incoming, uuids_expected) + assert uuids_incoming == uuids_expected def test_node_outdegree_unique_triple(self): """Test that the validation of links with outdegree `unique_triple` works correctly @@ -636,7 +657,7 @@ def test_node_outdegree_unique_triple(self): uuids_outgoing = set(node.uuid for node in creator.get_outgoing().all_nodes()) uuids_expected = set([data_one.uuid, data_two.uuid]) - self.assertEqual(uuids_outgoing, uuids_expected) + assert uuids_outgoing == uuids_expected def test_get_node_by_label(self): """Test the get_node_by_label() method of the 
`LinkManager` @@ -661,12 +682,12 @@ def test_get_node_by_label(self): # Retrieve a link when the label is unique output_the_input = data.get_outgoing(link_type=LinkType.INPUT_CALC).get_node_by_label('the_input') - self.assertEqual(output_the_input.pk, calc_two.pk) + assert output_the_input.pk == calc_two.pk - with self.assertRaises(exceptions.MultipleObjectsError): + with pytest.raises(exceptions.MultipleObjectsError): data.get_outgoing(link_type=LinkType.INPUT_CALC).get_node_by_label('input') - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): data.get_outgoing(link_type=LinkType.INPUT_CALC).get_node_by_label('some_weird_label') def test_tab_completable_properties(self): @@ -710,42 +731,42 @@ def test_tab_completable_properties(self): output2.add_incoming(top_workflow, link_type=LinkType.RETURN, link_label='result_b') # creator - self.assertEqual(output1.creator.pk, calc1.pk) - self.assertEqual(output2.creator.pk, calc2.pk) + assert output1.creator.pk == calc1.pk + assert output2.creator.pk == calc2.pk # caller (for calculations) - self.assertEqual(calc1.caller.pk, workflow.pk) - self.assertEqual(calc2.caller.pk, workflow.pk) + assert calc1.caller.pk == workflow.pk + assert calc2.caller.pk == workflow.pk # caller (for workflows) - self.assertEqual(workflow.caller.pk, top_workflow.pk) + assert workflow.caller.pk == top_workflow.pk # .inputs for calculations - self.assertEqual(calc1.inputs.input_value.pk, input1.pk) - self.assertEqual(calc2.inputs.input_value.pk, input2.pk) - with self.assertRaises(AttributeError): + assert calc1.inputs.input_value.pk == input1.pk + assert calc2.inputs.input_value.pk == input2.pk + with pytest.raises(AttributeError): _ = calc1.inputs.some_label # .inputs for workflows - self.assertEqual(top_workflow.inputs.a.pk, input1.pk) - self.assertEqual(top_workflow.inputs.b.pk, input2.pk) - self.assertEqual(workflow.inputs.a.pk, input1.pk) - self.assertEqual(workflow.inputs.b.pk, input2.pk) - with self.assertRaises(AttributeError): + assert top_workflow.inputs.a.pk == input1.pk + assert top_workflow.inputs.b.pk == input2.pk + assert workflow.inputs.a.pk == input1.pk + assert workflow.inputs.b.pk == input2.pk + with pytest.raises(AttributeError): _ = workflow.inputs.some_label # .outputs for calculations - self.assertEqual(calc1.outputs.result.pk, output1.pk) - self.assertEqual(calc2.outputs.result.pk, output2.pk) - with self.assertRaises(AttributeError): + assert calc1.outputs.result.pk == output1.pk + assert calc2.outputs.result.pk == output2.pk + with pytest.raises(AttributeError): _ = calc1.outputs.some_label # .outputs for workflows - self.assertEqual(top_workflow.outputs.result_a.pk, output1.pk) - self.assertEqual(top_workflow.outputs.result_b.pk, output2.pk) - self.assertEqual(workflow.outputs.result_a.pk, output1.pk) - self.assertEqual(workflow.outputs.result_b.pk, output2.pk) - with self.assertRaises(AttributeError): + assert top_workflow.outputs.result_a.pk == output1.pk + assert top_workflow.outputs.result_b.pk == output2.pk + assert workflow.outputs.result_a.pk == output1.pk + assert workflow.outputs.result_b.pk == output2.pk + with pytest.raises(AttributeError): _ = workflow.outputs.some_label From 3ea86030da4ab922454fbe81c3367c8d05eaaec7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 11 Jun 2021 14:07:34 +0200 Subject: [PATCH 07/18] `ProcessBuilder`: add the `_merge` method The class already had the `_update` method, but like the normal `update` method of a mapping, this will not recursively merge 
the contents of nested dictionaries with any existing values. Anything below the top level will simply get overwritten. However, often one wants to merge the contents of the dictionary with the existing namespace structure only overriding leaf nodes. To this end the `_merge` method is added which recursively merges the content of the new dictionary with the existing content. Cherry-pick: 97a7fa9f4a715dec69a7d1453ae713b806425ec2 --- aiida/engine/processes/builder.py | 58 +++++++++++++++++++---- tests/engine/test_process_builder.py | 69 ++++++++++++++++++++++++++-- 2 files changed, 114 insertions(+), 13 deletions(-) diff --git a/aiida/engine/processes/builder.py b/aiida/engine/processes/builder.py index c7f6939918..3f6eab4271 100644 --- a/aiida/engine/processes/builder.py +++ b/aiida/engine/processes/builder.py @@ -41,7 +41,7 @@ def __init__(self, port_namespace: PortNamespace) -> None: self._valid_fields = [] self._data = {} - # The name and port objects have to be passed to the defined functions as defaults for + # The name and port objects have to be passed to the defined functions as defaults for # their arguments, because this way the content at the time of defining the method is # saved. If they are used directly in the body, it will try to capture the value from # its enclosing scope at the time of being called. @@ -83,16 +83,28 @@ def __setattr__(self, attr: str, value: Any) -> None: else: try: port = self._port_namespace[attr] - except KeyError: + except KeyError as exception: if not self._port_namespace.dynamic: - raise AttributeError(f'Unknown builder parameter: {attr}') + raise AttributeError(f'Unknown builder parameter: {attr}') from exception + port = None # type: ignore[assignment] else: value = port.serialize(value) # type: ignore[union-attr] validation_error = port.validate(value) if validation_error: raise ValueError(f'invalid attribute value {validation_error.message}') - self._data[attr] = value + # If the attribute that is being set corresponds to a port that is a ``PortNamespace`` we need to make sure + # that the nested value remains a ``ProcessBuilderNamespace``. Otherwise, the nested namespaces will become + # plain dictionaries and no longer have the properties of the ``ProcessBuilderNamespace`` that provide all + # the autocompletion and validation when values are being set. Therefore we first construct a new instance + # of a ``ProcessBuilderNamespace`` for the port of the attribute that is being set and than iteratively set + # all the values within the mapping that is being assigned to the attribute. + if isinstance(port, PortNamespace): + self._data[attr] = ProcessBuilderNamespace(port) + for sub_key, sub_value in value.items(): + setattr(self._data[attr], sub_key, sub_value) + else: + self._data[attr] = value def __repr__(self): return self._data.__repr__() @@ -119,19 +131,45 @@ def __delitem__(self, item): def __delattr__(self, item): self._data.__delitem__(item) - def _update(self, *args, **kwds): - """Update the values of the builder namespace passing a mapping as argument or individual keyword value pairs. + def _recursive_merge(self, dictionary, key, value): + """Recursively merge the contents of ``dictionary`` setting its ``key`` to ``value``.""" + if isinstance(value, collections.abc.Mapping): + for inner_key, inner_value in value.items(): + self._recursive_merge(dictionary[key], inner_key, inner_value) + else: + dictionary[key] = value + + def _merge(self, *args, **kwds): + """Merge the content of a dictionary or keyword arguments in . 
- The method is prefixed with an underscore in order to not reserve the name for a potential port, but in - principle the method functions just as `collections.abc.MutableMapping.update`. + .. note:: This method differs in behavior from ``_update`` in that ``_merge`` will recursively update the + existing dictionary with the one that is specified in the arguments. The ``_update`` method will merge only + the keys on the top level, but any lower lying nested namespace will be replaced entirely. - :param args: a single mapping that should be mapped on the namespace + The method is prefixed with an underscore in order to not reserve the name for a potential port. - :param kwds: keyword value pairs that should be mapped onto the ports + :param args: a single mapping that should be mapped on the namespace. + :param kwds: keyword value pairs that should be mapped onto the ports. """ if len(args) > 1: raise TypeError(f'update expected at most 1 arguments, got {int(len(args))}') + if args: + for key, value in args[0].items(): + self._recursive_merge(self, key, value) + + for key, value in kwds.items(): + self._recursive_merge(self, key, value) + + def _update(self, *args, **kwds): + """Update the values of the builder namespace passing a mapping as argument or individual keyword value pairs. + + The method functions just as `collections.abc.MutableMapping.update` and is merely prefixed with an underscore + in order to not reserve the name for a potential port. + + :param args: a single mapping that should be mapped on the namespace. + :param kwds: keyword value pairs that should be mapped onto the ports. + """ if args: for key, value in args[0].items(): if isinstance(value, collections.abc.Mapping): diff --git a/tests/engine/test_process_builder.py b/tests/engine/test_process_builder.py index 39782ac81e..7abef2f457 100644 --- a/tests/engine/test_process_builder.py +++ b/tests/engine/test_process_builder.py @@ -14,7 +14,7 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import LinkType -from aiida.engine import WorkChain, Process +from aiida.engine import WorkChain, Process, ProcessBuilderNamespace from aiida.plugins import CalculationFactory DEFAULT_INT = 256 @@ -47,6 +47,28 @@ def define(cls, spec): spec.input('namespace.c') +class SimpleProcessNamespace(Process): + """Process with basic nested namespaces to test "pruning" of empty nested namespaces from the builder.""" + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input_namespace('namespace.nested', dynamic=True) + spec.input('namespace.a', valid_type=int) + spec.input('namespace.c', valid_type=dict) + + +class NestedNamespaceProcess(Process): + """Process with nested required ports to check the update and merge functionality.""" + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('nested.namespace.int', valid_type=int, required=True) + spec.input('nested.namespace.float', valid_type=float, required=True) + spec.input('nested.namespace.str', valid_type=str, required=False) + + class MappingData(Mapping, orm.Data): """Data sub class that is also a `Mapping`.""" @@ -159,7 +181,7 @@ def test_dynamic_getters_value(self): self.builder_workchain.boolean = self.inputs['boolean'] # Verify that the correct type is returned by the getter - self.assertTrue(isinstance(self.builder_workchain.dynamic.namespace, dict)) + self.assertTrue(isinstance(self.builder_workchain.dynamic.namespace, ProcessBuilderNamespace)) 
self.assertTrue(isinstance(self.builder_workchain.name.spaced, orm.Int)) self.assertTrue(isinstance(self.builder_workchain.name_spaced, orm.Str)) self.assertTrue(isinstance(self.builder_workchain.boolean, orm.Bool)) @@ -259,7 +281,7 @@ def test_calc_job_node_get_builder_restart(self): self.assertIn('options', builder.metadata) self.assertEqual(builder.x, orm.Int(1)) self.assertEqual(builder.y, orm.Int(2)) - self.assertDictEqual(builder.metadata.options, original.get_options()) + self.assertDictEqual(builder._inputs(prune=True)['metadata']['options'], original.get_options()) def test_code_get_builder(self): """Test that the `Code.get_builder` method returns a builder where the code is already set.""" @@ -283,3 +305,44 @@ def test_code_get_builder(self): # Check that it complains if the type is not the correct one (for the templatereplacer, it should be a Dict) with self.assertRaises(ValueError): builder.parameters = orm.Int(3) + + def test_set_attr(self): + """Test that ``__setattr__`` keeps sub portnamespaces as ``ProcessBuilderNamespace`` instances.""" + builder = LazyProcessNamespace.get_builder() + self.assertTrue(isinstance(builder.namespace, ProcessBuilderNamespace)) + self.assertTrue(isinstance(builder.namespace.nested, ProcessBuilderNamespace)) + + builder.namespace = {'a': 'a', 'c': 'c', 'nested': {'bird': 'mus'}} + self.assertTrue(isinstance(builder.namespace, ProcessBuilderNamespace)) + self.assertTrue(isinstance(builder.namespace.nested, ProcessBuilderNamespace)) + + def test_update(self): + """Test the ``_update`` method to update an existing builder with a dictionary.""" + builder = NestedNamespaceProcess.get_builder() + builder.nested.namespace = {'int': 1, 'float': 2.0} + self.assertEqual(builder._inputs(prune=True), {'nested': {'namespace': {'int': 1, 'float': 2.0}}}) + + # Since ``_update`` will replace nested namespaces and not recursively merge them, if we don't specify all + # required inputs, the validation should fail. + with self.assertRaises(ValueError): + builder._update({'nested': {'namespace': {'int': 5, 'str': 'x'}}}) + + # Now we specify all required inputs and an additional optional one and since it is a nested namespace + builder._update({'nested': {'namespace': {'int': 5, 'float': 3.0, 'str': 'x'}}}) + self.assertEqual(builder._inputs(prune=True), {'nested': {'namespace': {'int': 5, 'float': 3.0, 'str': 'x'}}}) + + def test_merge(self): + """Test the ``_merge`` method to merge a dictionary into an existing builder.""" + builder = NestedNamespaceProcess.get_builder() + builder.nested.namespace = {'int': 1, 'float': 2.0} + self.assertEqual(builder._inputs(prune=True), {'nested': {'namespace': {'int': 1, 'float': 2.0}}}) + + # Define only one of the required ports of `nested.namespace`. This should leave the `float` input untouched and + # even though not specified explicitly again, since the merged dictionary still contains it, the + # `nested.namespace` port should still be valid. 
+ builder._merge({'nested': {'namespace': {'int': 5}}}) + self.assertEqual(builder._inputs(prune=True), {'nested': {'namespace': {'int': 5, 'float': 2.0}}}) + + # Perform same test but passing the dictionary in as keyword arguments + builder._merge(**{'nested': {'namespace': {'int': 5}}}) + self.assertEqual(builder._inputs(prune=True), {'nested': {'namespace': {'int': 5, 'float': 2.0}}}) From dd4075e51cde95216341df92199ce4b10bf551fe Mon Sep 17 00:00:00 2001 From: Matt Clarke Date: Mon, 21 Jun 2021 12:24:05 +0100 Subject: [PATCH 08/18] `SgeScheduler`: fix bug where sanitized job title was not used (#4994) The job title was actually sanitized, removing characters that are not supported by the SGE scheduler, but the original string was used accidentally, which was not caught due to missing tests. Cherry-pick: f2367e95595dde6fcfdd64e1971ac805764a419d --- aiida/schedulers/plugins/sge.py | 2 +- tests/schedulers/test_sge.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/aiida/schedulers/plugins/sge.py b/aiida/schedulers/plugins/sge.py index 85e55eb89c..c07f92e503 100644 --- a/aiida/schedulers/plugins/sge.py +++ b/aiida/schedulers/plugins/sge.py @@ -206,7 +206,7 @@ def _get_submit_script_header(self, job_tmpl): if not job_title or (job_title[0] not in string.ascii_letters): job_title = f'j{job_title}' - lines.append(f'#$ -N {job_tmpl.job_name}') + lines.append(f'#$ -N {job_title}') if job_tmpl.import_sys_environment: lines.append('#$ -V') diff --git a/tests/schedulers/test_sge.py b/tests/schedulers/test_sge.py index 2c5fa680b9..cfa6cb543e 100644 --- a/tests/schedulers/test_sge.py +++ b/tests/schedulers/test_sge.py @@ -356,3 +356,16 @@ def _parse_time_string(string, fmt='%Y-%m-%dT%H:%M:%S'): # the seconds since epoch, as suggested on stackoverflow: # http://stackoverflow.com/questions/1697815 return datetime.datetime.fromtimestamp(time.mktime(time_struct)) + + def test_job_name_cleaning(self): + """Test that invalid characters are cleaned from job name.""" + from aiida.schedulers.datastructures import JobTemplate + + scheduler = SgeScheduler() + + job_tmpl = JobTemplate() + job_tmpl.job_resource = scheduler.create_job_resource(parallel_env='mpi8', tot_num_mpiprocs=16) + job_tmpl.job_name = 'Some/job:name@with*invalid-characters.' + + header = scheduler._get_submit_script_header(job_tmpl) + self.assertTrue('#$ -N Somejobnamewithinvalid-characters.' in header, header) From 80f0d7c8682d6edc35da134b04e762407c139112 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 8 Jul 2021 17:42:51 +0200 Subject: [PATCH 09/18] ORM: fix deprecation warning always being shown in link managers (#5011) The link managers for the `Node` class which are used for the `inputs` and `outputs` attributes and facilitate the tab-completion of incoming and outgoing links, was recently changed to deprecate the direct use of double underscores in link labels in favor of treating them as normal nested dictionaries. The deprecation warning was thrown whenever the label contained a double underscore, but this would therefore also trigger on dunder methods, which is not desirable behaviour. This inaccuracy manifested itself in the deprecation method being printed even when just activating the tab-completion on `node.outputs` or `node.inputs` without even specifying a label with a double underscore. 
It is not fully understood how `_get_node_by_link_label` is called in doing this, but it seems some caching mechanism is calling the `__wrapped__` attribute on the link manager, which in turn triggers the deprecation warning. An additional clause in the condition to exclude dunder methods fixes the behaviour. Cherry-pick: 53c5564529eb636cc376e4b9535fd59c756d5b0f --- aiida/orm/utils/managers.py | 16 ++++++++++++++-- tests/orm/utils/test_managers.py | 27 ++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/aiida/orm/utils/managers.py b/aiida/orm/utils/managers.py index ad53db290c..be13ed6a81 100644 --- a/aiida/orm/utils/managers.py +++ b/aiida/orm/utils/managers.py @@ -87,7 +87,10 @@ def _get_node_by_link_label(self, label): try: node = attribute_dict[label] except KeyError as exception: - if '__' in label: + # Check whether the label contains a double underscore, in which case we want to warn the user that this is + # deprecated. However, we need to exclude labels that corresponds to dunder methods, i.e., those that start + # and end with a double underscore. + if '__' in label and not (label.startswith('__') and label.endswith('__')): import functools import warnings from aiida.common.warnings import AiidaDeprecationWarning @@ -98,7 +101,16 @@ def _get_node_by_link_label(self, label): 'Support for double underscores will be removed in the future.', AiidaDeprecationWarning ) # pylint: disable=no-member namespaces = label.split('__') - return functools.reduce(lambda d, namespace: d.get(namespace), namespaces, attribute_dict) + try: + return functools.reduce(lambda d, namespace: d.get(namespace), namespaces, attribute_dict) + except TypeError as exc: + # This can be raised if part of the `namespaces` correspond to an actual leaf node, but is treated + # like a namespace + raise NotExistent from exc + except AttributeError as exc: + # This will be raised if any of the intermediate namespaces don't exist, and so the label node does + # not exist. 
+ raise NotExistent from exc raise NotExistent from exception return node diff --git a/tests/orm/utils/test_managers.py b/tests/orm/utils/test_managers.py index aee9d6dde6..e68f916fa4 100644 --- a/tests/orm/utils/test_managers.py +++ b/tests/orm/utils/test_managers.py @@ -152,6 +152,10 @@ def test_link_manager_with_nested_namespaces(clear_database_before_test): out1.add_incoming(calc, link_type=LinkType.CREATE, link_label='nested__sub__namespace') out1.store() + out2 = orm.Data() + out2.add_incoming(calc, link_type=LinkType.CREATE, link_label='remote_folder') + out2.store() + # Check that the recommended way of dereferencing works assert calc.inputs.nested.sub.namespace.uuid == inp1.uuid assert calc.outputs.nested.sub.namespace.uuid == out1.uuid @@ -164,10 +168,27 @@ def test_link_manager_with_nested_namespaces(clear_database_before_test): assert calc.inputs.nested__sub__namespace.uuid == inp1.uuid assert calc.outputs.nested__sub__namespace.uuid == out1.uuid + # Dunder methods should not invoke the deprecation warning + with pytest.warns(None) as record: + try: + calc.inputs.__name__ + except AttributeError: + pass + assert not record + # Must raise a AttributeError, otherwise tab competion will not work - with pytest.raises(AttributeError): - _ = calc.outputs.nested.not_existent + for attribute in ['not_existent', 'not__existent__nested']: + with pytest.raises(AttributeError): + _ = getattr(calc.outputs.nested, attribute) # Must raise a KeyError + for key in ['not_existent', 'not__existent__nested']: + with pytest.raises(KeyError): + _ = calc.outputs.nested[key] + + # Note that `remote_folder` corresponds to an actual leaf node, but it is treated like an intermediate namespace + with pytest.raises(AttributeError): + _ = calc.outputs.remote_folder__namespace + with pytest.raises(KeyError): - _ = calc.outputs.nested['not_existent'] + _ = calc.outputs['remote_folder__namespace'] From 3df48bdd8c7419d0a9e54eb258399880d5c07eb4 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 6 Jul 2021 15:33:32 +0200 Subject: [PATCH 10/18] CLI: do not provide command hints during tab-completion (#5012) The `verdi` command uses a custom `Group` subclass that overrides the `get_command` method to provide some additional information in case the provided command name does not exist. These hints should help the user spot potential typos by giving a list of existing commands with similar names. However, this feature was also being triggered during tab-completion. For example, typing `verdi comput list` followed by activating tab-completion would result in the error being displayed since `comput` is not a valid command. In this case, one does not want to display the error message with hint at all. The tricky part is that there is no canonical way to distinguish between a normal command execution and a tab-completion event. The best bet is to use the `resilient_parsing` attribute on the `Context` which is set to `True` during tab-completion. Although this attribute was introduced into `click` directly to support auto-completion functionality, the problem is that this is not the only use case for which this flag can be set. It is therefore possible that there is some code path where this flag is set to `True` but it does not actually correspond to a tab-completion event. For now there doesn't seem to be a better solution though and in most cases this approach should work. 
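As a rough illustration only (a toy group standing in for `verdi`, not code from this change), the flag can be observed as follows:

    import click

    @click.group()
    def cli():
        """Toy stand-in for the `verdi` command group."""

    # A normal invocation builds a strictly parsed context ...
    assert cli.make_context('cli', ['comput']).resilient_parsing is False

    # ... whereas the tab-completion machinery constructs the context with
    # `resilient_parsing=True`, which is what the `get_command` override now
    # checks before deciding whether to print the command hints.
    assert cli.make_context('cli', ['comput'], resilient_parsing=True).resilient_parsing is True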
Cherry-pick: 8e763bb680533ddc0e4fc2f2da6ca81aeecd8b4d --- aiida/cmdline/commands/cmd_verdi.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/aiida/cmdline/commands/cmd_verdi.py b/aiida/cmdline/commands/cmd_verdi.py index 6a395b1185..2cd299f164 100644 --- a/aiida/cmdline/commands/cmd_verdi.py +++ b/aiida/cmdline/commands/cmd_verdi.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """The main `verdi` click group.""" - import difflib + import click from aiida import __version__ @@ -52,23 +52,33 @@ def get_command(self, ctx, cmd_name): """ cmd = click.Group.get_command(self, ctx, cmd_name) - # return the exact match + # If we match an actual command, simply return the match if cmd is not None: return cmd + # If this command is called during tab-completion, we do not want to print an error message if the command can't + # be found, but instead we want to simply return here. However, in a normal command execution, we do want to + # execute the rest of this method to try and match commands that are similar in order to provide the user with + # some hints. The problem is that there is no one canonical way to determine whether the invocation is due to a + # normal command execution or a tab-complete operation. The `resilient_parsing` attribute of the `Context` is + # designed to allow things like tab-completion, however, it is not the only purpose. For now this is our best + # bet though to detect a tab-complete event. When `resilient_parsing` is switched on, we assume a tab-complete + # and do nothing in case the command name does not match an actual command. + if ctx.resilient_parsing: + return + if int(cmd_name.lower().encode('utf-8').hex(), 16) == 0x6769757365707065: import base64 import gzip click.echo(gzip.decompress(base64.b85decode(GIU.encode('utf-8'))).decode('utf-8')) return None - # we might get better results with the Levenshtein distance - # or more advanced methods implemented in FuzzyWuzzy or similar libs, - # but this is an easy win for now + # We might get better results with the Levenshtein distance or more advanced methods implemented in FuzzyWuzzy + # or similar libs, but this is an easy win for now. matches = difflib.get_close_matches(cmd_name, self.list_commands(ctx), cutoff=0.5) if not matches: - # single letters are sometimes not matched, try with a simple startswith + # Single letters are sometimes not matched so also try with a simple startswith matches = [c for c in sorted(self.list_commands(ctx)) if c.startswith(cmd_name)][:3] if matches: From b468b26f1dd28936256d87a9be2a6454ff046c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20R=C3=BC=C3=9Fmann?= Date: Tue, 27 Jul 2021 16:59:38 +0200 Subject: [PATCH 11/18] =?UTF-8?q?=F0=9F=90=9B=20FIX:=20`BandsData`=20matpl?= =?UTF-8?q?otlib=20with=20NaN=20values=20(#5024)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use `numpy.nanmin` and `numpy.nanmax` for computing y-limits. 
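For illustration (made-up band energies, not data from an actual calculation), the difference is:

    import numpy

    bands = numpy.array([[1.0, 2.0], [numpy.nan, 3.0]])

    # The plain reductions propagate the NaN and yield unusable axis limits ...
    numpy.array(bands).max()  # nan
    # ... whereas the NaN-aware variants simply ignore the missing values.
    numpy.nanmax(bands)  # 3.0
    numpy.nanmin(bands)  # 1.0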
Co-authored-by: Chris Sewell Cherry-pick: d13884289971b47f851d903900cbe17e5e3131c9 --- aiida/orm/nodes/data/array/bands.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 21ba303707..0559017a4f 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -784,9 +784,9 @@ def _matplotlib_get_dict( # axis limits if y_max_lim is None: - y_max_lim = numpy.array(bands).max() + y_max_lim = numpy.nanmax(bands) if y_min_lim is None: - y_min_lim = numpy.array(bands).min() + y_min_lim = numpy.nanmin(bands) x_min_lim = min(x) # this isn't a numpy array, but a list x_max_lim = max(x) all_data['x_min_lim'] = x_min_lim From af2e9e35028cb5990b18ce4293f774002f26001c Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 9 Aug 2021 15:07:36 +0200 Subject: [PATCH 12/18] Database: minimize PostgreSQL logs for Django backend (#5056) For the Django backend, we are recording the current version of the database schema manually in our own `DbSetting` table. We use this to determine whether the database schema version is compatible with that of the code that is installed. Since the way the version and generation of the schema are stored has changed over the years, we had to rely on a try-except method to determine which way was the case. The original try was for the original way of v0.x of `aiida-core`, and if that failed we would fall back to the modern method of storing the info introduced in `aiida-core==1.0`. However, each time the except was hit, meaning a programming error was encountered in the PostgreSQL statement, a number of warnings are logged to PostgreSQL. Not only can they appear in stdout if no log file has been configured for the PostgreSQL server (which can be the case for conda installations), but the log file will also accumulate these unnecessary logs, which are produced every time the database version is checked, i.e., when a profile is loaded, which is often. It might not be trivial to get completely rid of these errors being logged but we can invert the try-except logic. Assuming that the majority of users are now on v1.0 or above, they would always be hitting the except. By first checking for the schema-style of v1.0 instead, only users of v0.x will be getting the logs, which should be negligible. Cherry-pick: 9285245849f480a9668313d3dbfcbbc33691eea3 --- aiida/backends/djsite/manager.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/aiida/backends/djsite/manager.py b/aiida/backends/djsite/manager.py index 81c12c2dfe..edaf636c4b 100644 --- a/aiida/backends/djsite/manager.py +++ b/aiida/backends/djsite/manager.py @@ -87,14 +87,16 @@ def get_schema_generation_database(self): backend = get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access try: - result = backend.execute_raw(r"""SELECT tval FROM db_dbsetting WHERE key = 'schema_generation';""") - except ProgrammingError: result = backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'schema_generation';""") - - try: - return str(int(result[0][0])) - except (IndexError, TypeError, ValueError): + except ProgrammingError: + # If this value does not exist, the schema has to correspond to the first generation which didn't actually + # record its value explicitly in the database until ``aiida-core>=1.0.0``. 
return '1' + else: + try: + return str(int(result[0][0])) + except (IndexError, ValueError, TypeError): + return '1' def get_schema_version_database(self): """Return the database schema version. @@ -107,9 +109,9 @@ def get_schema_version_database(self): backend = get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access try: - result = backend.execute_raw(r"""SELECT tval FROM db_dbsetting WHERE key = 'db|schemaversion';""") - except ProgrammingError: result = backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'db|schemaversion';""") + except ProgrammingError: + result = backend.execute_raw(r"""SELECT tval FROM db_dbsetting WHERE key = 'db|schemaversion';""") return result[0][0] def set_schema_version_database(self, version): From 59ee84f522886552e19a0cc0331c7e0b35ba591b Mon Sep 17 00:00:00 2001 From: Dominik Gresch Date: Wed, 16 Jun 2021 10:09:50 +0200 Subject: [PATCH 13/18] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Allow=20numpy?= =?UTF-8?q?=20arrays=20to=20be=20serialized=20on=20process=20checkpoints?= =?UTF-8?q?=20(#4730)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To allow objects such as numpy arrays to be serialized to a process checkpoint, the `AiiDALoader` now inherits from `yaml.UnsafeLoader` instead of `yaml.FullLoader`. Note, this change represents a potential security risk, whereby maliciously crafted code could be added to the serialized data and then loaded upon importing an archive. To mitigate this risk, the function `deserialize` has been renamed to `deserialize_unsafe`, and node checkpoint attributes are removed before importing an archive. This code is not part of the public API, and so we assume no specific deprecations are required. This change has also allowed for a relaxation of the `pyaml` pinning (to 5.2), although it should be noted that this upgrade will not be realised until a similar relaxation is implemented in plumpy. 
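As a small sketch of the difference between the two loaders (illustrative only, not code from this patch):

    import numpy
    import yaml

    dumped = yaml.dump(numpy.array([1, 2, 3]))

    # The FullLoader refuses to construct arbitrary Python objects ...
    try:
        yaml.load(dumped, Loader=yaml.FullLoader)
    except yaml.constructor.ConstructorError:
        pass

    # ... whereas the UnsafeLoader, which the AiiDALoader now builds on, can
    # round-trip them.
    array = yaml.load(dumped, Loader=yaml.UnsafeLoader)
    assert (array == numpy.array([1, 2, 3])).all()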
Cherry-pick: 1bc9dbe43ff31b737ce29aca605ae15985d38ca0 --- aiida/engine/persistence.py | 2 +- aiida/engine/processes/process.py | 2 +- aiida/manage/manager.py | 2 +- aiida/orm/utils/serialize.py | 13 +++---- .../importexport/dbimport/backends/common.py | 15 +++++++- .../importexport/dbimport/backends/django.py | 4 +- .../importexport/dbimport/backends/sqla.py | 4 +- docs/source/nitpick-exceptions | 4 +- environment.yml | 2 +- setup.json | 2 +- tests/common/test_serialize.py | 38 ++++++++++++++++--- .../importexport/test_specific_import.py | 30 +++++++++++++++ 12 files changed, 95 insertions(+), 23 deletions(-) diff --git a/aiida/engine/persistence.py b/aiida/engine/persistence.py index 2ccdac03c1..5ee9970b14 100644 --- a/aiida/engine/persistence.py +++ b/aiida/engine/persistence.py @@ -121,7 +121,7 @@ def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.pe raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint') try: - bundle = serialize.deserialize(checkpoint) + bundle = serialize.deserialize_unsafe(checkpoint) except Exception: raise PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}') diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index 12d4d9dc6c..3064bfe75b 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -604,7 +604,7 @@ def decode_input_args(self, encoded: str) -> Dict[str, Any]: # pylint: disable= :param encoded: encoded (serialized) inputs :return: The decoded input args """ - return serialize.deserialize(encoded) + return serialize.deserialize_unsafe(encoded) def update_node_state(self, state: plumpy.process_states.State) -> None: self.update_outputs() diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 8f8bdfd1f1..390e62fba3 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -243,7 +243,7 @@ def create_communicator( if with_orm: from aiida.orm.utils import serialize encoder = functools.partial(serialize.serialize, encoding='utf-8') - decoder = serialize.deserialize + decoder = serialize.deserialize_unsafe else: # used by verdi status to get a communicator without needing to load the dbenv from aiida.common import json diff --git a/aiida/orm/utils/serialize.py b/aiida/orm/utils/serialize.py index ae1ebf49dc..ba7fdd85ad 100644 --- a/aiida/orm/utils/serialize.py +++ b/aiida/orm/utils/serialize.py @@ -176,13 +176,11 @@ def represent_data(self, data): return super().represent_data(data) -class AiiDALoader(yaml.FullLoader): +class AiiDALoader(yaml.UnsafeLoader): """AiiDA specific yaml loader - .. note:: we subclass the `FullLoader` which is the one that since `pyyaml>=5.1` is the loader that prevents - arbitrary code execution. Even though this is in principle only used internally, one could imagine someone - sharing a database with a maliciously crafted process instance dump, which when reloaded could execute arbitrary - code. This load prevents this: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation + .. note:: The `AiiDALoader` should only be used on trusted input, because it uses the `yaml.UnsafeLoader`. When + importing a shared database, we strip all process node checkpoints to avoid this being a security risk. """ @@ -219,10 +217,11 @@ def serialize(data, encoding=None): return serialized -def deserialize(serialized): +def deserialize_unsafe(serialized): """Deserialize a yaml dump that represents a serialized data structure. - .. 
note:: no need to use `yaml.safe_load` here because the `Loader` will ensure that loading is safe. + .. note:: This function should not be used on untrusted input, because + it is built upon `yaml.UnsafeLoader`. :param serialized: a yaml serialized string representation :return: the deserialized data structure diff --git a/aiida/tools/importexport/dbimport/backends/common.py b/aiida/tools/importexport/dbimport/backends/common.py index d5a78dbd20..c3212589f0 100644 --- a/aiida/tools/importexport/dbimport/backends/common.py +++ b/aiida/tools/importexport/dbimport/backends/common.py @@ -14,7 +14,7 @@ from aiida.common import timezone from aiida.common.folders import RepositoryFolder from aiida.common.progress_reporter import get_progress_reporter, create_callback -from aiida.orm import Group, ImportGroup, Node, QueryBuilder +from aiida.orm import Group, ImportGroup, Node, QueryBuilder, ProcessNode from aiida.orm.utils._repository import Repository from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract from aiida.tools.importexport.common import exceptions @@ -108,3 +108,16 @@ def _sanitize_extras(fields: dict) -> dict: if fields.get('node_type', '').endswith('code.Code.'): fields['extras'] = {key: value for key, value in fields['extras'].items() if not key == 'hidden'} return fields + + +def _strip_checkpoints(fields: dict) -> dict: + """Remove checkpoint from attributes of process nodes. + + :param fields: the database fields for the entity + """ + if fields.get('node_type', '').startswith('process.'): + fields = copy.copy(fields) + fields['attributes'] = { + key: value for key, value in fields['attributes'].items() if key != ProcessNode.CHECKPOINT_KEY + } + return fields diff --git a/aiida/tools/importexport/dbimport/backends/django.py b/aiida/tools/importexport/dbimport/backends/django.py index 562266dee4..b020b9b73c 100644 --- a/aiida/tools/importexport/dbimport/backends/django.py +++ b/aiida/tools/importexport/dbimport/backends/django.py @@ -34,7 +34,7 @@ from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader from aiida.tools.importexport.dbimport.backends.common import ( - _copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS + _copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS ) @@ -364,6 +364,8 @@ def _select_entity_data( if entity_name == NODE_ENTITY_NAME: # format extras fields = _sanitize_extras(fields) + # strip checkpoints + fields = _strip_checkpoints(fields) if extras_mode_new != 'import': fields.pop('extras', None) new_entries[entity_name][str(pk)] = fields diff --git a/aiida/tools/importexport/dbimport/backends/sqla.py b/aiida/tools/importexport/dbimport/backends/sqla.py index 463fd89850..89b22bed92 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla.py +++ b/aiida/tools/importexport/dbimport/backends/sqla.py @@ -40,7 +40,7 @@ from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader from aiida.tools.importexport.dbimport.backends.common import ( - _copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS + _copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS ) @@ -401,6 +401,8 @@ def _select_entity_data( if entity_name == NODE_ENTITY_NAME: # format extras fields = _sanitize_extras(fields) + # strip checkpoints + fields = _strip_checkpoints(fields) if extras_mode_new != 'import': fields.pop('extras', None) 
new_entries[entity_name][str(pk)] = fields diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 06635a95c4..57bcdd18ac 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -134,8 +134,8 @@ py:class yaml.Dumper py:class yaml.Loader py:class yaml.dumper.Dumper py:class yaml.loader.Loader -py:class yaml.FullLoader -py:class yaml.loader.FullLoader +py:class yaml.UnsafeLoader +py:class yaml.loader.UnsafeLoader py:class uuid.UUID py:class psycopg2.extensions.cursor diff --git a/environment.yml b/environment.yml index 849c4a23fb..77cdf2949e 100644 --- a/environment.yml +++ b/environment.yml @@ -31,7 +31,7 @@ dependencies: - psycopg2-binary~=2.8.3 - python-dateutil~=2.8 - pytz~=2019.3 -- pyyaml~=5.1.2 +- pyyaml~=5.1 - reentry~=1.3 - simplejson~=3.16 - sqlalchemy-utils~=0.36.0 diff --git a/setup.json b/setup.json index a37a793718..e39fd029d3 100644 --- a/setup.json +++ b/setup.json @@ -45,7 +45,7 @@ "psycopg2-binary~=2.8.3", "python-dateutil~=2.8", "pytz~=2019.3", - "pyyaml~=5.1.2", + "pyyaml~=5.1", "reentry~=1.3", "simplejson~=3.16", "sqlalchemy-utils~=0.36.0", diff --git a/tests/common/test_serialize.py b/tests/common/test_serialize.py index 720456678f..3b86b0cb58 100644 --- a/tests/common/test_serialize.py +++ b/tests/common/test_serialize.py @@ -9,6 +9,10 @@ ########################################################################### """Serialization tests""" +import types + +import numpy as np + from aiida import orm from aiida.orm.utils import serialize from aiida.backends.testbase import AiidaTestCase @@ -28,7 +32,7 @@ def test_serialize_round_trip(self): data = {'test': 1, 'list': [1, 2, 3, node_a], 'dict': {('Si',): node_b, 'foo': 'bar'}, 'baz': 'aar'} serialized_data = serialize.serialize(data) - deserialized_data = serialize.deserialize(serialized_data) + deserialized_data = serialize.deserialize_unsafe(serialized_data) # For now manual element-for-element comparison until we come up with general # purpose function that can equate two node instances properly @@ -49,7 +53,7 @@ def test_serialize_group(self): data = {'group': group_a} serialized_data = serialize.serialize(data) - deserialized_data = serialize.deserialize(serialized_data) + deserialized_data = serialize.deserialize_unsafe(serialized_data) self.assertEqual(data['group'].uuid, deserialized_data['group'].uuid) self.assertEqual(data['group'].label, deserialized_data['group'].label) @@ -57,13 +61,13 @@ def test_serialize_group(self): def test_serialize_node_round_trip(self): """Test you can serialize and deserialize a node""" node = orm.Data().store() - deserialized = serialize.deserialize(serialize.serialize(node)) + deserialized = serialize.deserialize_unsafe(serialize.serialize(node)) self.assertEqual(node.uuid, deserialized.uuid) def test_serialize_group_round_trip(self): """Test you can serialize and deserialize a group""" group = orm.Group(label='test_serialize_group_round_trip').store() - deserialized = serialize.deserialize(serialize.serialize(group)) + deserialized = serialize.deserialize_unsafe(serialize.serialize(group)) self.assertEqual(group.uuid, deserialized.uuid) self.assertEqual(group.label, deserialized.label) @@ -71,7 +75,7 @@ def test_serialize_group_round_trip(self): def test_serialize_computer_round_trip(self): """Test you can serialize and deserialize a computer""" computer = self.computer - deserialized = serialize.deserialize(serialize.serialize(computer)) + deserialized = serialize.deserialize_unsafe(serialize.serialize(computer)) # pylint: 
disable=no-member self.assertEqual(computer.uuid, deserialized.uuid) @@ -117,6 +121,28 @@ def test_mixed_attribute_normal_dict(self): attribute_dict['nested']['normal'] = {'a': 2} serialized = serialize.serialize(attribute_dict) - deserialized = serialize.deserialize(serialized) + deserialized = serialize.deserialize_unsafe(serialized) self.assertEqual(attribute_dict, deserialized) + + def test_serialize_numpy(self): # pylint: disable=no-self-use + """Regression test for #3709 + + Check that numpy arrays can be serialized. + """ + data = np.array([1, 2, 3]) + + serialized = serialize.serialize(data) + deserialized = serialize.deserialize_unsafe(serialized) + assert np.all(data == deserialized) + + def test_serialize_simplenamespace(self): # pylint: disable=no-self-use + """Regression test for #3709 + + Check that `types.SimpleNamespace` can be serialized. + """ + data = types.SimpleNamespace(a=1, b=2.1) + + serialized = serialize.serialize(data) + deserialized = serialize.deserialize_unsafe(serialized) + assert data == deserialized diff --git a/tests/tools/importexport/test_specific_import.py b/tests/tools/importexport/test_specific_import.py index 926f3956ac..e943f0eafe 100644 --- a/tests/tools/importexport/test_specific_import.py +++ b/tests/tools/importexport/test_specific_import.py @@ -349,3 +349,33 @@ def test_import_folder(self): src_folders += [os.path.join(dirpath, dirname) for dirname in dirnames] self.maxDiff = None # pylint: disable=invalid-name self.assertListEqual(org_folders, src_folders) + + def test_import_checkpoints(self): + """Check that process node checkpoints are stripped when importing. + + The process node checkpoints need to be stripped because they + could be maliciously crafted to execute arbitrary code, since + we use the `yaml.UnsafeLoader` to load them. + """ + node = orm.WorkflowNode().store() + node.set_checkpoint(12) + node.seal() + node_uuid = node.uuid + assert node.checkpoint == 12 + + with tempfile.NamedTemporaryFile() as handle: + nodes = [node] + export(nodes, filename=handle.name, overwrite=True) + + # Check that we have the expected number of nodes in the database + self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) + + # Clean the database and verify there are no nodes left + self.clean_db() + assert orm.QueryBuilder().append(orm.Node).count() == 0 + + import_data(handle.name) + + assert orm.QueryBuilder().append(orm.Node).count() == len(nodes) + node_new = orm.load_node(node_uuid) + assert node_new.checkpoint is None From 54871d32650f95443e7e1955dea609b70419bc5a Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 11 Aug 2021 09:26:08 +0200 Subject: [PATCH 14/18] Dependencies: update requirement `pyyaml~=5.4` (#5060) Earlier versions have critical security flaws that have been fixed in `pyyaml==5.4`. Note that `plumpy` also needs to be upgraded to `0.20.0` which adds support for this version of `pyyaml`. The `UnsafeLoader` is replaced by the `Loader` which are identical, but the former is only being kept as an alias for backwards compatibility but it might be removed in future releases. 
Cherry-pick: c78e0e27a60b6859dfd2ba66e0ded55ca2651ce4 --- aiida/orm/utils/serialize.py | 9 ++++----- environment.yml | 4 ++-- requirements/requirements-py-3.7.txt | 4 ++-- requirements/requirements-py-3.8.txt | 4 ++-- requirements/requirements-py-3.9.txt | 4 ++-- setup.json | 4 ++-- tests/engine/processes/test_exit_code.py | 3 ++- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/aiida/orm/utils/serialize.py b/aiida/orm/utils/serialize.py index ba7fdd85ad..74e5abaf55 100644 --- a/aiida/orm/utils/serialize.py +++ b/aiida/orm/utils/serialize.py @@ -176,11 +176,11 @@ def represent_data(self, data): return super().represent_data(data) -class AiiDALoader(yaml.UnsafeLoader): +class AiiDALoader(yaml.Loader): """AiiDA specific yaml loader - .. note:: The `AiiDALoader` should only be used on trusted input, because it uses the `yaml.UnsafeLoader`. When - importing a shared database, we strip all process node checkpoints to avoid this being a security risk. + .. note:: The `AiiDALoader` should only be used on trusted input, since it uses the `yaml.Loader` which is not safe. + When importing a shared database, we strip all process node checkpoints to avoid this being a security risk. """ @@ -220,8 +220,7 @@ def serialize(data, encoding=None): def deserialize_unsafe(serialized): """Deserialize a yaml dump that represents a serialized data structure. - .. note:: This function should not be used on untrusted input, because - it is built upon `yaml.UnsafeLoader`. + .. note:: This function should not be used on untrusted input, since it is built upon `yaml.Loader` which is unsafe. :param serialized: a yaml serialized string representation :return: the deserialized data structure diff --git a/environment.yml b/environment.yml index 77cdf2949e..2303028d5f 100644 --- a/environment.yml +++ b/environment.yml @@ -25,13 +25,13 @@ dependencies: - numpy~=1.17 - pamqp~=2.3 - paramiko>=2.7.2,~=2.7 -- plumpy~=0.19.0 +- plumpy~=0.20.0 - pgsu~=0.2.0 - psutil~=5.6 - psycopg2-binary~=2.8.3 - python-dateutil~=2.8 - pytz~=2019.3 -- pyyaml~=5.1 +- pyyaml~=5.4 - reentry~=1.3 - simplejson~=3.16 - sqlalchemy-utils~=0.36.0 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index 010b1b6807..d9d73e17e8 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -88,7 +88,7 @@ pickleshare==0.7.5 Pillow==8.1.1 plotly==4.14.3 pluggy==0.13.1 -plumpy==0.19.0 +plumpy==0.20.0 prometheus-client==0.9.0 prompt-toolkit==3.0.14 psutil==5.8.0 @@ -117,7 +117,7 @@ python-editor==1.0.4 python-memcached==1.59 pytray==0.3.1 pytz==2019.3 -PyYAML==5.1.2 +PyYAML==5.4.1 pyzmq==22.0.2 qtconsole==5.0.2 QtPy==1.9.0 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index 5ec2a1e63e..2a430c6ab0 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -87,7 +87,7 @@ pickleshare==0.7.5 Pillow==8.1.1 plotly==4.14.3 pluggy==0.13.1 -plumpy==0.19.0 +plumpy==0.20.0 prometheus-client==0.9.0 prompt-toolkit==3.0.14 psutil==5.8.0 @@ -116,7 +116,7 @@ python-editor==1.0.4 python-memcached==1.59 pytray==0.3.1 pytz==2019.3 -PyYAML==5.1.2 +PyYAML==5.4.1 pyzmq==22.0.2 qtconsole==5.0.2 QtPy==1.9.0 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index ea5e9f330b..5368ccd09d 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -87,7 +87,7 @@ pickleshare==0.7.5 Pillow==8.1.1 plotly==4.14.3 pluggy==0.13.1 -plumpy==0.19.0 
+plumpy==0.20.0 prometheus-client==0.9.0 prompt-toolkit==3.0.14 psutil==5.8.0 @@ -116,7 +116,7 @@ python-editor==1.0.4 python-memcached==1.59 pytray==0.3.1 pytz==2019.3 -PyYAML==5.1.2 +PyYAML==5.4.1 pyzmq==22.0.2 qtconsole==5.0.2 QtPy==1.9.0 diff --git a/setup.json b/setup.json index e39fd029d3..ecbd13230d 100644 --- a/setup.json +++ b/setup.json @@ -39,13 +39,13 @@ "numpy~=1.17", "pamqp~=2.3", "paramiko~=2.7,>=2.7.2", - "plumpy~=0.19.0", + "plumpy~=0.20.0", "pgsu~=0.2.0", "psutil~=5.6", "psycopg2-binary~=2.8.3", "python-dateutil~=2.8", "pytz~=2019.3", - "pyyaml~=5.1", + "pyyaml~=5.4", "reentry~=1.3", "simplejson~=3.16", "sqlalchemy-utils~=0.36.0", diff --git a/tests/engine/processes/test_exit_code.py b/tests/engine/processes/test_exit_code.py index 2672a11285..83b2bfdac0 100644 --- a/tests/engine/processes/test_exit_code.py +++ b/tests/engine/processes/test_exit_code.py @@ -39,7 +39,8 @@ def test_exit_code_serializability(): exit_code = ExitCode() serialized = yaml.dump(exit_code) - deserialized = yaml.full_load(serialized) + # The default loaders are "safe" and won't load an ``ExitCode``, however, the ``Loader`` loader will. + deserialized = yaml.load(serialized, Loader=yaml.Loader) assert deserialized == exit_code assert isinstance(deserialized, ExitCode) From ee11b3027f8eff3003c6d3877b479c419c862c57 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 12 Aug 2021 12:41:07 +0200 Subject: [PATCH 15/18] ORM: deprecate double underscores in `LinkManager` contains (#5067) The use of double underscores in the interface of the `LinkManager` was recently deprecated for v2.0, however, it unintentionally broke the `__contains__` operator. Legacy code that was using code like: if 'some__nested__namespace' in node.inputs which used to work, will now return false, breaking existing code. The solution is to override the `__contains__` operator and check for the presence of a double underscore in the key. If that is the case, now a deprecation warning is emitted, but the key is split on the double underscores and the namespaces are used to fetch the intended nested dictionary before applying the actual contains check on the leaf node. Cherry-pick: bf80fdeb82cb35d6c5b8b9e563f948ac5ee93794 --- aiida/orm/utils/managers.py | 36 +++++++++++++++++++++++++++----- tests/orm/utils/test_managers.py | 27 ++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/aiida/orm/utils/managers.py b/aiida/orm/utils/managers.py index be13ed6a81..dd11a4f0d4 100644 --- a/aiida/orm/utils/managers.py +++ b/aiida/orm/utils/managers.py @@ -12,9 +12,12 @@ to access members of other classes via TAB-completable attributes (e.g. the class underlying `calculation.inputs` to allow to do `calculation.inputs.