Skip to content

Commit

Permalink
Move decoding into frame
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Oct 24, 2024
1 parent a9f55e5 commit 699e388
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 166 deletions.
59 changes: 46 additions & 13 deletions shotover-proxy/tests/kafka_int_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use test_cases::produce_consume_partitions1;
use test_cases::produce_consume_partitions3;
use test_cases::{assert_topic_creation_is_denied_due_to_acl, setup_basic_user_acls};
use test_helpers::connection::kafka::node::run_node_smoke_test_scram;
use test_helpers::connection::kafka::python::run_python_bad_auth_sasl_scram;
use test_helpers::connection::kafka::python::run_python_smoke_test_sasl_scram;
use test_helpers::connection::kafka::{KafkaConnectionBuilder, KafkaDriver};
use test_helpers::docker_compose::docker_compose;
Expand Down Expand Up @@ -755,21 +756,53 @@ async fn cluster_sasl_scram_over_mtls_nodejs_and_python() {

let _docker_compose =
docker_compose("tests/test-configs/kafka/cluster-sasl-scram-over-mtls/docker-compose.yaml");
let shotover = shotover_process(
"tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml",
)
.start()
.await;

run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await;
run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await;
{
let shotover = shotover_process(
"tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml",
)
.start()
.await;

run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await;
run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await;

tokio::time::timeout(
Duration::from_secs(10),
shotover.shutdown_and_then_consume_events(&[]),
)
.await
.expect("Shotover did not shutdown within 10s");
tokio::time::timeout(
Duration::from_secs(10),
shotover.shutdown_and_then_consume_events(&[]),
)
.await
.expect("Shotover did not shutdown within 10s");
}

{
let shotover = shotover_process(
"tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml",
)
.start()
.await;

run_python_bad_auth_sasl_scram("127.0.0.1:9192", "incorrect_user", "super_password").await;
run_python_bad_auth_sasl_scram("127.0.0.1:9192", "super_user", "incorrect_password").await;

tokio::time::timeout(
Duration::from_secs(10),
shotover.shutdown_and_then_consume_events(&[EventMatcher::new()
.with_level(Level::Error)
.with_target("shotover::server")
.with_message(r#"encountered an error when flushing the chain kafka for shutdown
Caused by:
0: KafkaSinkCluster transform failed
1: Failed to receive responses (without sending requests)
2: Outgoing connection had pending requests, those requests/responses are lost so connection recovery cannot be attempted.
3: Failed to receive from ControlConnection
4: The other side of this connection closed the connection"#)
.with_count(Count::Times(2))]),
)
.await
.expect("Shotover did not shutdown within 10s");
}
}

#[rstest]
Expand Down
228 changes: 105 additions & 123 deletions shotover/src/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@ use crate::frame::{Frame, MessageType};
use crate::message::{Encodable, Message, MessageId, Messages};
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use kafka_protocol::messages::{
ApiKey, RequestHeader as RequestHeaderProtocol, RequestKind, ResponseHeader, ResponseKind,
SaslAuthenticateRequest, SaslAuthenticateResponse,
};
use kafka_protocol::messages::{ApiKey, RequestKind, ResponseKind};
use kafka_protocol::protocol::StrBytes;
use metrics::Histogram;
use std::sync::mpsc;
use std::time::Instant;
use tokio_util::codec::{Decoder, Encoder};

#[derive(Copy, Clone, Debug, PartialEq)]
pub struct RequestHeader {
// TODO: this should be i16???
pub api_key: ApiKey,
pub version: i16,
}
Expand Down Expand Up @@ -66,21 +63,36 @@ impl CodecBuilder for KafkaCodecBuilder {
pub struct RequestInfo {
header: RequestHeader,
id: MessageId,
expect_raw_sasl: Option<SaslType>,
expect_raw_sasl: Option<SaslRequestState>,
}

// Keeps track of the next expected sasl message
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum SaslType {
pub enum SaslRequestState {
/// The next message will be a sasl message in the PLAIN mechanism
Plain,
ScramMessage1,
ScramMessage2,
/// The next message will be the first sasl message in the SCRAM mechanism
ScramFirst,
/// The next message will be the final sasl message in the SCRAM mechanism
ScramFinal,
}

impl SaslRequestState {
fn from_name(mechanism: &StrBytes) -> Result<SaslRequestState> {
match mechanism.as_str() {
"PLAIN" => Ok(SaslRequestState::Plain),
"SCRAM-SHA-512" => Ok(SaslRequestState::ScramFirst),
"SCRAM-SHA-256" => Ok(SaslRequestState::ScramFirst),
mechanism => Err(anyhow!("Unknown sasl mechanism {mechanism}")),
}
}
}

pub struct KafkaDecoder {
// Some when Sink (because it receives responses)
request_header_rx: Option<mpsc::Receiver<RequestInfo>>,
direction: Direction,
expect_raw_sasl: Option<SaslType>,
expect_raw_sasl: Option<SaslRequestState>,
}

impl KafkaDecoder {
Expand Down Expand Up @@ -123,103 +135,82 @@ impl Decoder for KafkaDecoder {
pretty_hex::pretty_hex(&bytes)
);

let request_info = self
.request_header_rx
.as_ref()
.map(|rx| rx.recv())
.transpose()
.map_err(|_| CodecReadError::Parser(anyhow!("kafka encoder half was lost")))?;

struct Meta {
request_header: RequestHeader,
message_id: Option<u128>,
}

let request_info = self
.request_header_rx
.as_ref()
.map(|rx| {
rx.recv()
.map_err(|_| CodecReadError::Parser(anyhow!("kafka encoder half was lost")))
})
.transpose()?;

let message = if self.expect_raw_sasl.is_some() {
// Convert the unframed raw sasl into a framed sasl
// This allows transforms to correctly parse the message and inspect the sasl request
let kafka_frame = match self.direction {
Direction::Source => KafkaFrame::Request {
header: RequestHeaderProtocol::default()
.with_request_api_key(ApiKey::SaslAuthenticateKey as i16),
body: RequestKind::SaslAuthenticate(
SaslAuthenticateRequest::default().with_auth_bytes(bytes.freeze()),
),
},
Direction::Sink => KafkaFrame::Response {
let meta = if let Some(RequestInfo { header, id, .. }) = request_info {
Meta {
request_header: header,
message_id: Some(id),
}
} else if self.expect_raw_sasl.is_some() {
Meta {
request_header: RequestHeader {
api_key: ApiKey::SaslAuthenticateKey,
version: 0,
header: ResponseHeader::default(),
body: ResponseKind::SaslAuthenticate(
SaslAuthenticateResponse::default().with_auth_bytes(bytes.freeze()),
// TODO: we need to set with_error_code
),
},
};
let codec_state = CodecState::Kafka(KafkaCodecState {
request_header: None,
raw_sasl: self.expect_raw_sasl,
});
self.expect_raw_sasl = match self.expect_raw_sasl {
Some(SaslType::Plain) => None,
Some(SaslType::ScramMessage1) => Some(SaslType::ScramMessage2),
Some(SaslType::ScramMessage2) => None,
None => None,
};
Message::from_frame_and_codec_state_at_instant(
Frame::Kafka(kafka_frame),
codec_state,
// This code path is only used for requests, so message_id can be None.
message_id: None,
}
} else {
Meta {
request_header: RequestHeader {
api_key: ApiKey::try_from(i16::from_be_bytes(
bytes[4..6].try_into().unwrap(),
))
.unwrap(),
version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()),
},
// This code path is only used for requests, so message_id can be None.
message_id: None,
}
};
let mut message = if let Some(id) = meta.message_id.as_ref() {
let mut message = Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: Some(meta.request_header),
raw_sasl: self.expect_raw_sasl,
}),
Some(received_at),
)
);
message.set_request_id(*id);
message
} else {
let meta = if let Some(RequestInfo {
header,
id,
expect_raw_sasl,
}) = request_info
{
if let Some(expect_raw_sasl) = expect_raw_sasl {
self.expect_raw_sasl = Some(expect_raw_sasl);
}
Meta {
request_header: header,
message_id: Some(id),
}
} else {
Meta {
request_header: RequestHeader {
api_key: ApiKey::try_from(i16::from_be_bytes(
bytes[4..6].try_into().unwrap(),
))
.unwrap(),
version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()),
},
message_id: None,
}
};
let mut message = if let Some(id) = meta.message_id.as_ref() {
let mut message = Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: Some(meta.request_header),
raw_sasl: None,
}),
Some(received_at),
);
message.set_request_id(*id);
message
} else {
Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: None,
raw_sasl: None,
}),
Some(received_at),
)
};
Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: None,
raw_sasl: self.expect_raw_sasl,
}),
Some(received_at),
)
};

