diff --git a/kotlin-mbedtls-metrics/src/test/kotlin/org/opencoap/ssl/transport/metrics/micrometer/DtlsServerMetricsCallbacksTest.kt b/kotlin-mbedtls-metrics/src/test/kotlin/org/opencoap/ssl/transport/metrics/micrometer/DtlsServerMetricsCallbacksTest.kt index 702b62c4..5b3308a0 100644 --- a/kotlin-mbedtls-metrics/src/test/kotlin/org/opencoap/ssl/transport/metrics/micrometer/DtlsServerMetricsCallbacksTest.kt +++ b/kotlin-mbedtls-metrics/src/test/kotlin/org/opencoap/ssl/transport/metrics/micrometer/DtlsServerMetricsCallbacksTest.kt @@ -117,6 +117,7 @@ class DtlsServerMetricsCallbacksTest { } @Test + @Disabled("After implementation of invalid handshake datagrams dropping it's hard to simulate wrong handshake") fun `should report DTLS server metrics for handshake errors`() { server = DtlsServerTransport.create(conf, lifecycleCallbacks = metricsCallbacks).listen(echoHandler) val cliChannel: DatagramChannel = DatagramChannel.open() diff --git a/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt b/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt index 6b58dad1..12c2f5ab 100644 --- a/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt +++ b/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt @@ -36,7 +36,8 @@ class DtlsChannelHandler @JvmOverloads constructor( private val sslConfig: SslConfig, private val expireAfter: Duration = Duration.ofSeconds(60), private val sessionStore: SessionStore = NoOpsSessionStore, - private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {} + private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}, + private val cidRequired: Boolean = false ) : ChannelDuplexHandler() { private lateinit var ctx: ChannelHandlerContext lateinit var dtlsServer: DtlsServer @@ -52,7 +53,7 @@ class DtlsChannelHandler @JvmOverloads constructor( override fun handlerAdded(ctx: ChannelHandlerContext) { this.ctx = ctx - this.dtlsServer = DtlsServer(::write, sslConfig, expireAfter, sessionStore::write, lifecycleCallbacks, ctx.executor()) + this.dtlsServer = DtlsServer(::write, sslConfig, expireAfter, sessionStore::write, lifecycleCallbacks, ctx.executor(), cidRequired) } override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt index 45092423..96e355d6 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt @@ -63,7 +63,7 @@ import java.time.Duration.ofSeconds class SslConfig( private val conf: Memory, - val cidSupplier: CidSupplier, + val cidSupplier: CidSupplier?, private val mtu: Int, private val close: Closeable ) : Closeable by close { @@ -75,8 +75,10 @@ class SslConfig( mbedtls_ssl_setup(sslContext, conf).verify() mbedtls_ssl_set_timer_cb(sslContext, Pointer.NULL, NoOpsSetDelayCallback, NoOpsGetDelayCallback) - val cid = cidSupplier.next() - mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify() + val cid = cidSupplier?.next() + if (cid != null) { + mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify() + } mbedtls_ssl_set_mtu(sslContext, mtu) val clientId = peerAddress.toString() @@ -103,13 +105,13 @@ class SslConfig( @JvmStatic @JvmOverloads - fun client(auth: AuthConfig, cipherSuites: List = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig { + fun client(auth: AuthConfig, cipherSuites: List = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier? = EmptyCidSupplier, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig { return create(false, auth, cipherSuites, cidSupplier, reqAuthentication, 0, retransmitMin, retransmitMax) } @JvmStatic @JvmOverloads - fun server(auth: AuthConfig, cipherSuites: List = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, mtu: Int = 0, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig { + fun server(auth: AuthConfig, cipherSuites: List = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier? = EmptyCidSupplier, mtu: Int = 0, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig { return create(true, auth, cipherSuites, cidSupplier, reqAuthentication, mtu, retransmitMin, retransmitMax) } @@ -117,11 +119,11 @@ class SslConfig( isServer: Boolean, authConfig: AuthConfig, cipherSuites: List, - cidSupplier: CidSupplier, + cidSupplier: CidSupplier?, requiredAuthMode: Boolean = true, mtu: Int, retransmitMin: Duration, - retransmitMax: Duration, + retransmitMax: Duration ): SslConfig { val sslConfig = Memory(MbedtlsSizeOf.mbedtls_ssl_config).also(MbedtlsApi::mbedtls_ssl_config_init) val entropy = Memory(MbedtlsSizeOf.mbedtls_entropy_context).also(MbedtlsApi.Crypto::mbedtls_entropy_init) @@ -154,7 +156,7 @@ class SslConfig( mbedtls_ssl_conf_ciphersuites(sslConfig, cipherSuiteIds) } - if (cidSupplier != EmptyCidSupplier) { + if (cidSupplier != null && cidSupplier != EmptyCidSupplier) { mbedtls_ssl_conf_cid(sslConfig, cidSupplier.next().size, 0) } diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt index 8f431f68..f157996d 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt @@ -26,6 +26,7 @@ import org.opencoap.ssl.SslSession import org.slf4j.LoggerFactory import java.net.InetSocketAddress import java.nio.ByteBuffer +import java.nio.ByteOrder import java.time.Duration import java.time.Instant import java.util.concurrent.CompletableFuture @@ -38,7 +39,8 @@ class DtlsServer( private val expireAfter: Duration = Duration.ofSeconds(60), private val storeSession: (cid: ByteArray, session: SessionWithContext) -> Unit, private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}, - private val executor: ScheduledExecutorService + private val executor: ScheduledExecutorService, + private val cidRequired: Boolean = false ) { companion object { private val EMPTY_BUFFER = ByteBuffer.allocate(0) @@ -49,11 +51,12 @@ class DtlsServer( // note: non thread save, must be used only from same thread private val sessions = mutableMapOf() - private val cidSize = sslConfig.cidSupplier.next().size + private val cidSize = sslConfig.cidSupplier?.next()?.size ?: 0 val numberOfSessions get() = sessions.size fun handleReceived(adr: InetSocketAddress, buf: ByteBuffer): ReceiveResult { val cid by lazy { SslContext.peekCID(cidSize, buf) } + val isValidHandshake by lazy { isValidHandshakeRequest(buf) } val dtlsState = sessions[adr] return when { @@ -63,12 +66,19 @@ class DtlsServer( // no session, but dtls packet contains CID cid != null -> ReceiveResult.CidSessionMissing(cid!!) - // new handshake - else -> { + // start new handshake if datagram is valid + isValidHandshake -> { val dtlsHandshake = DtlsHandshake(sslConfig.newContext(adr), adr) sessions[adr] = dtlsHandshake dtlsHandshake.step(buf) } + + // drop silently + else -> { + logger.warn("[{}] Invalid DTLS session handshake.", adr) + reportMessageDrop(adr) + ReceiveResult.Handled + } } } @@ -186,6 +196,7 @@ class DtlsServer( when (ex) { is SslException -> logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message) + else -> logger.error(ex.toString(), ex) } @@ -305,4 +316,66 @@ class DtlsServer( lifecycleCallbacks.sessionFinished(peerAddress, reason, err) } } + + private fun isValidHandshakeRequest(buf: ByteBuffer): Boolean { + val workingBuf = buf.slice().order(ByteOrder.BIG_ENDIAN) + + // Check if the header is correct + val header = workingBuf.getLong(0) + if (header != 0x16FEFD0000000000L) { + logger.debug("Bad DTLS header") + return false + } + + // Check if it is a ClientHello handshake + val handshakeType = workingBuf.get(13).toInt() + if (handshakeType != 1) { + logger.debug("Bad handshake type") + return false + } + + // Check if CID is supported by the client in case if CID support is mandatory + if (cidRequired && !supportsCid(workingBuf)) { + logger.debug("No CID support") + return false + } + + return true + } + + private fun supportsCid(buf: ByteBuffer): Boolean { + val workingBuffer = buf.slice().order(ByteOrder.BIG_ENDIAN) + + // Go to the start of extensions + workingBuffer + // Skip DTLSHeader(13) + HandshakeHeader(12) + CookieLengthOffset(35) + .seek(60) + // Skip variable-length Cookie + .readByteAndSeek() + // Skip variable-length CipherSuites + .readShortAndSeek() + // Skip variable-length CompressionMethods + .readByteAndSeek() + // Limit buffer to the extensions length + .getShort().also { + workingBuffer.limit(workingBuffer.position() + it.toInt()) + } + + // Search for CID extension + while (workingBuffer.remaining() >= 4) { + val type = workingBuffer.getShort() + if (type == 0x36.toShort()) { + return true + } + + // Skip to the next extension + workingBuffer.readShortAndSeek() + } + + return false + } } + +private fun ByteBuffer.seek(offset: Int): ByteBuffer = this.position(this.position() + offset) as ByteBuffer +private fun ByteBuffer.readShortAndSeek(): ByteBuffer = this.getShort().let { this.seek(it.toInt()) } +private fun ByteBuffer.readByteAndSeek(): ByteBuffer = this.get().let { this.seek(it.toInt()) } diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt index 5ca80aa7..53d42a03 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt @@ -45,10 +45,11 @@ class DtlsServerTransport private constructor( expireAfter: Duration = Duration.ofSeconds(60), sessionStore: SessionStore = NoOpsSessionStore, transport: Transport = DatagramChannelAdapter.open(listenPort), - lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {} + lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}, + cidRequired: Boolean = false ): DtlsServerTransport { val executor = SingleThreadExecutor.create("dtls-srv-") - val dtlsServer = DtlsServer(transport, config, expireAfter, sessionStore::write, lifecycleCallbacks, executor) + val dtlsServer = DtlsServer(transport, config, expireAfter, sessionStore::write, lifecycleCallbacks, executor, cidRequired) return DtlsServerTransport(transport, dtlsServer, sessionStore, executor) } } diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt index edb9fe8f..08b373e6 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt @@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture.completedFuture class DtlsServerTest { val serverConf = SslConfig.server(CertificateAuth(Certs.serverChain, Certs.server.privateKey), listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"), false, RandomCidSupplier(16)) val clientConf = SslConfig.client(CertificateAuth.trusted(Certs.root.asX509()), cipherSuites = listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256")) + val clientConfNoCid = SslConfig.client(CertificateAuth.trusted(Certs.root.asX509()), cipherSuites = listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"), cidSupplier = null) private val sessionStore = HashMapSessionStore() private lateinit var dtlsServer: DtlsServer @@ -122,6 +123,42 @@ class DtlsServerTest { clientSession.close() } + @Test + fun `should handshake when CID is required`() { + dtlsServer = DtlsServer(::outboundTransport, serverConf, 100.millis, sessionStore::write, executor = SingleThreadExecutor.create("dtls-srv-"), cidRequired = true) + + // when + val clientSession = clientHandshake() + + // then + val dtlsPacket = clientSession.encrypt("terve".toByteBuffer()).order(ByteOrder.BIG_ENDIAN) + val dtlsPacketIn = (dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) as ReceiveResult.Decrypted).packet + assertEquals("terve", dtlsPacketIn.buffer.decodeToString()) + assertEquals(1, dtlsServer.numberOfSessions) + assertNotNull(dtlsPacketIn.sessionContext.sessionStartTimestamp) + + await.untilAsserted { + assertTrue(serverOutboundQueue.isEmpty()) + } + + clientSession.close() + } + + @Test + fun `should fail handshake when CID is required and client doesn't provide it`() { + dtlsServer = DtlsServer(::outboundTransport, serverConf, 100.millis, sessionStore::write, executor = SingleThreadExecutor.create("dtls-srv-"), cidRequired = true) + val send: (ByteBuffer) -> Unit = { dtlsServer.handleReceived(localAddress(2_5684), it) } + val cliHandshake = clientConfNoCid.newContext(localAddress(5684)) + + // when + cliHandshake.step(send) + + // then + await.untilAsserted { + assertTrue(serverOutboundQueue.isEmpty()) + } + } + @Test fun `should handshake with replaying records`() { lateinit var sendingBuffer: ByteBuffer diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTransportTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTransportTest.kt index 83a0b5da..7d4a5b19 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTransportTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTransportTest.kt @@ -234,10 +234,6 @@ class DtlsServerTransportTest { assertEquals(0, cliChannel.read("aaa".toByteBuffer())) cliChannel.close() - verify(atMost = 100) { - sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class)) - } - verify(exactly = 0) { sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class)) }