diff --git a/index.js b/index.js index 4f24ca1..ad52b9e 100644 --- a/index.js +++ b/index.js @@ -175,6 +175,9 @@ class WebSocketProxy { handleConnection (source, request, dest) { const url = this.findUpstream(request, dest) + const queryString = getQueryString(url.search, request.url, this.wsClientOptions, request) + url.search = queryString + const rewriteRequestHeaders = this.wsClientOptions.rewriteRequestHeaders const headersToRewrite = this.wsClientOptions.headers @@ -192,6 +195,22 @@ class WebSocketProxy { } } +function getQueryString (search, reqUrl, opts, request) { + if (typeof opts.queryString === 'function') { + return '?' + opts.queryString(search, reqUrl, request) + } + + if (opts.queryString) { + return '?' + qs.stringify(opts.queryString) + } + + if (search.length > 0) { + return search + } + + return '' +} + function defaultWsHeadersRewrite (headers, request) { if (request.headers.cookie) { return { ...headers, cookie: request.headers.cookie } diff --git a/test/websocket-querystring.js b/test/websocket-querystring.js new file mode 100644 index 0000000..c986f2e --- /dev/null +++ b/test/websocket-querystring.js @@ -0,0 +1,126 @@ +'use strict' + +const { test } = require('tap') +const Fastify = require('fastify') +const proxy = require('../') +const WebSocket = require('ws') +const { createServer } = require('node:http') +const { promisify } = require('node:util') +const { once } = require('node:events') +const qs = require('fast-querystring') + +const subprotocolValue = 'foo-subprotocol' + +test('websocket proxy with object queryString', async (t) => { + t.plan(7) + + const origin = createServer() + const wss = new WebSocket.Server({ server: origin }) + t.teardown(wss.close.bind(wss)) + t.teardown(origin.close.bind(origin)) + + const serverMessages = [] + wss.on('connection', (ws, request) => { + t.equal(ws.protocol, subprotocolValue) + t.equal(request.url, '/?q=test') + ws.on('message', (message, binary) => { + serverMessages.push([message.toString(), binary]) + // echo + ws.send(message, { binary }) + }) + }) + + await promisify(origin.listen.bind(origin))({ port: 0, host: '127.0.0.1' }) + + const server = Fastify() + server.register(proxy, { + upstream: `ws://127.0.0.1:${origin.address().port}`, + websocket: true, + wsClientOptions: { + queryString: { q: 'test' } + } + }) + + await server.listen({ port: 0, host: '127.0.0.1' }) + t.teardown(server.close.bind(server)) + + const ws = new WebSocket(`ws://127.0.0.1:${server.server.address().port}`, [subprotocolValue]) + await once(ws, 'open') + + ws.send('hello', { binary: false }) + const [reply0, binary0] = await once(ws, 'message') + t.equal(reply0.toString(), 'hello') + t.equal(binary0, false) + + ws.send(Buffer.from('fastify'), { binary: true }) + const [reply1, binary1] = await once(ws, 'message') + t.equal(reply1.toString(), 'fastify') + t.equal(binary1, true) + + t.strictSame(serverMessages, [ + ['hello', false], + ['fastify', true] + ]) + + await Promise.all([ + once(ws, 'close'), + server.close() + ]) +}) + +test('websocket proxy with function queryString', async (t) => { + t.plan(7) + + const origin = createServer() + const wss = new WebSocket.Server({ server: origin }) + t.teardown(wss.close.bind(wss)) + t.teardown(origin.close.bind(origin)) + + const serverMessages = [] + wss.on('connection', (ws, request) => { + t.equal(ws.protocol, subprotocolValue) + t.equal(request.url, '/?q=test') + ws.on('message', (message, binary) => { + serverMessages.push([message.toString(), binary]) + // echo + ws.send(message, { binary }) + }) + }) + + await promisify(origin.listen.bind(origin))({ port: 0, host: '127.0.0.1' }) + + const server = Fastify() + server.register(proxy, { + upstream: `ws://127.0.0.1:${origin.address().port}`, + websocket: true, + wsClientOptions: { + queryString: () => qs.stringify({ q: 'test' }) + } + }) + + await server.listen({ port: 0, host: '127.0.0.1' }) + t.teardown(server.close.bind(server)) + + const ws = new WebSocket(`ws://127.0.0.1:${server.server.address().port}`, [subprotocolValue]) + await once(ws, 'open') + + ws.send('hello', { binary: false }) + const [reply0, binary0] = await once(ws, 'message') + t.equal(reply0.toString(), 'hello') + t.equal(binary0, false) + + ws.send(Buffer.from('fastify'), { binary: true }) + const [reply1, binary1] = await once(ws, 'message') + t.equal(reply1.toString(), 'fastify') + t.equal(binary1, true) + + t.strictSame(serverMessages, [ + ['hello', false], + ['fastify', true] + ]) + + await Promise.all([ + once(ws, 'close'), + server.close() + ]) +}) diff --git a/types/index.d.ts b/types/index.d.ts index d40cfd9..166bf27 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -1,6 +1,13 @@ /// -import { FastifyPluginCallback, preHandlerHookHandler, preValidationHookHandler } from 'fastify'; +import { + FastifyPluginCallback, + FastifyRequest, + preHandlerHookHandler, + preValidationHookHandler, + RawServerBase, + RequestGenericInterface, +} from 'fastify'; import { FastifyReplyFromOptions, @@ -24,6 +31,12 @@ type FastifyHttpProxy = FastifyPluginCallback< >; declare namespace fastifyHttpProxy { + type QueryStringFunction = ( + search: string | undefined, + reqUrl: string, + request: FastifyRequest + ) => string; + export interface FastifyHttpProxyOptions extends FastifyReplyFromOptions { upstream: string; prefix?: string; @@ -34,7 +47,7 @@ declare namespace fastifyHttpProxy { preValidation?: preValidationHookHandler; config?: Object; replyOptions?: FastifyReplyFromHooks; - wsClientOptions?: ClientOptions; + wsClientOptions?: ClientOptions & { queryString?: { [key: string]: unknown } | QueryStringFunction; }; wsServerOptions?: ServerOptions; httpMethods?: string[]; constraints?: { [name: string]: any }; diff --git a/types/index.test-d.ts b/types/index.test-d.ts index b1373e9..f2f9ea9 100644 --- a/types/index.test-d.ts +++ b/types/index.test-d.ts @@ -1,6 +1,9 @@ import fastify, { RawReplyDefaultExpression, RawRequestDefaultExpression, + type FastifyRequest, + type RawServerBase, + type RequestGenericInterface, } from 'fastify'; import { expectError, expectType } from 'tsd'; import fastifyHttpProxy from '..'; @@ -52,6 +55,14 @@ app.register(fastifyHttpProxy, { constraints: { version: '1.0.2' }, websocket: true, wsUpstream: 'ws://origin.asd/connection', + wsClientOptions: { + queryString(search, reqUrl, request) { + expectType(search); + expectType(reqUrl); + expectType>(request); + return ''; + }, + }, internalRewriteLocationHeader: true, });