Skip to content

Commit

Permalink
need to pass more args to replace func
Browse files Browse the repository at this point in the history
  • Loading branch information
truib committed May 22, 2024
1 parent 8a9c913 commit 23406aa
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions eb_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def post_sanitycheck_hook(self, *args, **kwargs):
POST_SANITYCHECK_HOOKS[self.name](self, *args, **kwargs)


def replace_non_distributable_files_with_symlinks(self, package, allowlist):
def replace_non_distributable_files_with_symlinks(log, install_dir, package, allowlist):
"""
Replace files that cannot be distributed with symlinks into host_injections
"""
Expand All @@ -632,7 +632,7 @@ def replace_non_distributable_files_with_symlinks(self, package, allowlist):
raise EasyBuildError("Don't know how to strip non-distributable files from package %s.", package)

# iterate over all files in the package installation directory
for dir_path, _, files in os.walk(self.installdir):
for dir_path, _, files in os.walk(install_dir):
for filename in files:
full_path = os.path.join(dir_path, filename)
# we only really care about real files, i.e. not symlinks
Expand All @@ -643,16 +643,16 @@ def replace_non_distributable_files_with_symlinks(self, package, allowlist):
if '.' in filename:
extension = '.' + filename.split('.')[1]
if basename in allowlist:
self.log.debug("%s is found in allowlist, so keeping it: %s", basename, full_path)
log.debug("%s is found in allowlist, so keeping it: %s", basename, full_path)
elif extension_based[package] and '.' in filename and extension in allowlist:
self.log.debug("%s is found in allowlist, so keeping it: %s", extension, full_path)
log.debug("%s is found in allowlist, so keeping it: %s", extension, full_path)
else:
if extension_based[package]:
print_name = filename
else:
print_name = basename
self.log.debug("%s is not found in allowlist, so replacing it with symlink: %s",
print_name, full_path)
log.debug("%s is not found in allowlist, so replacing it with symlink: %s",
print_name, full_path)
# if it is not in the allowlist, delete the file and create a symlink to host_injections
host_inj_path = full_path.replace('versions', 'host_injections')
# make sure source and target of symlink are not the same
Expand Down Expand Up @@ -705,7 +705,7 @@ def post_sanitycheck_cuda(self, *args, **kwargs):

# replace files that are not distributable with symlinks into
# host_injections
replace_non_distributable_files_with_symlinks(self.name, allowlist)
replace_non_distributable_files_with_symlinks(self.log, self.installdir, self.name, allowlist)
else:
raise EasyBuildError("CUDA-specific hook triggered for non-CUDA easyconfig?!")

Expand Down Expand Up @@ -738,7 +738,7 @@ def post_sanitycheck_cudnn(self, *args, **kwargs):

# replace files that are not distributable with symlinks into
# host_injections
replace_non_distributable_files_with_symlinks(self.name, allowlist)
replace_non_distributable_files_with_symlinks(self.log, self.installdir, self.name, allowlist)
else:
raise EasyBuildError("cuDNN-specific hook triggered for non-cuDNN easyconfig?!")

Expand Down Expand Up @@ -771,7 +771,7 @@ def post_sanitycheck_cutensor(self, *args, **kwargs):

# replace files that are not distributable with symlinks into
# host_injections
replace_non_distributable_files_with_symlinks(self.name, allowlist)
replace_non_distributable_files_with_symlinks(self.log, self.installdir, self.name, allowlist)
else:
raise EasyBuildError("cuTENSOR-specific hook triggered for non-cuTENSOR easyconfig?!")

Expand Down

0 comments on commit 23406aa

Please sign in to comment.