Skip to content

Commit

Permalink
Remove extraneous mocking that was causing NPEs in DataflowWorkUnitCl…
Browse files Browse the repository at this point in the history
…ientTest
  • Loading branch information
kennknowles committed Feb 2, 2024
1 parent 30a778b commit f03b115
Showing 1 changed file with 61 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
package org.apache.beam.runners.dataflow.worker;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.when;

import com.google.api.client.http.LowLevelHttpResponse;
import com.google.api.client.json.Json;
import com.google.api.client.testing.http.MockHttpTransport;
import com.google.api.client.testing.http.MockLowLevelHttpRequest;
Expand Down Expand Up @@ -53,16 +49,13 @@
import org.apache.beam.sdk.util.FastNanoClockAndSleeper;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.rules.TestRule;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -75,36 +68,34 @@ public class DataflowWorkUnitClientTest {
private static final String PROJECT_ID = "TEST_PROJECT_ID";
private static final String JOB_ID = "TEST_JOB_ID";
private static final String WORKER_ID = "TEST_WORKER_ID";

@Rule public TestRule restoreSystemProperties = new RestoreSystemProperties();
@Rule public TestRule restoreLogging = new RestoreDataflowLoggingMDC();
@Rule public ExpectedException expectedException = ExpectedException.none();
@Rule public FastNanoClockAndSleeper fastNanoClockAndSleeper = new FastNanoClockAndSleeper();
@Mock private MockHttpTransport transport;
@Mock private MockLowLevelHttpRequest request;
private DataflowWorkerHarnessOptions pipelineOptions;

@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
when(transport.buildRequest(anyString(), anyString())).thenReturn(request);
doCallRealMethod().when(request).getContentAsString();

DataflowWorkerHarnessOptions createPipelineOptionsWithTransport(MockHttpTransport transport) {
Dataflow service = new Dataflow(transport, Transport.getJsonFactory(), null);
pipelineOptions = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
DataflowWorkerHarnessOptions pipelineOptions =
PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
pipelineOptions.setProject(PROJECT_ID);
pipelineOptions.setJobId(JOB_ID);
pipelineOptions.setWorkerId(WORKER_ID);
pipelineOptions.setGcpCredential(new TestCredential());
pipelineOptions.setDataflowClient(service);
pipelineOptions.setRegion("us-central1");
return pipelineOptions;
}

@Test
public void testCloudServiceCall() throws Exception {
WorkItem workItem = createWorkItem(PROJECT_ID, JOB_ID);

when(request.execute()).thenReturn(generateMockResponse(workItem));

MockLowLevelHttpResponse response = generateMockResponse(workItem);
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

assertEquals(Optional.of(workItem), client.getWorkItem());
Expand All @@ -124,30 +115,40 @@ public void testCloudServiceCall() throws Exception {

@Test
public void testCloudServiceCallMapTaskStagePropagation() throws Exception {
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

// Publish and acquire a map task work item, and verify we're now processing that stage.
final String stageName = "test_stage_name";
MapTask mapTask = new MapTask();
mapTask.setStageName(stageName);
WorkItem workItem = createWorkItem(PROJECT_ID, JOB_ID);
workItem.setMapTask(mapTask);
when(request.execute()).thenReturn(generateMockResponse(workItem));

MockLowLevelHttpResponse response = generateMockResponse(workItem);
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

assertEquals(Optional.of(workItem), client.getWorkItem());
assertEquals(stageName, DataflowWorkerLoggingMDC.getStageName());
}

@Test
public void testCloudServiceCallSeqMapTaskStagePropagation() throws Exception {
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

// Publish and acquire a seq map task work item, and verify we're now processing that stage.
final String stageName = "test_stage_name";
SeqMapTask seqMapTask = new SeqMapTask();
seqMapTask.setStageName(stageName);
WorkItem workItem = createWorkItem(PROJECT_ID, JOB_ID);
workItem.setSeqMapTask(seqMapTask);
when(request.execute()).thenReturn(generateMockResponse(workItem));

MockLowLevelHttpResponse response = generateMockResponse(workItem);
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

assertEquals(Optional.of(workItem), client.getWorkItem());
assertEquals(stageName, DataflowWorkerLoggingMDC.getStageName());
}
Expand All @@ -157,8 +158,11 @@ public void testCloudServiceCallNoWorkPresent() throws Exception {
// If there's no work the service should return an empty work item.
WorkItem workItem = new WorkItem();

when(request.execute()).thenReturn(generateMockResponse(workItem));

MockLowLevelHttpResponse response = generateMockResponse(workItem);
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

assertEquals(Optional.empty(), client.getWorkItem());
Expand All @@ -181,8 +185,11 @@ public void testCloudServiceCallNoWorkId() throws Exception {
WorkItem workItem = createWorkItem(PROJECT_ID, JOB_ID);
workItem.setId(null);

when(request.execute()).thenReturn(generateMockResponse(workItem));

MockLowLevelHttpResponse response = generateMockResponse(workItem);
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

assertEquals(Optional.empty(), client.getWorkItem());
Expand All @@ -201,8 +208,11 @@ public void testCloudServiceCallNoWorkId() throws Exception {

@Test
public void testCloudServiceCallNoWorkItem() throws Exception {
when(request.execute()).thenReturn(generateMockResponse());

MockLowLevelHttpResponse response = generateMockResponse();
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

assertEquals(Optional.empty(), client.getWorkItem());
Expand All @@ -228,8 +238,11 @@ public void testCloudServiceCallMultipleWorkItems() throws Exception {
WorkItem workItem1 = createWorkItem(PROJECT_ID, JOB_ID);
WorkItem workItem2 = createWorkItem(PROJECT_ID, JOB_ID);

when(request.execute()).thenReturn(generateMockResponse(workItem1, workItem2));

MockLowLevelHttpResponse response = generateMockResponse(workItem1, workItem2);
MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

client.getWorkItem();
Expand All @@ -242,7 +255,13 @@ public void testReportWorkerMessage_streamingScalingReport() throws Exception {
SendWorkerMessagesResponse workerMessage = new SendWorkerMessagesResponse();
workerMessage.setFactory(Transport.getJsonFactory());
response.setContent(workerMessage.toPrettyString());
when(request.execute()).thenReturn(response);

MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

StreamingScalingReport activeThreadsReport =
new StreamingScalingReport()
.setActiveThreadCount(1)
Expand All @@ -251,7 +270,6 @@ public void testReportWorkerMessage_streamingScalingReport() throws Exception {
.setMaximumThreadCount(4)
.setMaximumBundleCount(5)
.setMaximumBytes(6L);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);
WorkerMessage msg = client.createWorkerMessageFromStreamingScalingReport(activeThreadsReport);
client.reportWorkerMessage(Collections.singletonList(msg));

Expand All @@ -268,7 +286,13 @@ public void testReportWorkerMessage_perWorkerMetrics() throws Exception {
SendWorkerMessagesResponse workerMessage = new SendWorkerMessagesResponse();
workerMessage.setFactory(Transport.getJsonFactory());
response.setContent(workerMessage.toPrettyString());
when(request.execute()).thenReturn(response);

MockLowLevelHttpRequest request = new MockLowLevelHttpRequest().setResponse(response);
MockHttpTransport transport =
new MockHttpTransport.Builder().setLowLevelHttpRequest(request).build();
DataflowWorkerHarnessOptions pipelineOptions = createPipelineOptionsWithTransport(transport);
WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);

PerStepNamespaceMetrics stepNamespaceMetrics =
new PerStepNamespaceMetrics()
.setOriginalStep("s1")
Expand All @@ -279,7 +303,6 @@ public void testReportWorkerMessage_perWorkerMetrics() throws Exception {
new PerWorkerMetrics()
.setPerStepNamespaceMetrics(Collections.singletonList(stepNamespaceMetrics));

WorkUnitClient client = new DataflowWorkUnitClient(pipelineOptions, LOG);
WorkerMessage perWorkerMetricsMsg =
client.createWorkerMessageFromPerWorkerMetrics(perWorkerMetrics);
client.reportWorkerMessage(Collections.singletonList(perWorkerMetricsMsg));
Expand All @@ -290,7 +313,7 @@ public void testReportWorkerMessage_perWorkerMetrics() throws Exception {
assertEquals(ImmutableList.of(perWorkerMetricsMsg), actualRequest.getWorkerMessages());
}

private LowLevelHttpResponse generateMockResponse(WorkItem... workItems) throws Exception {
private MockLowLevelHttpResponse generateMockResponse(WorkItem... workItems) throws Exception {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setContentType(Json.MEDIA_TYPE);
LeaseWorkItemResponse lease = new LeaseWorkItemResponse();
Expand Down

0 comments on commit f03b115

Please sign in to comment.