Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file support to the bulk CDK #49931

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
return LocalBatch(records.asSequence().toList())
}

/** Wraps [file] in a [LocalFileBatch] and returns it; no other work is done here. */
override suspend fun processFile(file: DestinationFile): Batch = LocalFileBatch(file)

override suspend fun processBatch(batch: Batch): Batch {
return when (batch) {
is LocalBatch -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ abstract class DestinationConfiguration : Configuration {

open val numProcessRecordsWorkers: Int = 2
open val numProcessBatchWorkers: Int = 5
open val numProcessBatchWorkersForFileTransfer: Int = 3
open val batchQueueDepth: Int = 10

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.task.implementor.FileAggregateMessage
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Value
Expand Down Expand Up @@ -79,4 +80,13 @@ class SyncBeanFactory {
val channel = Channel<BatchEnvelope<*>>(config.batchQueueDepth)
return MultiProducerChannel(config.numProcessRecordsWorkers.toLong(), channel, "batchQueue")
}

/**
 * Provides the singleton queue registered under the name "fileMessageQueue", which carries
 * [FileTransferQueueMessage]s.
 *
 * The backing [Channel] is created with `config.batchQueueDepth` as its capacity argument.
 * The first [MultiProducerChannel] argument is fixed at 1 (presumably the number of
 * producers that must close the queue before it is considered closed -- TODO confirm
 * against [MultiProducerChannel]).
 *
 * @param config destination configuration supplying `batchQueueDepth`
 * @return the queue wrapping the newly created channel
 */
@Singleton
@Named("fileMessageQueue")
fun fileMessageQueue(
    config: DestinationConfiguration,
): MultiProducerChannel<FileTransferQueueMessage> =
    MultiProducerChannel(
        1,
        Channel<FileTransferQueueMessage>(config.batchQueueDepth),
        "fileMessageQueue",
    )
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.CheckpointMessageWrapped
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MessageQueueSupplier
import io.airbyte.cdk.load.message.QueueWriter
import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.implementor.CloseStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.FailStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.FailSyncTaskFactory
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.task.implementor.OpenStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.ProcessBatchTaskFactory
import io.airbyte.cdk.load.task.implementor.ProcessFileTaskFactory
Expand All @@ -37,6 +38,7 @@ import io.airbyte.cdk.load.util.setOnce
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import io.micronaut.context.annotation.Value
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.CancellationException
Expand All @@ -50,8 +52,6 @@ interface DestinationTaskLauncher : TaskLauncher {
suspend fun handleNewBatch(stream: DestinationStream.Descriptor, wrapped: BatchEnvelope<*>)
suspend fun handleStreamClosed(stream: DestinationStream.Descriptor)
suspend fun handleTeardownComplete(success: Boolean = true)
suspend fun handleFile(stream: DestinationStream.Descriptor, file: DestinationFile, index: Long)

suspend fun handleException(e: Exception)
suspend fun handleFailStreamComplete(stream: DestinationStream.Descriptor, e: Exception)
}
Expand Down Expand Up @@ -129,6 +129,7 @@ class DefaultDestinationTaskLauncher(
private val recordQueueSupplier:
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
private val checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
@Named("fileMessageQueue") private val fileTransferQueue: MessageQueue<FileTransferQueueMessage>
) : DestinationTaskLauncher {
private val log = KotlinLogging.logger {}

Expand Down Expand Up @@ -180,7 +181,8 @@ class DefaultDestinationTaskLauncher(
inputFlow = inputFlow,
recordQueueSupplier = recordQueueSupplier,
checkpointQueue = checkpointQueue,
this,
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
)
enqueue(inputConsumerTask)

Expand Down Expand Up @@ -209,6 +211,17 @@ class DefaultDestinationTaskLauncher(
val task = processBatchTaskFactory.make(this)
enqueue(task)
}
} else {
repeat(config.numProcessRecordsWorkers) {
log.info { "Launching process file task $it" }
enqueue(processFileTaskFactory.make(this))
}

repeat(config.numProcessBatchWorkersForFileTransfer) {
log.info { "Launching process batch task $it" }
val task = processBatchTaskFactory.make(this)
enqueue(task)
}
}

// Start flush task
Expand Down Expand Up @@ -284,14 +297,6 @@ class DefaultDestinationTaskLauncher(
}
}

override suspend fun handleFile(
stream: DestinationStream.Descriptor,
file: DestinationFile,
index: Long
) {
enqueue(processFileTaskFactory.make(this, stream, file, index))
}

override suspend fun handleException(e: Exception) {
catalog.streams
.map { failStreamTaskFactory.make(this, e, it.descriptor) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
Expand All @@ -22,7 +23,7 @@ class DefaultProcessBatchTask(
private val batchQueue: MultiProducerChannel<BatchEnvelope<*>>,
private val taskLauncher: DestinationTaskLauncher
) : ProcessBatchTask {

val log = KotlinLogging.logger {}
override suspend fun execute() {
batchQueue.consume().collect { batchEnvelope ->
val streamLoader = syncManager.getOrAwaitStreamLoader(batchEnvelope.streamDescriptor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,72 @@

package io.airbyte.cdk.load.task.implementor

import com.google.common.collect.Range
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.util.use
import io.airbyte.cdk.load.write.BatchAccumulator
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap

interface ProcessFileTask : ImplementorScope

class DefaultProcessFileTask(
private val streamDescriptor: DestinationStream.Descriptor,
private val taskLauncher: DestinationTaskLauncher,
private val syncManager: SyncManager,
private val file: DestinationFile,
private val index: Long,
private val taskLauncher: DestinationTaskLauncher,
private val inputQueue: MessageQueue<FileTransferQueueMessage>,
private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessFileTask {
val log = KotlinLogging.logger {}
private val accumulators = ConcurrentHashMap<DestinationStream.Descriptor, BatchAccumulator>()

override suspend fun execute() {
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
outputQueue.use {
inputQueue.consume().collect { (streamDescriptor, file, index) ->
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)

val batch = streamLoader.processFile(file)
val acc =
accumulators.getOrPut(streamDescriptor) {
streamLoader.createFileBatchAccumulator(outputQueue)
}

val wrapped = BatchEnvelope(batch, Range.singleton(index), streamDescriptor)
taskLauncher.handleNewBatch(streamDescriptor, wrapped)
acc.processFilePart(file, index)
}
}
}
}

interface ProcessFileTaskFactory {
fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream.Descriptor,
file: DestinationFile,
index: Long,
): ProcessFileTask
}

@Singleton
@Secondary
class DefaultFileRecordsTaskFactory(
private val syncManager: SyncManager,
@Named("fileMessageQueue")
private val fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
@Named("batchQueue") private val outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
) : ProcessFileTaskFactory {
override fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream.Descriptor,
file: DestinationFile,
index: Long,
): ProcessFileTask {
return DefaultProcessFileTask(stream, taskLauncher, syncManager, file, index)
return DefaultProcessFileTask(syncManager, taskLauncher, fileTransferQueue, outputQueue)
}
}

/**
 * Message carried on the "fileMessageQueue": one file to be processed for one stream.
 *
 * The property order is significant: consumers destructure this class positionally as
 * `(streamDescriptor, file, index)`, so reordering the properties would silently change
 * which value each destructured name receives.
 *
 * @property streamDescriptor descriptor of the stream the file belongs to
 * @property file the file message to process
 * @property index the index assigned to this file by the publisher when it was enqueued
 */
data class FileTransferQueueMessage(
val streamDescriptor: DestinationStream.Descriptor,
val file: DestinationFile,
val index: Long,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.GlobalCheckpoint
import io.airbyte.cdk.load.message.GlobalCheckpointWrapped
import io.airbyte.cdk.load.message.MessageQueue
import io.airbyte.cdk.load.message.MessageQueueSupplier
import io.airbyte.cdk.load.message.QueueWriter
import io.airbyte.cdk.load.message.SimpleBatch
Expand All @@ -34,9 +35,11 @@ import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.KillableScope
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.util.use
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Named
import jakarta.inject.Singleton

interface InputConsumerTask : KillableScope
Expand All @@ -61,6 +64,8 @@ class DefaultInputConsumerTask(
private val checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
private val syncManager: SyncManager,
private val destinationTaskLauncher: DestinationTaskLauncher,
@Named("fileMessageQueue")
private val fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
) : InputConsumerTask {
private val log = KotlinLogging.logger {}

Expand Down Expand Up @@ -97,15 +102,17 @@ class DefaultInputConsumerTask(
}
is DestinationFile -> {
val index = manager.countRecordIn()
destinationTaskLauncher.handleFile(stream, message, index)
// destinationTaskLauncher.handleFile(stream, message, index)
fileTransferQueue.publish(FileTransferQueueMessage(stream, message, index))
}
is DestinationFileStreamComplete -> {
reserved.release() // safe because multiple calls conflate
manager.markEndOfStream(true)
fileTransferQueue.close()
val envelope =
BatchEnvelope(
SimpleBatch(Batch.State.COMPLETE),
streamDescriptor = message.stream
streamDescriptor = message.stream,
)
destinationTaskLauncher.handleNewBatch(stream, envelope)
}
Expand Down Expand Up @@ -198,6 +205,7 @@ interface InputConsumerTaskFactory {
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
destinationTaskLauncher: DestinationTaskLauncher,
fileTransferQueue: MessageQueue<FileTransferQueueMessage>
): InputConsumerTask
}

Expand All @@ -212,14 +220,16 @@ class DefaultInputConsumerTaskFactory(private val syncManager: SyncManager) :
MessageQueueSupplier<DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
destinationTaskLauncher: DestinationTaskLauncher,
fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
): InputConsumerTask {
return DefaultInputConsumerTask(
catalog,
inputFlow,
recordQueueSupplier,
checkpointQueue,
syncManager,
destinationTaskLauncher
destinationTaskLauncher,
fileTransferQueue,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package io.airbyte.cdk.load.write

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.message.SimpleBatch
import io.airbyte.cdk.load.state.StreamProcessingFailed

Expand Down Expand Up @@ -48,8 +50,10 @@ interface StreamLoader : BatchAccumulator {

suspend fun start() {}
suspend fun createBatchAccumulator(): BatchAccumulator = this
/**
 * Creates the [BatchAccumulator] that will receive file messages for this stream.
 *
 * The default implementation returns this loader itself and does not use [outputQueue].
 *
 * @param outputQueue queue of batch envelopes made available to the accumulator
 * @return the accumulator to use for file messages
 */
suspend fun createFileBatchAccumulator(
    outputQueue: MultiProducerChannel<BatchEnvelope<*>>,
): BatchAccumulator {
    return this
}

suspend fun processFile(file: DestinationFile): Batch
suspend fun processBatch(batch: Batch): Batch = SimpleBatch(Batch.State.COMPLETE)
suspend fun close(streamFailure: StreamProcessingFailed? = null) {}
}
Expand All @@ -63,4 +67,9 @@ interface BatchAccumulator {
throw NotImplementedError(
"processRecords must be implemented if createBatchAccumulator is overridden"
)

/**
 * Handles a single file message for this accumulator.
 *
 * The default implementation always throws, so any accumulator that is handed file
 * messages must override this method.
 *
 * @param file the file message to process
 * @param index the index the caller associated with [file]
 * @throws NotImplementedError always, unless overridden
 */
suspend fun processFilePart(file: DestinationFile, index: Long): Unit =
    throw NotImplementedError(
        // Name the method and factory that actually apply here; the previous text was
        // copied from processRecords and pointed implementors at the wrong members.
        "processFilePart must be implemented if createFileBatchAccumulator is overridden"
    )
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import io.airbyte.cdk.load.task.implementor.FailStreamTask
import io.airbyte.cdk.load.task.implementor.FailStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.FailSyncTask
import io.airbyte.cdk.load.task.implementor.FailSyncTaskFactory
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.task.implementor.OpenStreamTask
import io.airbyte.cdk.load.task.implementor.OpenStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.ProcessBatchTaskFactory
Expand Down Expand Up @@ -159,7 +160,8 @@ class DestinationTaskLauncherTest<T : ScopedTask> {
MessageQueueSupplier<
DestinationStream.Descriptor, Reserved<DestinationStreamEvent>>,
checkpointQueue: QueueWriter<Reserved<CheckpointMessageWrapped>>,
destinationTaskLauncher: DestinationTaskLauncher
destinationTaskLauncher: DestinationTaskLauncher,
fileTransferQueue: MessageQueue<FileTransferQueueMessage>,
): InputConsumerTask {
return object : InputConsumerTask {
override suspend fun execute() {
Expand Down
Loading
Loading