// advanced to the next state of expect_raw_sasl
self.expect_raw_sasl = match self.expect_raw_sasl {
Some(SaslRequestState::Plain) => None,
Some(SaslRequestState::ScramFirst) => Some(SaslRequestState::ScramFinal),
Some(SaslRequestState::ScramFinal) => None,
None => None,
};

if let Some(request_info) = request_info {
// set expect_raw_sasl for responses
if let Some(expect_raw_sasl) = request_info.expect_raw_sasl {
self.expect_raw_sasl = Some(expect_raw_sasl);
}
} else {
// set expect_raw_sasl for requests
if meta.request_header.api_key == ApiKey::SaslHandshakeKey
&& meta.request_header.version == 0
{
Expand All @@ -229,25 +220,18 @@ impl Decoder for KafkaDecoder {
..
})) = message.frame()
{
self.expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() {
"PLAIN" => SaslType::Plain,
"SCRAM-SHA-512" => SaslType::ScramMessage1,
"SCRAM-SHA-256" => SaslType::ScramMessage1,
mechanism => {
return Err(CodecReadError::Parser(anyhow!(
"Unknown sasl mechanism {mechanism}"
)))
}
});
self.expect_raw_sasl = Some(
SaslRequestState::from_name(&sasl_handshake.mechanism)
.map_err(CodecReadError::Parser)?,
);

