diff --git a/README.md b/README.md index f95d9b6..9672cfc 100644 --- a/README.md +++ b/README.md @@ -219,9 +219,6 @@ Note that if property `wsUpstream` not specified then proxy will try to connect The options passed to [`new ws.Server()`](https://github.com/websockets/ws/blob/HEAD/doc/ws.md#class-websocketserver). -In case multiple websocket proxies are attached to the same HTTP server at different paths. -In this case, only the first `wsServerOptions` is applied. - ### `wsClientOptions` The options passed to the [`WebSocket` constructor](https://github.com/websockets/ws/blob/HEAD/doc/ws.md#class-websocket) for outgoing websockets. diff --git a/index.js b/index.js index 6c85d86..1db3276 100644 --- a/index.js +++ b/index.js @@ -10,6 +10,7 @@ const httpMethods = ['DELETE', 'GET', 'HEAD', 'PATCH', 'POST', 'PUT', 'OPTIONS'] const urlPattern = /^https?:\/\// const kWs = Symbol('ws') const kWsHead = Symbol('wsHead') +const kWsUpgradeListener = Symbol('wsUpgradeListener') function liftErrorCode (code) { /* istanbul ignore next */ @@ -74,32 +75,46 @@ function proxyWebSockets (source, target) { target.on('unexpected-response', () => close(1011, 'unexpected response')) } +function handleUpgrade (fastify, rawRequest, socket, head) { + // Save a reference to the socket and then dispatch the request through the normal fastify router so that it will invoke hooks and then eventually a route handler that might upgrade the socket. + rawRequest[kWs] = socket + rawRequest[kWsHead] = head + + const rawResponse = new ServerResponse(rawRequest) + rawResponse.assignSocket(socket) + fastify.routing(rawRequest, rawResponse) + + rawResponse.on('finish', () => { + socket.destroy() + }) +} + class WebSocketProxy { - constructor (fastify, wsServerOptions) { + constructor (fastify, { wsServerOptions, wsClientOptions, upstream, wsUpstream, replyOptions: { getUpstream } = {} }) { this.logger = fastify.log + this.wsClientOptions = { + rewriteRequestHeaders: defaultWsHeadersRewrite, + headers: {}, + ...wsClientOptions + } + this.upstream = convertUrlToWebSocket(upstream) + this.wsUpstream = wsUpstream ? convertUrlToWebSocket(wsUpstream) : '' + this.getUpstream = getUpstream const wss = new WebSocket.Server({ noServer: true, ...wsServerOptions }) - fastify.server.on('upgrade', (rawRequest, socket, head) => { - // Save a reference to the socket and then dispatch the request through the normal fastify router so that it will invoke hooks and then eventually a route handler that might upgrade the socket. - rawRequest[kWs] = socket - rawRequest[kWsHead] = head - - const rawResponse = new ServerResponse(rawRequest) - rawResponse.assignSocket(socket) - fastify.routing(rawRequest, rawResponse) - - rawResponse.on('finish', () => { - socket.destroy() - }) - }) + if (!fastify.server[kWsUpgradeListener]) { + fastify.server[kWsUpgradeListener] = (rawRequest, socket, head) => + handleUpgrade(fastify, rawRequest, socket, head) + fastify.server.on('upgrade', fastify.server[kWsUpgradeListener]) + } - this.handleUpgrade = (request, cb) => { + this.handleUpgrade = (request, dest, cb) => { wss.handleUpgrade(request.raw, request.raw[kWs], request.raw[kWsHead], (socket) => { - this.handleConnection(socket, request) + this.handleConnection(socket, request, dest) cb() }) } @@ -134,45 +149,33 @@ class WebSocketProxy { this.prefixList = [] } - addUpstream (prefix, rewritePrefix, upstream, wsUpstream, wsClientOptions) { - this.prefixList.push({ - prefix: new URL(prefix, 'ws://127.0.0.1').pathname, - rewritePrefix, - upstream: convertUrlToWebSocket(upstream), - wsUpstream: wsUpstream ? convertUrlToWebSocket(wsUpstream) : '', - wsClientOptions - }) + findUpstream (request, dest) { + const search = new URL(request.url, 'ws://127.0.0.1').search - // sort by decreasing prefix length, so that findUpstreamUrl() does longest prefix match - this.prefixList.sort((a, b) => b.prefix.length - a.prefix.length) - } - - findUpstream (request) { - const source = new URL(request.url, 'ws://127.0.0.1') - - for (const { prefix, rewritePrefix, upstream, wsUpstream, wsClientOptions } of this.prefixList) { - if (wsUpstream) { - const target = new URL(wsUpstream) - target.search = source.search - return { target, wsClientOptions } - } + if (typeof this.wsUpstream === 'string' && this.wsUpstream !== '') { + const target = new URL(this.wsUpstream) + target.search = search + return target + } - if (source.pathname.startsWith(prefix)) { - const target = new URL(source.pathname.replace(prefix, rewritePrefix), upstream) - target.search = source.search - return { target, wsClientOptions } - } + if (typeof this.upstream === 'string' && this.upstream !== '') { + const target = new URL(dest, this.upstream) + target.search = search + return target } + const upstream = this.getUpstream(request, '') + const target = new URL(dest, upstream) /* istanbul ignore next */ - throw new Error(`no upstream found for ${request.url}. this should not happened. Please report to https://github.com/fastify/fastify-http-proxy`) + target.protocol = upstream.indexOf('http:') === 0 ? 'ws:' : 'wss' + target.search = search + return target } - handleConnection (source, request) { - const upstream = this.findUpstream(request) - const { target: url, wsClientOptions } = upstream - const rewriteRequestHeaders = wsClientOptions?.rewriteRequestHeaders || defaultWsHeadersRewrite - const headersToRewrite = wsClientOptions?.headers || {} + handleConnection (source, request, dest) { + const url = this.findUpstream(request, dest) + const rewriteRequestHeaders = this.wsClientOptions.rewriteRequestHeaders + const headersToRewrite = this.wsClientOptions.headers const subprotocols = [] if (source.protocol) { @@ -180,7 +183,7 @@ class WebSocketProxy { } const headers = rewriteRequestHeaders(headersToRewrite, request) - const optionsWs = { ...(wsClientOptions || {}), headers } + const optionsWs = { ...this.wsClientOptions, headers } const target = new WebSocket(url, subprotocols, optionsWs) this.logger.debug({ url: url.href }, 'proxy websocket') @@ -195,41 +198,6 @@ function defaultWsHeadersRewrite (headers, request) { return { ...headers } } -const httpWss = new WeakMap() // http.Server => WebSocketProxy - -function setupWebSocketProxy (fastify, options, rewritePrefix) { - let wsProxy = httpWss.get(fastify.server) - if (!wsProxy) { - wsProxy = new WebSocketProxy(fastify, options.wsServerOptions) - httpWss.set(fastify.server, wsProxy) - } - - if ( - (typeof options.wsUpstream === 'string' && options.wsUpstream !== '') || - (typeof options.upstream === 'string' && options.upstream !== '') - ) { - wsProxy.addUpstream( - fastify.prefix, - rewritePrefix, - options.upstream, - options.wsUpstream, - options.wsClientOptions - ) - // The else block is validate earlier in the code - } else { - wsProxy.findUpstream = function (request) { - const source = new URL(request.url, 'ws://127.0.0.1') - const upstream = options.replyOptions.getUpstream(request, '') - const target = new URL(source.pathname, upstream) - /* istanbul ignore next */ - target.protocol = upstream.indexOf('http:') === 0 ? 'ws:' : 'wss' - target.search = source.search - return { target, wsClientOptions: options.wsClientOptions } - } - } - return wsProxy -} - function generateRewritePrefix (prefix, opts) { let rewritePrefix = opts.rewritePrefix || (opts.upstream ? new URL(opts.upstream).pathname : '/') @@ -303,7 +271,7 @@ async function fastifyHttpProxy (fastify, opts) { let wsProxy if (opts.websocket) { - wsProxy = setupWebSocketProxy(fastify, opts, rewritePrefix) + wsProxy = new WebSocketProxy(fastify, opts) } function extractUrlComponents (urlString) { @@ -321,16 +289,6 @@ async function fastifyHttpProxy (fastify, opts) { } function handler (request, reply) { - if (request.raw[kWs]) { - reply.hijack() - try { - wsProxy.handleUpgrade(request, noop) - } catch (err) { - /* istanbul ignore next */ - request.log.warn({ err }, 'websocket proxy error') - } - return - } const { path, queryParams } = extractUrlComponents(request.url) let dest = path @@ -350,6 +308,17 @@ async function fastifyHttpProxy (fastify, opts) { } else { dest = dest.replace(this.prefix, rewritePrefix) } + + if (request.raw[kWs]) { + reply.hijack() + try { + wsProxy.handleUpgrade(request, dest || '/', noop) + } catch (err) { + /* istanbul ignore next */ + request.log.warn({ err }, 'websocket proxy error') + } + return + } reply.from(dest || '/', replyOpts) } } diff --git a/test/websocket.js b/test/websocket.js index 4d85f15..b08395e 100644 --- a/test/websocket.js +++ b/test/websocket.js @@ -474,3 +474,103 @@ test('Proxy websocket with custom upstream url', async (t) => { server.close() ]) }) + +test('multiple websocket upstreams with host constraints', async (t) => { + t.plan(4) + + const server = Fastify() + + for (const name of ['foo', 'bar']) { + const origin = createServer() + const wss = new WebSocket.Server({ server: origin }) + t.teardown(wss.close.bind(wss)) + t.teardown(origin.close.bind(origin)) + + wss.once('connection', (ws) => { + ws.once('message', message => { + t.equal(message.toString(), `hello ${name}`) + // echo + ws.send(message) + }) + }) + + await promisify(origin.listen.bind(origin))({ port: 0, host: '127.0.0.1' }) + server.register(proxy, { + upstream: `ws://127.0.0.1:${origin.address().port}`, + websocket: true, + constraints: { host: name } + }) + } + + await server.listen({ port: 0, host: '127.0.0.1' }) + t.teardown(server.close.bind(server)) + + const wsClients = [] + for (const name of ['foo', 'bar']) { + const ws = new WebSocket(`ws://127.0.0.1:${server.server.address().port}`, { headers: { host: name } }) + await once(ws, 'open') + ws.send(`hello ${name}`) + const [reply] = await once(ws, 'message') + t.equal(reply.toString(), `hello ${name}`) + wsClients.push(ws) + } + + await Promise.all([ + ...wsClients.map(ws => once(ws, 'close')), + server.close() + ]) +}) + +test('multiple websocket upstreams with distinct server options', async (t) => { + t.plan(4) + + const server = Fastify() + + for (const name of ['foo', 'bar']) { + const origin = createServer() + const wss = new WebSocket.Server({ server: origin }) + t.teardown(wss.close.bind(wss)) + t.teardown(origin.close.bind(origin)) + + wss.once('connection', (ws, req) => { + t.equal(req.url, `/?q=${name}`) + ws.once('message', message => { + // echo + ws.send(message) + }) + }) + + await promisify(origin.listen.bind(origin))({ port: 0, host: '127.0.0.1' }) + server.register(proxy, { + upstream: `ws://127.0.0.1:${origin.address().port}`, + websocket: true, + constraints: { host: name }, + wsServerOptions: { + verifyClient: ({ req }) => { + t.equal(req.url, `/?q=${name}`) + return true + } + } + }) + } + + await server.listen({ port: 0, host: '127.0.0.1' }) + t.teardown(server.close.bind(server)) + + const wsClients = [] + for (const name of ['foo', 'bar']) { + const ws = new WebSocket( + `ws://127.0.0.1:${server.server.address().port}/?q=${name}`, + { headers: { host: name } } + ) + await once(ws, 'open') + ws.send(`hello ${name}`) + await once(ws, 'message') + wsClients.push(ws) + } + + await Promise.all([ + ...wsClients.map(ws => once(ws, 'close')), + server.close() + ]) +})