diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs
index 980bf379d..a65711bfb 100644
--- a/shotover/src/transforms/kafka/sink_cluster/mod.rs
+++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs
@@ -10,6 +10,7 @@ use crate::transforms::{
 use crate::transforms::{TransformConfig, TransformContextConfig};
 use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
+use bytes::{Bytes, BytesMut};
 use connections::{Connections, Destination};
 use dashmap::DashMap;
 use kafka_node::{ConnectionFactory, KafkaAddress, KafkaNode, KafkaNodeState};
@@ -253,6 +254,7 @@ impl TransformBuilder for KafkaSinkClusterBuilder {
             sasl_mechanism: None,
             authorize_scram_over_mtls: self.authorize_scram_over_mtls.as_ref().map(|x| x.build()),
             refetch_backoff: Duration::from_millis(1),
+            next_fetch_progress: Default::default(),
         })
     }
 
@@ -314,8 +316,11 @@ pub(crate) struct KafkaSinkCluster {
     authorize_scram_over_mtls: Option<AuthorizeScramOverMtls>,
     connections: Connections,
     refetch_backoff: Duration,
+    next_fetch_progress: HashMap<TopicPartition, Vec<Bytes>>,
 }
 
+type TopicPartition = (TopicName, i32);
+
 /// State of a Request/Response is maintained by this enum.
 /// The state progresses from Routed -> Sent -> Received
 #[derive(Debug)]
@@ -1787,13 +1792,72 @@ impl KafkaSinkCluster {
                 if let PendingRequestState::Received {
                     destination,
                     request,
-                    ..
+                    response,
                 } = &mut pending_request.state
                 {
+                    if let Some(Frame::Kafka(KafkaFrame::Response {
+                        body: ResponseBody::Fetch(fetch),
+                        ..
+                    })) = response.frame()
+                    {
+                        for response in &fetch.responses {
+                            for partition in &response.partitions {
+                                if let Some(records) = &partition.records {
+                                    // Buffer the records so they can be prepended to
+                                    // the response of the rerouted fetch.
+                                    self.next_fetch_progress
+                                        .entry((
+                                            response.topic.clone(),
+                                            partition.partition_index,
+                                        ))
+                                        .or_default()
+                                        .push(records.clone());
+
+                                    // We need to rewrite the request at this point to
+                                    // avoid re-requesting the data we already have.
+                                    // However to do that we need to know how many
+                                    // records we received, which requires parsing the
+                                    // record bytes.
+
+                                    // Byte offset of the records count field in the
+                                    // fixed header of a v2 RecordBatch.
+                                    // TODO: handle the old record batch format as well
+                                    const RECORD_COUNT_START: usize = 57;
+                                    if records.len() >= RECORD_COUNT_START + 4 {
+                                        let records_count = i32::from_be_bytes(
+                                            records[RECORD_COUNT_START
+                                                ..RECORD_COUNT_START + 4]
+                                                .try_into()
+                                                .unwrap(),
+                                        );
+                                        tracing::info!(
+                                            "storing {records_count} unused records"
+                                        );
+
+                                        if let Some(request) = request {
+                                            if let Some(Frame::Kafka(
+                                                KafkaFrame::Request {
+                                                    body: RequestBody::Fetch(fetch),
+                                                    ..
+                                                },
+                                            )) = request.frame()
+                                            {
+                                                // Only advance the fetch offset of the
+                                                // partition these records came from.
+                                                for topic in &mut fetch.topics {
+                                                    if topic.topic != response.topic {
+                                                        continue;
+                                                    }
+                                                    for request_partition in
+                                                        &mut topic.partitions
+                                                    {
+                                                        if request_partition.partition
+                                                            == partition.partition_index
+                                                        {
+                                                            request_partition
+                                                                .fetch_offset +=
+                                                                records_count as i64;
+                                                        }
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    } else {
+                        panic!("Must be fetch");
+                    }
+
                     pending_request.state = PendingRequestState::Routed {
                         destination: *destination,
                         request: request.take().unwrap(),
-                    }
+                    };
                 } else {
                     unreachable!()
                 }
@@ -1803,6 +1867,51 @@ impl KafkaSinkCluster {
                 break;
             } else {
                 self.refetch_backoff = Duration::from_millis(1);
+
+                // A fetch response is ready to be returned to the client:
+                // prepend any records buffered from earlier partial responses.
+                if !self.next_fetch_progress.is_empty() {
+                    if let PendingRequestState::Received { response, .. } =
+                        &mut self.pending_requests.front_mut().unwrap().state
+                    {
+                        if let Some(Frame::Kafka(KafkaFrame::Response {
+                            body: ResponseBody::Fetch(fetch),
+                            ..
+                        })) = response.frame()
+                        {
+                            for response in &mut fetch.responses {
+                                for partition in &mut response.partitions {
+                                    if let Some(extra_records) =
+                                        self.next_fetch_progress.get(&(
+                                            response.topic.clone(),
+                                            partition.partition_index,
+                                        ))
+                                    {
+                                        if !extra_records.is_empty() {
+                                            // Prefix records with all the stored
+                                            // extra records.
+                                            let mut combined_records = BytesMut::new();
+                                            for record in extra_records {
+                                                combined_records
+                                                    .extend_from_slice(record);
+                                            }
+                                            tracing::info!(
+                                                "combine extra_records={} base={}",
+                                                extra_records.len(),
+                                                partition.records.is_some()
+                                            );
+
+                                            if let Some(records) = &partition.records {
+                                                combined_records
+                                                    .extend_from_slice(records);
+                                            }
+                                            partition.records =
+                                                Some(combined_records.freeze());
+                                        }
+                                    }
+                                }
+                            }
+                        } else {
+                            panic!("Must be fetch");
+                        }
+                    }
+                }
+                self.next_fetch_progress.clear();
             }
         }
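
A note on the magic number: `RECORD_COUNT_START = 57` is the byte offset of the records-count field in the fixed 61-byte header of a v2 RecordBatch (base_offset i64 @ 0, batch_length i32 @ 8, partition_leader_epoch i32 @ 12, magic i8 @ 16, crc i32 @ 17, attributes i16 @ 21, last_offset_delta i32 @ 23, base_timestamp i64 @ 27, max_timestamp i64 @ 35, producer_id i64 @ 43, producer_epoch i16 @ 51, base_sequence i32 @ 53, records count i32 @ 57). A minimal sketch of a safer accessor; the helper name is hypothetical, and it rejects pre-v2 batches rather than misparsing them:

```rust
/// Hypothetical helper: record count of the first v2 RecordBatch in `records`.
/// Returns None for pre-v2 batches (magic != 2) or truncated input.
fn record_count_v2(records: &[u8]) -> Option<i32> {
    const MAGIC_OFFSET: usize = 16;
    const RECORD_COUNT_START: usize = 57;
    // The fixed v2 header is 61 bytes; anything shorter cannot hold the count.
    if records.len() < RECORD_COUNT_START + 4 || records[MAGIC_OFFSET] != 2 {
        return None;
    }
    Some(i32::from_be_bytes(
        records[RECORD_COUNT_START..RECORD_COUNT_START + 4]
            .try_into()
            .ok()?,
    ))
}
```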
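The TODO also matters because a partition's `records` bytes can carry several back-to-back batches, in which case a single read at offset 57 only counts the first one. Since `batch_length` (offset 8) covers everything after its own 12-byte prefix (`base_offset` plus `batch_length` itself), the batches can be walked and their counts summed. A sketch under that assumption, again with a hypothetical helper name:

```rust
/// Hypothetical helper: total record count across all v2 batches in `records`.
/// Each batch occupies `batch_length + 12` bytes, since `batch_length`
/// excludes the `base_offset` (8) and `batch_length` (4) fields themselves.
fn total_record_count_v2(mut records: &[u8]) -> Option<i64> {
    let mut total = 0;
    while !records.is_empty() {
        if records.len() < 61 || records[16] != 2 {
            return None; // truncated or pre-v2 batch
        }
        let batch_length = i32::from_be_bytes(records[8..12].try_into().ok()?) as usize;
        total += i32::from_be_bytes(records[57..61].try_into().ok()?) as i64;
        // Step over this batch to the start of the next one.
        records = records.get(batch_length + 12..)?;
    }
    Some(total)
}
```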