debug print exec job command and keep pool option
jakevc committed Jul 16, 2024
1 parent ded0728 commit 3df271b
Showing 2 changed files with 91 additions and 52 deletions.
104 changes: 52 additions & 52 deletions snakemake_executor_plugin_azure_batch/__init__.py
@@ -1,6 +1,6 @@
__author__ = "Johannes Köster, Jake VanCampen, Andreas Wilm"
__author__ = "Jake VanCampen, Johannes Köster, Andreas Wilm"
__copyright__ = "Copyright 2023, Snakemake community"
__email__ = "[email protected]"
__email__ = "[email protected]"
__license__ = "MIT"


@@ -58,7 +58,11 @@
AZURE_BATCH_RESOURCE_ENDPOINT,
DEFAULT_AUTO_SCALE_FORMULA,
)
from snakemake_executor_plugin_azure_batch.util import AzureIdentityCredentialAdapter
from snakemake_executor_plugin_azure_batch.util import (
AzureIdentityCredentialAdapter,
unpack_compute_node_errors,
unpack_task_failure_information,
)

# Required:
# Specify common settings shared by various executors.
@@ -72,7 +76,7 @@
# Define whether your executor plugin implies that there is no shared
# filesystem (True) or not (False).
# This is e.g. the case for cloud execution.
implies_no_shared_fs=False,
implies_no_shared_fs=True,
job_deploy_sources=True,
pass_default_storage_provider_args=True,
pass_default_resources_args=True,
@@ -130,6 +134,14 @@ class ExecutorSettings(ExecutorSettingsBase):
"env_var": True,
},
)
keep_pool: bool = field(
default=False,
metadata={
"help": "Keep the Azure batch pool after the workflow completes.",
"required": False,
"env_var": False,
},
)
managed_identity_resource_id: Optional[str] = field(
default=None,
metadata={
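
The new keep_pool setting lets a user skip pool teardown entirely, which is useful when iterating on a workflow: a warm pool avoids re-provisioning nodes on every run. Below is a minimal, self-contained sketch of how such a settings field gates teardown; snakemake normally derives a CLI flag from the plugin and field names (likely --azure-batch-keep-pool, though that exact flag name is an assumption, not something shown in this diff):

```python
from dataclasses import dataclass, field

# Self-contained sketch (not the plugin's actual classes): a settings
# dataclass field gated at shutdown time, mirroring the keep_pool
# option added above.

@dataclass
class Settings:
    keep_pool: bool = field(default=False)


def shutdown(settings: Settings) -> None:
    if settings.keep_pool:
        print("keeping pool for reuse")
    else:
        print("deleting batch pool")  # stand-in for batch_client.pool.delete(...)


shutdown(Settings(keep_pool=True))  # prints: keeping pool for reuse
```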
@@ -263,7 +275,7 @@ class Executor(RemoteExecutor):
def __post_init__(self):
# the snakemake/snakemake:latest container image
self.container_image = self.workflow.remote_execution_settings.container_image
self.settings = self.workflow.executor_settings
self.settings: ExecutorSettings = self.workflow.executor_settings
self.logger.debug(
f"ExecutorSettings: {pformat(self.workflow.executor_settings, indent=2)}"
)
@@ -333,30 +345,23 @@ def init_batch_client(self):
except Exception as e:
raise WorkflowError("Failed to initialize batch client", e)

def managed_id_present(self):
"""returns true if managed identity resource and client id are not None"""
return (
self.settings.managed_identity_client_id is not None
and self.settings.managed_identity_resource_id is not None
)

def shutdown(self):
# perform additional steps on shutdown
# if necessary (jobs were cancelled already)
if not self.settings.keep_pool:
try:
self.logger.debug("Deleting AzBatch job")
self.batch_client.job.delete(self.job_id)
except bm.BatchErrorException as be:
if be.error.code == "JobNotFound":
pass

try:
self.logger.debug("Deleting AzBatch job")
self.batch_client.job.delete(self.job_id)
except bm.BatchErrorException as be:
if be.error.code == "JobNotFound":
pass

try:
self.logger.debug("Deleting AzBatch pool")
self.batch_client.pool.delete(self.pool_id)
except bm.BatchErrorException as be:
if be.error.code == "PoolBeingDeleted":
pass
try:
self.logger.debug("Deleting AzBatch pool")
self.batch_client.pool.delete(self.pool_id)
except bm.BatchErrorException as be:
if be.error.code == "PoolBeingDeleted":
pass

super().shutdown()
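
Both delete calls tolerate the resource already being gone. Here is a sketch of that tolerate-one-expected-error pattern, using a stand-in exception class rather than the real azure.batch.models.BatchErrorException (whose code lives at be.error.code); note this sketch is slightly stricter than the committed code, since it re-raises unexpected error codes:

```python
class BatchErrorException(Exception):
    """Stand-in for azure.batch.models.BatchErrorException."""

    def __init__(self, code: str):
        super().__init__(code)
        self.code = code


def delete_job() -> None:
    # Simulate the service reporting that the job no longer exists.
    raise BatchErrorException("JobNotFound")


try:
    delete_job()
except BatchErrorException as be:
    if be.code != "JobNotFound":
        raise  # anything unexpected still surfaces
    # already deleted: nothing to do

print("shutdown continues")
```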

@@ -376,7 +381,8 @@ def run_job(self, job: JobExecutorInterface):
continue

exec_job = self.format_job_exec(job)
exec_job = f"/bin/bash -c {shlex.quote(exec_job)}"
remote_command = f"/bin/bash -c {shlex.quote(exec_job)}"
self.logger.debug(f"Remote command: {remote_command}")

# A string that uniquely identifies the Task within the Job.
task_uuid = str(uuid.uuid4())
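
Quoting the formatted job command with shlex.quote ensures the whole Snakemake invocation reaches the node as a single argument to /bin/bash -c, even when it contains spaces or quotes. A quick stdlib-only illustration (the command string is made up):

```python
import shlex

# A made-up job command containing characters the outer shell would
# otherwise split or interpret.
exec_job = 'snakemake --snakefile "Snakefile" --cores 1 all'
remote_command = f"/bin/bash -c {shlex.quote(exec_job)}"
print(remote_command)
# /bin/bash -c 'snakemake --snakefile "Snakefile" --cores 1 all'
```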
@@ -401,7 +407,7 @@ def run_job(self, job: JobExecutorInterface):
# the container, and the Task command line is executed in the container
task = bm.TaskAddParameter(
id=task_id,
command_line=exec_job,
command_line=remote_command,
container_settings=task_container_settings,
user_identity=bm.UserIdentity(auto_user=user),
environment_settings=envsettings,
@@ -438,7 +444,6 @@ def _report_task_status(self, job: SubmittedJobInfo):
)
except Exception as e:
self.logger.warning(f"Unable to get Azure Batch Task status: {e}")
# go on and query again next time
return True

self.logger.debug(
@@ -454,12 +459,15 @@ def _report_task_status(self, job: SubmittedJobInfo):
ei: bm.TaskExecutionInformation = task.execution_info
if ei is not None:
if ei.result == bm.TaskExecutionResult.failure:
msg = f"Azure Batch execution failure: {ei.failure_info.__dict__}"
formatted_failure_info = pformat(
unpack_task_failure_information(ei.failure_info), indent=2
)
msg = f"Batch Task Failure: {formatted_failure_info}\n"
self.report_job_error(job, msg=msg, stderr=stderr, stdout=stdout)
elif ei.result == bm.TaskExecutionResult.success:
self.report_job_success(job)
else:
msg = f"Unknown Azure task execution result: {ei.__dict__}"
msg = f"\nUnknown task execution result: {ei.__dict__}\n"
self.report_job_error(
job,
msg=msg,
@@ -470,7 +478,7 @@ def _report_task_status(self, job: SubmittedJobInfo):
else:
return True

def _report_node_errors(self, batch_job):
def _report_node_errors(self):
"""report node errors
Fails if start task fails on a node, or node state becomes unusable (this can
@@ -481,12 +489,10 @@ def _report_node_errors(self, batch_job):

node_list = self.batch_client.compute_node.list(self.pool_id)
for n in node_list:
if n.state == "unusable":
errors = []
if n.errors is not None:
for e in n.errors:
errors.append(e)
self.logger.error(f"An azure batch node became unusable: {errors}")
if n.state == bm.ComputeNodeState.unusable:
if n.errors:
errors = unpack_compute_node_errors(n.errors)
self.logger.error(f"An Azure Batch node became unusable: {errors}")

if n.start_task_info is not None and (
n.start_task_info.result == bm.TaskExecutionResult.failure
@@ -508,7 +514,7 @@ def _report_node_errors(self, batch_job):
stdout_stream = ""

msg = (
"Azure start task execution failed: "
"Start task execution failed: "
f"{n.start_task_info.failure_info.message}.\n"
f"stderr:\n{stderr_stream}\n"
f"stdout:\n{stdout_stream}"
@@ -536,18 +542,12 @@ async def check_active_jobs(

for batch_job in active_jobs:
async with self.status_rate_limiter:
# fail on pool resize errors
self._report_pool_errors(batch_job)

async with self.status_rate_limiter:
# report any node errors
self._report_node_errors(batch_job)

# report the task failure or success
self._report_node_errors()
still_running = self._report_task_status(batch_job)

if still_running:
# report as still running
yield batch_job

def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]):
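
check_active_jobs is an async generator: every job that is still running is yielded back so the scheduler polls it again on the next pass. A stripped-down sketch of that contract (poll stands in for _report_task_status, and the job values are just strings):

```python
import asyncio


async def check_active_jobs(active_jobs, poll):
    # Yield back each job that is still running; finished or failed
    # jobs are simply not yielded and drop out of the polling set.
    for job in active_jobs:
        if poll(job):
            yield job


async def main() -> None:
    still_running = lambda job: job == "job-1"
    async for job in check_active_jobs(["job-1", "job-2"], still_running):
        print(f"{job} is still running")


asyncio.run(main())  # prints: job-1 is still running
```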
@@ -569,8 +569,6 @@ def create_batch_pool(self):
version="latest",
)

# optional subnet network configuration
# requires AAD batch auth instead of batch key auth
network_config = None
if self.settings.pool_subnet_id is not None:
network_config = NetworkConfiguration(
@@ -579,15 +577,15 @@

# configure batch pool identity
batch_pool_identity = None
mrid = self.settings.managed_identity_resource_id
if self.settings.managed_identity_resource_id is not None:
batch_pool_identity = BatchPoolIdentity(
type=PoolIdentityType.USER_ASSIGNED,
user_assigned_identities={mrid: UserAssignedIdentities()},
user_assigned_identities={
self.settings.managed_identity_resource_id: UserAssignedIdentities()
},
)

# configure a container registry

# Specify container configuration, fetching an image
# https://docs.microsoft.com/en-us/azure/batch/batch-docker-container-workloads#prefetch-images-for-container-configuration
container_config = ContainerConfiguration(
@@ -639,7 +637,7 @@ def create_batch_pool(self):
# default to no start task
start_task_conf = None

# if configured us start task bash script from sas url
# if configured use start task bash script from sas url
if self.settings.node_start_task_sas_url is not None:
_SIMPLE_TASK_NAME = "start_task.sh"
start_task_admin = UserIdentity(
@@ -697,9 +695,11 @@ def create_batch_pool(self):
),
target_node_communication_mode=NodeCommunicationMode.CLASSIC,
)
# create pool if not exists
try:
self.logger.info(f"Creating pool: {self.pool_id}")
# we use the azure.mgmt.batch client to create the pool here because if you
# configure a managed identity for the batch nodes, the azure.batch client
# does not correctly apply it to the pool
self.batch_mgmt_client.pool.create(
resource_group_name=self.settings.resource_group_name,
account_name=self.settings.batch_account_name,
39 changes: 39 additions & 0 deletions snakemake_executor_plugin_azure_batch/util.py
@@ -1,4 +1,7 @@
from typing import List

import msrest.authentication as msa
from azure.batch.models import ComputeNodeError, NameValuePair, TaskFailureInformation
from azure.core.pipeline import PipelineContext, PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest
@@ -52,3 +55,39 @@ def set_token(self):
def signed_session(self, session=None):
self.set_token()
return super(AzureIdentityCredentialAdapter, self).signed_session(session)


def _error_item(code: str, message: str) -> dict:
return {
"code": code,
"message": message,
"error_details": [],
}


def unpack_task_failure_information(failure_info: TaskFailureInformation) -> dict:
"""
Unpack task failure information into a dict of the form:
{'code': '', 'message': '', 'error_details': [{'detail': 'description'}]}
"""
error_item = _error_item(failure_info.code, failure_info.message)
for detail in failure_info.details or []:
if isinstance(detail, NameValuePair):
error_item["error_details"].append({detail.name: detail.value})
return error_item
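
A usage sketch for the helper above, with a hypothetical failure payload shaped like what the Batch service returns (the field values are invented):

```python
from azure.batch.models import ErrorCategory, NameValuePair, TaskFailureInformation

# Hypothetical failure info, mirroring a non-zero task exit code.
info = TaskFailureInformation(
    category=ErrorCategory.user_error,
    code="FailureExitCode",
    message="The task exited with an exit code representing a failure",
    details=[NameValuePair(name="Message", value="exit code 1")],
)
print(unpack_task_failure_information(info))
# {'code': 'FailureExitCode',
#  'message': 'The task exited with an exit code representing a failure',
#  'error_details': [{'Message': 'exit code 1'}]}
```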


def unpack_compute_node_errors(node_errors: List[ComputeNodeError]) -> list:
"""
Unpack a list of compute node errors into a list of dicts of the form:
{'code': '', 'message': '', 'error_details': [{'detail': 'description'}]}
"""
errors = []
for node_error in node_errors:
error_item = _error_item(node_error.code, node_error.message)
if node_error.error_details:
for detail in node_error.error_details:
if isinstance(detail, NameValuePair):
error_item["error_details"].append({detail.name: detail.value})
errors.append(error_item)
return errors
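
And a matching sketch for the compute-node variant, again with invented values:

```python
from azure.batch.models import ComputeNodeError, NameValuePair

# Hypothetical node errors as the Batch service might report them.
node_errors = [
    ComputeNodeError(
        code="DiskFull",
        message="There is not enough disk space on the node",
        error_details=[NameValuePair(name="Disk", value="/mnt")],
    )
]
print(unpack_compute_node_errors(node_errors))
# [{'code': 'DiskFull',
#   'message': 'There is not enough disk space on the node',
#   'error_details': [{'Disk': '/mnt'}]}]
```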
