From 7819f8b4b0c4c135ea4543261757fa6e589b5ee4 Mon Sep 17 00:00:00 2001 From: David Woodhouse Date: Mon, 15 Apr 2019 14:19:38 +0100 Subject: [PATCH] Allow ESP functions to be overridden Signed-off-by: David Woodhouse --- esp.c | 14 ++++++++------ gnutls-esp.c | 14 ++++++++++++-- openconnect-internal.h | 7 ++++--- openssl-esp.c | 16 +++++++++++++--- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/esp.c b/esp.c index cc423a58..2d0352f1 100644 --- a/esp.c +++ b/esp.c @@ -121,7 +121,7 @@ int construct_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, uint memcpy(pkt->esp.iv, vpninfo->esp_out.iv, sizeof(pkt->esp.iv)); - ret = encrypt_esp_packet(vpninfo, pkt, pkt->len + padlen + 2); + ret = vpninfo->encrypt_esp_packet(vpninfo, pkt, pkt->len + padlen + 2); if (ret) return ret; @@ -179,14 +179,14 @@ int esp_mainloop(struct openconnect_info *vpninfo, int *timeout, int readable) pkt->len = len; if (pkt->esp.spi == esp->spi) { - if (decrypt_esp_packet(vpninfo, esp, pkt)) + if (vpninfo->decrypt_esp_packet(vpninfo, esp, pkt)) continue; } else if (pkt->esp.spi == old_esp->spi && ntohl(pkt->esp.seq) + esp->seq < vpninfo->old_esp_maxseq) { vpn_progress(vpninfo, PRG_TRACE, _("Received ESP packet from old SPI 0x%x, seq %u\n"), (unsigned)ntohl(old_esp->spi), (unsigned)ntohl(pkt->esp.seq)); - if (decrypt_esp_packet(vpninfo, old_esp, pkt)) + if (vpninfo->decrypt_esp_packet(vpninfo, old_esp, pkt)) continue; } else { vpn_progress(vpninfo, PRG_DEBUG, @@ -406,9 +406,11 @@ void esp_close(struct openconnect_info *vpninfo) void esp_shutdown(struct openconnect_info *vpninfo) { - destroy_esp_ciphers(&vpninfo->esp_in[0]); - destroy_esp_ciphers(&vpninfo->esp_in[1]); - destroy_esp_ciphers(&vpninfo->esp_out); + if (vpninfo->destroy_esp_ciphers) { + vpninfo->destroy_esp_ciphers(&vpninfo->esp_in[0]); + vpninfo->destroy_esp_ciphers(&vpninfo->esp_in[1]); + vpninfo->destroy_esp_ciphers(&vpninfo->esp_out); + } if (vpninfo->proto->udp_close) vpninfo->proto->udp_close(vpninfo); if (vpninfo->dtls_state != DTLS_DISABLED) diff --git a/gnutls-esp.c b/gnutls-esp.c index e350ff79..2bc6519d 100644 --- a/gnutls-esp.c +++ b/gnutls-esp.c @@ -28,7 +28,12 @@ #include #include -void destroy_esp_ciphers(struct esp *esp) +static int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, + struct pkt *pkt); +static int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, + int crypt_len); + +static void destroy_esp_ciphers(struct esp *esp) { if (esp->cipher) { gnutls_cipher_deinit(esp->cipher); @@ -114,11 +119,16 @@ int init_esp_ciphers(struct openconnect_info *vpninfo, struct esp *esp_out, stru return ret; } + vpninfo->decrypt_esp_packet = decrypt_esp_packet; + vpninfo->encrypt_esp_packet = encrypt_esp_packet; + vpninfo->destroy_esp_ciphers = destroy_esp_ciphers; + return 0; } /* pkt->len shall be the *payload* length. Omitting the header and the 12-byte HMAC */ -int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, struct pkt *pkt) +static int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, + struct pkt *pkt) { unsigned char hmac_buf[MAX_HMAC_SIZE]; int err; diff --git a/openconnect-internal.h b/openconnect-internal.h index 91610d41..b3a5ec82 100644 --- a/openconnect-internal.h +++ b/openconnect-internal.h @@ -752,6 +752,10 @@ struct openconnect_info { DELAY_CLOSE_IMMEDIATE_CALLBACK, } delay_close; /* Delay close of mainloop */ + void (*destroy_esp_ciphers)(struct esp *esp); + int (*decrypt_esp_packet)(struct openconnect_info *vpninfo, struct esp *esp, struct pkt *pkt); + int (*encrypt_esp_packet)(struct openconnect_info *vpninfo, struct pkt *pkt, int crypt_len); + int verbose; void *cbdata; openconnect_validate_peer_cert_vfn validate_peer_cert; @@ -1362,10 +1366,7 @@ int openconnect_setup_esp_keys(struct openconnect_info *vpninfo, int new_keys); int construct_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, uint8_t next_hdr); /* {gnutls,openssl}-esp.c */ -void destroy_esp_ciphers(struct esp *esp); int init_esp_ciphers(struct openconnect_info *vpninfo, struct esp *out, struct esp *in); -int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, struct pkt *pkt); -int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, int crypt_len); /* {gnutls,openssl}.c */ const char *openconnect_get_tls_library_version(void); diff --git a/openssl-esp.c b/openssl-esp.c index 459e8c09..c2cd8cb0 100644 --- a/openssl-esp.c +++ b/openssl-esp.c @@ -27,6 +27,11 @@ #include #include +static int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, + struct pkt *pkt); +static int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, + int crypt_len); + #if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER) #define EVP_CIPHER_CTX_free(c) do { \ @@ -45,7 +50,7 @@ static inline HMAC_CTX *HMAC_CTX_new(void) } #endif -void destroy_esp_ciphers(struct esp *esp) +static void destroy_esp_ciphers(struct esp *esp) { if (esp->cipher) { EVP_CIPHER_CTX_free(esp->cipher); @@ -103,6 +108,10 @@ static int init_esp_cipher(struct openconnect_info *vpninfo, struct esp *esp, destroy_esp_ciphers(esp); } + vpninfo->decrypt_esp_packet = decrypt_esp_packet; + vpninfo->encrypt_esp_packet = encrypt_esp_packet; + vpninfo->destroy_esp_ciphers = destroy_esp_ciphers; + return 0; } @@ -151,7 +160,8 @@ int init_esp_ciphers(struct openconnect_info *vpninfo, struct esp *esp_out, stru } /* pkt->len shall be the *payload* length. Omitting the header and the 12-byte HMAC */ -int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, struct pkt *pkt) +static int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, + struct pkt *pkt) { unsigned char hmac_buf[MAX_HMAC_SIZE]; unsigned int hmac_len = sizeof(hmac_buf); @@ -189,7 +199,7 @@ int decrypt_esp_packet(struct openconnect_info *vpninfo, struct esp *esp, struct return 0; } -int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, int crypt_len) +static int encrypt_esp_packet(struct openconnect_info *vpninfo, struct pkt *pkt, int crypt_len) { int blksize = 16; unsigned int hmac_len = vpninfo->hmac_out_len;