Skip to content

Commit

Permalink
Fix sarama and kafkajs drivers (#1690)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Jul 16, 2024
1 parent ce4a153 commit 2d37bfe
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
/docs/mdbook_bin
/shotover-proxy/build/packages
/some_local_file
/test-helpers/src/connection/kafka/node/node_modules
42 changes: 42 additions & 0 deletions shotover-proxy/tests/kafka_int_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::time::Duration;
use std::time::Instant;
use test_cases::produce_consume_partitions1;
use test_cases::{assert_topic_creation_is_denied_due_to_acl, setup_basic_user_acls};
use test_helpers::connection::kafka::node::run_node_smoke_test_scram;
use test_helpers::connection::kafka::{KafkaConnectionBuilder, KafkaDriver};
use test_helpers::docker_compose::docker_compose;
use test_helpers::shotover_process::{Count, EventMatcher};
Expand Down Expand Up @@ -34,6 +35,24 @@ async fn passthrough_standard(#[case] driver: KafkaDriver) {
.expect("Shotover did not shutdown within 10s");
}

#[tokio::test]
async fn passthrough_nodejs() {
let _docker_compose =
docker_compose("tests/test-configs/kafka/passthrough/docker-compose.yaml");
let shotover = shotover_process("tests/test-configs/kafka/passthrough/topology.yaml")
.start()
.await;

test_helpers::connection::kafka::node::run_node_smoke_test("127.0.0.1:9192").await;

tokio::time::timeout(
Duration::from_secs(10),
shotover.shutdown_and_then_consume_events(&[]),
)
.await
.expect("Shotover did not shutdown within 10s");
}

#[rstest]
#[cfg_attr(feature = "kafka-cpp-driver-tests", case::cpp(KafkaDriver::Cpp))]
#[case::java(KafkaDriver::Java)]
Expand Down Expand Up @@ -435,6 +454,29 @@ async fn assert_connection_fails_with_incorrect_password(driver: KafkaDriver, us
);
}

#[rstest]
#[tokio::test]
async fn cluster_sasl_scram_over_mtls_nodejs() {
test_helpers::cert::generate_kafka_test_certs();

let _docker_compose =
docker_compose("tests/test-configs/kafka/cluster-sasl-scram-over-mtls/docker-compose.yaml");
let shotover = shotover_process(
"tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml",
)
.start()
.await;

run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await;

tokio::time::timeout(
Duration::from_secs(10),
shotover.shutdown_and_then_consume_events(&[]),
)
.await
.expect("Shotover did not shutdown within 10s");
}

