Skip to content

Commit

Permalink
Move decoding into frame
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Oct 24, 2024
1 parent a9f55e5 commit 443130d
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 96 deletions.
156 changes: 67 additions & 89 deletions shotover/src/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use crate::frame::{Frame, MessageType};
use crate::message::{Encodable, Message, MessageId, Messages};
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use kafka_protocol::messages::{
ApiKey, RequestHeader as RequestHeaderProtocol, RequestKind, ResponseHeader, ResponseKind,
SaslAuthenticateRequest, SaslAuthenticateResponse,
};
use kafka_protocol::messages::{ApiKey, RequestKind, ResponseKind};
use metrics::Histogram;
use std::sync::mpsc;
use std::time::Instant;
Expand Down Expand Up @@ -123,11 +120,6 @@ impl Decoder for KafkaDecoder {
pretty_hex::pretty_hex(&bytes)
);

struct Meta {
request_header: RequestHeader,
message_id: Option<u128>,
}

let request_info = self
.request_header_rx
.as_ref()
Expand All @@ -137,89 +129,75 @@ impl Decoder for KafkaDecoder {
})
.transpose()?;

let message = if self.expect_raw_sasl.is_some() {
// Convert the unframed raw sasl into a framed sasl
// This allows transforms to correctly parse the message and inspect the sasl request
let kafka_frame = match self.direction {
Direction::Source => KafkaFrame::Request {
header: RequestHeaderProtocol::default()
.with_request_api_key(ApiKey::SaslAuthenticateKey as i16),
body: RequestKind::SaslAuthenticate(
SaslAuthenticateRequest::default().with_auth_bytes(bytes.freeze()),
),
},
Direction::Sink => KafkaFrame::Response {
struct Meta {
request_header: RequestHeader,
message_id: Option<u128>,
}

let meta = if let Some(RequestInfo { header, id, .. }) = request_info {
Meta {
request_header: header,
message_id: Some(id),
}
} else if self.expect_raw_sasl.is_some() {
Meta {
request_header: RequestHeader {
api_key: ApiKey::SaslAuthenticateKey,
version: 0,
header: ResponseHeader::default(),
body: ResponseKind::SaslAuthenticate(
SaslAuthenticateResponse::default().with_auth_bytes(bytes.freeze()),
// TODO: we need to set with_error_code
),
},
};
let codec_state = CodecState::Kafka(KafkaCodecState {
request_header: None,
raw_sasl: self.expect_raw_sasl,
});
self.expect_raw_sasl = match self.expect_raw_sasl {
Some(SaslType::Plain) => None,
Some(SaslType::ScramMessage1) => Some(SaslType::ScramMessage2),
Some(SaslType::ScramMessage2) => None,
None => None,
};
Message::from_frame_and_codec_state_at_instant(
Frame::Kafka(kafka_frame),
codec_state,
// This code path is only used for requests, so message_id can be None.
message_id: None,
}
} else {
Meta {
request_header: RequestHeader {
api_key: ApiKey::try_from(i16::from_be_bytes(
bytes[4..6].try_into().unwrap(),
))
.unwrap(),
version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()),
},
// This code path is only used for requests, so message_id can be None.
message_id: None,
}
};
let mut message = if let Some(id) = meta.message_id.as_ref() {
let mut message = Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: Some(meta.request_header),
raw_sasl: self.expect_raw_sasl,
}),
Some(received_at),
)
);
message.set_request_id(*id);
message
} else {
let meta = if let Some(RequestInfo {
header,
id,
expect_raw_sasl,
}) = request_info
{
if let Some(expect_raw_sasl) = expect_raw_sasl {
self.expect_raw_sasl = Some(expect_raw_sasl);
}
Meta {
request_header: header,
message_id: Some(id),
}
} else {
Meta {
request_header: RequestHeader {
api_key: ApiKey::try_from(i16::from_be_bytes(
bytes[4..6].try_into().unwrap(),
))
.unwrap(),
version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()),
},
message_id: None,
}
};
let mut message = if let Some(id) = meta.message_id.as_ref() {
let mut message = Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: Some(meta.request_header),
raw_sasl: None,
}),
Some(received_at),
);
message.set_request_id(*id);
message
} else {
Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: None,
raw_sasl: None,
}),
Some(received_at),
)
};
Message::from_bytes_at_instant(
bytes.freeze(),
CodecState::Kafka(KafkaCodecState {
request_header: None,
raw_sasl: self.expect_raw_sasl,
}),
Some(received_at),
)
};

