Skip to content

Commit

Permalink
Add support for DTLS raw keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanLennox committed Nov 2, 2024
1 parent 7245c6a commit de1ea08
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ import java.security.KeyPair
data class CertificateInfo(
val keyPair: KeyPair,
val certificate: org.bouncycastle.tls.Certificate,
val rawKeyCertificate: org.bouncycastle.tls.Certificate,
val localFingerprintHashFunction: String,
val localFingerprint: String,
val localRawKeyFingerprint: String,
val creationTimestampMs: Long
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class DtlsConfig {
}
}

val negotiateRawKeyFingerprints: Boolean by config {
"jmt.dtls.negotiate-raw-key-fingerprints".from(JitsiConfig.newConfig)
}

companion object {
val config = DtlsConfig()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,19 @@ class DtlsStack(
val localFingerprint: String
get() = certificateInfo.localFingerprint

val localRawKeyFingerprint: String
get() = certificateInfo.localRawKeyFingerprint

/**
* The remote fingerprints sent to us over the signaling path.
*/
var remoteFingerprints: Map<String, String> = HashMap()

/**
* The remote raw key fingerprints.
*/
var remoteRawKeyFingerprints: Map<String, String> = HashMap()

/**
* A handler which will be invoked when DTLS application data is received
*/
Expand Down Expand Up @@ -174,7 +182,7 @@ class DtlsStack(
*/
private fun verifyAndValidateRemoteCertificate(remoteCertificate: Certificate?) {
remoteCertificate?.let {
DtlsUtils.verifyAndValidateCertificate(it, remoteFingerprints)
DtlsUtils.verifyAndValidateCertificate(it, remoteFingerprints, remoteRawKeyFingerprints)
// The above throws an exception if the checks fail.
logger.cdebug { "Fingerprints verified." }
} ?: run {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,27 @@
package org.jitsi.nlj.dtls

import org.bouncycastle.asn1.ASN1Encoding
import org.bouncycastle.asn1.ASN1Object
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.X500NameBuilder
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.Certificate
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder
import org.bouncycastle.jce.ECNamedCurveTable
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.operator.DefaultDigestAlgorithmIdentifierFinder
import org.bouncycastle.operator.bc.BcDefaultDigestProvider
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.tls.AlertDescription
import org.bouncycastle.tls.CertificateEntry
import org.bouncycastle.tls.CertificateType
import org.bouncycastle.tls.TlsContext
import org.bouncycastle.tls.TlsUtils
import org.bouncycastle.tls.crypto.TlsSecret
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCertificate
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto
import org.bouncycastle.tls.crypto.impl.bc.BcTlsRawKeyCertificate
import org.jitsi.utils.logging2.Logger
import org.jitsi.utils.logging2.cdebug
import org.jitsi.utils.logging2.cerror
Expand Down Expand Up @@ -68,14 +73,24 @@ class DtlsUtils {
val localFingerprintHashFunction = x509certificate.getHashFunction()
val localFingerprint = x509certificate.getFingerprint(localFingerprintHashFunction)

val sPKI = x509certificate.subjectPublicKeyInfo
val rawKeyFingerprint = sPKI.getFingerprint(localFingerprintHashFunction)

val certificate = org.bouncycastle.tls.Certificate(
arrayOf(BcTlsCertificate(BC_TLS_CRYPTO, x509certificate))
)
val rawKeyCertificate = org.bouncycastle.tls.Certificate(
CertificateType.RawPublicKey,
null,
arrayOf(CertificateEntry(BcTlsRawKeyCertificate(BC_TLS_CRYPTO, sPKI), null))
)
return CertificateInfo(
keyPair,
certificate,
rawKeyCertificate,
localFingerprintHashFunction,
localFingerprint,
rawKeyFingerprint,
System.currentTimeMillis()
)
}
Expand Down Expand Up @@ -158,14 +173,26 @@ class DtlsUtils {
*/
fun verifyAndValidateCertificate(
certificateInfo: org.bouncycastle.tls.Certificate,
remoteFingerprints: Map<String, String>
remoteFingerprints: Map<String, String>,
remoteRawKeyFingerprints: Map<String, String>
) {
if (certificateInfo.certificateList.isEmpty()) {
throw DtlsException("No remote fingerprints.")
}
val type = certificateInfo.certificateType
for (currCertificate in certificateInfo.certificateList) {
val x509Cert = Certificate.getInstance(currCertificate.encoded)
verifyAndValidateCertificate(x509Cert, remoteFingerprints)
when (type) {
CertificateType.X509 -> {
val x509Cert = Certificate.getInstance(currCertificate.encoded)
verifyAndValidateCertificate(x509Cert, remoteFingerprints)
}
CertificateType.RawPublicKey -> {
val sPKI = SubjectPublicKeyInfo.getInstance(currCertificate.encoded)
verifyAndValidateRawPublicKey(sPKI, remoteRawKeyFingerprints)
}
else ->
throw DtlsException("Invalid certificate type")
}
}
}

Expand Down Expand Up @@ -229,6 +256,18 @@ class DtlsUtils {
}
}

private fun verifyAndValidateRawPublicKey(
sPKI: SubjectPublicKeyInfo,
remoteRawKeyFingerprints: Map<String, String>
) {
if (!remoteRawKeyFingerprints.any { (hash, fingerprint) ->
sPKI.getFingerprint(hash) == fingerprint
}
) {
throw DtlsException("No remote raw key fingerprint matches SubjectPublicKeyInfo")
}
}

/**
* Determine and return the hash function (as a [String]) used by this certificate
*/
Expand All @@ -242,10 +281,10 @@ class DtlsUtils {
}

/**
* Computes the fingerprint of a [org.bouncycastle.asn1.x509.Certificate] using [hashFunction] and returns it
* as a [String]
* Computes the fingerprint of a [ASN1Object] (e.g. a [org.bouncycastle.asn1.x509.Certificate])
* using [hashFunction] and returns it as a [String]
*/
private fun Certificate.getFingerprint(hashFunction: String): String {
private fun ASN1Object.getFingerprint(hashFunction: String): String {
val digAlgId = DefaultDigestAlgorithmIdentifierFinder().find(hashFunction.uppercase())
val digest = BcDefaultDigestProvider.INSTANCE.get(digAlgId)
val input: ByteArray = getEncoded(ASN1Encoding.DER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package org.jitsi.nlj.dtls

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import org.bouncycastle.crypto.util.PrivateKeyFactory
import org.bouncycastle.tls.AlertDescription
import org.bouncycastle.tls.Certificate
import org.bouncycastle.tls.CertificateRequest
import org.bouncycastle.tls.CertificateType
import org.bouncycastle.tls.DefaultTlsClient
import org.bouncycastle.tls.ExporterLabel
import org.bouncycastle.tls.ExtensionType
Expand All @@ -28,6 +30,7 @@ import org.bouncycastle.tls.SignatureAlgorithm
import org.bouncycastle.tls.SignatureAndHashAlgorithm
import org.bouncycastle.tls.TlsAuthentication
import org.bouncycastle.tls.TlsCredentials
import org.bouncycastle.tls.TlsFatalAlert
import org.bouncycastle.tls.TlsSRTPUtils
import org.bouncycastle.tls.TlsServerCertificate
import org.bouncycastle.tls.TlsSession
Expand Down Expand Up @@ -81,12 +84,20 @@ class TlsClientImpl(
return object : TlsAuthentication {
override fun getClientCredentials(certificateRequest: CertificateRequest): TlsCredentials {
// NOTE: can't set clientCredentials when it is declared because 'context' won't be set yet
val cert = when (context.securityParametersHandshake.clientCertificateType) {
CertificateType.RawPublicKey ->
certificateInfo.rawKeyCertificate
CertificateType.X509 ->
certificateInfo.certificate
else ->
throw TlsFatalAlert(AlertDescription.internal_error)
}
if (clientCredentials == null) {
clientCredentials = BcDefaultTlsCredentialedSigner(
TlsCryptoParameters(context),
(context.crypto as BcTlsCrypto),
PrivateKeyFactory.createKey(certificateInfo.keyPair.private.encoded),
certificateInfo.certificate,
cert,
if (TlsUtils.isSignatureAlgorithmsExtensionAllowed(context.serverVersion)) {
SignatureAndHashAlgorithm(
HashAlgorithm.sha256,
Expand Down Expand Up @@ -123,6 +134,20 @@ class TlsClientImpl(
return clientExtensions
}

override fun getAllowedClientCertificateTypes(): ShortArray? {
if (DtlsConfig.config.negotiateRawKeyFingerprints) {
return shortArrayOf(CertificateType.X509, CertificateType.RawPublicKey)
}
return null
}

override fun getAllowedServerCertificateTypes(): ShortArray? {
if (DtlsConfig.config.negotiateRawKeyFingerprints) {
return shortArrayOf(CertificateType.X509, CertificateType.RawPublicKey)
}
return null
}

override fun processServerExtensions(serverExtensions: Hashtable<*, *>?) {
// TODO: a few cases we should be throwing alerts for in here. see old TlsClientImpl
val useSRTPData = TlsSRTPUtils.getUseSRTPExtension(serverExtensions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import org.bouncycastle.crypto.util.PrivateKeyFactory
import org.bouncycastle.tls.Certificate
import org.bouncycastle.tls.CertificateRequest
import org.bouncycastle.tls.CertificateType
import org.bouncycastle.tls.ClientCertificateType
import org.bouncycastle.tls.DefaultTlsServer
import org.bouncycastle.tls.ExporterLabel
Expand All @@ -29,6 +30,7 @@ import org.bouncycastle.tls.SignatureAlgorithm
import org.bouncycastle.tls.SignatureAndHashAlgorithm
import org.bouncycastle.tls.TlsCredentialedDecryptor
import org.bouncycastle.tls.TlsCredentialedSigner
import org.bouncycastle.tls.TlsExtensionsUtils
import org.bouncycastle.tls.TlsSRTPUtils
import org.bouncycastle.tls.TlsSession
import org.bouncycastle.tls.TlsUtils
Expand Down Expand Up @@ -64,6 +66,8 @@ class TlsServerImpl(

private var session: TlsSession? = null

private var useRawKeys: Boolean = false

/**
* Only set after a handshake has completed
*/
Expand Down Expand Up @@ -96,6 +100,17 @@ class TlsServerImpl(
val protectionProfiles = useSRTPData.protectionProfiles
chosenSrtpProtectionProfile =
DtlsUtils.chooseSrtpProtectionProfile(SrtpConfig.protectionProfiles, protectionProfiles.asIterable())

if (DtlsConfig.config.negotiateRawKeyFingerprints) {
val remoteServerCertTypes = TlsExtensionsUtils.getServerCertificateTypeExtensionClient(clientExtensions)
val remoteClientCertTypes = TlsExtensionsUtils.getClientCertificateTypeExtensionClient(clientExtensions)

if (remoteServerCertTypes?.contains(CertificateType.RawPublicKey) == true &&
remoteClientCertTypes?.contains(CertificateType.RawPublicKey) == true
) {
useRawKeys = true
}
}
}

override fun getCipherSuites() = DtlsConfig.config.cipherSuites.toIntArray()
Expand All @@ -109,15 +124,27 @@ class TlsServerImpl(
}

override fun getECDSASignerCredentials(): TlsCredentialedSigner {
val cert = if (useRawKeys) {
certificateInfo.rawKeyCertificate
} else {
certificateInfo.certificate
}
return BcDefaultTlsCredentialedSigner(
TlsCryptoParameters(context),
(context.crypto as BcTlsCrypto),
PrivateKeyFactory.createKey(certificateInfo.keyPair.private.encoded),
certificateInfo.certificate,
cert,
SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.ecdsa)
)
}

override fun getAllowedClientCertificateTypes(): ShortArray? {
if (useRawKeys) {
return shortArrayOf(CertificateType.RawPublicKey)
}
return null
}

override fun getCertificateRequest(): CertificateRequest {
val signatureAlgorithms = Vector<SignatureAndHashAlgorithm>(1)
signatureAlgorithms.add(SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.ecdsa))
Expand Down
3 changes: 3 additions & 0 deletions jitsi-media-transform/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ jmt {
// TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
]

// Whether to send and recognize fingerprints for raw keys
negotiate-raw-key-fingerprints = false
}
srtp {
// The maximum number of packets that can be discarded early (without going through the SRTP stack for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import kotlin.concurrent.thread
class DtlsTest : ShouldSpec() {
override fun isolationMode(): IsolationMode? = IsolationMode.InstancePerLeaf
private val debugEnabled = true
private val pcapEnabled = false
private val pcapEnabled = true
private val logger = StdoutLogger(_level = Level.OFF)

fun debug(s: String) {
Expand All @@ -50,9 +50,15 @@ class DtlsTest : ShouldSpec() {
dtlsClient.remoteFingerprints = mapOf(
dtlsServer.localFingerprintHashFunction to dtlsServer.localFingerprint
)
dtlsClient.remoteRawKeyFingerprints = mapOf(
dtlsServer.localFingerprintHashFunction to dtlsServer.localRawKeyFingerprint
)
dtlsServer.remoteFingerprints = mapOf(
dtlsClient.localFingerprintHashFunction to dtlsClient.localFingerprint
)
dtlsServer.remoteRawKeyFingerprints = mapOf(
dtlsClient.localFingerprintHashFunction to dtlsClient.localRawKeyFingerprint
)

// The DTLS server's send is wired directly to the DTLS client's receive
dtlsServer.outgoingDataHandler = object : DtlsStack.OutgoingDataHandler {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.jitsi.videobridge.transport.dtls

import org.ice4j.util.Buffer
import org.jitsi.nlj.dtls.DtlsClient
import org.jitsi.nlj.dtls.DtlsConfig
import org.jitsi.nlj.dtls.DtlsServer
import org.jitsi.nlj.dtls.DtlsStack
import org.jitsi.nlj.srtp.TlsRole
Expand All @@ -28,6 +29,7 @@ import org.jitsi.utils.logging2.createChildLogger
import org.jitsi.utils.queue.PacketQueue
import org.jitsi.videobridge.util.TaskPools
import org.jitsi.xmpp.extensions.jingle.DtlsFingerprintPacketExtension
import org.jitsi.xmpp.extensions.jingle.DtlsRawKeyFingerprintPacketExtension
import org.jitsi.xmpp.extensions.jingle.IceUdpTransportPacketExtension
import java.util.concurrent.atomic.AtomicBoolean

Expand Down Expand Up @@ -192,6 +194,15 @@ class DtlsTransport(parentLogger: Logger, id: String) {
if (cryptex) {
fingerprintPE.cryptex = true
}
if (DtlsConfig.config.negotiateRawKeyFingerprints) {
val rawKeyFingerprintPE = iceUdpTransportPe.getFirstChildOfType(
DtlsRawKeyFingerprintPacketExtension::class.java
) ?: run {
DtlsRawKeyFingerprintPacketExtension().also { iceUdpTransportPe.addChildExtension(it) }
}
rawKeyFingerprintPE.fingerprint = dtlsStack.localRawKeyFingerprint
rawKeyFingerprintPE.hash = dtlsStack.localFingerprintHashFunction
}
}

fun enqueueBuffer(buffer: Buffer) = dtlsQueue.add(buffer)
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>jitsi-xmpp-extensions</artifactId>
<version>1.0-81-g3816e5a</version>
<version>1.0-SNAPSHOT</version>
</dependency>
</dependencies>
</dependencyManagement>
Expand Down

0 comments on commit de1ea08

Please sign in to comment.