Skip to content

Commit

Permalink
Transport & Engine: factor out getcwd() & chdir() for compati…
Browse files Browse the repository at this point in the history
…bility with upcoming async transport (aiidateam#6594)

The practice of setting the working directory for a whole instance of the Transport class is abandoned.
This is done, by *always* passing absolute paths to `Transport`, instead of relative paths.
However, both `getcwd()` and `chdir()` remain in the code base, for backward compatibility. 

From now on, any change in `aiida-core`, should respect the new practice.
  • Loading branch information
khsrali authored Nov 5, 2024
1 parent 2ed19dc commit 6f5c35e
Show file tree
Hide file tree
Showing 15 changed files with 257 additions and 161 deletions.
7 changes: 6 additions & 1 deletion src/aiida/calculations/monitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import tempfile
from pathlib import Path

from aiida.orm import CalcJobNode
from aiida.transports import Transport
Expand All @@ -19,7 +20,11 @@ def always_kill(node: CalcJobNode, transport: Transport) -> str | None:
:returns: A string if the job should be killed, `None` otherwise.
"""
with tempfile.NamedTemporaryFile('w+') as handle:
transport.getfile('_aiidasubmit.sh', handle.name)
cwd = node.get_remote_workdir()
if cwd is None:
raise ValueError('The remote work directory cannot be None')

transport.getfile(str(Path(cwd).joinpath('_aiidasubmit.sh')), handle.name)
handle.seek(0)
output = handle.read()

Expand Down
6 changes: 1 addition & 5 deletions src/aiida/cmdline/commands/cmd_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,7 @@ def _computer_create_temp_file(transport, scheduler, authinfo, computer):
file_content = f"Test from 'verdi computer test' on {datetime.datetime.now().isoformat()}"
workdir = authinfo.get_workdir().format(username=transport.whoami())

try:
transport.chdir(workdir)
except OSError:
transport.makedirs(workdir)
transport.chdir(workdir)
transport.makedirs(workdir, ignore_existing=True)

with tempfile.NamedTemporaryFile(mode='w+') as tempf:
fname = os.path.split(tempf.name)[1]
Expand Down
135 changes: 73 additions & 62 deletions src/aiida/engine/daemon/execmanager.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion src/aiida/engine/processes/calcjobs/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,6 @@ def _perform_dry_run(self):
with LocalTransport() as transport:
with SubmitTestFolder() as folder:
calc_info = self.presubmit(folder)
transport.chdir(folder.abspath)
upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True)
self.node.dry_run_info = { # type: ignore[attr-defined]
'folder': folder.abspath,
Expand Down
1 change: 0 additions & 1 deletion src/aiida/engine/processes/calcjobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ async def task_monitor_job(
async def do_monitor():
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
transport.chdir(node.get_remote_workdir())
return monitors.process(node, transport)

try:
Expand Down
49 changes: 16 additions & 33 deletions src/aiida/orm/nodes/data/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,10 @@ def is_empty(self):
transport = authinfo.get_transport()

with transport:
try:
transport.chdir(self.get_remote_path())
except OSError:
# If the transport OSError the directory no longer exists and was deleted
if not transport.isdir(self.get_remote_path()):
return True

return not transport.listdir()
return not transport.listdir(self.get_remote_path())

def getfile(self, relpath, destpath):
"""Connects to the remote folder and retrieves the content of a file.
Expand Down Expand Up @@ -96,22 +93,15 @@ def listdir(self, relpath='.'):
authinfo = self.get_authinfo()

with authinfo.get_transport() as transport:
try:
full_path = os.path.join(self.get_remote_path(), relpath)
transport.chdir(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)
exc.errno = exception.errno
raise exc from exception
else:
raise
full_path = os.path.join(self.get_remote_path(), relpath)
if not transport.isdir(full_path):
raise OSError(
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)

try:
return transport.listdir()
return transport.listdir(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
Expand All @@ -132,22 +122,15 @@ def listdir_withattributes(self, path='.'):
authinfo = self.get_authinfo()

with authinfo.get_transport() as transport:
try:
full_path = os.path.join(self.get_remote_path(), path)
transport.chdir(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)
exc.errno = exception.errno
raise exc from exception
else:
raise
full_path = os.path.join(self.get_remote_path(), path)
if not transport.isdir(full_path):
raise OSError(
f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a '
'directory or has been deleted.'
)

try:
return transport.listdir_withattributes()
return transport.listdir_withattributes(full_path)
except OSError as exception:
if exception.errno in (2, 20): # directory not existing or not a directory
exc = OSError(
Expand Down
5 changes: 1 addition & 4 deletions src/aiida/orm/utils/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@ def clean_remote(transport: Transport, path: str) -> None:
if not transport.is_open:
raise ValueError('the transport should already be open')

basedir, relative_path = os.path.split(path)

try:
transport.chdir(basedir)
transport.rmtree(relative_path)
transport.rmtree(path)
except OSError:
pass

Expand Down
7 changes: 4 additions & 3 deletions src/aiida/schedulers/plugins/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ class BashCliScheduler(Scheduler, metaclass=abc.ABCMeta):
def submit_job(self, working_directory: str, filename: str) -> str | ExitCode:
"""Submit a job.
:param working_directory: The absolute filepath to the working directory where the job is to be exectued.
:param working_directory: The absolute filepath to the working directory where the job is to be executed.
:param filename: The filename of the submission script relative to the working directory.
"""
self.transport.chdir(working_directory)
result = self.transport.exec_command_wait(self._get_submit_command(escape_for_bash(filename)))
result = self.transport.exec_command_wait(
self._get_submit_command(escape_for_bash(filename)), workdir=working_directory
)
return self._parse_submit_output(*result)

def get_jobs(
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/schedulers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def create_job_resource(cls, **kwargs):
def submit_job(self, working_directory: str, filename: str) -> str | ExitCode:
"""Submit a job.
:param working_directory: The absolute filepath to the working directory where the job is to be exectued.
:param working_directory: The absolute filepath to the working directory where the job is to be executed.
:param filename: The filename of the submission script relative to the working directory.
:returns:
"""
Expand Down
43 changes: 27 additions & 16 deletions src/aiida/transports/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@
###########################################################################
"""Local transport"""

###
### GP: a note on the local transport:
### I believe that we must not use os.chdir to keep track of the folder
### in which we are, since this may have very nasty side effects in other
### parts of code, and make things not thread-safe.
### we should instead keep track internally of the 'current working directory'
### in the exact same way as paramiko does already.

import contextlib
import errno
import glob
Expand Down Expand Up @@ -101,7 +93,11 @@ def curdir(self):
raise TransportInternalError('Error, local method called for LocalTransport without opening the channel first')

def chdir(self, path):
"""Changes directory to path, emulated internally.
"""
PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE.
`chdir()` is DEPRECATED and will be removed in the next major version.
Changes directory to path, emulated internally.
:param path: path to cd into
:raise OSError: if the directory does not have read attributes.
"""
Expand All @@ -123,7 +119,11 @@ def normalize(self, path='.'):
return os.path.realpath(os.path.join(self.curdir, path))

def getcwd(self):
"""Returns the current working directory, emulated by the transport"""
"""
PLEASE DON'T USE `getcwd()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE.
`getcwd()` is DEPRECATED and will be removed in the next major version.
Returns the current working directory, emulated by the transport"""
return self.curdir

@staticmethod
Expand Down Expand Up @@ -695,11 +695,9 @@ def isfile(self, path):
return os.path.isfile(os.path.join(self.curdir, path))

@contextlib.contextmanager
def _exec_command_internal(self, command, **kwargs):
def _exec_command_internal(self, command, workdir=None, **kwargs):
"""Executes the specified command in bash login shell.
Before the command is executed, changes directory to the current
working directory as returned by self.getcwd().
For executing commands and waiting for them to finish, use
exec_command_wait.
Expand All @@ -710,6 +708,10 @@ def _exec_command_internal(self, command, **kwargs):
:param command: the command to execute. The command is assumed to be
already escaped using :py:func:`aiida.common.escaping.escape_for_bash`.
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
if None, the command will be executed in the current working directory,
from DEPRECATED `self.getcwd()`.
:return: a tuple with (stdin, stdout, stderr, proc),
where stdin, stdout and stderr behave as file-like objects,
Expand All @@ -724,26 +726,35 @@ def _exec_command_internal(self, command, **kwargs):

command = bash_commmand + escape_for_bash(command)

if workdir:
cwd = workdir
else:
cwd = self.getcwd()

with subprocess.Popen(
command,
shell=True,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=self.getcwd(),
cwd=cwd,
start_new_session=True,
) as process:
yield process

def exec_command_wait_bytes(self, command, stdin=None, **kwargs):
def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs):
"""Executes the specified command and waits for it to finish.
:param command: the command to execute
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
if None, the command will be executed in the current working directory,
from DEPRECATED `self.getcwd()`.
:return: a tuple with (return_value, stdout, stderr) where stdout and stderr
are both bytes and the return_value is an int.
"""
with self._exec_command_internal(command) as process:
with self._exec_command_internal(command, workdir) as process:
if stdin is not None:
# Implicitly assume that the desired encoding is 'utf-8' if I receive a string.
# Also, if I get a StringIO, I just read it all in memory and put it into a BytesIO.
Expand Down
53 changes: 36 additions & 17 deletions src/aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,11 @@ def __str__(self):
return f"{'OPEN' if self._is_open else 'CLOSED'} [{conn_info}]"

def chdir(self, path):
"""Change directory of the SFTP session. Emulated internally by paramiko.
"""
PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE.
`chdir()` is DEPRECATED and will be removed in the next major version.
Change directory of the SFTP session. Emulated internally by paramiko.
Differently from paramiko, if you pass None to chdir, nothing
happens and the cwd is unchanged.
Expand Down Expand Up @@ -646,7 +650,11 @@ def lstat(self, path):
return self.sftp.lstat(path)

def getcwd(self):
"""Return the current working directory for this SFTP session, as
"""
PLEASE DON'T USE `getcwd()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE.
`getcwd()` is DEPRECATED and will be removed in the next major version.
Return the current working directory for this SFTP session, as
emulated by paramiko. If no directory has been set with chdir,
this method will return None. But in __enter__ this is set explicitly,
so this should never happen within this class.
Expand Down Expand Up @@ -1218,17 +1226,18 @@ def listdir(self, path='.', pattern=None):
:param pattern: returns the list of files matching pattern.
Unix only. (Use to emulate ``ls *`` for example)
"""
if not pattern:
return self.sftp.listdir(path)
if path.startswith('/'):
base_dir = path
abs_dir = path
else:
base_dir = os.path.join(self.getcwd(), path)
abs_dir = os.path.join(self.getcwd(), path)

filtered_list = self.glob(os.path.join(base_dir, pattern))
if not base_dir.endswith('/'):
base_dir += '/'
return [re.sub(base_dir, '', i) for i in filtered_list]
if not pattern:
return self.sftp.listdir(abs_dir)

filtered_list = self.glob(os.path.join(abs_dir, pattern))
if not abs_dir.endswith('/'):
abs_dir += '/'
return [re.sub(abs_dir, '', i) for i in filtered_list]

def remove(self, path):
"""Remove a single file at 'path'"""
Expand Down Expand Up @@ -1276,11 +1285,9 @@ def isfile(self, path):
return False
raise # Typically if I don't have permissions (errno=13)

def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):
def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1, workdir=None):
"""Executes the specified command in bash login shell.
Before the command is executed, changes directory to the current
working directory as returned by self.getcwd().
For executing commands and waiting for them to finish, use
exec_command_wait.
Expand All @@ -1291,6 +1298,10 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):
stderr on the same buffer (i.e., stdout).
Note: If combine_stderr is True, stderr will always be empty.
:param bufsize: same meaning of the one used by paramiko.
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
if None, the command will be executed in the current working directory,
from DEPRECATED `self.getcwd()`, if that has a value.
:return: a tuple with (stdin, stdout, stderr, channel),
where stdin, stdout and stderr behave as file-like objects,
Expand All @@ -1300,8 +1311,10 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):
channel = self.sshclient.get_transport().open_session()
channel.set_combine_stderr(combine_stderr)

if self.getcwd() is not None:
escaped_folder = escape_for_bash(self.getcwd())
if workdir is not None:
command_to_execute = f'cd {workdir} && ( {command} )'
elif (cwd := self.getcwd()) is not None:
escaped_folder = escape_for_bash(cwd)
command_to_execute = f'cd {escaped_folder} && ( {command} )'
else:
command_to_execute = command
Expand All @@ -1320,7 +1333,9 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1):

return stdin, stdout, stderr, channel

def exec_command_wait_bytes(self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01):
def exec_command_wait_bytes(
self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir=None
):
"""Executes the specified command and waits for it to finish.
:param command: the command to execute
Expand All @@ -1330,14 +1345,18 @@ def exec_command_wait_bytes(self, command, stdin=None, combine_stderr=False, buf
self._exec_command_internal()
:param bufsize: same meaning of paramiko.
:param timeout: ssh channel timeout for stdout, stderr.
:param workdir: (optional, default=None) if set, the command will be executed
in the specified working directory.
:return: a tuple with (return_value, stdout, stderr) where stdout and stderr
are both bytes and the return_value is an int.
"""
import socket
import time

ssh_stdin, stdout, stderr, channel = self._exec_command_internal(command, combine_stderr, bufsize=bufsize)
ssh_stdin, stdout, stderr, channel = self._exec_command_internal(
command, combine_stderr, bufsize=bufsize, workdir=workdir
)

if stdin is not None:
if isinstance(stdin, str):
Expand Down
Loading

0 comments on commit 6f5c35e

Please sign in to comment.