diff --git a/shotover-proxy/tests/kafka_int_tests/mod.rs b/shotover-proxy/tests/kafka_int_tests/mod.rs index 3eaabb642..e046aad48 100644 --- a/shotover-proxy/tests/kafka_int_tests/mod.rs +++ b/shotover-proxy/tests/kafka_int_tests/mod.rs @@ -9,6 +9,7 @@ use test_cases::produce_consume_partitions1; use test_cases::produce_consume_partitions3; use test_cases::{assert_topic_creation_is_denied_due_to_acl, setup_basic_user_acls}; use test_helpers::connection::kafka::node::run_node_smoke_test_scram; +use test_helpers::connection::kafka::python::run_python_bad_auth_sasl_scram; use test_helpers::connection::kafka::python::run_python_smoke_test_sasl_scram; use test_helpers::connection::kafka::{KafkaConnectionBuilder, KafkaDriver}; use test_helpers::docker_compose::docker_compose; @@ -755,21 +756,53 @@ async fn cluster_sasl_scram_over_mtls_nodejs_and_python() { let _docker_compose = docker_compose("tests/test-configs/kafka/cluster-sasl-scram-over-mtls/docker-compose.yaml"); - let shotover = shotover_process( - "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", - ) - .start() - .await; - run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await; - run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await; + { + let shotover = shotover_process( + "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", + ) + .start() + .await; + + run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await; + run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await; - tokio::time::timeout( - Duration::from_secs(10), - shotover.shutdown_and_then_consume_events(&[]), - ) - .await - .expect("Shotover did not shutdown within 10s"); + tokio::time::timeout( + Duration::from_secs(10), + shotover.shutdown_and_then_consume_events(&[]), + ) + .await + .expect("Shotover did not shutdown within 10s"); + } + + { + let shotover = shotover_process( + 
"tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", + ) + .start() + .await; + + run_python_bad_auth_sasl_scram("127.0.0.1:9192", "incorrect_user", "super_password").await; + run_python_bad_auth_sasl_scram("127.0.0.1:9192", "super_user", "incorrect_password").await; + + tokio::time::timeout( + Duration::from_secs(10), + shotover.shutdown_and_then_consume_events(&[EventMatcher::new() + .with_level(Level::Error) + .with_target("shotover::server") + .with_message(r#"encountered an error when flushing the chain kafka for shutdown + +Caused by: + 0: KafkaSinkCluster transform failed + 1: Failed to receive responses (without sending requests) + 2: Outgoing connection had pending requests, those requests/responses are lost so connection recovery cannot be attempted. + 3: Failed to receive from ControlConnection + 4: The other side of this connection closed the connection"#) + .with_count(Count::Times(2))]), + ) + .await + .expect("Shotover did not shutdown within 10s"); + } } #[rstest] diff --git a/shotover/src/codec/kafka.rs b/shotover/src/codec/kafka.rs index e61a9bb45..dddfdbd2d 100644 --- a/shotover/src/codec/kafka.rs +++ b/shotover/src/codec/kafka.rs @@ -5,10 +5,8 @@ use crate::frame::{Frame, MessageType}; use crate::message::{Encodable, Message, MessageId, Messages}; use anyhow::{anyhow, Result}; use bytes::BytesMut; -use kafka_protocol::messages::{ - ApiKey, RequestHeader as RequestHeaderProtocol, RequestKind, ResponseHeader, ResponseKind, - SaslAuthenticateRequest, SaslAuthenticateResponse, -}; +use kafka_protocol::messages::{ApiKey, RequestKind, ResponseKind}; +use kafka_protocol::protocol::StrBytes; use metrics::Histogram; use std::sync::mpsc; use std::time::Instant; @@ -16,7 +14,6 @@ use tokio_util::codec::{Decoder, Encoder}; #[derive(Copy, Clone, Debug, PartialEq)] pub struct RequestHeader { - // TODO: this should be i16??? 
pub api_key: ApiKey, pub version: i16, } @@ -66,21 +63,36 @@ impl CodecBuilder for KafkaCodecBuilder { pub struct RequestInfo { header: RequestHeader, id: MessageId, - expect_raw_sasl: Option, + expect_raw_sasl: Option, } +// Keeps track of the next expected sasl message #[derive(Debug, Clone, PartialEq, Copy)] -pub enum SaslType { +pub enum SaslRequestState { + /// The next message will be a sasl message in the PLAIN mechanism Plain, - ScramMessage1, - ScramMessage2, + /// The next message will be the first sasl message in the SCRAM mechanism + ScramFirst, + /// The next message will be the final sasl message in the SCRAM mechanism + ScramFinal, +} + +impl SaslRequestState { + fn from_name(mechanism: &StrBytes) -> Result { + match mechanism.as_str() { + "PLAIN" => Ok(SaslRequestState::Plain), + "SCRAM-SHA-512" => Ok(SaslRequestState::ScramFirst), + "SCRAM-SHA-256" => Ok(SaslRequestState::ScramFirst), + mechanism => Err(anyhow!("Unknown sasl mechanism {mechanism}")), + } + } } pub struct KafkaDecoder { // Some when Sink (because it receives responses) request_header_rx: Option>, direction: Direction, - expect_raw_sasl: Option, + expect_raw_sasl: Option, } impl KafkaDecoder { @@ -123,103 +135,82 @@ impl Decoder for KafkaDecoder { pretty_hex::pretty_hex(&bytes) ); + let request_info = self + .request_header_rx + .as_ref() + .map(|rx| rx.recv()) + .transpose() + .map_err(|_| CodecReadError::Parser(anyhow!("kafka encoder half was lost")))?; + struct Meta { request_header: RequestHeader, message_id: Option, } - let request_info = self - .request_header_rx - .as_ref() - .map(|rx| { - rx.recv() - .map_err(|_| CodecReadError::Parser(anyhow!("kafka encoder half was lost"))) - }) - .transpose()?; - - let message = if self.expect_raw_sasl.is_some() { - // Convert the unframed raw sasl into a framed sasl - // This allows transforms to correctly parse the message and inspect the sasl request - let kafka_frame = match self.direction { - Direction::Source => KafkaFrame::Request { 
- header: RequestHeaderProtocol::default() - .with_request_api_key(ApiKey::SaslAuthenticateKey as i16), - body: RequestKind::SaslAuthenticate( - SaslAuthenticateRequest::default().with_auth_bytes(bytes.freeze()), - ), - }, - Direction::Sink => KafkaFrame::Response { + let meta = if let Some(RequestInfo { header, id, .. }) = request_info { + Meta { + request_header: header, + message_id: Some(id), + } + } else if self.expect_raw_sasl.is_some() { + Meta { + request_header: RequestHeader { + api_key: ApiKey::SaslAuthenticateKey, version: 0, - header: ResponseHeader::default(), - body: ResponseKind::SaslAuthenticate( - SaslAuthenticateResponse::default().with_auth_bytes(bytes.freeze()), - // TODO: we need to set with_error_code - ), }, - }; - let codec_state = CodecState::Kafka(KafkaCodecState { - request_header: None, - raw_sasl: self.expect_raw_sasl, - }); - self.expect_raw_sasl = match self.expect_raw_sasl { - Some(SaslType::Plain) => None, - Some(SaslType::ScramMessage1) => Some(SaslType::ScramMessage2), - Some(SaslType::ScramMessage2) => None, - None => None, - }; - Message::from_frame_and_codec_state_at_instant( - Frame::Kafka(kafka_frame), - codec_state, + // This code path is only used for requests, so message_id can be None. + message_id: None, + } + } else { + Meta { + request_header: RequestHeader { + api_key: ApiKey::try_from(i16::from_be_bytes( + bytes[4..6].try_into().unwrap(), + )) + .unwrap(), + version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()), + }, + // This code path is only used for requests, so message_id can be None. 
+ message_id: None, + } + }; + let mut message = if let Some(id) = meta.message_id.as_ref() { + let mut message = Message::from_bytes_at_instant( + bytes.freeze(), + CodecState::Kafka(KafkaCodecState { + request_header: Some(meta.request_header), + raw_sasl: self.expect_raw_sasl, + }), Some(received_at), - ) + ); + message.set_request_id(*id); + message message } else { - let meta = if let Some(RequestInfo { - header, - id, - expect_raw_sasl, - }) = request_info - { - if let Some(expect_raw_sasl) = expect_raw_sasl { - self.expect_raw_sasl = Some(expect_raw_sasl); - } - Meta { - request_header: header, - message_id: Some(id), - } - } else { - Meta { - request_header: RequestHeader { - api_key: ApiKey::try_from(i16::from_be_bytes( - bytes[4..6].try_into().unwrap(), - )) - .unwrap(), - version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()), - }, - message_id: None, - } - }; - let mut message = if let Some(id) = meta.message_id.as_ref() { - let mut message = Message::from_bytes_at_instant( - bytes.freeze(), - CodecState::Kafka(KafkaCodecState { - request_header: Some(meta.request_header), - raw_sasl: None, - }), - Some(received_at), - ); - message.set_request_id(*id); - message - } else { - Message::from_bytes_at_instant( - bytes.freeze(), - CodecState::Kafka(KafkaCodecState { - request_header: None, - raw_sasl: None, - }), - Some(received_at), - ) - }; + Message::from_bytes_at_instant( + bytes.freeze(), + CodecState::Kafka(KafkaCodecState { + request_header: None, + raw_sasl: self.expect_raw_sasl, + }), + Some(received_at), + ) + }; + // advance to the next state of expect_raw_sasl + self.expect_raw_sasl = match self.expect_raw_sasl { + Some(SaslRequestState::Plain) => None, + Some(SaslRequestState::ScramFirst) => Some(SaslRequestState::ScramFinal), + Some(SaslRequestState::ScramFinal) => None, + None => None, + }; + + if let Some(request_info) = request_info { + // set expect_raw_sasl for responses + if let Some(expect_raw_sasl) = request_info.expect_raw_sasl { + 
self.expect_raw_sasl = Some(expect_raw_sasl); + } + } else { + // set expect_raw_sasl for requests if meta.request_header.api_key == ApiKey::SaslHandshakeKey && meta.request_header.version == 0 { @@ -229,16 +220,10 @@ impl Decoder for KafkaDecoder { .. })) = message.frame() { - self.expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() { - "PLAIN" => SaslType::Plain, - "SCRAM-SHA-512" => SaslType::ScramMessage1, - "SCRAM-SHA-256" => SaslType::ScramMessage1, - mechanism => { - return Err(CodecReadError::Parser(anyhow!( - "Unknown sasl mechanism {mechanism}" - ))) - } - }); + self.expect_raw_sasl = Some( + SaslRequestState::from_name(&sasl_handshake.mechanism) + .map_err(CodecReadError::Parser)?, + ); // Clear raw bytes of the message to force the encoder to encode from frame. // This is needed because the encoder only has access to the frame if it does not have any raw bytes, @@ -246,8 +231,7 @@ impl Decoder for KafkaDecoder { message.invalidate_cache(); } } - message - }; + } Ok(Some(vec![message])) } else { @@ -288,7 +272,11 @@ impl Encoder for KafkaEncoder { let response_is_dummy = m.response_is_dummy(); let id = m.id(); let received_at = m.received_from_source_or_sink_at; - let codec_state = m.codec_state.as_kafka(); + let message_contains_raw_sasl = if let CodecState::Kafka(codec_state) = m.codec_state { + codec_state.raw_sasl.is_some() + } else { + false + }; let mut expect_raw_sasl = None; let result = match m.into_encodable() { Encodable::Bytes(bytes) => { @@ -296,7 +284,7 @@ impl Encoder for KafkaEncoder { Ok(()) } Encodable::Frame(frame) => { - if codec_state.raw_sasl.is_some() { + if message_contains_raw_sasl { match frame { Frame::Kafka(KafkaFrame::Request { body: RequestKind::SaslAuthenticate(body), @@ -315,23 +303,17 @@ impl Encoder for KafkaEncoder { Ok(()) } else { let frame = frame.into_kafka().unwrap(); - // it is garanteed that all v0 SaslHandshakes will be in a parsed state since we parse it in the KafkaDecoder. 
+ // it is guaranteed that all v0 SaslHandshakes will be in a parsed state since we parse + invalidate_cache in the KafkaDecoder. if let KafkaFrame::Request { body: RequestKind::SaslHandshake(sasl_handshake), header, } = &frame { if header.request_api_version == 0 { - expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() { - "PLAIN" => SaslType::Plain, - "SCRAM-SHA-512" => SaslType::ScramMessage1, - "SCRAM-SHA-256" => SaslType::ScramMessage1, - mechanism => { - return Err(CodecWriteError::Encoder(anyhow!( - "Unknown sasl mechanism {mechanism}" - ))) - } - }); + expect_raw_sasl = Some( + SaslRequestState::from_name(&sasl_handshake.mechanism) + .map_err(CodecWriteError::Encoder)?, + ); } } frame.encode(dst) @@ -343,7 +325,7 @@ impl Encoder for KafkaEncoder { // or if it will generate a dummy response if !dst[start..].is_empty() && !response_is_dummy { if let Some(tx) = self.request_header_tx.as_ref() { - let header = if codec_state.raw_sasl.is_some() { + let header = if message_contains_raw_sasl { RequestHeader { api_key: ApiKey::SaslAuthenticateKey, version: 0, diff --git a/shotover/src/codec/mod.rs b/shotover/src/codec/mod.rs index 6a373cfe5..ec6cb3a13 100644 --- a/shotover/src/codec/mod.rs +++ b/shotover/src/codec/mod.rs @@ -7,7 +7,7 @@ use core::fmt; #[cfg(feature = "kafka")] use kafka::RequestHeader; #[cfg(feature = "kafka")] -use kafka::SaslType; +use kafka::SaslRequestState; use metrics::{histogram, Histogram}; use tokio_util::codec::{Decoder, Encoder}; @@ -46,6 +46,14 @@ pub fn message_latency(direction: Direction, destination_name: String) -> Histog } } +/// Database protocols are often designed such that their messages can be parsed without knowledge of any state of prior messages. +/// When protocols remain stateless, Shotover's parser implementations can remain fairly simple. +/// However in the real world there is often some kind of connection level state that we need in order to parse requests. 
+/// +/// Shotover solves this issue via this enum which provides any of the connection level state required to decode and then reencode messages. +/// 1. The Decoder includes this value in all messages it produces. +/// 2. If any transforms call `.frame()` this value is used to parse the frame of the message. +/// 3. The Encoder uses this value to reencode the message if it has been modified. #[derive(Debug, Clone, PartialEq, Copy)] pub enum CodecState { #[cfg(feature = "cassandra")] @@ -86,8 +94,13 @@ impl CodecState { #[cfg(feature = "kafka")] #[derive(Debug, Clone, PartialEq, Copy)] pub struct KafkaCodecState { + /// When the message is: + /// a request - this value is None + /// a response - this value is Some and contains the header values of the corresponding request. pub request_header: Option, - pub raw_sasl: Option, + /// When `Some` this message is not a valid kafka protocol message and is instead a raw SASL message. + /// KafkaFrame will parse this as a SaslHandshake to hide the legacy raw SASL message from transform implementations. 
+ pub raw_sasl: Option, } #[derive(Debug)] diff --git a/shotover/src/frame/kafka.rs b/shotover/src/frame/kafka.rs index c08773b1a..8b55ce77a 100644 --- a/shotover/src/frame/kafka.rs +++ b/shotover/src/frame/kafka.rs @@ -2,7 +2,9 @@ use crate::codec::kafka::RequestHeader as CodecRequestHeader; use crate::codec::KafkaCodecState; use anyhow::{anyhow, Context, Result}; use bytes::{BufMut, Bytes, BytesMut}; -use kafka_protocol::messages::{ApiKey, RequestHeader, ResponseHeader}; +use kafka_protocol::messages::{ + ApiKey, RequestHeader, ResponseHeader, SaslAuthenticateRequest, SaslAuthenticateResponse, +}; use kafka_protocol::protocol::{Decodable, Encodable}; use std::fmt::{Display, Formatter, Result as FmtResult}; @@ -70,12 +72,35 @@ impl Display for KafkaFrame { impl KafkaFrame { pub fn from_bytes(mut bytes: Bytes, codec_state: KafkaCodecState) -> Result { - // remove length header - let _ = bytes.split_to(4); - - match &codec_state.request_header { - Some(request_header) => KafkaFrame::parse_response(bytes, *request_header), - None => KafkaFrame::parse_request(bytes), + if codec_state.raw_sasl.is_some() { + match &codec_state.request_header { + Some(_) => Ok(KafkaFrame::Response { + version: 0, + header: ResponseHeader::default(), + body: ResponseBody::SaslAuthenticate( + SaslAuthenticateResponse::default().with_auth_bytes(bytes), + // We don't set error_code field when the response contains a scram error, which sounds problematic. + // But in reality, at least for raw sasl mode, if kafka encounters an auth failure, + // it just kills the connection without sending any sasl response at all. + // So we never actually receive a scram response containing an error and + // so there would be no case where the error_code field would need to be set. 
+ ), + }), + None => Ok(KafkaFrame::Request { + header: RequestHeader::default() + .with_request_api_key(ApiKey::SaslAuthenticateKey as i16), + body: RequestBody::SaslAuthenticate( + SaslAuthenticateRequest::default().with_auth_bytes(bytes), + ), + }), + } + } else { + // remove length header + let _ = bytes.split_to(4); + match &codec_state.request_header { + Some(request_header) => KafkaFrame::parse_response(bytes, *request_header), + None => KafkaFrame::parse_request(bytes), + } } } diff --git a/shotover/src/transforms/kafka/sink_cluster/connections.rs b/shotover/src/transforms/kafka/sink_cluster/connections.rs index 97d71c4a5..8764230b8 100644 --- a/shotover/src/transforms/kafka/sink_cluster/connections.rs +++ b/shotover/src/transforms/kafka/sink_cluster/connections.rs @@ -238,16 +238,14 @@ impl Connections { } else { KafkaNodeState::Up }; - nodes - .iter() - .find(|x| match destination { - Destination::Id(id) => x.broker_id == id, - Destination::ControlConnection => { - &x.kafka_address == self.control_connection_address.as_ref().unwrap() - } - }) - .unwrap() - .set_state(node_state); + if let Some(node) = nodes.iter().find(|x| match destination { + Destination::Id(id) => x.broker_id == id, + Destination::ControlConnection => { + &x.kafka_address == self.control_connection_address.as_ref().unwrap() + } + }) { + node.set_state(node_state); + } if old_connection .map(|old| old.pending_requests_count()) diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index d03a568ad..1ade57a9c 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -273,8 +273,10 @@ impl AtomicBrokerId { } fn set(&self, value: BrokerId) { - self.0 - .store(value.0.into(), std::sync::atomic::Ordering::Relaxed) + if value != -1 { + self.0 + .store(value.0.into(), std::sync::atomic::Ordering::Relaxed) + } } fn clear(&self) { @@ -989,7 +991,8 @@ impl KafkaSinkCluster { 
| RequestBody::AlterConfigs(_) | RequestBody::CreatePartitions(_) | RequestBody::DeleteTopics(_) - | RequestBody::CreateAcls(_), + | RequestBody::CreateAcls(_) + | RequestBody::ApiVersions(_), .. })) => self.route_to_random_broker(message), @@ -2406,9 +2409,9 @@ impl KafkaSinkCluster { ResponseError::try_from_code(topic.error_code) { tracing::info!( - "Response to CreateTopics included error NOT_CONTROLLER and so reset controller broker, previously was {:?}", - self.controller_broker.get() - ); + "Response to CreateTopics included error NOT_CONTROLLER and so reset controller broker, previously was {:?}", + self.controller_broker.get() + ); self.controller_broker.clear(); break; } diff --git a/test-helpers/src/connection/kafka/python.rs b/test-helpers/src/connection/kafka/python.rs index 90c7e62b1..7f500ed31 100644 --- a/test-helpers/src/connection/kafka/python.rs +++ b/test-helpers/src/connection/kafka/python.rs @@ -80,6 +80,32 @@ pub async fn run_python_smoke_test_sasl_scram(address: &str, user: &str, passwor .unwrap(); } +pub async fn run_python_bad_auth_sasl_scram(address: &str, user: &str, password: &str) { + ensure_uv_is_installed().await; + + let project_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/python"); + let uv_binary = uv_binary_path(); + let config = format!( + r#"{{ + 'bootstrap_servers': ["{address}"], + 'security_protocol': "SASL_PLAINTEXT", + 'sasl_mechanism': "SCRAM-SHA-256", + 'sasl_plain_username': "{user}", + 'sasl_plain_password': "{password}", +}}"# + ); + tokio::time::timeout( + Duration::from_secs(60), + run_command_async( + &project_dir, + uv_binary.to_str().unwrap(), + &["run", "auth_fail.py", &config], + ), + ) + .await + .unwrap(); +} + /// Install a specific version of UV to: /// * avoid developers having to manually install an external tool /// * avoid issues due to a different version being installed diff --git a/test-helpers/src/connection/kafka/python/auth_fail.py 
b/test-helpers/src/connection/kafka/python/auth_fail.py new file mode 100644 index 000000000..f81229e18 --- /dev/null +++ b/test-helpers/src/connection/kafka/python/auth_fail.py @@ -0,0 +1,18 @@ +from kafka import KafkaConsumer +from kafka import KafkaProducer +from kafka.errors import KafkaError +import sys + +def main(): + config = eval(sys.argv[1]) + print("Running kafka-python script with config:") + print(config) + + try: + KafkaProducer(**config) + raise Exception("KafkaProducer was successfully created but expected to fail due to using incorrect username/password") + except KafkaError: + print("kafka-python auth_fail script passed all test cases") + +if __name__ == "__main__": + main() diff --git a/test-helpers/src/connection/kafka/python/main.py b/test-helpers/src/connection/kafka/python/main.py index 927383fe6..57d62e2a2 100644 --- a/test-helpers/src/connection/kafka/python/main.py +++ b/test-helpers/src/connection/kafka/python/main.py @@ -1,5 +1,7 @@ from kafka import KafkaConsumer +from kafka import KafkaAdminClient from kafka import KafkaProducer +from kafka.admin import NewTopic import sys def main(): @@ -7,19 +9,28 @@ def main(): print("Running kafka-python script with config:") print(config) + admin = KafkaAdminClient(**config) + admin.create_topics([ + NewTopic( + name='python_test_topic', + num_partitions=1, + replication_factor=1 + ) + ]) + producer = KafkaProducer(**config) - producer.send('test_topic', b'some_message_bytes').get(timeout=10) - producer.send('test_topic', b'another_message').get(timeout=10) + producer.send('python_test_topic', b'some_message_bytes').get(timeout=30) + producer.send('python_test_topic', b'another_message').get(timeout=30) - consumer = KafkaConsumer('test_topic', auto_offset_reset='earliest', **config) + consumer = KafkaConsumer('python_test_topic', auto_offset_reset='earliest', **config) msg = next(consumer) - assert(msg.topic == "test_topic") + assert(msg.topic == "python_test_topic") assert(msg.value ==
b"some_message_bytes") assert(msg.offset == 0) msg = next(consumer) - assert(msg.topic == "test_topic") + assert(msg.topic == "python_test_topic") assert(msg.value == b"another_message") assert(msg.offset == 1) diff --git a/test-helpers/src/lib.rs b/test-helpers/src/lib.rs index ba3a52bc4..5cf831393 100644 --- a/test-helpers/src/lib.rs +++ b/test-helpers/src/lib.rs @@ -42,6 +42,7 @@ pub async fn run_command_async(current_dir: &Path, command: &str, args: &[&str]) let output = tokio::process::Command::new(command) .args(args) .current_dir(current_dir) + .kill_on_drop(true) .status() .await .unwrap();