
Commit

Merge branch 'master' of https://github.com/ahmedabu98/beam into managed_bigquery
ahmedabu98 committed Nov 8, 2024
2 parents c0767d7 + 2488ca1 commit a600f62
Showing 54 changed files with 849 additions and 444 deletions.
@@ -1,3 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 1
}
1 change: 1 addition & 0 deletions .github/workflows/beam_PreCommit_Java.yml
@@ -19,6 +19,7 @@ on:
tags: ['v*']
branches: ['master', 'release-*']
paths:
- "buildSrc/**"
- 'model/**'
- 'sdks/java/**'
- 'runners/**'
1 change: 0 additions & 1 deletion CHANGES.md
@@ -88,7 +88,6 @@
* Removed support for Flink 1.15 and 1.16
* Removed support for Python 3.8
* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).
* Upgrade antlr from 4.7 to 4.13.1 ([#33016](https://github.com/apache/beam/pull/33016)).

## Bugfixes

BeamModulePlugin.groovy
@@ -665,8 +665,8 @@ class BeamModulePlugin implements Plugin<Project> {
activemq_junit : "org.apache.activemq.tooling:activemq-junit:$activemq_version",
activemq_kahadb_store : "org.apache.activemq:activemq-kahadb-store:$activemq_version",
activemq_mqtt : "org.apache.activemq:activemq-mqtt:$activemq_version",
antlr : "org.antlr:antlr4:4.13.1",
antlr_runtime : "org.antlr:antlr4-runtime:4.13.1",
antlr : "org.antlr:antlr4:4.7",
antlr_runtime : "org.antlr:antlr4-runtime:4.7",
args4j : "args4j:args4j:2.33",
auto_value_annotations : "com.google.auto.value:auto-value-annotations:$autovalue_version",
avro : "org.apache.avro:avro:1.11.3",
14 changes: 8 additions & 6 deletions examples/notebooks/beam-ml/run_inference_vllm.ipynb
@@ -352,11 +352,11 @@
"\n",
"1. In the sidebar, click **Files** to open the **Files** pane.\n",
"2. In an environment with Docker installed, download the file **VllmDockerfile** file to an empty folder.\n",
"3. Run the following commands. Replace `<REPOSITORY_NAME>` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository.\n",
"3. Run the following commands. Replace `<REPOSITORY_NAME>:<TAG>` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository and tag.\n",
"\n",
" ```\n",
" docker build -t \"<REPOSITORY_NAME>:latest\" -f VllmDockerfile ./\n",
" docker image push \"<REPOSITORY_NAME>:latest\"\n",
" docker build -t \"<REPOSITORY_NAME>:<TAG>\" -f VllmDockerfile ./\n",
" docker image push \"<REPOSITORY_NAME>:<TAG>\"\n",
" ```"
],
"metadata": {
@@ -373,7 +373,8 @@
"First, define the pipeline options that you want to use to launch the Dataflow job. Before running the next cell, replace the following variables:\n",
"\n",
"- `<BUCKET_NAME>`: the name of a valid [Google Cloud Storage](https://cloud.google.com/storage?e=48754805&hl=en) bucket. Don't include a `gs://` prefix or trailing slashes.\n",
"- `<REPOSITORY_NAME>`: the name of the Google Artifact Registry repository that you used in the previous step. Don't include the `latest` tag, because this tag is automatically appended as part of the cell.\n",
"- `<REPOSITORY_NAME>`: the name of the Google Artifact Registry repository that you used in the previous step. \n",
"- `<IMAGE_TAG>`: image tag used in the previous step. Prefer a versioned tag or SHA instead of :latest tag or mutable tags.\n",
"- `<PROJECT_ID>`: the name of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n",
"\n",
"This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. When you use this service option, Dataflow to installs a T4 GPU that uses a `5xx` series Nvidia driver on each worker machine. The 5xx driver is required to run vLLM jobs."
@@ -396,7 +397,7 @@
"options = PipelineOptions()\n",
"\n",
"BUCKET_NAME = '<BUCKET_NAME>' # Replace with your bucket name.\n",
"CONTAINER_LOCATION = '<REPOSITORY_NAME>' # Replace with your container location (<your_gar_repository> from the previous step)\n",
"CONTAINER_IMAGE = '<REPOSITORY_NAME>:<TAG>' # Replace with the image repository and tag from the previous step.\n",
"PROJECT_NAME = '<PROJECT_ID>' # Replace with your GCP project\n",
"\n",
"options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n",
@@ -428,7 +429,7 @@
"# Choose a machine type compatible with GPU type\n",
"options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n",
"\n",
"options.view_as(WorkerOptions).worker_harness_container_image = '%s:latest' % CONTAINER_LOCATION"
"options.view_as(WorkerOptions).sdk_container_image = CONTAINER_IMAGE"
],
"metadata": {
"id": "kXy9FRYVCSjq"
@@ -484,6 +485,7 @@
" def process(self, element, *args, **kwargs):\n",
" yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n",
"\n",
"logging.getLogger().setLevel(logging.INFO) # Output additional Dataflow Job metadata and launch logs. \n",
"prompts = [\n",
" \"Hello, my name is\",\n",
" \"The president of the United States is\",\n",
22 changes: 12 additions & 10 deletions runners/direct-java/build.gradle
@@ -22,12 +22,12 @@ plugins { id 'org.apache.beam.module' }
// Shade away runner execution utilities because this causes ServiceLoader conflicts with
// TransformPayloadTranslatorRegistrar amongst other runners. This only happens in the DirectRunner
// because it is likely to appear on the classpath of another runner.
def dependOnProjects = [
":runners:core-java",
":runners:local-java",
":runners:java-fn-execution",
":sdks:java:core",
]
def dependOnProjectsAndConfigs = [
":runners:core-java":null,
":runners:local-java":null,
":runners:java-fn-execution":null,
":sdks:java:core":"shadow",
]

applyJavaNature(
automaticModuleName: 'org.apache.beam.runners.direct',
@@ -36,8 +36,8 @@ applyJavaNature(
],
shadowClosure: {
dependencies {
dependOnProjects.each {
include(project(path: it, configuration: "shadow"))
dependOnProjectsAndConfigs.each {
include(project(path: it.key, configuration: "shadow"))
}
}
},
@@ -63,8 +63,10 @@ configurations {
dependencies {
shadow library.java.vendored_guava_32_1_2_jre
shadow project(path: ":model:pipeline", configuration: "shadow")
dependOnProjects.each {
implementation project(it)
dependOnProjectsAndConfigs.each {
// For projects producing shadowjar, use the packaged jar as dependency to
// handle redirected packages from it
implementation project(path: it.key, configuration: it.value)
}
shadow library.java.vendored_grpc_1_60_1
shadow library.java.joda_time
StreamingDataflowWorker.java
@@ -65,6 +65,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
@@ -199,6 +200,7 @@ private StreamingDataflowWorker(
this.workCommitter =
windmillServiceEnabled
? StreamingEngineWorkCommitter.builder()
.setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(
WindmillStreamPool.create(
numCommitThreads,
WeightedBoundedQueue.java
@@ -18,49 +18,40 @@
package org.apache.beam.runners.dataflow.worker.streaming;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;

/** Bounded set of queues, with a maximum total weight. */
/** Queue bounded by a {@link WeightedSemaphore}. */
public final class WeightedBoundedQueue<V> {

private final LinkedBlockingQueue<V> queue;
private final int maxWeight;
private final Semaphore limit;
private final Function<V, Integer> weigher;
private final WeightedSemaphore<V> weightedSemaphore;

private WeightedBoundedQueue(
LinkedBlockingQueue<V> linkedBlockingQueue,
int maxWeight,
Semaphore limit,
Function<V, Integer> weigher) {
LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V> weightedSemaphore) {
this.queue = linkedBlockingQueue;
this.maxWeight = maxWeight;
this.limit = limit;
this.weigher = weigher;
this.weightedSemaphore = weightedSemaphore;
}

public static <V> WeightedBoundedQueue<V> create(int maxWeight, Function<V, Integer> weigherFn) {
return new WeightedBoundedQueue<>(
new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn);
public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V> weightedSemaphore) {
return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weightedSemaphore);
}

/**
* Adds the value to the queue, blocking if this would cause the overall weight to exceed the
* limit.
*/
public void put(V value) {
limit.acquireUninterruptibly(weigher.apply(value));
weightedSemaphore.acquireUninterruptibly(value);
queue.add(value);
}

/** Returns and removes the next value, or null if there is no such value. */
public @Nullable V poll() {
V result = queue.poll();
@Nullable V result = queue.poll();
if (result != null) {
limit.release(weigher.apply(result));
weightedSemaphore.release(result);
}
return result;
}
@@ -76,26 +67,22 @@ public void put(V value) {
* @throws InterruptedException if interrupted while waiting
*/
public @Nullable V poll(long timeout, TimeUnit unit) throws InterruptedException {
V result = queue.poll(timeout, unit);
@Nullable V result = queue.poll(timeout, unit);
if (result != null) {
limit.release(weigher.apply(result));
weightedSemaphore.release(result);
}
return result;
}

/** Returns and removes the next value, or blocks until one is available. */
public @Nullable V take() throws InterruptedException {
public V take() throws InterruptedException {
V result = queue.take();
limit.release(weigher.apply(result));
weightedSemaphore.release(result);
return result;
}

/** Returns the current weight of the queue. */
public int queuedElementsWeight() {
return maxWeight - limit.availablePermits();
}

public int size() {
@VisibleForTesting
int size() {
return queue.size();
}
}
WeightedSemaphore.java (new file)
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming;

import java.util.concurrent.Semaphore;
import java.util.function.Function;

public final class WeightedSemaphore<V> {
private final int maxWeight;
private final Semaphore limit;
private final Function<V, Integer> weigher;

private WeightedSemaphore(int maxWeight, Semaphore limit, Function<V, Integer> weigher) {
this.maxWeight = maxWeight;
this.limit = limit;
this.weigher = weigher;
}

public static <V> WeightedSemaphore<V> create(int maxWeight, Function<V, Integer> weigherFn) {
return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn);
}

public void acquireUninterruptibly(V value) {
limit.acquireUninterruptibly(computePermits(value));
}

public void release(V value) {
limit.release(computePermits(value));
}

private int computePermits(V value) {
return Math.min(weigher.apply(value), maxWeight);
}

public int currentWeight() {
return maxWeight - limit.availablePermits();
}
}
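
A minimal sketch (not part of this commit) of how the split composes: because the weight-tracking semaphore now lives outside the queue, a single WeightedSemaphore can back several WeightedBoundedQueue instances so that their combined weight shares one budget, which is the pattern Commits.maxCommitByteSemaphore() sets up for the commit queues in the files below. The sketch assumes the Dataflow worker module is on the classpath and uses String payloads weighed by length in place of Beam's Commit type; the class name is hypothetical.

import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;

// Sketch only: illustrates the shared-semaphore pattern introduced in this commit.
public class SharedWeightSketch {
  public static void main(String[] args) throws InterruptedException {
    // One semaphore caps the combined weight of everything queued, here at 32 weight units.
    WeightedSemaphore<String> budget = WeightedSemaphore.create(32, String::length);

    // Both queues draw permits from the same semaphore, so they are bounded together
    // rather than each holding its own independent limit.
    WeightedBoundedQueue<String> queueA = WeightedBoundedQueue.create(budget);
    WeightedBoundedQueue<String> queueB = WeightedBoundedQueue.create(budget);

    queueA.put("hello");   // acquires 5 permits
    queueB.put("world!");  // acquires 6 permits
    System.out.println(budget.currentWeight()); // prints 11

    queueA.take();         // removes "hello" and releases its 5 permits
    System.out.println(budget.currentWeight()); // prints 6
  }
}
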
Commits.java (new file)
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.commits;

import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;

/** Utility class for commits. */
@Internal
public final class Commits {

/** Max bytes of commits queued on the user worker. */
@VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB

private Commits() {}

public static WeightedSemaphore<Commit> maxCommitByteSemaphore() {
return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize);
}
}
StreamingApplianceWorkCommitter.java
@@ -42,7 +42,6 @@
public final class StreamingApplianceWorkCommitter implements WorkCommitter {
private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class);
private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB

private final Consumer<CommitWorkRequest> commitWorkFn;
private final WeightedBoundedQueue<Commit> commitQueue;
@@ -53,9 +52,7 @@ public final class StreamingApplianceWorkCommitter {
private StreamingApplianceWorkCommitter(
Consumer<CommitWorkRequest> commitWorkFn, Consumer<CompleteCommit> onCommitComplete) {
this.commitWorkFn = commitWorkFn;
this.commitQueue =
WeightedBoundedQueue.create(
MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
this.commitQueue = WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore());
this.commitWorkers =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
@@ -73,10 +70,9 @@ public static StreamingApplianceWorkCommitter create(
}

@Override
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
if (!commitWorkers.isShutdown()) {
commitWorkers.submit(this::commitLoop);
commitWorkers.execute(this::commitLoop);
}
}
