Skip to content

Commit

Permalink
Extract MPI code from execute_task()
Browse files Browse the repository at this point in the history
The `execute_task()` function is used by multiple executors, but the MPI
code is specific to HTEX.
  • Loading branch information
rjmello committed Nov 19, 2024
1 parent 9fb5269 commit 2d8edf6
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,28 +590,25 @@ def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_i
os.environ[key] = prefix_table[key]


def execute_task(bufs, mpi_launcher: Optional[str] = None):
"""Deserialize the buffer and execute the task.
def _init_mpi_env(mpi_launcher: str, resource_spec: Dict):
node_list = resource_spec.get("MPI_NODELIST")
if node_list is None:
return
nodes_for_task = node_list.split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task)


def execute_task(bufs: bytes):
"""Deserialize the buffer and execute the task.
Returns the result or throws exception.
"""
user_ns = locals()
user_ns.update({'__builtins__': __builtins__})

f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, user_ns, copy=False)
f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, copy=False)

for varname in resource_spec:
envname = "PARSL_" + str(varname).upper()
os.environ[envname] = str(resource_spec[varname])

if resource_spec.get("MPI_NODELIST"):
worker_id = os.environ['PARSL_WORKER_RANK']
nodes_for_task = resource_spec["MPI_NODELIST"].split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
assert mpi_launcher
update_resource_spec_env_vars(mpi_launcher,
resource_spec=resource_spec,
node_info=nodes_for_task)
# We might need to look into callability of the function from itself
# since we change it's name in the new namespace
prefix = "parsl_"
Expand All @@ -620,13 +617,18 @@ def execute_task(bufs, mpi_launcher: Optional[str] = None):
kwargname = prefix + "kwargs"
resultname = prefix + "result"

user_ns.update({fname: f,
argname: args,
kwargname: kwargs,
resultname: resultname})

code = "{0} = {1}(*{2}, **{3})".format(resultname, fname,
argname, kwargname)

user_ns = locals()
user_ns.update({
'__builtins__': __builtins__,
fname: f,
argname: args,
kwargname: kwargs,
resultname: resultname
})

exec(code, user_ns, user_ns)
return user_ns.get(resultname)

Expand Down Expand Up @@ -786,8 +788,10 @@ def manager_is_alive():
ready_worker_count.value -= 1
worker_enqueued = False

_init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=req["resource_spec"])

try:
result = execute_task(req['buffer'], mpi_launcher=mpi_launcher)
result = execute_task(req['buffer'])
serialized_result = serialize(result, buffer_threshold=1000000)
except Exception as e:
logger.info('Caught an exception: {}'.format(e))
Expand Down

0 comments on commit 2d8edf6

Please sign in to comment.