// Clear raw bytes of the message to force the encoder to encode from frame.
// This is needed because the encoder only has access to the frame if it does not have any raw bytes,
// and the encoder needs to inspect the frame to set its own sasl state.
message.invalidate_cache();
}
}
message
};
}

Ok(Some(vec![message]))
} else {
Expand Down Expand Up @@ -288,15 +272,19 @@ impl Encoder<Messages> for KafkaEncoder {
let response_is_dummy = m.response_is_dummy();
let id = m.id();
let received_at = m.received_from_source_or_sink_at;
let codec_state = m.codec_state.as_kafka();
let message_contains_raw_sasl = if let CodecState::Kafka(codec_state) = m.codec_state {
codec_state.raw_sasl.is_some()
} else {
false
};
let mut expect_raw_sasl = None;
let result = match m.into_encodable() {
Encodable::Bytes(bytes) => {
dst.extend_from_slice(&bytes);
Ok(())
}
Encodable::Frame(frame) => {
if codec_state.raw_sasl.is_some() {
if message_contains_raw_sasl {
match frame {
Frame::Kafka(KafkaFrame::Request {
body: RequestKind::SaslAuthenticate(body),
Expand All @@ -315,23 +303,17 @@ impl Encoder<Messages> for KafkaEncoder {
Ok(())
} else {
let frame = frame.into_kafka().unwrap();
// it is garanteed that all v0 SaslHandshakes will be in a parsed state since we parse it in the KafkaDecoder.
// it is garanteed that all v0 SaslHandshakes will be in a parsed state since we parse + invalidate_cache in the KafkaDecoder.
if let KafkaFrame::Request {
body: RequestKind::SaslHandshake(sasl_handshake),
header,
} = &frame
{
if header.request_api_version == 0 {
expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() {
"PLAIN" => SaslType::Plain,
"SCRAM-SHA-512" => SaslType::ScramMessage1,
"SCRAM-SHA-256" => SaslType::ScramMessage1,
mechanism => {
return Err(CodecWriteError::Encoder(anyhow!(
"Unknown sasl mechanism {mechanism}"
)))
}
});
expect_raw_sasl = Some(
SaslRequestState::from_name(&sasl_handshake.mechanism)
.map_err(CodecWriteError::Encoder)?,
);
}
}
frame.encode(dst)
Expand All @@ -343,7 +325,7 @@ impl Encoder<Messages> for KafkaEncoder {
// or if it will generate a dummy response
if !dst[start..].is_empty() && !response_is_dummy {
if let Some(tx) = self.request_header_tx.as_ref() {
let header = if codec_state.raw_sasl.is_some() {
let header = if message_contains_raw_sasl {
RequestHeader {
api_key: ApiKey::SaslAuthenticateKey,
version: 0,
Expand Down
Loading

0 comments on commit 699e388

Please sign in to comment.