Skip to content

Commit

Permalink
Provide a way to require CID support from a peer (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
akolosov-n authored Feb 29, 2024
1 parent f431eb8 commit b0d3fde
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class DtlsServerMetricsCallbacksTest {
}

@Test
@Disabled("After implementation of invalid handshake datagrams dropping it's hard to simulate wrong handshake")
fun `should report DTLS server metrics for handshake errors`() {
server = DtlsServerTransport.create(conf, lifecycleCallbacks = metricsCallbacks).listen(echoHandler)
val cliChannel: DatagramChannel = DatagramChannel.open()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class DtlsChannelHandler @JvmOverloads constructor(
private val sslConfig: SslConfig,
private val expireAfter: Duration = Duration.ofSeconds(60),
private val sessionStore: SessionStore = NoOpsSessionStore,
private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}
private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
private val cidRequired: Boolean = false
) : ChannelDuplexHandler() {
private lateinit var ctx: ChannelHandlerContext
lateinit var dtlsServer: DtlsServer
Expand All @@ -52,7 +53,7 @@ class DtlsChannelHandler @JvmOverloads constructor(

override fun handlerAdded(ctx: ChannelHandlerContext) {
this.ctx = ctx
this.dtlsServer = DtlsServer(::write, sslConfig, expireAfter, sessionStore::write, lifecycleCallbacks, ctx.executor())
this.dtlsServer = DtlsServer(::write, sslConfig, expireAfter, sessionStore::write, lifecycleCallbacks, ctx.executor(), cidRequired)
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
Expand Down
18 changes: 10 additions & 8 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ import java.time.Duration.ofSeconds

class SslConfig(
private val conf: Memory,
val cidSupplier: CidSupplier,
val cidSupplier: CidSupplier?,
private val mtu: Int,
private val close: Closeable
) : Closeable by close {
Expand All @@ -75,8 +75,10 @@ class SslConfig(
mbedtls_ssl_setup(sslContext, conf).verify()
mbedtls_ssl_set_timer_cb(sslContext, Pointer.NULL, NoOpsSetDelayCallback, NoOpsGetDelayCallback)

val cid = cidSupplier.next()
mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify()
val cid = cidSupplier?.next()
if (cid != null) {
mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify()
}
mbedtls_ssl_set_mtu(sslContext, mtu)

val clientId = peerAddress.toString()
Expand All @@ -103,25 +105,25 @@ class SslConfig(

@JvmStatic
@JvmOverloads
fun client(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
fun client(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier? = EmptyCidSupplier, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
return create(false, auth, cipherSuites, cidSupplier, reqAuthentication, 0, retransmitMin, retransmitMax)
}

@JvmStatic
@JvmOverloads
fun server(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, mtu: Int = 0, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
fun server(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier? = EmptyCidSupplier, mtu: Int = 0, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
return create(true, auth, cipherSuites, cidSupplier, reqAuthentication, mtu, retransmitMin, retransmitMax)
}

private fun create(
isServer: Boolean,
authConfig: AuthConfig,
cipherSuites: List<String>,
cidSupplier: CidSupplier,
cidSupplier: CidSupplier?,
requiredAuthMode: Boolean = true,
mtu: Int,
retransmitMin: Duration,
retransmitMax: Duration,
retransmitMax: Duration
): SslConfig {
val sslConfig = Memory(MbedtlsSizeOf.mbedtls_ssl_config).also(MbedtlsApi::mbedtls_ssl_config_init)
val entropy = Memory(MbedtlsSizeOf.mbedtls_entropy_context).also(MbedtlsApi.Crypto::mbedtls_entropy_init)
Expand Down Expand Up @@ -154,7 +156,7 @@ class SslConfig(
mbedtls_ssl_conf_ciphersuites(sslConfig, cipherSuiteIds)
}

if (cidSupplier != EmptyCidSupplier) {
if (cidSupplier != null && cidSupplier != EmptyCidSupplier) {
mbedtls_ssl_conf_cid(sslConfig, cidSupplier.next().size, 0)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.opencoap.ssl.SslSession
import org.slf4j.LoggerFactory
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.time.Duration
import java.time.Instant
import java.util.concurrent.CompletableFuture
Expand All @@ -38,7 +39,8 @@ class DtlsServer(
private val expireAfter: Duration = Duration.ofSeconds(60),
private val storeSession: (cid: ByteArray, session: SessionWithContext) -> Unit,
private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
private val executor: ScheduledExecutorService
private val executor: ScheduledExecutorService,
private val cidRequired: Boolean = false
) {
companion object {
private val EMPTY_BUFFER = ByteBuffer.allocate(0)
Expand All @@ -49,11 +51,12 @@ class DtlsServer(

// note: non thread save, must be used only from same thread
private val sessions = mutableMapOf<InetSocketAddress, DtlsState>()
private val cidSize = sslConfig.cidSupplier.next().size
private val cidSize = sslConfig.cidSupplier?.next()?.size ?: 0
val numberOfSessions get() = sessions.size

fun handleReceived(adr: InetSocketAddress, buf: ByteBuffer): ReceiveResult {
val cid by lazy { SslContext.peekCID(cidSize, buf) }
val isValidHandshake by lazy { isValidHandshakeRequest(buf) }
val dtlsState = sessions[adr]

return when {
Expand All @@ -63,12 +66,19 @@ class DtlsServer(
// no session, but dtls packet contains CID
cid != null -> ReceiveResult.CidSessionMissing(cid!!)

// new handshake
else -> {
// start new handshake if datagram is valid
isValidHandshake -> {
val dtlsHandshake = DtlsHandshake(sslConfig.newContext(adr), adr)
sessions[adr] = dtlsHandshake
dtlsHandshake.step(buf)
}

// drop silently
else -> {
logger.warn("[{}] Invalid DTLS session handshake.", adr)
reportMessageDrop(adr)
ReceiveResult.Handled
}
}
}

Expand Down Expand Up @@ -186,6 +196,7 @@ class DtlsServer(
when (ex) {
is SslException ->
logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message)

else ->
logger.error(ex.toString(), ex)
}
Expand Down Expand Up @@ -305,4 +316,66 @@ class DtlsServer(
lifecycleCallbacks.sessionFinished(peerAddress, reason, err)
}
}

private fun isValidHandshakeRequest(buf: ByteBuffer): Boolean {
val workingBuf = buf.slice().order(ByteOrder.BIG_ENDIAN)

// Check if the header is correct
val header = workingBuf.getLong(0)
if (header != 0x16FEFD0000000000L) {
logger.debug("Bad DTLS header")
return false
}

// Check if it is a ClientHello handshake
val handshakeType = workingBuf.get(13).toInt()
if (handshakeType != 1) {
logger.debug("Bad handshake type")
return false
}

// Check if CID is supported by the client in case if CID support is mandatory
if (cidRequired && !supportsCid(workingBuf)) {
logger.debug("No CID support")
return false
}

return true
}

private fun supportsCid(buf: ByteBuffer): Boolean {
val workingBuffer = buf.slice().order(ByteOrder.BIG_ENDIAN)

// Go to the start of extensions
workingBuffer
// Skip DTLSHeader(13) + HandshakeHeader(12) + CookieLengthOffset(35)
.seek(60)
// Skip variable-length Cookie
.readByteAndSeek()
// Skip variable-length CipherSuites
.readShortAndSeek()
// Skip variable-length CompressionMethods
.readByteAndSeek()
// Limit buffer to the extensions length
.getShort().also {
workingBuffer.limit(workingBuffer.position() + it.toInt())
}

// Search for CID extension
while (workingBuffer.remaining() >= 4) {
val type = workingBuffer.getShort()
if (type == 0x36.toShort()) {
return true
}

// Skip to the next extension
workingBuffer.readShortAndSeek()
}

return false
}
}

private fun ByteBuffer.seek(offset: Int): ByteBuffer = this.position(this.position() + offset) as ByteBuffer
private fun ByteBuffer.readShortAndSeek(): ByteBuffer = this.getShort().let { this.seek(it.toInt()) }
private fun ByteBuffer.readByteAndSeek(): ByteBuffer = this.get().let { this.seek(it.toInt()) }
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ class DtlsServerTransport private constructor(
expireAfter: Duration = Duration.ofSeconds(60),
sessionStore: SessionStore = NoOpsSessionStore,
transport: Transport<ByteBufferPacket> = DatagramChannelAdapter.open(listenPort),
lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}
lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
cidRequired: Boolean = false
): DtlsServerTransport {
val executor = SingleThreadExecutor.create("dtls-srv-")
val dtlsServer = DtlsServer(transport, config, expireAfter, sessionStore::write, lifecycleCallbacks, executor)
val dtlsServer = DtlsServer(transport, config, expireAfter, sessionStore::write, lifecycleCallbacks, executor, cidRequired)
return DtlsServerTransport(transport, dtlsServer, sessionStore, executor)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture.completedFuture
class DtlsServerTest {
val serverConf = SslConfig.server(CertificateAuth(Certs.serverChain, Certs.server.privateKey), listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"), false, RandomCidSupplier(16))
val clientConf = SslConfig.client(CertificateAuth.trusted(Certs.root.asX509()), cipherSuites = listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"))
val clientConfNoCid = SslConfig.client(CertificateAuth.trusted(Certs.root.asX509()), cipherSuites = listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"), cidSupplier = null)

private val sessionStore = HashMapSessionStore()
private lateinit var dtlsServer: DtlsServer
Expand Down Expand Up @@ -122,6 +123,42 @@ class DtlsServerTest {
clientSession.close()
}

@Test
fun `should handshake when CID is required`() {
dtlsServer = DtlsServer(::outboundTransport, serverConf, 100.millis, sessionStore::write, executor = SingleThreadExecutor.create("dtls-srv-"), cidRequired = true)

// when
val clientSession = clientHandshake()

// then
val dtlsPacket = clientSession.encrypt("terve".toByteBuffer()).order(ByteOrder.BIG_ENDIAN)
val dtlsPacketIn = (dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) as ReceiveResult.Decrypted).packet
assertEquals("terve", dtlsPacketIn.buffer.decodeToString())
assertEquals(1, dtlsServer.numberOfSessions)
assertNotNull(dtlsPacketIn.sessionContext.sessionStartTimestamp)

await.untilAsserted {
assertTrue(serverOutboundQueue.isEmpty())
}

clientSession.close()
}

@Test
fun `should fail handshake when CID is required and client doesn't provide it`() {
dtlsServer = DtlsServer(::outboundTransport, serverConf, 100.millis, sessionStore::write, executor = SingleThreadExecutor.create("dtls-srv-"), cidRequired = true)
val send: (ByteBuffer) -> Unit = { dtlsServer.handleReceived(localAddress(2_5684), it) }
val cliHandshake = clientConfNoCid.newContext(localAddress(5684))

// when
cliHandshake.step(send)

// then
await.untilAsserted {
assertTrue(serverOutboundQueue.isEmpty())
}
}

@Test
fun `should handshake with replaying records`() {
lateinit var sendingBuffer: ByteBuffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,6 @@ class DtlsServerTransportTest {
assertEquals(0, cliChannel.read("aaa".toByteBuffer()))
cliChannel.close()

verify(atMost = 100) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class))
}

verify(exactly = 0) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
}
Expand Down

0 comments on commit b0d3fde

Please sign in to comment.