Migrate to Spark 3 DataSource v2 interfaces
Fixes woltapp#3
mojodna committed Apr 3, 2024
1 parent a16f07d commit 7b75ff3
Showing 6 changed files with 99 additions and 70 deletions.
5 changes: 3 additions & 2 deletions build.sbt
@@ -21,8 +21,9 @@ val mavenLocal = "Local Maven Repository" at Path.userHome.asFile.toURI.toURL +
resolvers += mavenLocal

libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % "2.4.4" % "provided",
"org.apache.spark" %% "spark-sql" % "2.4.4" % "provided",
"org.apache.spark" %% "spark-core" % "3.5.1" % "provided",
"org.apache.spark" %% "spark-sql" % "3.5.1" % "provided",
"org.apache.spark" %% "spark-hadoop-cloud" % "3.5.1" % "provided",
"com.wolt.osm" % "parallelpbf" % "0.3.1",
"org.scalatest" %% "scalatest-funsuite" % "3.2.18" % "it,test",
"org.scalactic" %% "scalactic" % "3.2.18" % "it,test"
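
The Spark artifacts move from 2.4.4 to 3.5.1, and spark-hadoop-cloud is added as a provided dependency. A minimal sbt sketch of the companion setting this bump implies, assuming the build targets one of the Scala versions Spark 3.5.1 is published for (2.12 or 2.13); the exact patch version below is illustrative:

// Assumption: the Scala binary version must match the Spark 3.5.1 artifacts (2.12.x or 2.13.x),
// since %% appends it to the artifact names above.
ThisBuild / scalaVersion := "2.12.18"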
11 changes: 7 additions & 4 deletions src/main/scala/com/wolt/osm/spark/OsmSource/OsmPartition.scala
@@ -1,9 +1,12 @@
package com.wolt.osm.spark.OsmSource

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.types.StructType

class OsmPartition(input: String, hadoop: SerializableHadoopConfigration, schema: StructType, threads: Int, partitionsNo: Int, partition: Int, useLocal: Boolean) extends InputPartition[InternalRow] {
override def createPartitionReader(): InputPartitionReader[InternalRow] = new OsmPartitionReader(input, hadoop, schema, threads, partitionsNo, partition, useLocal)
class OsmPartition(val input: String, val hadoop: SerializableHadoopConfigration, val schema: StructType, val threads: Int, val partitionsNo: Int, val partition: Int, val useLocal: Boolean) extends InputPartition

object OsmPartition {
def unapply(inputPartition: OsmPartition): Option[(String, SerializableHadoopConfigration, StructType, Int, Int, Int, Boolean)] = {
Some((inputPartition.input, inputPartition.hadoop, inputPartition.schema, inputPartition.threads, inputPartition.partitionsNo, inputPartition.partition, inputPartition.useLocal))
}
}
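
In the v2 API an InputPartition is only a serializable descriptor that Spark ships to executors; reader construction moves to a PartitionReaderFactory (added later in this commit), and the companion object's unapply lets that reader destructure the descriptor back into its fields. A minimal sketch of how the pieces are expected to fit together, assuming an active SparkSession; the input path and the thread/partition counts are illustrative:

// Illustrative only: Spark drives this itself after planInputPartitions().
val hadoopConf = new SerializableHadoopConfigration(spark.sessionState.newHadoopConf())
val partition = new OsmPartition("data.osm.pbf", hadoopConf, OsmSource.schema, 2, 4, 0, false)
val reader = new OsmPartitionReaderFactory().createReader(partition)
while (reader.next()) {
  val row = reader.get() // InternalRow laid out according to OsmSource.schema
}
reader.close()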
src/main/scala/com/wolt/osm/spark/OsmSource/OsmPartitionReader.scala
@@ -1,23 +1,22 @@
package com.wolt.osm.spark.OsmSource

import java.io.{FileInputStream, InputStream}
import java.util.concurrent._
import java.util.function.Consumer

import com.wolt.osm.parallelpbf.ParallelBinaryParser
import com.wolt.osm.parallelpbf.entity._
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkFiles
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.unsafe.types.UTF8String

import java.io.{FileInputStream, InputStream}
import java.util.concurrent._
import java.util.function.Consumer
import scala.collection.JavaConverters._
import scala.collection.mutable

class OsmPartitionReader(input: String, hadoop:SerializableHadoopConfigration, schema: StructType, threads: Int, partitionsNo: Int, partition: Int, useLocal: Boolean) extends InputPartitionReader[InternalRow] {
class OsmPartitionReader(inputPartition: OsmPartition) extends PartitionReader[InternalRow] {
val OsmPartition(input, hadoop, schema, threads, partitionsNo, partition, useLocal) = inputPartition
private val schemaColumnNames = schema.fields.map(_.name)

private val parserTask = new FutureTask[Unit](new Callable[Unit]() {
50 changes: 50 additions & 0 deletions src/main/scala/com/wolt/osm/spark/OsmSource/OsmScan.scala
@@ -0,0 +1,50 @@
package com.wolt.osm.spark.OsmSource

import org.apache.hadoop.fs.Path
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}

import java.io.File
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import scala.util.Try

class OsmScan(val options: CaseInsensitiveStringMap) extends Scan with Batch {
private val path = options.get("path")
private val useLocal = Option(options.get("useLocalFile")).getOrElse("").equalsIgnoreCase("true")

private val spark = SparkSession.active
private val hadoop = spark.sessionState.newHadoopConf()
private val hadoopConfigration = new SerializableHadoopConfigration(hadoop)

if (useLocal) {
if (!new File(SparkFiles.get(path)).canRead) {
throw new RuntimeException(s"Input unavailable: $path")
}
} else {
val source = new Path(path)
val fs = source.getFileSystem(hadoop)
if (!fs.exists(source)) {
throw new RuntimeException(s"Input unavailable: $path")
}
}

private val partitions = Option(options.get("partitions")).getOrElse("1")
private val threads = Option(options.get("threads")).getOrElse("1")

override def readSchema(): StructType = OsmSource.schema

override def planInputPartitions(): Array[InputPartition] = {
val partitionsNo = Try(partitions.toInt).getOrElse(1)
val threadsNo = Try(threads.toInt).getOrElse(1)
val shiftedPartitions = partitionsNo - 1
(0 to shiftedPartitions).map(p => new OsmPartition(path, hadoopConfigration, readSchema(), threadsNo, partitionsNo, p, useLocal)).toArray
}

override def toBatch: Batch = this

override def createReaderFactory(): PartitionReaderFactory = new OsmPartitionReaderFactory()
}
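
OsmScan doubles as its own Batch: the constructor validates that the path option is readable (via SparkFiles when useLocalFile is set, otherwise through the Hadoop FileSystem), and planInputPartitions emits one OsmPartition per requested partition, each carrying the thread count for the parallel PBF parser. A minimal sketch of the calls Spark makes once the scan exists, assuming the path resolves to a real file; the option values are illustrative:

// Illustrative only: Spark performs these steps itself during batch planning.
import scala.collection.JavaConverters._
val opts = new CaseInsensitiveStringMap(
  Map("path" -> "data.osm.pbf", "partitions" -> "4", "threads" -> "2").asJava)
val scan = new OsmScan(opts)
val batch = scan.toBatch
val parts = batch.planInputPartitions()   // four OsmPartition descriptors
val factory = batch.createReaderFactory() // OsmPartitionReaderFactory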
62 changes: 32 additions & 30 deletions src/main/scala/com/wolt/osm/spark/OsmSource/OsmSource.scala
@@ -1,13 +1,14 @@
package com.wolt.osm.spark.OsmSource

import java.io.File

import org.apache.hadoop.fs.Path
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.v2.reader.DataSourceReader
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.util
import scala.collection.JavaConverters._

object OsmSource {
val OSM_SOURCE_NAME = "com.wolt.osm.spark.OsmSource"
@@ -41,26 +42,27 @@ object OsmSource {
val schema: StructType = StructType(fields)
}

class DefaultSource extends DataSourceV2 with ReadSupport {
override def createReader(options: DataSourceOptions): DataSourceReader = {
val path = options.get("path").get
val useLocal = options.get("useLocalFile").orElse("").equalsIgnoreCase("true")

val spark = SparkSession.active
val hadoop = spark.sessionState.newHadoopConf()
val hadoopConfiguration = new SerializableHadoopConfigration(hadoop)

if (useLocal) {
if (!new File(SparkFiles.get(path)).canRead) {
throw new RuntimeException(s"Input unavailable: $path")
}
} else {
val source = new Path(path)
val fs = source.getFileSystem(hadoop)
if (!fs.exists(source)) {
throw new RuntimeException(s"Input unavailable: $path")
}
}
new OsmSourceReader(path, hadoopConfiguration, options.get("partitions").orElse("1"), options.get("threads").orElse("1"), useLocal)
}
}
class DefaultSource extends TableProvider {
override def inferSchema(options: CaseInsensitiveStringMap): StructType = OsmSource.schema

override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = new OsmTable()
}

class OsmTable extends Table with SupportsRead {
override def name(): String = this.getClass.toString

override def schema(): StructType = OsmSource.schema

override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new OsmScanBuilder(options)
}

class OsmScanBuilder(val options: CaseInsensitiveStringMap) extends ScanBuilder {

override def build(): Scan = new OsmScan(options)
}

class OsmPartitionReaderFactory() extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = new OsmPartitionReader(partition.asInstanceOf[OsmPartition])
}
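
The old DataSourceV2 / ReadSupport / DataSourceReader entry point is replaced by the v2 connector chain: DefaultSource (a TableProvider) returns an OsmTable, whose ScanBuilder builds the OsmScan above, which plans OsmPartitions and supplies the OsmPartitionReaderFactory. From the caller's side the source is still addressed by name; a minimal usage sketch, assuming OSM_SOURCE_NAME resolves to this package's DefaultSource, with an illustrative file name:

// Hypothetical usage; "partitions", "threads" and "useLocalFile" are the option keys OsmScan reads.
val osm = spark.read
  .format(OsmSource.OSM_SOURCE_NAME) // "com.wolt.osm.spark.OsmSource"
  .option("partitions", "8")
  .option("threads", "4")
  .load("extract.osm.pbf")
osm.printSchema() // columns come from OsmSource.schema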
26 changes: 0 additions & 26 deletions src/main/scala/com/wolt/osm/spark/OsmSource/OsmSourceReader.scala

This file was deleted.
