Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python multi-language guide #33348

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions examples/multi-language/python/wordcount_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging

import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms.external_transform_provider import ExternalTransformProvider
from apache_beam.typehints.row_type import RowTypeConstraint
Expand Down Expand Up @@ -60,39 +59,30 @@
--expansion_service_port <PORT>
"""

# Original Java transform is in ExtractWordsProvider.java
EXTRACT_IDENTIFIER = "beam:schematransform:org.apache.beam:extract_words:v1"
# Original Java transform is in JavaCountProvider.java
COUNT_IDENTIFIER = "beam:schematransform:org.apache.beam:count:v1"
# Original Java transform is in WriteWordsProvider.java
WRITE_IDENTIFIER = "beam:schematransform:org.apache.beam:write_words:v1"


def run(input_path, output_path, expansion_service_port, pipeline_args):
pipeline_options = PipelineOptions(pipeline_args)

# Discover and get external transforms from this expansion service
provider = ExternalTransformProvider("localhost:" + expansion_service_port)
# Get transforms with identifiers, then use them as you would a regular
# native PTransform
# Retrieve portable transforms
Extract = provider.get_urn(EXTRACT_IDENTIFIER)
Count = provider.get_urn(COUNT_IDENTIFIER)
Write = provider.get_urn(WRITE_IDENTIFIER)

with beam.Pipeline(options=pipeline_options) as p:
lines = p | 'Read' >> ReadFromText(input_path)

words = (lines
| 'Prepare Rows' >> beam.Map(lambda line: beam.Row(line=line))
| 'Extract Words' >> Extract())
word_counts = words | 'Count Words' >> Count()
formatted_words = (
word_counts
| 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % (
row.word, row.count))).with_output_types(
RowTypeConstraint.from_fields([('line', str)])))

formatted_words | 'Write' >> Write(file_path_prefix=output_path)
_ = (p
| 'Read' >> beam.io.ReadFromText(input_path)
| 'Prepare Rows' >> beam.Map(lambda line: beam.Row(line=line))
| 'Extract Words' >> Extract(filter=["king", "palace"])
| 'Count Words' >> Count()
| 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % (
row.word, row.count))).with_output_types(
RowTypeConstraint.from_fields([('line', str)]))
| 'Write' >> Write(file_path_prefix=output_path))


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
Expand All @@ -36,7 +39,6 @@
/**
 * Splits a line into separate words and returns only the words that are contained in the
 * configured filter list.
 */
@AutoService(SchemaTransformProvider.class)
public class ExtractWordsProvider extends TypedSchemaTransformProvider<Configuration> {
public static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build();

@Override
public String identifier() {
Expand All @@ -45,32 +47,60 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
return PCollectionRowTuple.of(
"output",
input.get("input").apply(ParDo.of(new ExtractWordsFn())).setRowSchema(OUTPUT_SCHEMA));
}
};
return new ExtractWordsTransform(configuration);
}

static class ExtractWordsFn extends DoFn<Row, Row> {
@ProcessElement
public void processElement(@Element Row element, OutputReceiver<Row> receiver) {
// Split the line into words.
String line = Preconditions.checkStateNotNull(element.getString("line"));
String[] words = line.split("[^\\p{L}]+", -1);
static class ExtractWordsTransform extends SchemaTransform {
private static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build();
private final List<String> filter;

for (String word : words) {
if (!word.isEmpty()) {
receiver.output(Row.withSchema(OUTPUT_SCHEMA).withFieldValue("word", word).build());
}
}
ExtractWordsTransform(Configuration configuration) {
this.filter = configuration.getFilter();
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
return PCollectionRowTuple.of(
"output",
input
.getSinglePCollection()
.apply(
ParDo.of(
new DoFn<Row, Row>() {
@ProcessElement
public void process(@Element Row element, OutputReceiver<Row> receiver) {
// Split the line into words.
String line = Preconditions.checkStateNotNull(element.getString("line"));
String[] words = line.split("[^\\p{L}]+", -1);
Arrays.stream(words)
.filter(filter::contains)
.forEach(
word ->
receiver.output(
Row.withSchema(OUTPUT_SCHEMA)
.withFieldValue("word", word)
.build()));
}
}))
.setRowSchema(OUTPUT_SCHEMA));
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {}
// Configuration for the extract-words transform. The concrete implementation
// (AutoValue_ExtractWordsProvider_Configuration) is generated by AutoValue.
public abstract static class Configuration {
  // Entry point for constructing a Configuration.
  public static Builder builder() {
    return new AutoValue_ExtractWordsProvider_Configuration.Builder();
  }

  // ExtractWordsTransform applies `.filter(filter::contains)`, i.e. it emits only
  // the words that ARE in this list. The description therefore says "keep";
  // the previous wording ("filter out") stated the opposite of the behavior.
  @SchemaFieldDescription("List of words to keep. Words not in this list are dropped.")
  public abstract List<String> getFilter();

  @AutoValue.Builder
  public abstract static class Builder {
    // Sets the list of words the transform keeps.
    public abstract Builder setFilter(List<String> filter);

    public abstract Configuration build();
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,35 +44,37 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
Schema outputSchema =
Schema.builder().addStringField("word").addInt64Field("count").build();
return new JavaCountTransform();
}

static class JavaCountTransform extends SchemaTransform {
static final Schema OUTPUT_SCHEMA =
Schema.builder().addStringField("word").addInt64Field("count").build();

PCollection<Row> wordCounts =
input
.get("input")
.apply(Count.perElement())
.apply(
MapElements.into(TypeDescriptors.rows())
.via(
kv ->
Row.withSchema(outputSchema)
.withFieldValue(
"word",
Preconditions.checkStateNotNull(
kv.getKey().getString("word")))
.withFieldValue("count", kv.getValue())
.build()))
.setRowSchema(outputSchema);
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
PCollection<Row> wordCounts =
input
.get("input")
.apply(Count.perElement())
.apply(
MapElements.into(TypeDescriptors.rows())
.via(
kv ->
Row.withSchema(OUTPUT_SCHEMA)
.withFieldValue(
"word",
Preconditions.checkStateNotNull(
kv.getKey().getString("word")))
.withFieldValue("count", kv.getValue())
.build()))
.setRowSchema(OUTPUT_SCHEMA);

return PCollectionRowTuple.of("output", wordCounts);
}
};
return PCollectionRowTuple.of("output", wordCounts);
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {}
public abstract static class Configuration {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,32 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
input
.get("input")
.apply(
MapElements.into(TypeDescriptors.strings())
.via(row -> Preconditions.checkStateNotNull(row.getString("line"))))
.apply(TextIO.write().to(configuration.getFilePathPrefix()));
return new WriteWordsTransform(configuration);
}

static class WriteWordsTransform extends SchemaTransform {
private final String filePathPrefix;

WriteWordsTransform(Configuration configuration) {
this.filePathPrefix = configuration.getFilePathPrefix();
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
input
.get("input")
.apply(
MapElements.into(TypeDescriptors.strings())
.via(row -> Preconditions.checkStateNotNull(row.getString("line"))))
.apply(TextIO.write().to(filePathPrefix));

return PCollectionRowTuple.empty(input.getPipeline());
}
};
return PCollectionRowTuple.empty(input.getPipeline());
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {
public abstract static class Configuration {
public static Builder builder() {
return new AutoValue_WriteWordsProvider_Configuration.Builder();
}
Expand Down
14 changes: 7 additions & 7 deletions sdks/python/apache_beam/transforms/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,14 +962,14 @@ def __init__(
self, path_to_jar, extra_args=None, classpath=None, append_args=None):
if extra_args and append_args:
raise ValueError('Only one of extra_args or append_args may be provided')
self._path_to_jar = path_to_jar
self.path_to_jar = path_to_jar
self._extra_args = extra_args
self._classpath = classpath or []
self._service_count = 0
self._append_args = append_args or []

def is_existing_service(self):
return subprocess_server.is_service_endpoint(self._path_to_jar)
return subprocess_server.is_service_endpoint(self.path_to_jar)

@staticmethod
def _expand_jars(jar):
Expand Down Expand Up @@ -997,7 +997,7 @@ def _expand_jars(jar):
def _default_args(self):
"""Default arguments to be used by `JavaJarExpansionService`."""

to_stage = ','.join([self._path_to_jar] + sum((
to_stage = ','.join([self.path_to_jar] + sum((
JavaJarExpansionService._expand_jars(jar)
for jar in self._classpath or []), []))
args = ['{{PORT}}', f'--filesToStage={to_stage}']
Expand All @@ -1009,24 +1009,24 @@ def _default_args(self):

def __enter__(self):
if self._service_count == 0:
self._path_to_jar = subprocess_server.JavaJarServer.local_jar(
self._path_to_jar)
self.path_to_jar = subprocess_server.JavaJarServer.local_jar(
self.path_to_jar)
if self._extra_args is None:
self._extra_args = self._default_args() + self._append_args
# Consider memoizing these servers (with some timeout).
logging.info(
'Starting a JAR-based expansion service from JAR %s ' + (
'and with classpath: %s' %
self._classpath if self._classpath else ''),
self._path_to_jar)
self.path_to_jar)
classpath_urls = [
subprocess_server.JavaJarServer.local_jar(path)
for jar in self._classpath
for path in JavaJarExpansionService._expand_jars(jar)
]
self._service_provider = subprocess_server.JavaJarServer(
ExpansionAndArtifactRetrievalStub,
self._path_to_jar,
self.path_to_jar,
self._extra_args,
classpath=classpath_urls)
self._service = self._service_provider.__enter__()
Expand Down
Loading
Loading