#[rstest]
//#[cfg_attr(feature = "kafka-cpp-driver-tests", case::cpp(KafkaDriver::Cpp))] // CPP driver does not support scram
#[case::java(KafkaDriver::Java)]
Expand Down
2 changes: 1 addition & 1 deletion shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ impl KafkaSinkCluster {
})) => {
if let Some(scram_over_mtls) = &mut self.authorize_scram_over_mtls {
if let Some(username) = get_username_from_scram_request(auth_bytes) {
scram_over_mtls.username = username;
scram_over_mtls.set_username(username).await?;
}
}
self.connection_factory.add_auth_request(request.clone());
Expand Down
34 changes: 27 additions & 7 deletions shotover/src/transforms/kafka/sink_cluster/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::frame::kafka::{KafkaFrame, RequestBody, ResponseBody};
use crate::frame::Frame;
use crate::message::Message;
use crate::tls::TlsConnector;
use crate::transforms::kafka::sink_cluster::scram_over_mtls::OriginalScramState;
use crate::transforms::kafka::sink_cluster::SASL_SCRAM_MECHANISMS;
use anyhow::{anyhow, Context, Result};
use bytes::Bytes;
Expand Down Expand Up @@ -89,9 +90,8 @@ impl ConnectionFactory {
if let Some(scram_over_mtls) = authorize_scram_over_mtls {
if let Some(sasl_mechanism) = sasl_mechanism {
if SASL_SCRAM_MECHANISMS.contains(&sasl_mechanism.as_str()) {
self.perform_tokenauth_scram_exchange(scram_over_mtls, &mut connection)
.await
.context("Failed to perform delegation token SCRAM exchange")?;
self.scram_over_mtls(scram_over_mtls, &mut connection)
.await?;
} else {
self.replay_sasl(&mut connection).await?;
}
Expand All @@ -106,6 +106,29 @@ impl ConnectionFactory {
Ok(connection)
}

async fn scram_over_mtls(
&self,
scram_over_mtls: &AuthorizeScramOverMtls,
connection: &mut SinkConnection,
) -> Result<()> {
if matches!(
scram_over_mtls.original_scram_state,
OriginalScramState::AuthSuccess
) {
// The original connection is authorized, so we are free to make authorize more session
self.perform_tokenauth_scram_exchange(scram_over_mtls, connection)
.await
.context("Failed to perform delegation token SCRAM exchange")
} else {
// If the original session has not authenticated yet, this is probably the first outgoing connection.
// So just create it with no outgoing connections, the client will perform the remainder of the scram handshake.
//
// If the original session failed to authenticate we cannot authorize this session.
// So just perform no scram handshake and let kafka uphold the authorization requirements for us.
Ok(())
}
}

/// authorize_scram_over_mtls creates new connections via delegation tokens.
/// Kafka implements delegation tokens as just a special case of SCRAM.
/// In particular kafka utilizes scram's concept of extensions to send `tokenauth=true` to the server,
Expand Down Expand Up @@ -141,10 +164,7 @@ impl ConnectionFactory {
));
}

let delegation_token = scram_over_mtls
.token_task
.get_token_for_user(scram_over_mtls.username.clone())
.await?;
let delegation_token = scram_over_mtls.get_token_for_user().await?;

// SCRAM client-first
let mut scram = Scram::<Sha256>::new(
Expand Down
59 changes: 45 additions & 14 deletions shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
use super::node::{ConnectionFactory, KafkaAddress};
use crate::{
connection::SinkConnection,
tls::{TlsConnector, TlsConnectorConfig},
};

use anyhow::{Context, Result};
use futures::stream::FuturesUnordered;
use kafka_protocol::protocol::StrBytes;
use metrics::{histogram, Histogram};
use rand::rngs::SmallRng;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Notify;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::StreamExt;

use crate::{
connection::SinkConnection,
tls::{TlsConnector, TlsConnectorConfig},
};

use super::node::{ConnectionFactory, KafkaAddress};

mod create_token;
mod recreate_token_queue;

Expand Down Expand Up @@ -70,6 +67,18 @@ impl TokenTask {
TokenTask { tx }
}

/// Informs the token task that we will need this token soon so it should start creating it if needed.
pub async fn prefetch_token_for_user(&self, username: String) -> Result<()> {
let (response_tx, _response_rx) = oneshot::channel();
self.tx
.send(TokenRequest {
username,
response_tx,
})
.await
.context("Failed to request delegation token from token task")
}

/// Request a token from the task.
/// If the task has a token for the user cached it will return it quickly.
/// If the task does not have a token for the user cached it will:
Expand Down Expand Up @@ -244,9 +253,31 @@ pub struct AuthorizeScramOverMtls {
/// Tracks the state of the original scram connections responses created from the clients actual requests
pub original_scram_state: OriginalScramState,
/// Shared task that fetches delegation tokens
pub token_task: TokenTask,
token_task: TokenTask,
/// The username used in the original scram auth to generate the delegation token
pub username: String,
username: String,
}

impl AuthorizeScramOverMtls {
pub async fn set_username(&mut self, username: String) -> Result<()> {
self.token_task
.prefetch_token_for_user(username.clone())
.await?;
self.username = username;
Ok(())
}

pub async fn get_token_for_user(&self) -> Result<DelegationToken> {
if !matches!(self.original_scram_state, OriginalScramState::AuthSuccess) {
// This should be enforced by logic that occurs before calling this method.
// This is a final check to enforce security, if this panic occurs it indicates a bug elsewhere in shotover.
panic!("Cannot hand out tokens to a connection that has not authenticated yet.")
}

self.token_task
.get_token_for_user(self.username.clone())
.await
}
}

pub enum OriginalScramState {
Expand Down
1 change: 1 addition & 0 deletions test-helpers/src/connection/kafka/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pretty_assertions::assert_eq;
#[cfg(feature = "kafka-cpp-driver-tests")]
pub mod cpp;
pub mod java;
pub mod node;

use anyhow::Result;
#[cfg(feature = "kafka-cpp-driver-tests")]
Expand Down
46 changes: 46 additions & 0 deletions test-helpers/src/connection/kafka/node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::path::Path;

pub async fn run_node_smoke_test(address: &str) {
let dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/node");
let config = format!(
r#"({{
clientId: 'nodejs-client',
brokers: ["{address}"],
}})"#
);
run_command(&dir, "npm", &["install"]).await;
run_command(&dir, "npm", &["start", &config]).await;
}

pub async fn run_node_smoke_test_scram(address: &str, user: &str, password: &str) {
let dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/node");
let config = format!(
r#"({{
clientId: 'nodejs-client',
brokers: ["{address}"],
sasl: {{
mechanism: 'scram-sha-256',
username: '{user}',
password: '{password}'
}}
}})"#
);
run_command(&dir, "npm", &["install"]).await;
run_command(&dir, "npm", &["start", &config]).await;
}

