Skip to content

Commit

Permalink
Update on_kill function of CWLStepOperator to stop running docker con…
Browse files Browse the repository at this point in the history
…tainer
  • Loading branch information
michael-kotliar committed Jul 29, 2020
1 parent 4e69852 commit a0f607a
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 23 deletions.
41 changes: 19 additions & 22 deletions cwl_airflow/extensions/operators/cwlstepoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from cwl_airflow.utilities.cwl import (
execute_workflow_step,
get_containers,
kill_containers,
collect_reports
)
from cwl_airflow.utilities.report import post_status
Expand All @@ -30,34 +32,29 @@ def execute(self, context):

post_status(context)

self.job_data = collect_reports(context) # we need it also in "on_kill"
_, step_report = execute_workflow_step(
workflow=context["dag"].workflow,
task_id=self.task_id,
job_data=collect_reports(context),
job_data=self.job_data,
cwl_args=context["dag"].default_args["cwl"]
)

return step_report


# def on_kill(self):
# _logger.info("Stop docker containers")
# for cidfile in glob.glob(os.path.join(self.dag.default_args["cidfile_dir"], self.task_id + "*.cid")): # make this better, doesn't look good to read from self.dag.default_args
# try:
# with open(cidfile, "r") as inp_stream:
# _logger.debug(f"""Read container id from {cidfile}""")
# command = ["docker", "kill", inp_stream.read()]
# _logger.debug(f"""Call {" ".join(command)}""")
# p = subprocess.Popen(command, shell=False)
# try:
# p.wait(timeout=10)
# except subprocess.TimeoutExpired:
# p.kill()
# except Exception as ex:
# _logger.error(f"""Failed to stop docker container with ID from {cidfile}\n {ex}""")

# # _logger.info(f"""Delete temporary output directory {self.outdir}""")
# # try:
# # shutil.rmtree(self.outdir)
# # except Exception as ex:
# # _logger.error(f"""Failed to delete temporary output directory {self.outdir}\n {ex}""")
def on_kill(self):
"""
Function is called only if task is manually stopped, for example, from UI.
First, we need to find all cidfiles that correspond to the current step.
We can have more than one cidfile, if previous run of this step has failed.
We search for cidfiles in the subfolder "task_id" of the "tmp_folder" read
from "job_data". For all found cidfile we check if cointainer is still
running and try to stop it. If container was not running, was not found or
had been already successfully killed, we remove the correspondent cidfile.
If container was running but we failed to kill it do not remove cidfile.
"""

kill_containers(
get_containers(self.job_data, self.task_id)
)
52 changes: 51 additions & 1 deletion cwl_airflow/utilities/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import zlib
import errno
import shutil
import docker
import logging
import binascii

from uuid import uuid4
Expand Down Expand Up @@ -52,7 +54,8 @@
get_rootname,
remove_field_from_dict,
get_uncompressed,
get_compressed
get_compressed,
get_files
)


Expand Down Expand Up @@ -322,6 +325,53 @@ def relocate_outputs(
return relocated_job_data, workflow_report


def get_containers(job_data, task_id):
"""
Searches for cidfiles in the "step_tmp_folder", loads
container IDs from the found files, adds them to dict
in a form of {cid: location}. If nothing found,
returns {}.
"""

containers = {}

step_tmp_folder, _, _, _ = get_temp_folders(
task_id=task_id,
job_data=job_data
)

for location in get_files(step_tmp_folder, ".*\\.cid$").values():
try:
with open(location, "r") as input_stream:
containers[input_stream.read()] = location
except OSError as err:
logging.error(f"Failed to read container ID \
from {location} due to \n{err}")

return containers


def kill_containers(containers):
"""
Iterates over "containers" dictionary received from "get_containers"
and tries to kill all running containers based on cid. If killed
container was not in "running" state, was successfully killed or not
found at all, removes correspondent cidfile.
"""

docker_client = docker.from_env()
for cid, location in containers.items():
try:
container = docker_client.containers.get(cid)
if container.status == "running":
container.kill()
os.remove(location)
except docker.errors.NotFound as err:
os.remove(location)
except docker.errors.APIError as err:
logging.error(f"Failed to kill container. \n {err}")


def execute_workflow_step(
workflow,
task_id,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def get_version():
"pyjwt",
"connexion",
"tornado",
"docker",
"swagger-ui-bundle"
],
zip_safe=False,
Expand Down
1 change: 1 addition & 0 deletions tests/data/cid/dummy_1.cid
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
43dd79ede44946a1954d27327000d1c013cb39196c096ce70bc158ee7531a557
1 change: 1 addition & 0 deletions tests/data/cid/dummy_2.cid
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
000d1c013cb39196c096ce70bc158ee7531a55743dd79ede44946a1954d27327
61 changes: 61 additions & 0 deletions tests/test_cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
convert_to_workflow,
get_default_cwl_args,
overwrite_deprecated_dag,
get_containers,
CWL_TMP_FOLDER,
CWL_OUTPUTS_FOLDER,
CWL_PICKLE_FOLDER,
Expand All @@ -49,6 +50,66 @@
tempfile.tempdir = "/private/tmp"


@pytest.mark.parametrize(
"task_id, cidfiles, control_containers",
[
(
"bam_to_bedgraph",
[
"dummy_1.cid",
"dummy_2.cid"
],
{
"43dd79ede44946a1954d27327000d1c013cb39196c096ce70bc158ee7531a557": "dummy_1.cid",
"000d1c013cb39196c096ce70bc158ee7531a55743dd79ede44946a1954d27327": "dummy_2.cid"
}
),
(
"bam_to_bedgraph",
[
"dummy_1.cid"
],
{
"43dd79ede44946a1954d27327000d1c013cb39196c096ce70bc158ee7531a557": "dummy_1.cid"
}
),
(
"bam_to_bedgraph",
[],
{}
)
]
)
def test_get_containers(task_id, cidfiles, control_containers, monkeypatch):
temp_home = tempfile.mkdtemp()
monkeypatch.delenv("AIRFLOW_HOME", raising=False)
monkeypatch.delenv("AIRFLOW_CONFIG", raising=False)
monkeypatch.setattr(
os.path,
"expanduser",
lambda x: x.replace("~", temp_home)
)

for cidfile in cidfiles:
shutil.copy(
os.path.join(DATA_FOLDER, "cid", cidfile),
get_dir(os.path.join(temp_home, task_id))
)

try:
containers = get_containers({"tmp_folder": temp_home}, task_id)
control_containers = {
cid: os.path.join(temp_home, task_id, filename)
for cid, filename in control_containers.items()
}
except (BaseException, Exception) as err:
assert False, f"Failed to run test. \n {err}"
finally:
shutil.rmtree(temp_home)
assert control_containers == containers, \
"Failed to find cidfiles"


@pytest.mark.parametrize(
"dag_location, workflow_location, control_deprecated_files",
[
Expand Down

0 comments on commit a0f607a

Please sign in to comment.