// advanced to the next state of expect_raw_sasl
self.expect_raw_sasl = match self.expect_raw_sasl {
Some(SaslType::Plain) => None,
Some(SaslType::ScramMessage1) => Some(SaslType::ScramMessage2),
Some(SaslType::ScramMessage2) => None,
None => None,
};

if let Some(request_info) = request_info {
// set expect_raw_sasl for responses
if let Some(expect_raw_sasl) = request_info.expect_raw_sasl {
self.expect_raw_sasl = Some(expect_raw_sasl);
}
} else {
// set expect_raw_sasl for requests
if meta.request_header.api_key == ApiKey::SaslHandshakeKey
&& meta.request_header.version == 0
{
Expand All @@ -229,6 +207,7 @@ impl Decoder for KafkaDecoder {
..
})) = message.frame()
{
// TODO: move into function
self.expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() {
"PLAIN" => SaslType::Plain,
"SCRAM-SHA-512" => SaslType::ScramMessage1,
Expand All @@ -246,8 +225,7 @@ impl Decoder for KafkaDecoder {
message.invalidate_cache();
}
}
message
};
}

Ok(Some(vec![message]))
} else {
Expand Down
2 changes: 2 additions & 0 deletions shotover/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ impl CodecState {
#[derive(Debug, Clone, PartialEq, Copy)]
pub struct KafkaCodecState {
pub request_header: Option<RequestHeader>,
/// When some this message is not in the kafka protocol and is instead a raw SASL message
/// KafkaFrame will parse this as a SaslHandshake to hide the legacy raw message from transform implementations.
pub raw_sasl: Option<SaslType>,
}

Expand Down
35 changes: 28 additions & 7 deletions shotover/src/frame/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::codec::kafka::RequestHeader as CodecRequestHeader;
use crate::codec::KafkaCodecState;
use anyhow::{anyhow, Context, Result};
use bytes::{BufMut, Bytes, BytesMut};
use kafka_protocol::messages::{ApiKey, RequestHeader, ResponseHeader};
use kafka_protocol::messages::{
ApiKey, RequestHeader, ResponseHeader, SaslAuthenticateRequest, SaslAuthenticateResponse,
};
use kafka_protocol::protocol::{Decodable, Encodable};
use std::fmt::{Display, Formatter, Result as FmtResult};

Expand Down Expand Up @@ -70,12 +72,31 @@ impl Display for KafkaFrame {

impl KafkaFrame {
pub fn from_bytes(mut bytes: Bytes, codec_state: KafkaCodecState) -> Result<Self> {
// remove length header
let _ = bytes.split_to(4);

match &codec_state.request_header {
Some(request_header) => KafkaFrame::parse_response(bytes, *request_header),
None => KafkaFrame::parse_request(bytes),
if codec_state.raw_sasl.is_some() {
match &codec_state.request_header {
Some(_) => Ok(KafkaFrame::Response {
version: 0,
header: ResponseHeader::default(),
body: ResponseBody::SaslAuthenticate(
SaslAuthenticateResponse::default().with_auth_bytes(bytes),
// TODO: we need to set with_error_code
),
}),
None => Ok(KafkaFrame::Request {
header: RequestHeader::default()
.with_request_api_key(ApiKey::SaslAuthenticateKey as i16),
body: RequestBody::SaslAuthenticate(
SaslAuthenticateRequest::default().with_auth_bytes(bytes),
),
}),
}
} else {
// remove length header
let _ = bytes.split_to(4);
match &codec_state.request_header {
Some(request_header) => KafkaFrame::parse_response(bytes, *request_header),
None => KafkaFrame::parse_request(bytes),
}
}
}

Expand Down

0 comments on commit 443130d

Please sign in to comment.