async fn run_command(current_dir: &Path, command: &str, args: &[&str]) -> String {
let output = tokio::process::Command::new(command)
.args(args)
.current_dir(current_dir)
.output()
.await
.unwrap();

let stdout = String::from_utf8(output.stdout).unwrap();
let stderr = String::from_utf8(output.stderr).unwrap();
if !output.status.success() {
panic!("command {command} {args:?} failed:\nstdout:\n{stdout}\nstderr:\n{stderr}")
}
stdout
}
70 changes: 70 additions & 0 deletions test-helpers/src/connection/kafka/node/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
const { Kafka } = require('kafkajs')
const fs = require('fs')
const assert = require('assert')

function delay(time) {
return new Promise(resolve => setTimeout(resolve, time));
}

const run = async () => {
const args = process.argv.slice(2);
const config = args[0];

const kafka = new Kafka(eval(config))

// Producing
const producer = kafka.producer()
await producer.connect()
await producer.send({
topic: 'test',
messages: [
{ value: 'foo' },
],
})
await producer.send({
topic: 'test',
messages: [
{ value: 'a longer string' },
],
})
await producer.disconnect()

// Consuming
const consumer = kafka.consumer({ groupId: 'test-group' })
await consumer.connect()
await consumer.subscribe({ topic: 'test', fromBeginning: true })

messages = []
await consumer.run({
eachMessage: async ({ topic, partition, message }) => {
messages.push({
topic,
partition,
offset: message.offset,
value: message.value.toString(),
})
},
})

// Use a very primitive sleep loop since nodejs doesnt seem to have any kind of mpsc or channel functionality :/
while (messages.length < 2) {
await delay(10);
}
assert.deepStrictEqual(messages, [
{
topic: 'test',
partition: 0,
offset: '0',
value: 'foo',
},
{
topic: 'test',
partition: 0,
offset: '1',
value: 'a longer string',
}
])
await consumer.disconnect()
}

run()
25 changes: 25 additions & 0 deletions test-helpers/src/connection/kafka/node/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions test-helpers/src/connection/kafka/node/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"name": "kafkajs_wrapper",
"version": "1.0.0",
"main": "index.js",
"scripts": {
"start": "node index.js"
},
"author": "",
"license": "Apache-2.0",
"description": "",
"dependencies": {
"kafkajs": "^2.2.4"
}
}

0 comments on commit 2d37bfe

Please sign in to comment.