From f6b9c0c909702f44e3b16f2cec1bc30c5195448f Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Sat, 13 Apr 2024 00:58:45 -0300 Subject: [PATCH] fix(socket) fix error in case of failure/returning a error in the open handler (#10154) * fix socket * one more test * always clean callback on deinit * Update src/bun.js/api/bun/socket.zig Co-authored-by: Jarred Sumner * make context close private * keep old logic * move clean step to SocketContext.close * add comment * wait for close on stop * cleanup --------- Co-authored-by: Jarred Sumner --- src/bun.js/api/bun/socket.zig | 25 ++++++---- src/deps/uws.zig | 43 +++++++++++++++++ test/js/bun/net/socket.test.ts | 87 +++++++++++++++++++++++++++++++++- 3 files changed, 144 insertions(+), 11 deletions(-) diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 835052d680b19f..b8c570acda5657 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -218,8 +218,9 @@ const Handlers = struct { this.unprotect(); // will deinit when is not wrapped or when is the TCP wrapped connection if (wrapped != .tls) { - if (ctx) |ctx_| + if (ctx) |ctx_| { ctx_.deinit(ssl); + } } bun.default_allocator.destroy(this); } @@ -825,23 +826,27 @@ pub const Listener = struct { const arguments = callframe.arguments(1); log("close", .{}); - if (arguments.len > 0 and arguments.ptr[0].isBoolean() and arguments.ptr[0].toBoolean() and this.socket_context != null) { - this.socket_context.?.close(this.ssl); - this.listener = null; - } else { - var listener = this.listener orelse return JSValue.jsUndefined(); - this.listener = null; - listener.close(this.ssl); - } + var listener = this.listener orelse return JSValue.jsUndefined(); + this.listener = null; this.poll_ref.unref(this.handlers.vm); + // if we already have no active connections, we can deinit the context now if (this.handlers.active_connections == 0) { this.handlers.unprotect(); - this.socket_context.?.close(this.ssl); + // deiniting the context will also close the listener this.socket_context.?.deinit(this.ssl); this.socket_context = null; this.strong_self.clear(); this.strong_data.clear(); + } else { + const forceClose = arguments.len > 0 and arguments.ptr[0].isBoolean() and arguments.ptr[0].toBoolean() and this.socket_context != null; + if (forceClose) { + // close all connections in this context and wait for them to close + this.socket_context.?.close(this.ssl); + } else { + // only close the listener and wait for the connections to close by it self + listener.close(this.ssl); + } } return JSValue.jsUndefined(); diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 386934c50d2dd8..e1fa199c5b411c 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -893,6 +893,47 @@ pub const SocketContext = opaque { us_socket_context_free(@as(i32, 0), this); } + pub fn cleanCallbacks(ctx: *SocketContext, is_ssl: bool) void { + const ssl_int: i32 = @intFromBool(is_ssl); + // replace callbacks with dummy ones + const DummyCallbacks = struct { + fn open(socket: *Socket, _: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + return socket; + } + fn close(socket: *Socket, _: i32, _: ?*anyopaque) callconv(.C) ?*Socket { + return socket; + } + fn data(socket: *Socket, _: [*c]u8, _: i32) callconv(.C) ?*Socket { + return socket; + } + fn writable(socket: *Socket) callconv(.C) ?*Socket { + return socket; + } + fn timeout(socket: *Socket) callconv(.C) ?*Socket { + return socket; + } + fn connect_error(socket: *Socket, _: i32) callconv(.C) ?*Socket { + return socket; + } + fn end(socket: *Socket) callconv(.C) ?*Socket { + return socket; + } + fn handshake(_: *Socket, _: i32, _: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void {} + fn long_timeout(socket: *Socket) callconv(.C) ?*Socket { + return socket; + } + }; + us_socket_context_on_open(ssl_int, ctx, DummyCallbacks.open); + us_socket_context_on_close(ssl_int, ctx, DummyCallbacks.close); + us_socket_context_on_data(ssl_int, ctx, DummyCallbacks.data); + us_socket_context_on_writable(ssl_int, ctx, DummyCallbacks.writable); + us_socket_context_on_timeout(ssl_int, ctx, DummyCallbacks.timeout); + us_socket_context_on_connect_error(ssl_int, ctx, DummyCallbacks.connect_error); + us_socket_context_on_end(ssl_int, ctx, DummyCallbacks.end); + us_socket_context_on_handshake(ssl_int, ctx, DummyCallbacks.handshake, null); + us_socket_context_on_long_timeout(ssl_int, ctx, DummyCallbacks.long_timeout); + } + fn getLoop(this: *SocketContext, ssl: bool) ?*Loop { if (ssl) { return us_socket_context_loop(@as(i32, 1), this); @@ -902,6 +943,8 @@ pub const SocketContext = opaque { /// closes and deinit the SocketContexts pub fn deinit(this: *SocketContext, ssl: bool) void { + // we clean the callbacks to avoid UAF because we are deiniting + this.cleanCallbacks(ssl); this.close(ssl); //always deinit in next iteration if (ssl) { diff --git a/test/js/bun/net/socket.test.ts b/test/js/bun/net/socket.test.ts index af131410fa6800..c46a46b43950e3 100644 --- a/test/js/bun/net/socket.test.ts +++ b/test/js/bun/net/socket.test.ts @@ -1,7 +1,7 @@ import { expect, it } from "bun:test"; import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows } from "harness"; import { connect, fileURLToPath, SocketHandler, spawn } from "bun"; - +import type { Socket } from "bun"; it("should coerce '0' to 0", async () => { const listener = Bun.listen({ // @ts-expect-error @@ -271,3 +271,88 @@ it("socket.timeout works", async () => { it("should allow large amounts of data to be sent and received", async () => { expect([fileURLToPath(new URL("./socket-huge-fixture.js", import.meta.url))]).toRun(); }, 60_000); + +it("it should not crash when getting a ReferenceError on client socket open", async () => { + const server = Bun.serve({ + port: 8080, + hostname: "localhost", + fetch() { + return new Response("Hello World"); + }, + }); + try { + const { resolve, reject, promise } = Promise.withResolvers(); + let client: Socket | null = null; + const timeout = setTimeout(() => { + client?.end(); + reject(new Error("Timeout")); + }, 1000); + client = await Bun.connect({ + port: server.port, + hostname: server.hostname, + socket: { + open(socket) { + // ReferenceError: Can't find variable: bytes + // @ts-expect-error + socket.write(bytes); + }, + error(socket, error) { + clearTimeout(timeout); + resolve(error); + }, + close(socket) { + // we need the close handler + resolve({ message: "Closed" }); + }, + data(socket, data) {}, + }, + }); + + const result: any = await promise; + expect(result?.message).toBe("Can't find variable: bytes"); + } finally { + server.stop(true); + } +}); + +it("it should not crash when returning a Error on client socket open", async () => { + const server = Bun.serve({ + port: 8080, + hostname: "localhost", + fetch() { + return new Response("Hello World"); + }, + }); + try { + const { resolve, reject, promise } = Promise.withResolvers(); + let client: Socket | null = null; + const timeout = setTimeout(() => { + client?.end(); + reject(new Error("Timeout")); + }, 1000); + client = await Bun.connect({ + port: server.port, + hostname: server.hostname, + socket: { + //@ts-ignore + open(socket) { + return new Error("CustomError"); + }, + error(socket, error) { + clearTimeout(timeout); + resolve(error); + }, + close(socket) { + // we need the close handler + resolve({ message: "Closed" }); + }, + data(socket, data) {}, + }, + }); + + const result: any = await promise; + expect(result?.message).toBe("CustomError"); + } finally { + server.stop(true); + } +});