Skip to content

Commit

Permalink
ENH: Emit correct logs when no diagnostics with octue get-diagnostics
Browse files Browse the repository at this point in the history
  • Loading branch information
cortadocodes committed Jul 4, 2024
1 parent 1615a9e commit 37d2a5d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 14 deletions.
6 changes: 5 additions & 1 deletion octue/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,17 @@ def get_diagnostics(cloud_path, local_path, download_datasets):
)
)

GoogleCloudStorageClient().download_all_files(
local_paths = GoogleCloudStorageClient().download_all_files(
local_path=local_path,
cloud_path=cloud_path,
filter=filter,
recursive=True,
)

if not local_paths:
logger.warning("No diagnostics found at %r.", cloud_path)
return

# Update the manifests with the local paths of the datasets.
if download_datasets:
for manifest_type in ("configuration_manifest", "input_manifest"):
Expand Down
30 changes: 17 additions & 13 deletions octue/cloud/storage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,38 +180,42 @@ def download_all_files(self, local_path, cloud_path, filter=None, recursive=Fals
:param str cloud_path: the path to a cloud storage directory to download
:param callable|None filter: an optional callable to filter which files are downloaded from the cloud path; the callable should take a blob as its only positional argument
:param bool recursive: if `True`, also download all files in all subdirectories of the cloud directory recursively
:return None:
:return list(str): the list of paths the files were downloaded to
"""
bucket, _ = self._get_bucket_and_path_in_bucket(cloud_path)

cloud_and_local_paths = [
{
"cloud_path": storage.path.generate_gs_path(bucket.name, blob.name),
"local_path": os.path.join(
cloud_paths = []
local_paths = []

for blob in self.scandir(cloud_path, filter=filter, recursive=recursive):
cloud_paths.append(storage.path.generate_gs_path(bucket.name, blob.name))

local_paths.append(
os.path.join(
local_path,
storage.path.relpath(
storage.path.generate_gs_path(bucket.name, blob.name),
cloud_path,
),
),
}
for blob in self.scandir(cloud_path, filter=filter, recursive=recursive)
]
)
)

if not cloud_and_local_paths:
if not cloud_paths:
logger.warning(
"Attempted to download files from %r but it appears empty. Please check this is the correct path.",
cloud_path,
)
return
return []

def download_file(cloud_and_local_path):
self.download_to_file(cloud_and_local_path["local_path"], cloud_and_local_path["cloud_path"])
self.download_to_file(cloud_and_local_path[0], cloud_and_local_path[1])

with concurrent.futures.ThreadPoolExecutor() as executor:
for path in executor.map(download_file, cloud_and_local_paths):
for path in executor.map(download_file, zip(local_paths, cloud_paths)):
logger.debug("Downloaded file to %r.", path)

return local_paths

def download_as_string(self, cloud_path, timeout=_DEFAULT_TIMEOUT):
"""Download a file to a string from a Google Cloud bucket at gs://<bucket_name>/<path_in_bucket>.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,33 @@ def setUpClass(cls):

diagnostics.upload(storage.path.join(cls.DIAGNOSTICS_CLOUD_PATH, cls.ANALYSIS_ID))

def test_warning_logged_if_no_diagnostics_found(self):
"""Test that a warning about there being no diagnostics is logged if the diagnostics cloud path is empty."""
with tempfile.TemporaryDirectory() as temporary_directory:
result = CliRunner().invoke(
octue_cli,
[
"get-diagnostics",
storage.path.join(self.DIAGNOSTICS_CLOUD_PATH, "9f4ccee3-15b0-4a03-b5ac-c19e1d66a709"),
"--local-path",
temporary_directory,
],
)

self.assertIn(
"Attempted to download files from 'gs://octue-sdk-python-test-bucket/diagnostics/9f4ccee3-15b0-4a03-b5ac-"
"c19e1d66a709' but it appears empty. Please check this is the correct path.",
result.output,
)

self.assertIn(
"No diagnostics found at 'gs://octue-sdk-python-test-bucket/diagnostics/9f4ccee3-15b0-4a03-b5ac-"
"c19e1d66a709'",
result.output,
)

self.assertNotIn("Downloaded diagnostics from", result.output)

def test_get_diagnostics(self):
"""Test that only the values files, manifests, and questions file are downloaded when using the
`get-diagnostics` CLI command.
Expand Down

0 comments on commit 37d2a5d

Please sign in to comment.