diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs
index feefdd0e2..4ba41a4ab 100644
--- a/shotover/src/transforms/kafka/sink_cluster/mod.rs
+++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs
@@ -1,6 +1,6 @@
 use crate::frame::kafka::{KafkaFrame, RequestBody, ResponseBody};
 use crate::frame::{Frame, MessageType};
-use crate::message::{Message, MessageIdMap, Messages};
+use crate::message::{Message, Messages};
 use crate::tls::{TlsConnector, TlsConnectorConfig};
 use crate::transforms::{
     ChainState, DownChainProtocol, Transform, TransformBuilder, TransformContextBuilder,
@@ -204,11 +204,6 @@ impl TransformBuilder for KafkaSinkClusterBuilder {
                 transform_context.force_run_chain,
             ),
             pending_requests: Default::default(),
-            // TODO: this approach with `find_coordinator_requests` and `routed_to_coordinator_for_group`
-            // is prone to memory leaks and logic errors.
-            // We should replace these fields with extra state within `PendingRequestState::Received/Sent`.
-            find_coordinator_requests: Default::default(),
-            routed_to_coordinator_for_group: Default::default(),
             temp_responses_buffer: Default::default(),
             sasl_mechanism: None,
             authorize_scram_over_mtls: self.authorize_scram_over_mtls.as_ref().map(|x| x.build()),
@@ -267,8 +262,6 @@ struct KafkaSinkCluster {
     /// Maintains the state of each request/response pair.
     /// Ordering must be maintained to ensure responses match up with their request.
     pending_requests: VecDeque<PendingRequest>,
-    find_coordinator_requests: MessageIdMap<FindCoordinator>,
-    routed_to_coordinator_for_group: MessageIdMap<GroupId>,
     /// A temporary buffer used when receiving responses, only held onto in order to avoid reallocating.
     temp_responses_buffer: Vec<Message>,
     sasl_mechanism: Option<String>,
@@ -318,13 +311,16 @@ impl PendingRequestState {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 enum PendingRequestTy {
     Fetch {
         originally_sent_at: Instant,
         max_wait_ms: i32,
         min_bytes: i32,
     },
+    FindCoordinator(FindCoordinator),
+    // Covers multiple request types: JoinGroup, DeleteGroups etc.
+    RoutedToGroup(GroupId),
     Other,
 }
 
@@ -347,46 +343,24 @@ impl Transform for KafkaSinkCluster {
         &mut self,
         chain_state: &'shorter mut ChainState<'longer>,
     ) -> Result<Messages> {
-        let mut responses = if chain_state.requests.is_empty() {
+        if chain_state.requests.is_empty() {
             // there are no requests, so no point sending any, but we should check for any responses without awaiting
-            self.recv_responses()
+            self.recv_responses(&mut chain_state.close_client_connection)
                 .await
-                .context("Failed to receive responses (without sending requests)")?
+                .context("Failed to receive responses (without sending requests)")
         } else {
             self.update_local_nodes().await;
 
-            for request in &mut chain_state.requests {
-                let id = request.id();
-                if let Some(Frame::Kafka(KafkaFrame::Request {
-                    body: RequestBody::FindCoordinator(find_coordinator),
-                    ..
-                })) = request.frame()
-                {
-                    self.find_coordinator_requests.insert(
-                        id,
-                        FindCoordinator {
-                            key: find_coordinator.key.clone(),
-                            key_type: find_coordinator.key_type,
-                        },
-                    );
-                }
-            }
-
             self.route_requests(std::mem::take(&mut chain_state.requests))
                 .await
                 .context("Failed to route requests")?;
             self.send_requests()
                 .await
                 .context("Failed to send requests")?;
-            self.recv_responses()
+            self.recv_responses(&mut chain_state.close_client_connection)
                 .await
-                .context("Failed to receive responses")?
-        };
-
-        self.process_responses(&mut responses, &mut chain_state.close_client_connection)
-            .await
-            .context("Failed to process responses")?;
-        Ok(responses)
+                .context("Failed to receive responses")
+        }
     }
 }
 
@@ -795,6 +769,12 @@ impl KafkaSinkCluster {
                     let group_id = groups.groups_names.first().unwrap().clone();
                     self.route_to_coordinator(message, group_id);
                 }
+                Some(Frame::Kafka(KafkaFrame::Request {
+                    body: RequestBody::FindCoordinator(_),
+                    ..
+                })) => {
+                    self.route_find_coordinator(message);
+                }
 
                 // route to controller broker
                 Some(Frame::Kafka(KafkaFrame::Request {
@@ -1237,6 +1217,8 @@ impl KafkaSinkCluster {
                         unreachable!()
                     }
                 }
+                PendingRequestTy::RoutedToGroup(_) => None,
+                PendingRequestTy::FindCoordinator(_) => None,
                 PendingRequestTy::Other => None,
             };
             let mut value = PendingRequestState::Sent {
@@ -1330,7 +1312,7 @@ impl KafkaSinkCluster {
 
     /// Receive all responses from the outgoing connections, returns all responses that are ready to be returned.
    /// For response ordering reasons, some responses will remain in self.pending_requests until other responses are received.
-    async fn recv_responses(&mut self) -> Result<Vec<Message>> {
+    async fn recv_responses(&mut self, close_client_connection: &mut bool) -> Result<Vec<Message>> {
         // To work around borrow checker issues, store connection errors in this temporary list before handling them.
         let mut connection_errors = vec![];
 
@@ -1404,19 +1386,20 @@ impl KafkaSinkCluster {
                     )
                 });
                 if all_combined_received {
+                    let pending_request_ty = pending_request.ty.clone();
                     // perform special handling for certain message types
                     if let PendingRequestTy::Fetch {
                         originally_sent_at,
                         max_wait_ms,
                         min_bytes,
-                    } = pending_request.ty
+                    } = &pending_request_ty
                     {
                         // resend the requests if we havent yet met the `max_wait_ms` and `min_bytes` requirements
-                        if originally_sent_at.elapsed() < Duration::from_millis(max_wait_ms as u64)
+                        if originally_sent_at.elapsed() < Duration::from_millis(*max_wait_ms as u64)
                             && Self::total_fetch_record_bytes(
                                 &mut self.pending_requests,
                                 combine_responses,
-                            ) < min_bytes as i64
+                            ) < *min_bytes as i64
                         {
                             tokio::time::sleep(self.refetch_backoff).await;
 
@@ -1448,13 +1431,15 @@ impl KafkaSinkCluster {
                 }
 
                 // The next response we are waiting on has been received, add it to responses
-                if combine_responses == 1 {
+                let mut response = if combine_responses == 1 {
                     if let Some(PendingRequest {
                         state: PendingRequestState::Received { response, .. },
                         ..
                     }) = self.pending_requests.pop_front()
                     {
-                        responses.push(response);
+                        response
+                    } else {
+                        unreachable!("Guaranteed by all_combined_received")
                     }
                 } else {
                     let drain = self.pending_requests.drain(..combine_responses).map(|x| {
@@ -1468,8 +1453,13 @@ impl KafkaSinkCluster {
                             unreachable!("Guaranteed by all_combined_received")
                         }
                     });
-                    responses.push(Self::combine_responses(drain)?);
-                }
+                    Self::combine_responses(drain)?
+                };
+
+                self.process_response(&mut response, pending_request_ty, close_client_connection)
+                    .await
+                    .context("Failed to process response")?;
+                responses.push(response);
             } else {
                 // The pending_request is not received, we need to break to maintain response ordering.
                 break;
@@ -1623,224 +1613,223 @@ impl KafkaSinkCluster {
         Ok(())
     }
 
-    async fn process_responses(
+    async fn process_response(
         &mut self,
-        responses: &mut [Message],
+        response: &mut Message,
+        request_ty: PendingRequestTy,
         close_client_connection: &mut bool,
     ) -> Result<()> {
-        for response in responses.iter_mut() {
-            let request_id = response.request_id().unwrap();
-            match response.frame() {
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::FindCoordinator(find_coordinator),
-                    version,
-                    ..
-                })) => {
-                    let request = self
-                        .find_coordinator_requests
-                        .remove(&request_id)
-                        .ok_or_else(|| anyhow!("Received find_coordinator but not requested"))?;
-
+        match response.frame() {
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::FindCoordinator(find_coordinator),
+                version,
+                ..
+            })) => {
+                if let PendingRequestTy::FindCoordinator(request) = request_ty {
                    self.process_find_coordinator_response(*version, request, find_coordinator);
                    self.rewrite_find_coordinator_response(*version, find_coordinator);
                    response.invalidate_cache();
+                } else {
+                    return Err(anyhow!("Received find_coordinator but not requested"));
                 }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::SaslHandshake(handshake),
-                    ..
-                })) => {
-                    // If authorize_scram_over_mtls is disabled there is no way that scram can work through KafkaSinkCluster
-                    // since it is specifically designed such that replay attacks wont work.
-                    // So when authorize_scram_over_mtls is disabled report to the user that SCRAM is not enabled.
-                    if self.authorize_scram_over_mtls.is_none() {
-                        // remove scram from supported mechanisms
-                        handshake
-                            .mechanisms
-                            .retain(|x| !SASL_SCRAM_MECHANISMS.contains(&x.as_str()));
-
-                        // declare unsupported if the client requested SCRAM
-                        if let Some(sasl_mechanism) = &self.sasl_mechanism {
-                            if SASL_SCRAM_MECHANISMS.contains(&sasl_mechanism.as_str()) {
-                                handshake.error_code =
-                                    ResponseError::UnsupportedSaslMechanism.code();
-                            }
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::SaslHandshake(handshake),
+                ..
+            })) => {
+                // If authorize_scram_over_mtls is disabled there is no way that scram can work through KafkaSinkCluster
+                // since it is specifically designed such that replay attacks wont work.
+                // So when authorize_scram_over_mtls is disabled report to the user that SCRAM is not enabled.
+                if self.authorize_scram_over_mtls.is_none() {
+                    // remove scram from supported mechanisms
+                    handshake
+                        .mechanisms
+                        .retain(|x| !SASL_SCRAM_MECHANISMS.contains(&x.as_str()));
+
+                    // declare unsupported if the client requested SCRAM
+                    if let Some(sasl_mechanism) = &self.sasl_mechanism {
+                        if SASL_SCRAM_MECHANISMS.contains(&sasl_mechanism.as_str()) {
+                            handshake.error_code = ResponseError::UnsupportedSaslMechanism.code();
                         }
-
-                        response.invalidate_cache();
                     }
+
+                    response.invalidate_cache();
                 }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::SaslAuthenticate(authenticate),
-                    ..
-                })) => {
-                    self.process_sasl_authenticate(authenticate, close_client_connection)?;
-                }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::Produce(produce),
-                    ..
-                })) => {
-                    // Clear this optional field to avoid making clients try to bypass shotover
-                    produce.node_endpoints.clear();
-                    for (topic_name, response_topic) in &mut produce.responses {
-                        for response_partition in &response_topic.partition_responses {
-                            if let Some(ResponseError::NotLeaderOrFollower) =
-                                ResponseError::try_from_code(response_partition.error_code)
-                            {
-                                if response_partition.current_leader.leader_id != -1 {
-                                    // The broker has informed us who the new leader is, we can just directly update the leader
-                                    if let Some(mut stored_topic) =
-                                        self.topic_by_name.get_mut(topic_name)
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::SaslAuthenticate(authenticate),
+                ..
+            })) => {
+                self.process_sasl_authenticate(authenticate, close_client_connection)?;
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::Produce(produce),
+                ..
+            })) => {
+                // Clear this optional field to avoid making clients try to bypass shotover
+                produce.node_endpoints.clear();
+                for (topic_name, response_topic) in &mut produce.responses {
+                    for response_partition in &response_topic.partition_responses {
+                        if let Some(ResponseError::NotLeaderOrFollower) =
+                            ResponseError::try_from_code(response_partition.error_code)
+                        {
+                            if response_partition.current_leader.leader_id != -1 {
+                                // The broker has informed us who the new leader is, we can just directly update the leader
+                                if let Some(mut stored_topic) =
+                                    self.topic_by_name.get_mut(topic_name)
+                                {
+                                    if let Some(stored_partition) = stored_topic
+                                        .partitions
+                                        .get_mut(response_partition.index as usize)
                                     {
-                                        if let Some(stored_partition) = stored_topic
-                                            .partitions
-                                            .get_mut(response_partition.index as usize)
+                                        if response_partition.current_leader.leader_epoch
+                                            > stored_partition.leader_epoch
                                         {
-                                            if response_partition.current_leader.leader_epoch
-                                                > stored_partition.leader_epoch
-                                            {
-                                                stored_partition.leader_id =
-                                                    response_partition.current_leader.leader_id;
-                                                stored_partition.leader_epoch =
-                                                    response_partition.current_leader.leader_epoch;
-                                            }
-                                            tracing::info!(
+                                            stored_partition.leader_id =
+                                                response_partition.current_leader.leader_id;
+                                            stored_partition.leader_epoch =
+                                                response_partition.current_leader.leader_epoch;
+                                        }
+                                        tracing::info!(
                                             "Produce response included error NOT_LEADER_OR_FOLLOWER and so updated leader in topic {:?} partition {}",
                                             topic_name,
                                             response_partition.index
                                         );
-                                        }
                                     }
-                                } else {
-                                    // The broker doesnt know who the new leader is, clear the entire topic.
-                                    self.topic_by_name.remove(topic_name);
-                                    tracing::info!(
+                                }
+                            } else {
+                                // The broker doesnt know who the new leader is, clear the entire topic.
+                                self.topic_by_name.remove(topic_name);
+                                tracing::info!(
                                 "Produce response included error NOT_LEADER_OR_FOLLOWER and so cleared topic {:?}",
                                 topic_name,
                             );
-                                    break;
-                                }
+                                break;
                             }
                         }
-                        for response_partition in &mut response_topic.partition_responses {
-                            // Clear this optional field to avoid making clients try to bypass shotover
-                            response_partition.current_leader =
-                                ProduceResponseLeaderIdAndEpoch::default();
-                        }
                     }
-                    response.invalidate_cache();
+                    for response_partition in &mut response_topic.partition_responses {
+                        // Clear this optional field to avoid making clients try to bypass shotover
+                        response_partition.current_leader =
+                            ProduceResponseLeaderIdAndEpoch::default();
+                    }
                 }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::Fetch(fetch),
-                    ..
-                })) => {
-                    // Clear this optional field to avoid making clients try to bypass shotover
-                    // partition.current_leader and partition.preferred_read_replica are cleared due to the same reason
-                    fetch.node_endpoints.clear();
-                    for fetch_response in &mut fetch.responses {
-                        for partition in &mut fetch_response.partitions {
-                            partition.current_leader = FetchResponseLeaderIdAndEpoch::default();
-                            partition.preferred_read_replica = BrokerId(-1);
-                            if let Some(ResponseError::NotLeaderOrFollower) =
-                                ResponseError::try_from_code(partition.error_code)
-                            {
-                                // The fetch response includes the leader_id which a client could use to route a fetch request to,
-                                // but we cant use it to fix our list of replicas, so our only option is to clear the whole thing.
-                                self.topic_by_name.remove(&fetch_response.topic);
-                                self.topic_by_id.remove(&fetch_response.topic_id);
-                                tracing::info!(
+                response.invalidate_cache();
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::Fetch(fetch),
+                ..
+            })) => {
+                // Clear this optional field to avoid making clients try to bypass shotover
+                // partition.current_leader and partition.preferred_read_replica are cleared due to the same reason
+                fetch.node_endpoints.clear();
+                for fetch_response in &mut fetch.responses {
+                    for partition in &mut fetch_response.partitions {
+                        partition.current_leader = FetchResponseLeaderIdAndEpoch::default();
+                        partition.preferred_read_replica = BrokerId(-1);
+                        if let Some(ResponseError::NotLeaderOrFollower) =
+                            ResponseError::try_from_code(partition.error_code)
+                        {
+                            // The fetch response includes the leader_id which a client could use to route a fetch request to,
+                            // but we cant use it to fix our list of replicas, so our only option is to clear the whole thing.
+                            self.topic_by_name.remove(&fetch_response.topic);
+                            self.topic_by_id.remove(&fetch_response.topic_id);
+                            tracing::info!(
                                 "Fetch response included error NOT_LEADER_OR_FOLLOWER and so cleared metadata for topic {:?} {:?}",
                                 fetch_response.topic,
                                 fetch_response.topic_id
                             );
-                                break;
-                            }
+                            break;
                         }
                     }
-                    response.invalidate_cache();
                 }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::Heartbeat(heartbeat),
-                    ..
-                })) => self.handle_coordinator_routing_error(request_id, heartbeat.error_code),
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::SyncGroup(sync_group),
-                    ..
-                })) => self.handle_coordinator_routing_error(request_id, sync_group.error_code),
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::OffsetFetch(offset_fetch),
-                    ..
-                })) => self.handle_coordinator_routing_error(request_id, offset_fetch.error_code),
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::JoinGroup(join_group),
-                    ..
-                })) => self.handle_coordinator_routing_error(request_id, join_group.error_code),
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::LeaveGroup(leave_group),
-                    ..
-                })) => self.handle_coordinator_routing_error(request_id, leave_group.error_code),
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::DeleteGroups(delete_groups),
-                    ..
-                })) => {
-                    for (group_id, result) in &delete_groups.results {
-                        if let Some(ResponseError::NotCoordinator) =
-                            ResponseError::try_from_code(result.error_code)
-                        {
-                            // Need to run this to avoid memory leaks, since route_to_coordinator is called for DeleteGroup requests
-                            self.routed_to_coordinator_for_group.remove(&request_id);
-
-                            // Need to run this to ensure we remove for all groups
-                            self.group_to_coordinator_broker.remove(group_id);
-                        }
+                response.invalidate_cache();
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::Heartbeat(heartbeat),
+                ..
+            })) => self.handle_group_coordinator_routing_error(&request_ty, heartbeat.error_code),
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::SyncGroup(sync_group),
+                ..
+            })) => self.handle_group_coordinator_routing_error(&request_ty, sync_group.error_code),
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::OffsetFetch(offset_fetch),
+                ..
+            })) => {
+                self.handle_group_coordinator_routing_error(&request_ty, offset_fetch.error_code)
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::JoinGroup(join_group),
+                ..
+            })) => self.handle_group_coordinator_routing_error(&request_ty, join_group.error_code),
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::LeaveGroup(leave_group),
+                ..
+            })) => self.handle_group_coordinator_routing_error(&request_ty, leave_group.error_code),
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::DeleteGroups(delete_groups),
+                ..
+            })) => {
+                // clear metadata that resulted in NotCoordinator error
+                for (group_id, result) in &delete_groups.results {
+                    if let Some(ResponseError::NotCoordinator) =
+                        ResponseError::try_from_code(result.error_code)
+                    {
+                        self.group_to_coordinator_broker.remove(group_id);
                     }
                 }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::CreateTopics(create_topics),
-                    ..
-                })) => {
-                    for topic in create_topics.topics.values() {
-                        if let Some(ResponseError::NotController) =
-                            ResponseError::try_from_code(topic.error_code)
-                        {
-                            tracing::info!(
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::CreateTopics(create_topics),
+                ..
+            })) => {
+                for topic in create_topics.topics.values() {
+                    if let Some(ResponseError::NotController) =
+                        ResponseError::try_from_code(topic.error_code)
+                    {
+                        tracing::info!(
                             "Response to CreateTopics included error NOT_CONTROLLER and so reset controller broker, previously was {:?}",
                             self.controller_broker.get()
                         );
-                            self.controller_broker.clear();
-                            break;
-                        }
+                        self.controller_broker.clear();
+                        break;
                     }
                 }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::Metadata(metadata),
-                    ..
-                })) => {
-                    self.process_metadata_response(metadata).await;
-                    self.rewrite_metadata_response(metadata)?;
-                    response.invalidate_cache();
-                }
-                Some(Frame::Kafka(KafkaFrame::Response {
-                    body: ResponseBody::DescribeCluster(_),
-                    ..
-                })) => {
-                    // If clients were to send this we would need to rewrite the broker information.
-                    // However I dont think clients actually send this, so just error to ensure we dont break invariants.
-                    return Err(anyhow!(
-                        "I think this is a raft specific message and never sent by clients"
-                    ));
-                }
-                _ => {}
             }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::Metadata(metadata),
+                ..
+            })) => {
+                self.process_metadata_response(metadata).await;
+                self.rewrite_metadata_response(metadata)?;
+                response.invalidate_cache();
+            }
+            Some(Frame::Kafka(KafkaFrame::Response {
+                body: ResponseBody::DescribeCluster(_),
+                ..
+            })) => {
+                // If clients were to send this we would need to rewrite the broker information.
+                // However I dont think clients actually send this, so just error to ensure we dont break invariants.
+                return Err(anyhow!(
+                    "I think this is a raft specific message and never sent by clients"
+                ));
+            }
+            _ => {}
         }
         Ok(())
     }
 
     /// This method must be called for every response to a request that was routed via `route_to_coordinator`
-    fn handle_coordinator_routing_error(&mut self, request_id: u128, error_code: i16) {
+    fn handle_group_coordinator_routing_error(
+        &mut self,
+        pending_request_ty: &PendingRequestTy,
+        error_code: i16,
+    ) {
         if let Some(ResponseError::NotCoordinator) = ResponseError::try_from_code(error_code) {
-            if let Some(group_id) = self.routed_to_coordinator_for_group.remove(&request_id) {
-                let broker_id = self.group_to_coordinator_broker.remove(&group_id);
+            if let PendingRequestTy::RoutedToGroup(group_id) = pending_request_ty {
+                let broker_id = self.group_to_coordinator_broker.remove(group_id);
                 tracing::info!(
                     "Response was error NOT_COORDINATOR and so cleared group id {:?} coordinator mapping to broker {:?}",
                     group_id,
@@ -1935,8 +1924,6 @@ impl KafkaSinkCluster {
 
     fn route_to_coordinator(&mut self, request: Message, group_id: GroupId) {
         let destination = self.group_to_coordinator_broker.get(&group_id);
-        self.routed_to_coordinator_for_group
-            .insert(request.id(), group_id.clone());
         let destination = match destination {
             Some(destination) => *destination,
            None => {
@@ -1945,16 +1932,38 @@ impl KafkaSinkCluster {
 
-        self.pending_requests.push_back(PendingRequest {
-            state: PendingRequestState::routed(destination, request),
-            ty: PendingRequestTy::Other,
-            combine_responses: 1,
-        });
         tracing::debug!(
             "Routing request relating to group id {:?} to broker {}",
             group_id.0,
             destination.0
         );
+
+        self.pending_requests.push_back(PendingRequest {
+            state: PendingRequestState::routed(destination, request),
+            ty: PendingRequestTy::RoutedToGroup(group_id),
+            combine_responses: 1,
+        });
+    }
+
+    fn route_find_coordinator(&mut self, mut request: Message) {
+        if let Some(Frame::Kafka(KafkaFrame::Request {
+            body: RequestBody::FindCoordinator(find_coordinator),
+            ..
+        })) = request.frame()
+        {
+            let destination = random_broker_id(&self.nodes, &mut self.rng);
+            let ty = PendingRequestTy::FindCoordinator(FindCoordinator {
+                key: find_coordinator.key.clone(),
+                key_type: find_coordinator.key_type,
+            });
+            tracing::debug!("Routing FindCoordinator to random broker {}", destination.0);
+
+            self.pending_requests.push_back(PendingRequest {
+                state: PendingRequestState::routed(destination, request),
+                ty,
+                combine_responses: 1,
+            });
+        }
    }
 
     async fn process_metadata_response(&mut self, metadata: &MetadataResponse) {
@@ -2075,7 +2084,7 @@ impl KafkaSinkCluster {
         if version <= 3 {
             if find_coordinator.error_code == 0 {
                 self.group_to_coordinator_broker
-                    .insert(GroupId(request.key.clone()), find_coordinator.node_id);
+                    .insert(GroupId(request.key), find_coordinator.node_id);
             }
         } else {
             for coordinator in &find_coordinator.coordinators {
@@ -2362,6 +2371,7 @@ struct Partition {
     external_rack_replica_nodes: Vec<BrokerId>,
 }
 
+#[derive(Debug, Clone)]
 struct FindCoordinator {
     key: StrBytes,
     key_type: i8,
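Not part of the patch: below is a minimal, self-contained Rust sketch (hypothetical names, not Shotover's actual types) of the pattern this diff moves to — per-request metadata travels inside the pending-request queue entry itself (here `PendingKind`) instead of living in side maps keyed by message id, so the metadata is consumed together with the response and cannot leak if a request is dropped.

```rust
use std::collections::VecDeque;

// Hypothetical stand-in for the per-request metadata carried by each queue entry.
#[derive(Debug, Clone)]
enum PendingKind {
    FindCoordinator { key: String },
    RoutedToGroup { group_id: String },
    Other,
}

#[derive(Debug)]
struct Pending {
    kind: PendingKind,
    response: Option<String>, // stand-in for a received response
}

#[derive(Default)]
struct Sink {
    pending: VecDeque<Pending>,
}

impl Sink {
    // Record the routing decision together with the queued request.
    fn route(&mut self, kind: PendingKind) {
        self.pending.push_back(Pending { kind, response: None });
    }

    // Pop responses in order; the kind captured at routing time drives response
    // handling and is dropped with the entry, so no side map needs cleaning up.
    fn take_ready(&mut self) -> Vec<(String, PendingKind)> {
        let mut out = Vec::new();
        while matches!(self.pending.front(), Some(p) if p.response.is_some()) {
            let p = self.pending.pop_front().unwrap();
            out.push((p.response.unwrap(), p.kind));
        }
        out
    }
}

fn main() {
    let mut sink = Sink::default();
    sink.route(PendingKind::FindCoordinator { key: "group-1".into() });
    sink.pending[0].response = Some("coordinator=broker-3".into());
    for (response, kind) in sink.take_ready() {
        println!("{response} handled as {kind:?}");
    }
}
```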