diff --git a/shotover-proxy/tests/kafka_int_tests/mod.rs b/shotover-proxy/tests/kafka_int_tests/mod.rs index a7b559f38..e046aad48 100644 --- a/shotover-proxy/tests/kafka_int_tests/mod.rs +++ b/shotover-proxy/tests/kafka_int_tests/mod.rs @@ -9,6 +9,8 @@ use test_cases::produce_consume_partitions1; use test_cases::produce_consume_partitions3; use test_cases::{assert_topic_creation_is_denied_due_to_acl, setup_basic_user_acls}; use test_helpers::connection::kafka::node::run_node_smoke_test_scram; +use test_helpers::connection::kafka::python::run_python_bad_auth_sasl_scram; +use test_helpers::connection::kafka::python::run_python_smoke_test_sasl_scram; use test_helpers::connection::kafka::{KafkaConnectionBuilder, KafkaDriver}; use test_helpers::docker_compose::docker_compose; use test_helpers::shotover_process::{Count, EventMatcher}; @@ -37,7 +39,7 @@ async fn passthrough_standard(#[case] driver: KafkaDriver) { } #[tokio::test] -async fn passthrough_nodejs() { +async fn passthrough_nodejs_and_python() { let _docker_compose = docker_compose("tests/test-configs/kafka/passthrough/docker-compose.yaml"); let shotover = shotover_process("tests/test-configs/kafka/passthrough/topology.yaml") @@ -45,23 +47,6 @@ async fn passthrough_nodejs() { .await; test_helpers::connection::kafka::node::run_node_smoke_test("127.0.0.1:9192").await; - - tokio::time::timeout( - Duration::from_secs(10), - shotover.shutdown_and_then_consume_events(&[]), - ) - .await - .expect("Shotover did not shutdown within 10s"); -} - -#[tokio::test] -async fn passthrough_python() { - let _docker_compose = - docker_compose("tests/test-configs/kafka/passthrough/docker-compose.yaml"); - let shotover = shotover_process("tests/test-configs/kafka/passthrough/topology.yaml") - .start() - .await; - test_helpers::connection::kafka::python::run_python_smoke_test("127.0.0.1:9192").await; tokio::time::timeout( @@ -206,6 +191,27 @@ async fn passthrough_sasl_plain(#[case] driver: KafkaDriver) { shotover.shutdown_and_then_consume_events(&[]).await; } +#[cfg(feature = "alpha-transforms")] +#[rstest] +#[tokio::test] +async fn passthrough_sasl_plain_python() { + let _docker_compose = + docker_compose("tests/test-configs/kafka/passthrough-sasl-plain/docker-compose.yaml"); + let shotover = + shotover_process("tests/test-configs/kafka/passthrough-sasl-plain/topology.yaml") + .start() + .await; + + test_helpers::connection::kafka::python::run_python_smoke_test_sasl_plain( + "127.0.0.1:9192", + "user", + "password", + ) + .await; + + shotover.shutdown_and_then_consume_events(&[]).await; +} + #[rstest] #[case::java(KafkaDriver::Java)] #[tokio::test(flavor = "multi_thread")] // multi_thread is needed since java driver will block when consuming, causing shotover logs to not appear @@ -745,25 +751,58 @@ async fn assert_connection_fails_with_incorrect_password(driver: KafkaDriver, us #[rstest] #[tokio::test] -async fn cluster_sasl_scram_over_mtls_nodejs() { +async fn cluster_sasl_scram_over_mtls_nodejs_and_python() { test_helpers::cert::generate_kafka_test_certs(); let _docker_compose = docker_compose("tests/test-configs/kafka/cluster-sasl-scram-over-mtls/docker-compose.yaml"); - let shotover = shotover_process( - "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", - ) - .start() - .await; - run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await; + { + let shotover = shotover_process( + "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", + ) + .start() + .await; - tokio::time::timeout( - Duration::from_secs(10), - shotover.shutdown_and_then_consume_events(&[]), - ) - .await - .expect("Shotover did not shutdown within 10s"); + run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await; + run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await; + + tokio::time::timeout( + Duration::from_secs(10), + shotover.shutdown_and_then_consume_events(&[]), + ) + .await + .expect("Shotover did not shutdown within 10s"); + } + + { + let shotover = shotover_process( + "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", + ) + .start() + .await; + + run_python_bad_auth_sasl_scram("127.0.0.1:9192", "incorrect_user", "super_password").await; + run_python_bad_auth_sasl_scram("127.0.0.1:9192", "super_user", "incorrect_password").await; + + tokio::time::timeout( + Duration::from_secs(10), + shotover.shutdown_and_then_consume_events(&[EventMatcher::new() + .with_level(Level::Error) + .with_target("shotover::server") + .with_message(r#"encountered an error when flushing the chain kafka for shutdown + +Caused by: + 0: KafkaSinkCluster transform failed + 1: Failed to receive responses (without sending requests) + 2: Outgoing connection had pending requests, those requests/responses are lost so connection recovery cannot be attempted. + 3: Failed to receive from ControlConnection + 4: The other side of this connection closed the connection"#) + .with_count(Count::Times(2))]), + ) + .await + .expect("Shotover did not shutdown within 10s"); + } } #[rstest] diff --git a/shotover/benches/benches/codec/kafka.rs b/shotover/benches/benches/codec/kafka.rs index c1ab3bea3..2c34b1637 100644 --- a/shotover/benches/benches/codec/kafka.rs +++ b/shotover/benches/benches/codec/kafka.rs @@ -1,6 +1,7 @@ use bytes::{Bytes, BytesMut}; use criterion::{criterion_group, BatchSize, Criterion}; use shotover::codec::kafka::KafkaCodecBuilder; +use shotover::codec::kafka::KafkaCodecState; use shotover::codec::{CodecBuilder, CodecState, Direction}; use shotover::message::Message; use tokio_util::codec::{Decoder, Encoder}; @@ -77,9 +78,10 @@ fn criterion_benchmark(c: &mut Criterion) { { let mut message = Message::from_bytes( Bytes::from(message.to_vec()), - CodecState::Kafka { + CodecState::Kafka(KafkaCodecState { request_header: None, - }, + raw_sasl: false, + }), ); // force the message to be parsed and clear raw message message.frame(); @@ -113,9 +115,10 @@ fn criterion_benchmark(c: &mut Criterion) { for (message, _) in KAFKA_REQUESTS { let mut message = Message::from_bytes( Bytes::from(message.to_vec()), - CodecState::Kafka { + CodecState::Kafka(KafkaCodecState { request_header: None, - }, + raw_sasl: false, + }), ); // force the message to be parsed and clear raw message message.frame(); diff --git a/shotover/src/codec/kafka.rs b/shotover/src/codec/kafka.rs index 3ce03910b..11a87f394 100644 --- a/shotover/src/codec/kafka.rs +++ b/shotover/src/codec/kafka.rs @@ -1,10 +1,12 @@ use super::{message_latency, CodecWriteError, Direction}; use crate::codec::{CodecBuilder, CodecReadError, CodecState}; -use crate::frame::MessageType; +use crate::frame::kafka::KafkaFrame; +use crate::frame::{Frame, MessageType}; use crate::message::{Encodable, Message, MessageId, Messages}; use anyhow::{anyhow, Result}; use bytes::BytesMut; -use kafka_protocol::messages::ApiKey; +use kafka_protocol::messages::{ApiKey, RequestKind, ResponseKind}; +use kafka_protocol::protocol::StrBytes; use metrics::Histogram; use std::sync::mpsc; use std::time::Instant; @@ -56,16 +58,40 @@ impl CodecBuilder for KafkaCodecBuilder { MessageType::Kafka } } - +#[derive(Debug)] pub struct RequestInfo { header: RequestHeader, id: MessageId, + expect_raw_sasl: Option, +} + +// Keeps track of the next expected sasl message +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum SaslMessageState { + /// The next message will be a sasl message in the PLAIN mechanism + Plain, + /// The next message will be the first sasl message in the SCRAM mechanism + ScramFirst, + /// The next message will be the final sasl message in the SCRAM mechanism + ScramFinal, +} + +impl SaslMessageState { + fn from_name(mechanism: &StrBytes) -> Result { + match mechanism.as_str() { + "PLAIN" => Ok(SaslMessageState::Plain), + "SCRAM-SHA-512" => Ok(SaslMessageState::ScramFirst), + "SCRAM-SHA-256" => Ok(SaslMessageState::ScramFirst), + mechanism => Err(anyhow!("Unknown sasl mechanism {mechanism}")), + } + } } pub struct KafkaDecoder { // Some when Sink (because it receives responses) request_header_rx: Option>, direction: Direction, + expect_raw_sasl: Option, } impl KafkaDecoder { @@ -76,12 +102,13 @@ impl KafkaDecoder { KafkaDecoder { request_header_rx, direction, + expect_raw_sasl: None, } } } fn get_length_of_full_message(src: &BytesMut) -> Option { - if src.len() > 4 { + if src.len() >= 4 { let size = u32::from_be_bytes(src[0..4].try_into().unwrap()) as usize + 4; if size <= src.len() { Some(size) @@ -106,28 +133,105 @@ impl Decoder for KafkaDecoder { self.direction, pretty_hex::pretty_hex(&bytes) ); - let message = if let Some(rx) = self.request_header_rx.as_ref() { - let RequestInfo { header, id } = rx - .recv() - .map_err(|_| CodecReadError::Parser(anyhow!("kafka encoder half was lost")))?; + + let request_info = self + .request_header_rx + .as_ref() + .map(|rx| rx.recv()) + .transpose() + .map_err(|_| CodecReadError::Parser(anyhow!("kafka encoder half was lost")))?; + + struct Meta { + request_header: RequestHeader, + message_id: Option, + } + + let meta = if let Some(RequestInfo { header, id, .. }) = request_info { + Meta { + request_header: header, + message_id: Some(id), + } + } else if self.expect_raw_sasl.is_some() { + Meta { + request_header: RequestHeader { + api_key: ApiKey::SaslAuthenticateKey, + version: 0, + }, + // This code path is only used for requests, so message_id can be None. + message_id: None, + } + } else { + Meta { + request_header: RequestHeader { + api_key: ApiKey::try_from(i16::from_be_bytes( + bytes[4..6].try_into().unwrap(), + )) + .unwrap(), + version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()), + }, + // This code path is only used for requests, so message_id can be None. + message_id: None, + } + }; + let mut message = if let Some(id) = meta.message_id.as_ref() { let mut message = Message::from_bytes_at_instant( bytes.freeze(), - CodecState::Kafka { - request_header: Some(header), - }, + CodecState::Kafka(KafkaCodecState { + request_header: Some(meta.request_header), + raw_sasl: self.expect_raw_sasl.is_some(), + }), Some(received_at), ); - message.set_request_id(id); + message.set_request_id(*id); message } else { Message::from_bytes_at_instant( bytes.freeze(), - CodecState::Kafka { + CodecState::Kafka(KafkaCodecState { request_header: None, - }, + raw_sasl: self.expect_raw_sasl.is_some(), + }), Some(received_at), ) }; + + // advanced to the next state of expect_raw_sasl + self.expect_raw_sasl = match self.expect_raw_sasl { + Some(SaslMessageState::Plain) => None, + Some(SaslMessageState::ScramFirst) => Some(SaslMessageState::ScramFinal), + Some(SaslMessageState::ScramFinal) => None, + None => None, + }; + + if let Some(request_info) = request_info { + // set expect_raw_sasl for responses + if let Some(expect_raw_sasl) = request_info.expect_raw_sasl { + self.expect_raw_sasl = Some(expect_raw_sasl); + } + } else { + // set expect_raw_sasl for requests + if meta.request_header.api_key == ApiKey::SaslHandshakeKey + && meta.request_header.version == 0 + { + // Only parse the full frame once we manually check its a v0 sasl handshake + if let Some(Frame::Kafka(KafkaFrame::Request { + body: RequestKind::SaslHandshake(sasl_handshake), + .. + })) = message.frame() + { + self.expect_raw_sasl = Some( + SaslMessageState::from_name(&sasl_handshake.mechanism) + .map_err(CodecReadError::Parser)?, + ); + + // Clear raw bytes of the message to force the encoder to encode from frame. + // This is needed because the encoder only has access to the frame if it does not have any raw bytes, + // and the encoder needs to inspect the frame to set its own sasl state. + message.invalidate_cache(); + } + } + } + Ok(Some(vec![message])) } else { Ok(None) @@ -167,28 +271,84 @@ impl Encoder for KafkaEncoder { let response_is_dummy = m.response_is_dummy(); let id = m.id(); let received_at = m.received_from_source_or_sink_at; + let message_contains_raw_sasl = if let CodecState::Kafka(codec_state) = m.codec_state { + codec_state.raw_sasl + } else { + false + }; + let mut expect_raw_sasl = None; let result = match m.into_encodable() { Encodable::Bytes(bytes) => { dst.extend_from_slice(&bytes); Ok(()) } - Encodable::Frame(frame) => frame.into_kafka().unwrap().encode(dst), + Encodable::Frame(frame) => { + if message_contains_raw_sasl { + match frame { + Frame::Kafka(KafkaFrame::Request { + body: RequestKind::SaslAuthenticate(body), + .. + }) => { + dst.extend_from_slice(&body.auth_bytes); + } + Frame::Kafka(KafkaFrame::Response { + body: ResponseKind::SaslAuthenticate(body), + .. + }) => { + dst.extend_from_slice(&body.auth_bytes); + } + _ => unreachable!("not expected {frame:?}"), + } + Ok(()) + } else { + let frame = frame.into_kafka().unwrap(); + // it is garanteed that all v0 SaslHandshakes will be in a parsed state since we parse + invalidate_cache in the KafkaDecoder. + if let KafkaFrame::Request { + body: RequestKind::SaslHandshake(sasl_handshake), + header, + } = &frame + { + if header.request_api_version == 0 { + expect_raw_sasl = Some( + SaslMessageState::from_name(&sasl_handshake.mechanism) + .map_err(CodecWriteError::Encoder)?, + ); + } + } + frame.encode(dst) + } + } }; // Skip if the message wrote nothing to dst, possibly due to being a dummy message. // or if it will generate a dummy response if !dst[start..].is_empty() && !response_is_dummy { if let Some(tx) = self.request_header_tx.as_ref() { - let api_key = i16::from_be_bytes(dst[start + 4..start + 6].try_into().unwrap()); - let version = i16::from_be_bytes(dst[start + 6..start + 8].try_into().unwrap()); - let api_key = ApiKey::try_from(api_key).map_err(|_| { - CodecWriteError::Encoder(anyhow!("unknown api key {api_key}")) - })?; - tx.send(RequestInfo { - header: RequestHeader { api_key, version }, + let header = if message_contains_raw_sasl { + RequestHeader { + api_key: ApiKey::SaslAuthenticateKey, + version: 0, + } + } else { + let api_key = + i16::from_be_bytes(dst[start + 4..start + 6].try_into().unwrap()); + let version = + i16::from_be_bytes(dst[start + 6..start + 8].try_into().unwrap()); + // TODO: handle unknown API key + let api_key = ApiKey::try_from(api_key).map_err(|_| { + CodecWriteError::Encoder(anyhow!("unknown api key {api_key}")) + })?; + + RequestHeader { api_key, version } + }; + + let request_info = RequestInfo { + header, id, - }) - .map_err(|e| CodecWriteError::Encoder(anyhow!(e)))?; + expect_raw_sasl, + }; + tx.send(request_info) + .map_err(|e| CodecWriteError::Encoder(anyhow!(e)))?; } } @@ -204,3 +364,15 @@ impl Encoder for KafkaEncoder { }) } } + +#[cfg(feature = "kafka")] +#[derive(Debug, Clone, PartialEq, Copy)] +pub struct KafkaCodecState { + /// When the message is: + /// a request - this value is None + /// a response - this value is Some and contains the header values of the corresponding request. + pub request_header: Option, + /// When `true` this message is not a valid kafka protocol message and is instead a raw SASL message. + /// KafkaFrame will parse this as a SaslHandshake to hide the legacy raw SASL message from transform implementations. + pub raw_sasl: bool, +} diff --git a/shotover/src/codec/mod.rs b/shotover/src/codec/mod.rs index f750ab526..3d5036b79 100644 --- a/shotover/src/codec/mod.rs +++ b/shotover/src/codec/mod.rs @@ -5,7 +5,7 @@ use crate::{frame::MessageType, message::Messages}; use cassandra_protocol::compression::Compression; use core::fmt; #[cfg(feature = "kafka")] -use kafka::RequestHeader; +use kafka::KafkaCodecState; use metrics::{histogram, Histogram}; use tokio_util::codec::{Decoder, Encoder}; @@ -44,6 +44,14 @@ pub fn message_latency(direction: Direction, destination_name: String) -> Histog } } +/// Database protocols are often designed such that their messages can be parsed without knowledge of any state of prior messages. +/// When protocols remain stateless, Shotover's parser implementations can remain fairly simple. +/// However in the real world there is often some kind of connection level state that we need to track in order to parse messages. +/// +/// Shotover solves this issue via this enum which provides any of the connection level state required to decode and then reencode messages. +/// 1. The Decoder includes this value in all messages it produces. +/// 2. If any transforms call `.frame()` this value is used to parse the frame of the message. +/// 3. The Encoder uses this value to reencode the message if it has been modified. #[derive(Debug, Clone, PartialEq, Copy)] pub enum CodecState { #[cfg(feature = "cassandra")] @@ -53,9 +61,7 @@ pub enum CodecState { #[cfg(feature = "redis")] Redis, #[cfg(feature = "kafka")] - Kafka { - request_header: Option, - }, + Kafka(KafkaCodecState), Dummy, #[cfg(feature = "opensearch")] OpenSearch, @@ -73,9 +79,9 @@ impl CodecState { } #[cfg(feature = "kafka")] - pub fn as_kafka(&self) -> Option { + pub fn as_kafka(&self) -> KafkaCodecState { match self { - CodecState::Kafka { request_header } => *request_header, + CodecState::Kafka(state) => *state, _ => { panic!("This is a {self:?}, expected CodecState::Kafka") } diff --git a/shotover/src/frame/kafka.rs b/shotover/src/frame/kafka.rs index 68d373648..c8eb20ea9 100644 --- a/shotover/src/frame/kafka.rs +++ b/shotover/src/frame/kafka.rs @@ -1,7 +1,10 @@ +use crate::codec::kafka::KafkaCodecState; use crate::codec::kafka::RequestHeader as CodecRequestHeader; use anyhow::{anyhow, Context, Result}; use bytes::{BufMut, Bytes, BytesMut}; -use kafka_protocol::messages::{ApiKey, RequestHeader, ResponseHeader}; +use kafka_protocol::messages::{ + ApiKey, RequestHeader, ResponseHeader, SaslAuthenticateRequest, SaslAuthenticateResponse, +}; use kafka_protocol::protocol::{Decodable, Encodable}; use std::fmt::{Display, Formatter, Result as FmtResult}; @@ -68,16 +71,36 @@ impl Display for KafkaFrame { } impl KafkaFrame { - pub fn from_bytes( - mut bytes: Bytes, - request_header: Option, - ) -> Result { - // remove length header - let _ = bytes.split_to(4); - - match request_header { - Some(request_header) => KafkaFrame::parse_response(bytes, request_header), - None => KafkaFrame::parse_request(bytes), + pub fn from_bytes(mut bytes: Bytes, codec_state: KafkaCodecState) -> Result { + if codec_state.raw_sasl { + match &codec_state.request_header { + Some(_) => Ok(KafkaFrame::Response { + version: 0, + header: ResponseHeader::default(), + body: ResponseBody::SaslAuthenticate( + SaslAuthenticateResponse::default().with_auth_bytes(bytes), + // We dont set error_code field when the response contains a scram error, which sounds problematic. + // But in reality, at least for raw sasl mode, if kafka encounters an auth failure, + // it just kills the connection without sending any sasl response at all. + // So we never actually receive a scram response containing an error and + // so there would be no case where the error_code field would need to be set. + ), + }), + None => Ok(KafkaFrame::Request { + header: RequestHeader::default() + .with_request_api_key(ApiKey::SaslAuthenticateKey as i16), + body: RequestBody::SaslAuthenticate( + SaslAuthenticateRequest::default().with_auth_bytes(bytes), + ), + }), + } + } else { + // remove length header + let _ = bytes.split_to(4); + match &codec_state.request_header { + Some(request_header) => KafkaFrame::parse_response(bytes, *request_header), + None => KafkaFrame::parse_request(bytes), + } } } diff --git a/shotover/src/frame/mod.rs b/shotover/src/frame/mod.rs index c138c100e..aa9c61745 100644 --- a/shotover/src/frame/mod.rs +++ b/shotover/src/frame/mod.rs @@ -1,5 +1,7 @@ //! parsed AST-like representations of messages +#[cfg(feature = "kafka")] +use crate::codec::kafka::KafkaCodecState; use crate::codec::CodecState; use anyhow::{anyhow, Result}; use bytes::Bytes; @@ -94,9 +96,10 @@ impl Frame { #[cfg(feature = "redis")] Frame::Redis(_) => CodecState::Redis, #[cfg(feature = "kafka")] - Frame::Kafka(_) => CodecState::Kafka { + Frame::Kafka(_) => CodecState::Kafka(KafkaCodecState { request_header: None, - }, + raw_sasl: false, + }), Frame::Dummy => CodecState::Dummy, #[cfg(feature = "opensearch")] Frame::OpenSearch(_) => CodecState::OpenSearch, diff --git a/shotover/src/transforms/kafka/sink_cluster/connections.rs b/shotover/src/transforms/kafka/sink_cluster/connections.rs index 97d71c4a5..8764230b8 100644 --- a/shotover/src/transforms/kafka/sink_cluster/connections.rs +++ b/shotover/src/transforms/kafka/sink_cluster/connections.rs @@ -238,16 +238,14 @@ impl Connections { } else { KafkaNodeState::Up }; - nodes - .iter() - .find(|x| match destination { - Destination::Id(id) => x.broker_id == id, - Destination::ControlConnection => { - &x.kafka_address == self.control_connection_address.as_ref().unwrap() - } - }) - .unwrap() - .set_state(node_state); + if let Some(node) = nodes.iter().find(|x| match destination { + Destination::Id(id) => x.broker_id == id, + Destination::ControlConnection => { + &x.kafka_address == self.control_connection_address.as_ref().unwrap() + } + }) { + node.set_state(node_state); + } if old_connection .map(|old| old.pending_requests_count()) diff --git a/shotover/src/transforms/kafka/sink_cluster/kafka_node.rs b/shotover/src/transforms/kafka/sink_cluster/kafka_node.rs index 9c06dc9f0..5811da4c2 100644 --- a/shotover/src/transforms/kafka/sink_cluster/kafka_node.rs +++ b/shotover/src/transforms/kafka/sink_cluster/kafka_node.rs @@ -117,7 +117,7 @@ impl ConnectionFactory { scram_over_mtls.original_scram_state, OriginalScramState::AuthSuccess ) { - // The original connection is authorized, so we are free to make authorize more session + // The original connection is authorized, so we are free to authorize more sessions self.perform_tokenauth_scram_exchange(scram_over_mtls, connection) .await .context("Failed to perform delegation token SCRAM exchange") @@ -144,10 +144,19 @@ impl ConnectionFactory { scram_over_mtls: &AuthorizeScramOverMtls, connection: &mut SinkConnection, ) -> Result<()> { - let mut auth_requests = self.auth_requests.clone(); - // send/receive SaslHandshake - connection.send(vec![auth_requests.remove(0)])?; + let mut sasl_handshake_request = self.auth_requests.first().unwrap().clone(); + if let Some(Frame::Kafka(KafkaFrame::Request { header, .. })) = + sasl_handshake_request.frame() + { + // If the request is version 0 it requires SaslAuthenticate messages to be sent as raw bytes which is impossible. + // So instead force it to version 1. + if header.request_api_version == 0 { + header.request_api_version = 1; + sasl_handshake_request.invalidate_cache(); + } + } + connection.send(vec![sasl_handshake_request])?; let mut handshake_response = connection.recv().await?.pop().unwrap(); if let Some(Frame::Kafka(KafkaFrame::Response { body: ResponseBody::SaslHandshake(handshake_response), @@ -176,19 +185,35 @@ impl ConnectionFactory { ) .map_err(|x| anyhow!("{x:?}"))? .with_first_extensions("tokenauth=true".to_owned()); - connection.send(vec![Self::create_auth_request(scram.initial())])?; + connection + .send(vec![Self::create_auth_request(scram.initial())]) + .context("Failed to send first SCRAM request")?; // SCRAM server-first - let first_scram_response = connection.recv().await?.pop().unwrap(); + let first_scram_response = connection + .recv() + .await + .context("Failed to receive first scram response")? + .pop() + .unwrap(); let first_scram_response = Self::process_auth_response(first_scram_response) .context("first response to delegation token SCRAM reported an error")?; // SCRAM client-final - let final_scram_request = scram.response(&first_scram_response)?; - connection.send(vec![Self::create_auth_request(final_scram_request)])?; + let final_scram_request = scram + .response(&first_scram_response) + .context("Failed to generate final scram request")?; + connection + .send(vec![Self::create_auth_request(final_scram_request)]) + .context("Failed to send final SCRAM request")?; // SCRAM server-final - let final_scram_response = connection.recv().await?.pop().unwrap(); + let final_scram_response = connection + .recv() + .await + .context("Failed to receive second scram response")? + .pop() + .unwrap(); let final_scram_response = Self::process_auth_response(final_scram_response) .context("final response to delegation token SCRAM reported an error")?; scram diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index 980bf379d..1ade57a9c 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -273,8 +273,10 @@ impl AtomicBrokerId { } fn set(&self, value: BrokerId) { - self.0 - .store(value.0.into(), std::sync::atomic::Ordering::Relaxed) + if value != -1 { + self.0 + .store(value.0.into(), std::sync::atomic::Ordering::Relaxed) + } } fn clear(&self) { @@ -989,7 +991,8 @@ impl KafkaSinkCluster { | RequestBody::AlterConfigs(_) | RequestBody::CreatePartitions(_) | RequestBody::DeleteTopics(_) - | RequestBody::CreateAcls(_), + | RequestBody::CreateAcls(_) + | RequestBody::ApiVersions(_), .. })) => self.route_to_random_broker(message), @@ -2406,9 +2409,9 @@ impl KafkaSinkCluster { ResponseError::try_from_code(topic.error_code) { tracing::info!( - "Response to CreateTopics included error NOT_CONTROLLER and so reset controller broker, previously was {:?}", - self.controller_broker.get() - ); + "Response to CreateTopics included error NOT_CONTROLLER and so reset controller broker, previously was {:?}", + self.controller_broker.get() + ); self.controller_broker.clear(); break; } @@ -2972,10 +2975,9 @@ impl KafkaSinkCluster { metadata.controller_id = shotover_node.broker_id; } else { - return Err(anyhow!( - "Invalid metadata, controller points at unknown broker {:?}", - metadata.controller_id - )); + // controller is either -1 or an unknown broker + // In both cases it is reasonable to set to -1 to indicate the controller is unknown. + metadata.controller_id = BrokerId(-1); } Ok(()) diff --git a/test-helpers/src/connection/kafka/python.rs b/test-helpers/src/connection/kafka/python.rs index 4013c782f..7f500ed31 100644 --- a/test-helpers/src/connection/kafka/python.rs +++ b/test-helpers/src/connection/kafka/python.rs @@ -28,6 +28,84 @@ pub async fn run_python_smoke_test(address: &str) { .unwrap(); } +pub async fn run_python_smoke_test_sasl_plain(address: &str, user: &str, password: &str) { + ensure_uv_is_installed().await; + + let project_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/python"); + let uv_binary = uv_binary_path(); + let config = format!( + r#"{{ + 'bootstrap_servers': ["{address}"], + 'security_protocol': "SASL_PLAINTEXT", + 'sasl_mechanism': "PLAIN", + 'sasl_plain_username': "{user}", + 'sasl_plain_password': "{password}", +}}"# + ); + tokio::time::timeout( + Duration::from_secs(60), + run_command_async( + &project_dir, + uv_binary.to_str().unwrap(), + &["run", "main.py", &config], + ), + ) + .await + .unwrap(); +} + +pub async fn run_python_smoke_test_sasl_scram(address: &str, user: &str, password: &str) { + ensure_uv_is_installed().await; + + let project_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/python"); + let uv_binary = uv_binary_path(); + let config = format!( + r#"{{ + 'bootstrap_servers': ["{address}"], + 'security_protocol': "SASL_PLAINTEXT", + 'sasl_mechanism': "SCRAM-SHA-256", + 'sasl_plain_username': "{user}", + 'sasl_plain_password': "{password}", +}}"# + ); + tokio::time::timeout( + Duration::from_secs(60), + run_command_async( + &project_dir, + uv_binary.to_str().unwrap(), + &["run", "main.py", &config], + ), + ) + .await + .unwrap(); +} + +pub async fn run_python_bad_auth_sasl_scram(address: &str, user: &str, password: &str) { + ensure_uv_is_installed().await; + + let project_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/python"); + let uv_binary = uv_binary_path(); + let config = format!( + r#"{{ + 'bootstrap_servers': ["{address}"], + 'security_protocol': "SASL_PLAINTEXT", + 'sasl_mechanism': "SCRAM-SHA-256", + 'sasl_plain_username': "{user}", + 'sasl_plain_password': "{password}", +}}"# + ); + tokio::time::timeout( + Duration::from_secs(60), + run_command_async( + &project_dir, + uv_binary.to_str().unwrap(), + &["run", "auth_fail.py", &config], + ), + ) + .await + .unwrap(); +} + /// Install a specific version of UV to: /// * avoid developers having to manually install an external tool /// * avoid issues due to a different version being installed diff --git a/test-helpers/src/connection/kafka/python/auth_fail.py b/test-helpers/src/connection/kafka/python/auth_fail.py new file mode 100644 index 000000000..f81229e18 --- /dev/null +++ b/test-helpers/src/connection/kafka/python/auth_fail.py @@ -0,0 +1,18 @@ +from kafka import KafkaConsumer +from kafka import KafkaProducer +from kafka.errors import KafkaError +import sys + +def main(): + config = eval(sys.argv[1]) + print("Running kafka-python script with config:") + print(config) + + try: + KafkaProducer(**config) + raise Exception("KafkaProducer was succesfully created but expected to fail due to using incorrect username/password") + except KafkaError: + print("kafka-python auth_fail script passed all test cases") + +if __name__ == "__main__": + main() diff --git a/test-helpers/src/connection/kafka/python/main.py b/test-helpers/src/connection/kafka/python/main.py index 927383fe6..57d62e2a2 100644 --- a/test-helpers/src/connection/kafka/python/main.py +++ b/test-helpers/src/connection/kafka/python/main.py @@ -1,5 +1,7 @@ from kafka import KafkaConsumer +from kafka import KafkaAdminClient from kafka import KafkaProducer +from kafka.admin import NewTopic import sys def main(): @@ -7,19 +9,28 @@ def main(): print("Running kafka-python script with config:") print(config) + admin = KafkaAdminClient(**config) + admin.create_topics([ + NewTopic( + name='python_test_topic', + num_partitions=1, + replication_factor=1 + ) + ]) + producer = KafkaProducer(**config) - producer.send('test_topic', b'some_message_bytes').get(timeout=10) - producer.send('test_topic', b'another_message').get(timeout=10) + producer.send('python_test_topic', b'some_message_bytes').get(timeout=30) + producer.send('python_test_topic', b'another_message').get(timeout=30) - consumer = KafkaConsumer('test_topic', auto_offset_reset='earliest', **config) + consumer = KafkaConsumer('python_test_topic', auto_offset_reset='earliest', **config) msg = next(consumer) - assert(msg.topic == "test_topic") + assert(msg.topic == "python_test_topic") assert(msg.value == b"some_message_bytes") assert(msg.offset == 0) msg = next(consumer) - assert(msg.topic == "test_topic") + assert(msg.topic == "python_test_topic") assert(msg.value == b"another_message") assert(msg.offset == 1) diff --git a/test-helpers/src/lib.rs b/test-helpers/src/lib.rs index ba3a52bc4..5cf831393 100644 --- a/test-helpers/src/lib.rs +++ b/test-helpers/src/lib.rs @@ -42,6 +42,7 @@ pub async fn run_command_async(current_dir: &Path, command: &str, args: &[&str]) let output = tokio::process::Command::new(command) .args(args) .current_dir(current_dir) + .kill_on_drop(true) .status() .await .unwrap();