Add stage_file_with_retry and replace stage_file with it in apiclient.
shunping committed Dec 21, 2024
1 parent a51a0e1 commit 0ffdd54
Showing 2 changed files with 91 additions and 10 deletions.
35 changes: 28 additions & 7 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -557,13 +557,11 @@ def _cached_gcs_file_copy(self, from_path, to_path, sha256):
         source_file_names=[cached_path], destination_file_names=[to_path])
     _LOGGER.info('Copied cached artifact from %s to %s', from_path, to_path)
 
-  @retry.with_exponential_backoff(
-      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
   def _uncached_gcs_file_copy(self, from_path, to_path):
     to_folder, to_name = os.path.split(to_path)
     total_size = os.path.getsize(from_path)
-    with open(from_path, 'rb') as f:
-      self.stage_file(to_folder, to_name, f, total_size=total_size)
+    self.stage_file_with_retry(
+        to_folder, to_name, from_path, total_size=total_size)
 
   def _stage_resources(self, pipeline, options):
     google_cloud_options = options.view_as(GoogleCloudOptions)
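
Note on the hunk above: the exponential-backoff decorator comes off _uncached_gcs_file_copy and moves into the new stage_file_with_retry helper, and the call site now hands over the path rather than an open file handle, so every retry attempt can reopen the file from the start. A minimal sketch of the stream-position hazard this avoids (upload is a hypothetical stand-in for stage_file):

import io

def upload(stream):
  # Hypothetical stand-in for stage_file: consumes the stream, then fails
  # the way a transient server error would.
  stream.read()
  raise IOError('transient server error')

stream = io.BytesIO(b'pipeline-graph')
try:
  upload(stream)
except IOError:
  pass

# A naive retry wrapped around the call would now re-send an empty payload:
assert stream.read() == b''

# stage_file_with_retry guards against this by reopening the file (path
# form) or rewinding the stream to offset 0 (stream form) on each attempt:
stream.seek(0)
assert stream.read() == b'pipeline-graph'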
@@ -692,6 +690,29 @@ def stage_file(
             (gcs_or_local_path, e))
       raise
 
+  @retry.with_exponential_backoff(
+      retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+  def stage_file_with_retry(
+      self,
+      gcs_or_local_path,
+      file_name,
+      stream_or_path,
+      mime_type='application/octet-stream',
+      total_size=None):
+
+    if isinstance(stream_or_path, str):
+      path = stream_or_path
+      with open(path, 'rb') as stream:
+        self.stage_file(
+            gcs_or_local_path, file_name, stream, mime_type, total_size)
+    elif isinstance(stream_or_path, io.BufferedIOBase):
+      stream = stream_or_path
+      assert stream.seekable(), "stream must be seekable"
+      if stream.tell() > 0:
+        stream.seek(0)
+      self.stage_file(
+          gcs_or_local_path, file_name, stream, mime_type, total_size)
+
   @retry.no_retries  # Using no_retries marks this as an integration point.
   def create_job(self, job):
     """Creates job description. May stage and/or submit for remote execution."""
@@ -703,7 +724,7 @@ def create_job(self, job):
           job.options.view_as(GoogleCloudOptions).template_location)
 
     if job.options.view_as(DebugOptions).lookup_experiment('upload_graph'):
-      self.stage_file(
+      self.stage_file_with_retry(
           job.options.view_as(GoogleCloudOptions).staging_location,
           "dataflow_graph.json",
           io.BytesIO(job.json().encode('utf-8')))
@@ -718,7 +739,7 @@
     if job_location:
       gcs_or_local_path = os.path.dirname(job_location)
       file_name = os.path.basename(job_location)
-      self.stage_file(
+      self.stage_file_with_retry(
           gcs_or_local_path, file_name, io.BytesIO(job.json().encode('utf-8')))
 
     if not template_location:
Expand Down Expand Up @@ -790,7 +811,7 @@ def create_job_description(self, job):
     resources = self._stage_resources(job.proto_pipeline, job.options)
 
     # Stage proto pipeline.
-    self.stage_file(
+    self.stage_file_with_retry(
         job.google_cloud_options.staging_location,
         shared_names.STAGED_PIPELINE_FILENAME,
         io.BytesIO(job.proto_pipeline.SerializeToString()))
66 changes: 63 additions & 3 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
@@ -19,11 +19,13 @@

 # pytype: skip-file
 
+import io
 import itertools
 import json
 import logging
 import os
 import sys
+import time
 import unittest
 
 import mock
@@ -1064,7 +1066,11 @@ def test_graph_is_uploaded(self):
         side_effect=None):
       client.create_job(job)
       client.stage_file.assert_called_once_with(
-          mock.ANY, "dataflow_graph.json", mock.ANY)
+          mock.ANY,
+          "dataflow_graph.json",
+          mock.ANY,
+          'application/octet-stream',
+          None)
       client.create_job_description.assert_called_once()
 
   def test_create_job_returns_existing_job(self):
@@ -1174,8 +1180,18 @@ def test_template_file_generation_with_upload_graph(self):
       client.create_job(job)
 
       client.stage_file.assert_has_calls([
-          mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY),
-          mock.call(mock.ANY, 'template', mock.ANY)
+          mock.call(
+              mock.ANY,
+              'dataflow_graph.json',
+              mock.ANY,
+              'application/octet-stream',
+              None),
+          mock.call(
+              mock.ANY,
+              'template',
+              mock.ANY,
+              'application/octet-stream',
+              None)
       ])
       client.create_job_description.assert_called_once()
       # template is generated, but job should not be submitted to the
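
The widened expectations in the two tests above follow from the helper's forwarding: stage_file_with_retry passes mime_type and total_size to stage_file positionally, so the patched mock now records five positional arguments even when the caller relied on the defaults:

# Inside stage_file_with_retry (from the diff above):
self.stage_file(gcs_or_local_path, file_name, stream, mime_type, total_size)
# The patched stage_file therefore sees, e.g.:
#   mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY,
#             'application/octet-stream', None)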
Expand Down Expand Up @@ -1653,6 +1669,50 @@ def exists_return_value(*args):
         }))
     self.assertEqual(pipeline, pipeline_expected)
 
+  def test_stage_file_with_retry(self):
+    count = 0
+
+    def effect(self, *args, **kwargs):
+      nonlocal count
+      count += 1
+      if count > 1:
+        return
+      raise Exception("This exception is raised for testing purpose.")
+
+    pipeline_options = PipelineOptions([
+        '--project',
+        'test_project',
+        '--job_name',
+        'test_job_name',
+        '--temp_location',
+        'gs://test-location/temp',
+    ])
+    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
+    client = apiclient.DataflowApplicationClient(pipeline_options)
+
+    with mock.patch.object(time, 'sleep'):
+      count = 0
+      with mock.patch("builtins.open",
+                      mock.mock_open(read_data="data")) as mock_file_open:
+        with mock.patch.object(client, 'stage_file') as mock_stage_file:
+          mock_stage_file.side_effect = effect
+          # call with a file name
+          client.stage_file_with_retry(
+              "/to", "new_name", "/from/old_name", total_size=1024)
+          self.assertEqual(mock_file_open.call_count, 2)
+          self.assertEqual(mock_stage_file.call_count, 2)
+
+      count = 0
+      with mock.patch("builtins.open",
+                      mock.mock_open(read_data="data")) as mock_file_open:
+        with mock.patch.object(client, 'stage_file') as mock_stage_file:
+          mock_stage_file.side_effect = effect
+          # call with a seekable stream
+          client.stage_file_with_retry(
+              "/to", "new_name", io.BytesIO(b'test'), total_size=4)
+          mock_file_open.assert_not_called()
+          self.assertEqual(mock_stage_file.call_count, 2)
+
+
 if __name__ == '__main__':
   unittest.main()
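
One detail in the new test worth calling out: retry.with_exponential_backoff sleeps between attempts, so the test patches time.sleep to keep the retry loop instant, and the shared effect side effect fails only the first call. A condensed, self-contained sketch of the same fail-once pattern (call_with_retry is a hypothetical stand-in for the decorator):

import time
from unittest import mock

def call_with_retry(fn, attempts=2):
  # Hypothetical stand-in for retry.with_exponential_backoff.
  for i in range(attempts):
    try:
      return fn()
    except Exception:
      if i == attempts - 1:
        raise
      time.sleep(2**i)  # patched out below, so no real delay

state = {'n': 0}

def fail_once():
  # Raise on the first call, succeed on every later one.
  state['n'] += 1
  if state['n'] == 1:
    raise Exception('simulated transient failure')
  return 'ok'

with mock.patch.object(time, 'sleep') as fake_sleep:
  assert call_with_retry(fail_once) == 'ok'
  assert state['n'] == 2           # first attempt failed, second succeeded
  fake_sleep.assert_called_once()  # backoff path ran, without real sleeping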
