From a194c238e2922746c191199d9f77f674386c40d7 Mon Sep 17 00:00:00 2001 From: Simon Whitty Date: Fri, 22 Nov 2024 09:42:21 +1100 Subject: [PATCH] convert setPktInfo() to SocketOption --- FlyingSocks/Sources/Socket.swift | 23 ++++++++++------------- FlyingSocks/Tests/SocketTests.swift | 18 ++++++++++++++++++ FlyingSocks/XCTests/SocketTests.swift | 14 ++++++++++++++ 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 2b56d02e..58dbd473 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -123,25 +123,14 @@ public struct Socket: Sendable, Hashable { // enable return of ip_pktinfo/ipv6_pktinfo on recvmsg() private func setPktInfo(domain: Int32) throws { - var enable = Int32(1) - let level: Int32 - let name: Int32 - switch domain { case AF_INET: - level = Socket.ipproto_ip - name = Self.ip_pktinfo + try setValue(true, for: .packetInfoIP) case AF_INET6: - level = Socket.ipproto_ipv6 - name = Self.ipv6_recvpktinfo + try setValue(true, for: .packetInfoIPv6) default: return } - - let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout.size)) - guard result >= 0 else { - throw SocketError.makeFailed("SetPktInfoOption") - } } public func setValue(_ value: O.Value, for option: O) throws { @@ -573,6 +562,14 @@ public extension SocketOption where Self == BoolSocketOption { BoolSocketOption(name: SO_REUSEADDR) } + static var packetInfoIP: Self { + BoolSocketOption(level: Socket.ipproto_ip, name: Socket.ip_pktinfo) + } + + static var packetInfoIPv6: Self { + BoolSocketOption(level: Socket.ipproto_ipv6, name: Socket.ipv6_recvpktinfo) + } + #if canImport(Darwin) // Prevents SIG_TRAP when app is paused / running in background. static var noSIGPIPE: Self { diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 148e8e69..218ab923 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -321,6 +321,24 @@ struct SocketTests { try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength) } } + + @Test + func makes_datagram_ip4() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + + #expect( + try socket.getValue(for: .packetInfoIP) == true + ) + } + + @Test + func makes_datagram_ip6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + + #expect( + try socket.getValue(for: .packetInfoIPv6) == true + ) + } } extension Socket.Flags { diff --git a/FlyingSocks/XCTests/SocketTests.swift b/FlyingSocks/XCTests/SocketTests.swift index e555e459..e4962720 100644 --- a/FlyingSocks/XCTests/SocketTests.swift +++ b/FlyingSocks/XCTests/SocketTests.swift @@ -258,6 +258,20 @@ final class SocketTests: XCTestCase { let buffer = UnsafeMutablePointer.allocate(capacity: Int(maxLength)) XCTAssertThrowsError(try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength)) } + + func testMakes_datagram_ip4() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIP) + ) + } + + func testMakes_datagram_ip6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIPv6) + ) + } } extension Socket.Flags {