diff --git a/Driver.cpp b/Driver.cpp index 5d3c64a..6175809 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * @@ -309,8 +310,6 @@ VOID OvpnEvtFileCleanup(WDFFILEOBJECT fileObject) { // peer might already be deleted (VOID)OvpnPeerDel(device); - InterlockedExchange(&device->UserspacePid, 0); - if (device->Adapter != NULL) { OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateDisconnected); } @@ -476,6 +475,9 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle)); + // Initialize peers tree + RtlInitializeGenericTable(&device->Peers, OvpnPeerCompareByPeerIdRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); + LOG_IF_NOT_NT_SUCCESS(status = OvpnAdapterCreate(device)); done: @@ -483,3 +485,49 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { return status; } + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + NTSTATUS status; + BOOLEAN newElem; + + RtlInsertElementGenericTable(&device->Peers, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); + + if (newElem) { + status = STATUS_SUCCESS; + } + else { + LOG_ERROR("Unable to add new peer"); + status = STATUS_NO_MEMORY; + } + return status; +} + +_Use_decl_annotations_ +VOID +OvpnFlushPeers(POVPN_DEVICE device) { + OvpnCleanupPeerTable(&device->Peers); +} + +_Use_decl_annotations_ +VOID +OvpnCleanupPeerTable(RTL_GENERIC_TABLE* peers) +{ + while (!RtlIsGenericTableEmpty(peers)) { + PVOID ptr = RtlGetElementGenericTable(peers, 0); + OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; + RtlDeleteElementGenericTable(peers, ptr); + + OvpnPeerCtxFree(peer); + } +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers) +{ + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} \ No newline at end of file diff --git a/Driver.h b/Driver.h index 9f46a49..f33fc0e 100644 --- a/Driver.h +++ b/Driver.h @@ -76,22 +76,6 @@ struct OVPN_DEVICE { OVPN_STATS Stats; - // keepalive interval in seconds - _Guarded_by_(SpinLock) - LONG KeepaliveInterval; - - // keepalive timeout in seconds - _Guarded_by_(SpinLock) - LONG KeepaliveTimeout; - - // timer used to send periodic ping messages to the server if no data has been sent within the past KeepaliveInterval seconds - _Guarded_by_(SpinLock) - WDFTIMER KeepaliveXmitTimer; - - // timer used to report keepalive timeout error to userspace when no data has been received for KeepaliveTimeout seconds - _Guarded_by_(SpinLock) - WDFTIMER KeepaliveRecvTimer; - BCRYPT_ALG_HANDLE AesAlgHandle; BCRYPT_ALG_HANDLE ChachaAlgHandle; @@ -99,20 +83,41 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) UINT16 MSS; - _Guarded_by_(SpinLock) - OvpnCryptoContext CryptoContext; - _Guarded_by_(SpinLock) OvpnSocket Socket; _Guarded_by_(SpinLock) NETADAPTER Adapter; - // pid of userspace process which called NEW_PEER - _Interlocked_ - LONG UserspacePid; + _Guarded_by_(SpinLock) + RTL_GENERIC_TABLE Peers; + + SIZE_T CryptoOverhead; }; typedef OVPN_DEVICE * POVPN_DEVICE; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DEVICE, OvpnGetDeviceContext) + +static inline +BOOLEAN +OvpnHasPeers(_In_ POVPN_DEVICE device) +{ + return !RtlIsGenericTableEmpty(&device->Peers); +} + +struct OvpnPeerContext; + +_Must_inspect_result_ +NTSTATUS +OvpnAddPeer(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); + +VOID +OvpnFlushPeers(_In_ POVPN_DEVICE device); + +VOID +OvpnCleanupPeerTable(_In_ RTL_GENERIC_TABLE*); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnGetFirstPeer(_In_ RTL_GENERIC_TABLE*); diff --git a/crypto.cpp b/crypto.cpp index f75df88..0e2d9a0 100644 --- a/crypto.cpp +++ b/crypto.cpp @@ -263,8 +263,6 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, keySlot->KeyId = cryptoData->KeyId; keySlot->PeerId = cryptoData->PeerId; - cryptoContext->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; - LOG_INFO("New key", TraceLoggingValue(cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM ? "aes-gcm" : "chacha20-poly1305", "alg"), TraceLoggingValue(cryptoData->KeyId, "KeyId"), TraceLoggingValue(cryptoData->KeyId, "PeerId")); } @@ -272,8 +270,6 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, cryptoContext->Encrypt = OvpnCryptoEncryptNone; cryptoContext->Decrypt = OvpnCryptoDecryptNone; - cryptoContext->CryptoOverhead = NONE_CRYPTO_OVERHEAD; - LOG_INFO("Using cipher none"); } else { diff --git a/crypto.h b/crypto.h index 11e277e..ee35cba 100644 --- a/crypto.h +++ b/crypto.h @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * diff --git a/peer.cpp b/peer.cpp index 8b92ca0..a7cdd07 100644 --- a/peer.cpp +++ b/peer.cpp @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * @@ -26,6 +27,55 @@ #include "timer.h" #include "socket.h" +_Use_decl_annotations_ +OvpnPeerContext* +OvpnPeerCtxAlloc() +{ + OvpnPeerContext* peer = (OvpnPeerContext*)ExAllocatePool2(POOL_FLAG_NON_PAGED, sizeof(OvpnPeerContext), 'ovpn'); + if (peer != NULL) { + RtlZeroMemory(peer, sizeof(OvpnPeerContext)); + } + return peer; +} + +_Use_decl_annotations_ +VOID +OvpnPeerCtxFree(OvpnPeerContext* peer) +{ + OvpnCryptoUninit(&peer->CryptoContext); + OvpnTimerDestroy(&peer->KeepaliveXmitTimer); + OvpnTimerDestroy(&peer->KeepaliveRecvTimer); + + ExFreePoolWithTag(peer, 'ovpn'); +} + +_Use_decl_annotations_ +PVOID +OvpnPeerAllocateRoutine(_RTL_GENERIC_TABLE* table, CLONG size) +{ + UNREFERENCED_PARAMETER(table); + + return ExAllocatePool2(POOL_FLAG_NON_PAGED, size, 'ovpn'); +} + +_Use_decl_annotations_ +VOID +OvpnPeerFreeRoutine(_RTL_GENERIC_TABLE* table, PVOID buffer) +{ + UNREFERENCED_PARAMETER(table); + + ExFreePoolWithTag(buffer, 'ovpn'); +} + +RTL_GENERIC_COMPARE_RESULTS OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +{ + UNREFERENCED_PARAMETER(table); + UNREFERENCED_PARAMETER(first); + UNREFERENCED_PARAMETER(second); + + return GenericEqual; +} + static VOID OvpnPeerZeroStats(POVPN_STATS stats) @@ -54,40 +104,54 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_NEW_PEER), (PVOID*)&peer, nullptr)); - ULONG newPid = IoGetRequestorProcessId(WdfRequestWdmGetIrp(request)); - LONG existingPid = InterlockedCompareExchange(&device->UserspacePid, 0, 0); - - if (existingPid != 0) { - LOG_INFO("Peer already added, deleting existing peer"); - LOG_IF_NOT_NT_SUCCESS(OvpnPeerDel(device)); + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + const BOOLEAN peerExists = OvpnHasPeers(device); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + if (peerExists) { + LOG_WARN("Peer already exists"); + status = STATUS_OBJECTID_EXISTS; + goto done; } - InterlockedExchange(&device->UserspacePid, newPid); - - LOG_INFO("Userspace client connected", TraceLoggingValue(newPid, "pid")); - POVPN_DRIVER driver = OvpnGetDriverContext(WdfGetDriver()); PWSK_SOCKET socket = NULL; BOOLEAN proto_tcp = peer->Proto == OVPN_PROTO_TCP; SIZE_T remoteAddrSize = peer->Remote.Addr4.sin_family == AF_INET ? sizeof(peer->Remote.Addr4) : sizeof(peer->Remote.Addr6); - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnSocketInit(&driver->WskProviderNpi, &driver->WskRegistration, peer->Local.Addr4.sin_family, proto_tcp, (PSOCKADDR)&peer->Local, - (PSOCKADDR)&peer->Remote, remoteAddrSize, device, &socket)); + OvpnPeerContext* peerCtx = OvpnPeerCtxAlloc(); + if (peerCtx == NULL) { + status = STATUS_NO_MEMORY; + goto done; + } - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - RtlZeroMemory(&device->CryptoContext, sizeof(OvpnCryptoContext)); - device->Socket.Socket = socket; - device->Socket.Tcp = proto_tcp; - RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); - RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnSocketInit(&driver->WskProviderNpi, + &driver->WskRegistration, peer->Local.Addr4.sin_family, proto_tcp, + (PSOCKADDR)&peer->Local, + (PSOCKADDR)&peer->Remote, + remoteAddrSize, device, &socket)); - OvpnPeerZeroStats(&device->Stats); + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - if (proto_tcp) { - LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); - // start async connect - status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); + if (status != STATUS_SUCCESS) { + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + OvpnPeerCtxFree(peerCtx); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketClose(socket)); + } + else { + device->Socket.Socket = socket; + device->Socket.Tcp = proto_tcp; + RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); + RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + OvpnPeerZeroStats(&device->Stats); + + if (proto_tcp) { + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); + // start async connect + status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); + } } done: @@ -102,22 +166,12 @@ OvpnPeerDel(POVPN_DEVICE device) { LOG_ENTER(); - if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { - LOG_INFO("Peer not added."); - return STATUS_INVALID_DEVICE_REQUEST; - } - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - OvpnTimerDestroy(&device->KeepaliveXmitTimer); - OvpnTimerDestroy(&device->KeepaliveRecvTimer); - - OvpnCryptoUninit(&device->CryptoContext); - - InterlockedExchange(&device->UserspacePid, 0); - PWSK_SOCKET socket = device->Socket.Socket; + device->Socket.Socket = NULL; + OvpnFlushPeers(device); RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); @@ -151,48 +205,52 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) { LOG_ENTER(); - if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { + NTSTATUS status = STATUS_SUCCESS; + + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + + if (peer == NULL) { LOG_ERROR("Peer not added"); - return STATUS_INVALID_DEVICE_REQUEST; + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } - POVPN_SET_PEER peer = NULL; - NTSTATUS status; - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_SET_PEER), (PVOID*)&peer, nullptr)); + POVPN_SET_PEER set_peer = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_SET_PEER), (PVOID*)&set_peer, nullptr)); - LOG_INFO("Set peer", TraceLoggingValue(peer->KeepaliveInterval, "interval"), - TraceLoggingValue(peer->KeepaliveTimeout, "timeout"), - TraceLoggingValue(peer->MSS, "MSS")); + LOG_INFO("Set peer", TraceLoggingValue(set_peer->KeepaliveInterval, "interval"), + TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), + TraceLoggingValue(set_peer->MSS, "MSS")); - if (peer->MSS != -1) { - device->MSS = (UINT16)peer->MSS; + if (set_peer->MSS != -1) { + device->MSS = (UINT16)set_peer->MSS; } - if (peer->KeepaliveInterval != -1) { - device->KeepaliveInterval = peer->KeepaliveInterval; + if (set_peer->KeepaliveInterval != -1) { + peer->KeepaliveInterval = set_peer->KeepaliveInterval; - if (device->KeepaliveInterval > 0) { + if (peer->KeepaliveInterval > 0) { // keepalive xmit timer, sends ping packets - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerXmitCreate(device->WdfDevice, peer->KeepaliveInterval, &device->KeepaliveXmitTimer)); - OvpnTimerReset(device->KeepaliveXmitTimer, peer->KeepaliveInterval); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerXmitCreate(device->WdfDevice, peer, peer->KeepaliveInterval, &peer->KeepaliveXmitTimer)); + OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval); } else { LOG_INFO("Destroy xmit timer"); - OvpnTimerDestroy(&device->KeepaliveXmitTimer); + OvpnTimerDestroy(&peer->KeepaliveXmitTimer); } } if (peer->KeepaliveTimeout != -1) { - device->KeepaliveTimeout = peer->KeepaliveTimeout; + peer->KeepaliveTimeout = set_peer->KeepaliveTimeout; - if (device->KeepaliveTimeout > 0) { + if (peer->KeepaliveTimeout > 0) { // keepalive recv timer, detects keepalive timeout - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerRecvCreate(device->WdfDevice, &device->KeepaliveRecvTimer)); - OvpnTimerReset(device->KeepaliveRecvTimer, peer->KeepaliveTimeout); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerRecvCreate(device->WdfDevice, peer, &peer->KeepaliveRecvTimer)); + OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout); } else { LOG_INFO("Destroy recv timer"); - OvpnTimerDestroy(&device->KeepaliveRecvTimer); + OvpnTimerDestroy(&peer->KeepaliveRecvTimer); } } @@ -205,13 +263,16 @@ _Use_decl_annotations_ NTSTATUS OvpnPeerGetStats(POVPN_DEVICE device, WDFREQUEST request, ULONG_PTR* bytesReturned) { - if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { + NTSTATUS status = STATUS_SUCCESS; + + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { LOG_ERROR("Peer not added"); - return STATUS_INVALID_DEVICE_REQUEST; + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } POVPN_STATS stats = NULL; - NTSTATUS status; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_STATS), (PVOID*)&stats, NULL)); stats->LostInControlPackets = InterlockedCompareExchangeNoFence(&device->Stats.LostInControlPackets, 0, 0); @@ -239,16 +300,20 @@ OvpnPeerStartVPN(POVPN_DEVICE device) { LOG_ENTER(); - if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { + NTSTATUS status = STATUS_SUCCESS; + + if (!OvpnHasPeers(device)) { LOG_ERROR("Peer not added"); - return STATUS_INVALID_DEVICE_REQUEST; + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateConnected); +done: LOG_EXIT(); - return STATUS_SUCCESS; + return status; } _Use_decl_annotations_ @@ -257,13 +322,15 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) { LOG_ENTER(); - if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { + NTSTATUS status = STATUS_SUCCESS; + + if (!OvpnHasPeers(device)) { LOG_ERROR("Peer not added"); - return STATUS_INVALID_DEVICE_REQUEST; + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } POVPN_CRYPTO_DATA cryptoData = NULL; - NTSTATUS status; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr)); @@ -271,6 +338,7 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) switch (cryptoData->CipherAlg) { case OVPN_CIPHER_ALG_AES_GCM: algHandle = device->AesAlgHandle; + device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; break; case OVPN_CIPHER_ALG_CHACHA20_POLY1305: @@ -280,10 +348,20 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) status = STATUS_INVALID_DEVICE_REQUEST; goto done; } + device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; + + default: + device->CryptoOverhead = NONE_CRYPTO_OVERHEAD; break; } - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&device->CryptoContext, cryptoData, algHandle)); + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + status = STATUS_OBJECTID_NOT_FOUND; + goto done; + } + + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoData, algHandle)); done: LOG_EXIT(); @@ -297,14 +375,26 @@ OvpnPeerSwapKeys(POVPN_DEVICE device) { LOG_ENTER(); - if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { + NTSTATUS status = STATUS_SUCCESS; + + if (!OvpnHasPeers(device)) { LOG_ERROR("Peer not added"); - return STATUS_INVALID_DEVICE_REQUEST; + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; + } + + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + LOG_ERROR("Peer not found"); + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } - OvpnCryptoSwapKeys(&device->CryptoContext); + OvpnCryptoSwapKeys(&peer->CryptoContext); +done: LOG_EXIT(); - return STATUS_SUCCESS; + return status; } + diff --git a/peer.h b/peer.h index 4eecca6..4495171 100644 --- a/peer.h +++ b/peer.h @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * @@ -26,6 +27,36 @@ #include "driver.h" #include "uapi\ovpn-dco.h" +struct OvpnPeerContext +{ + OvpnCryptoContext CryptoContext; + + INT32 PeerId; + + // keepalive interval in seconds + LONG KeepaliveInterval; + + // keepalive timeout in seconds + LONG KeepaliveTimeout; + + // timer used to send periodic ping messages to the server if no data has been sent within the past KeepaliveInterval seconds + WDFTIMER KeepaliveXmitTimer; + + // timer used to report keepalive timeout error to userspace when no data has been received for KeepaliveTimeout seconds + WDFTIMER KeepaliveRecvTimer; +}; + +_Must_inspect_result_ +OvpnPeerContext* +OvpnPeerCtxAlloc(); + +VOID +OvpnPeerCtxFree(_In_ OvpnPeerContext*); + +RTL_GENERIC_ALLOCATE_ROUTINE OvpnPeerAllocateRoutine; +RTL_GENERIC_FREE_ROUTINE OvpnPeerFreeRoutine; +RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByPeerIdRoutine; + _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS diff --git a/rxqueue.cpp b/rxqueue.cpp index bc0f8c0..ce7d71b 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -115,7 +115,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoContext.CryptoOverhead, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoOverhead, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index 3b342f0..ef57e37 100644 --- a/socket.cpp +++ b/socket.cpp @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * @@ -30,6 +31,7 @@ #include "rxqueue.h" #include "timer.h" #include "socket.h" +#include "peer.h" IO_COMPLETION_ROUTINE OvpnSocketSyncOpCompletionRoutine; @@ -159,6 +161,13 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ { InterlockedExchangeAddNoFence64(&device->Stats.TransportBytesReceived, len); + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + LOG_WARN("No peer"); + InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); + return; + } + OVPN_RX_BUFFER* buffer; // fetch buffer for plaintext @@ -169,9 +178,9 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - if (device->CryptoContext.Decrypt) { + if (peer->CryptoContext.Decrypt) { UCHAR keyId = OvpnCryptoKeyIdExtract(op); - OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&device->CryptoContext, keyId); + OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&peer->CryptoContext, keyId); if (!keySlot) { status = STATUS_INVALID_DEVICE_STATE; @@ -179,8 +188,8 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ } else { // decrypt into plaintext buffer - status = device->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data); - buffer->Len = len - device->CryptoContext.CryptoOverhead; + status = peer->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data); + buffer->Len = len - device->CryptoOverhead; } } else { @@ -194,10 +203,10 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - OvpnTimerReset(device->KeepaliveRecvTimer, device->KeepaliveTimeout); + OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout); // points to the beginning of plaintext - UCHAR* buf = buffer->Data + device->CryptoContext.CryptoOverhead; + UCHAR* buf = buffer->Data + device->CryptoOverhead; // ping packet? if (OvpnTimerIsKeepaliveMessage(buf, buffer->Len)) { diff --git a/socket.h b/socket.h index e94247c..b6dee5a 100644 --- a/socket.h +++ b/socket.h @@ -66,7 +66,7 @@ OvpnSocketInit(_In_ WSK_PROVIDER_NPI* wskProviderNpi, _In_ WSK_REGISTRATION* wsk _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS -OvpnSocketClose(_In_ PWSK_SOCKET socket); +OvpnSocketClose(_In_opt_ PWSK_SOCKET socket); _Must_inspect_result_ NTSTATUS diff --git a/timer.cpp b/timer.cpp index d9563db..bac0e9b 100644 --- a/timer.cpp +++ b/timer.cpp @@ -19,17 +19,28 @@ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ +// without this include, RTL_GENERIC_TABLE in Device.h is undefined +#include + #include "bufferpool.h" #include "driver.h" #include "trace.h" #include "timer.h" #include "socket.h" +#include "peer.h" static const UCHAR OvpnKeepaliveMessage[] = { 0x2a, 0x18, 0x7b, 0xf3, 0x64, 0x1e, 0xb4, 0xcb, 0x07, 0xed, 0x2d, 0x0a, 0x98, 0x1f, 0xc7, 0x48 }; +// Context added to a timer's attributes +typedef struct _OVPN_PEER_TIMER_CONTEXT { + OvpnPeerContext* Peer; +} OVPN_PEER_TIMER_CONTEXT, * POVPN_PEER_TIMER_CONTEXT; + +WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_PEER_TIMER_CONTEXT, OvpnGetPeerTimerContext); + _Use_decl_annotations_ BOOLEAN OvpnTimerIsKeepaliveMessage(const PUCHAR buf, SIZE_T len) { @@ -39,26 +50,31 @@ BOOLEAN OvpnTimerIsKeepaliveMessage(const PUCHAR buf, SIZE_T len) _Function_class_(EVT_WDF_TIMER) static VOID OvpnTimerXmit(WDFTIMER timer) { + LOG_ENTER(); + POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); OVPN_TX_BUFFER* buffer; NTSTATUS status; LOG_IF_NOT_NT_SUCCESS(status = OvpnTxBufferPoolGet(device->TxBufferPool, &buffer)); if (!NT_SUCCESS(status)) { + LOG_EXIT(); return; } // copy keepalive magic message to the buffer RtlCopyMemory(OvpnTxBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); + OvpnPeerContext* peer = timerCtx->Peer; KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); - if (device->CryptoContext.Encrypt) { + if (peer->CryptoContext.Encrypt) { // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoContext.CryptoOverhead); + OvpnTxBufferPush(buffer, device->CryptoOverhead); // in-place encrypt, always with primary key - status = device->CryptoContext.Encrypt(&device->CryptoContext.Primary, buffer->Data, buffer->Len); + status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len); } else { status = STATUS_INVALID_DEVICE_STATE; @@ -76,11 +92,15 @@ static VOID OvpnTimerXmit(WDFTIMER timer) OvpnTxBufferPoolPut(buffer); } ExReleaseSpinLockShared(&device->SpinLock, kiqrl); + + LOG_EXIT(); } _Function_class_(EVT_WDF_TIMER) static VOID OvpnTimerRecv(WDFTIMER timer) { + LOG_ENTER(); + LOG_WARN("Keepalive timeout"); POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); @@ -94,6 +114,8 @@ static VOID OvpnTimerRecv(WDFTIMER timer) ULONG_PTR bytesSent = 0; WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); } + + LOG_EXIT(); } _Use_decl_annotations_ @@ -107,8 +129,10 @@ VOID OvpnTimerDestroy(WDFTIMER* timer) } } -static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, ULONG period, PFN_WDF_TIMER func, _Inout_ WDFTIMER* timer) +static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, PFN_WDF_TIMER func, _Inout_ WDFTIMER* timer) { + LOG_ENTER(); + if (*timer != WDF_NO_HANDLE) { WdfTimerStop(*timer, FALSE); WdfObjectDelete(*timer); @@ -122,27 +146,37 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, ULONG period, PFN_WDF_TIMER fu WDF_OBJECT_ATTRIBUTES timerAttributes; WDF_OBJECT_ATTRIBUTES_INIT(&timerAttributes); + WDF_OBJECT_ATTRIBUTES_SET_CONTEXT_TYPE(&timerAttributes, OVPN_PEER_TIMER_CONTEXT); timerAttributes.ParentObject = parent; - return WdfTimerCreate(&timerConfig, &timerAttributes, timer); + *timer = WDF_NO_HANDLE; + NTSTATUS status; + LOG_IF_NOT_NT_SUCCESS(status = WdfTimerCreate(&timerConfig, &timerAttributes, timer)); + if (NT_SUCCESS(status)) { + POVPN_PEER_TIMER_CONTEXT pTimerContext = OvpnGetPeerTimerContext(*timer); + pTimerContext->Peer = peer; + } + + LOG_EXIT(); + return status; } _Use_decl_annotations_ -NTSTATUS OvpnTimerXmitCreate(WDFOBJECT parent, ULONG period, WDFTIMER* timer) +NTSTATUS OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, WDFTIMER* timer) { NTSTATUS status; LOG_INFO("Create xmit timer", TraceLoggingValue(period, "period")); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, period, OvpnTimerXmit, timer)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, period, OvpnTimerXmit, timer)); return status; } _Use_decl_annotations_ -NTSTATUS OvpnTimerRecvCreate(WDFOBJECT parent, WDFTIMER* timer) +NTSTATUS OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, WDFTIMER* timer) { NTSTATUS status; LOG_INFO("Create recv timer"); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, 0, OvpnTimerRecv, timer)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, 0, OvpnTimerRecv, timer)); return status; } diff --git a/timer.h b/timer.h index 41f52a5..b9f8bfc 100644 --- a/timer.h +++ b/timer.h @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * @@ -29,11 +30,11 @@ OvpnTimerReset(WDFTIMER timer, ULONG dueTime); _Must_inspect_result_ NTSTATUS -OvpnTimerXmitCreate(WDFOBJECT parent, ULONG period, _Inout_ WDFTIMER* timer); +OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, _Inout_ WDFTIMER* timer); _Must_inspect_result_ NTSTATUS -OvpnTimerRecvCreate(WDFOBJECT parent, _Inout_ WDFTIMER* timer); +OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer); VOID OvpnTimerDestroy(_Inout_ WDFTIMER* timer); diff --git a/txqueue.cpp b/txqueue.cpp index 711b555..a93a470 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -2,6 +2,7 @@ * ovpn-dco-win OpenVPN protocol accelerator for Windows * * Copyright (C) 2020-2021 OpenVPN Inc + * Copyright (C) 2023 Rubicon Communications LLC (Netgate) * * Author: Lev Stipakov * @@ -33,6 +34,7 @@ #include "timer.h" #include "txqueue.h" #include "socket.h" +#include "peer.h" _Must_inspect_result_ _Requires_shared_lock_held_(device->SpinLock) @@ -81,14 +83,22 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); } + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + status = STATUS_ADDRESS_NOT_ASSOCIATED; + OvpnTxBufferPoolPut(buffer); + LOG_WARN("No peer"); + goto out; + } + InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len); - if (device->CryptoContext.Encrypt) { + if (peer->CryptoContext.Encrypt) { // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoContext.CryptoOverhead); + OvpnTxBufferPush(buffer, device->CryptoOverhead); // in-place encrypt, always with primary key - status = device->CryptoContext.Encrypt(&device->CryptoContext.Primary, buffer->Data, buffer->Len); + status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len); } else { status = STATUS_INVALID_DEVICE_STATE; @@ -117,6 +127,8 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET *tail = buffer; } + + OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval); } else { OvpnTxBufferPoolPut(buffer); @@ -167,14 +179,9 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) } NetPacketIteratorSet(&pi); - // reset keepalive timer - if (packetSent) { - OvpnTimerReset(device->KeepaliveXmitTimer, device->KeepaliveInterval); - - if (!device->Socket.Tcp) { - // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead)); - } + if (packetSent && !device->Socket.Tcp) { + // this will use WskSendMessages to send buffers list which we constructed before + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead)); } ExReleaseSpinLockShared(&device->SpinLock, kirql);