Skip to content

Commit

Permalink
Add WebSocket host constraints (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
valeneiko authored Jan 30, 2024
1 parent 2229d76 commit 916d85c
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 98 deletions.
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,6 @@ Note that if property `wsUpstream` not specified then proxy will try to connect

The options passed to [`new ws.Server()`](https://github.com/websockets/ws/blob/HEAD/doc/ws.md#class-websocketserver).

In case multiple websocket proxies are attached to the same HTTP server at different paths.
In this case, only the first `wsServerOptions` is applied.

### `wsClientOptions`

The options passed to the [`WebSocket` constructor](https://github.com/websockets/ws/blob/HEAD/doc/ws.md#class-websocket) for outgoing websockets.
Expand Down
159 changes: 64 additions & 95 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const httpMethods = ['DELETE', 'GET', 'HEAD', 'PATCH', 'POST', 'PUT', 'OPTIONS']
const urlPattern = /^https?:\/\//
const kWs = Symbol('ws')
const kWsHead = Symbol('wsHead')
const kWsUpgradeListener = Symbol('wsUpgradeListener')

function liftErrorCode (code) {
/* istanbul ignore next */
Expand Down Expand Up @@ -74,32 +75,46 @@ function proxyWebSockets (source, target) {
target.on('unexpected-response', () => close(1011, 'unexpected response'))
}

function handleUpgrade (fastify, rawRequest, socket, head) {
// Save a reference to the socket and then dispatch the request through the normal fastify router so that it will invoke hooks and then eventually a route handler that might upgrade the socket.
rawRequest[kWs] = socket
rawRequest[kWsHead] = head

const rawResponse = new ServerResponse(rawRequest)
rawResponse.assignSocket(socket)
fastify.routing(rawRequest, rawResponse)

rawResponse.on('finish', () => {
socket.destroy()
})
}

class WebSocketProxy {
constructor (fastify, wsServerOptions) {
constructor (fastify, { wsServerOptions, wsClientOptions, upstream, wsUpstream, replyOptions: { getUpstream } = {} }) {
this.logger = fastify.log
this.wsClientOptions = {
rewriteRequestHeaders: defaultWsHeadersRewrite,
headers: {},
...wsClientOptions
}
this.upstream = convertUrlToWebSocket(upstream)
this.wsUpstream = wsUpstream ? convertUrlToWebSocket(wsUpstream) : ''
this.getUpstream = getUpstream

const wss = new WebSocket.Server({
noServer: true,
...wsServerOptions
})

fastify.server.on('upgrade', (rawRequest, socket, head) => {
// Save a reference to the socket and then dispatch the request through the normal fastify router so that it will invoke hooks and then eventually a route handler that might upgrade the socket.
rawRequest[kWs] = socket
rawRequest[kWsHead] = head

const rawResponse = new ServerResponse(rawRequest)
rawResponse.assignSocket(socket)
fastify.routing(rawRequest, rawResponse)

rawResponse.on('finish', () => {
socket.destroy()
})
})
if (!fastify.server[kWsUpgradeListener]) {
fastify.server[kWsUpgradeListener] = (rawRequest, socket, head) =>
handleUpgrade(fastify, rawRequest, socket, head)
fastify.server.on('upgrade', fastify.server[kWsUpgradeListener])
}

this.handleUpgrade = (request, cb) => {
this.handleUpgrade = (request, dest, cb) => {
wss.handleUpgrade(request.raw, request.raw[kWs], request.raw[kWsHead], (socket) => {
this.handleConnection(socket, request)
this.handleConnection(socket, request, dest)
cb()
})
}
Expand Down Expand Up @@ -134,53 +149,41 @@ class WebSocketProxy {
this.prefixList = []
}

addUpstream (prefix, rewritePrefix, upstream, wsUpstream, wsClientOptions) {
this.prefixList.push({
prefix: new URL(prefix, 'ws://127.0.0.1').pathname,
rewritePrefix,
upstream: convertUrlToWebSocket(upstream),
wsUpstream: wsUpstream ? convertUrlToWebSocket(wsUpstream) : '',
wsClientOptions
})
findUpstream (request, dest) {
const search = new URL(request.url, 'ws://127.0.0.1').search

// sort by decreasing prefix length, so that findUpstreamUrl() does longest prefix match
this.prefixList.sort((a, b) => b.prefix.length - a.prefix.length)
}

findUpstream (request) {
const source = new URL(request.url, 'ws://127.0.0.1')

for (const { prefix, rewritePrefix, upstream, wsUpstream, wsClientOptions } of this.prefixList) {
if (wsUpstream) {
const target = new URL(wsUpstream)
target.search = source.search
return { target, wsClientOptions }
}
if (typeof this.wsUpstream === 'string' && this.wsUpstream !== '') {
const target = new URL(this.wsUpstream)
target.search = search
return target
}

if (source.pathname.startsWith(prefix)) {
const target = new URL(source.pathname.replace(prefix, rewritePrefix), upstream)
target.search = source.search
return { target, wsClientOptions }
}
if (typeof this.upstream === 'string' && this.upstream !== '') {
const target = new URL(dest, this.upstream)
target.search = search
return target
}

const upstream = this.getUpstream(request, '')
const target = new URL(dest, upstream)
/* istanbul ignore next */
throw new Error(`no upstream found for ${request.url}. this should not happened. Please report to https://github.com/fastify/fastify-http-proxy`)
target.protocol = upstream.indexOf('http:') === 0 ? 'ws:' : 'wss'
target.search = search
return target
}

handleConnection (source, request) {
const upstream = this.findUpstream(request)
const { target: url, wsClientOptions } = upstream
const rewriteRequestHeaders = wsClientOptions?.rewriteRequestHeaders || defaultWsHeadersRewrite
const headersToRewrite = wsClientOptions?.headers || {}
handleConnection (source, request, dest) {
const url = this.findUpstream(request, dest)
const rewriteRequestHeaders = this.wsClientOptions.rewriteRequestHeaders
const headersToRewrite = this.wsClientOptions.headers

const subprotocols = []
if (source.protocol) {
subprotocols.push(source.protocol)
}

const headers = rewriteRequestHeaders(headersToRewrite, request)
const optionsWs = { ...(wsClientOptions || {}), headers }
const optionsWs = { ...this.wsClientOptions, headers }

const target = new WebSocket(url, subprotocols, optionsWs)
this.logger.debug({ url: url.href }, 'proxy websocket')
Expand All @@ -195,41 +198,6 @@ function defaultWsHeadersRewrite (headers, request) {
return { ...headers }
}

const httpWss = new WeakMap() // http.Server => WebSocketProxy

function setupWebSocketProxy (fastify, options, rewritePrefix) {
let wsProxy = httpWss.get(fastify.server)
if (!wsProxy) {
wsProxy = new WebSocketProxy(fastify, options.wsServerOptions)
httpWss.set(fastify.server, wsProxy)
}

if (
(typeof options.wsUpstream === 'string' && options.wsUpstream !== '') ||
(typeof options.upstream === 'string' && options.upstream !== '')
) {
wsProxy.addUpstream(
fastify.prefix,
rewritePrefix,
options.upstream,
options.wsUpstream,
options.wsClientOptions
)
// The else block is validate earlier in the code
} else {
wsProxy.findUpstream = function (request) {
const source = new URL(request.url, 'ws://127.0.0.1')
const upstream = options.replyOptions.getUpstream(request, '')
const target = new URL(source.pathname, upstream)
/* istanbul ignore next */
target.protocol = upstream.indexOf('http:') === 0 ? 'ws:' : 'wss'
target.search = source.search
return { target, wsClientOptions: options.wsClientOptions }
}
}
return wsProxy
}

function generateRewritePrefix (prefix, opts) {
let rewritePrefix = opts.rewritePrefix || (opts.upstream ? new URL(opts.upstream).pathname : '/')

Expand Down Expand Up @@ -303,7 +271,7 @@ async function fastifyHttpProxy (fastify, opts) {
let wsProxy

if (opts.websocket) {
wsProxy = setupWebSocketProxy(fastify, opts, rewritePrefix)
wsProxy = new WebSocketProxy(fastify, opts)
}

function extractUrlComponents (urlString) {
Expand All @@ -321,16 +289,6 @@ async function fastifyHttpProxy (fastify, opts) {
}

function handler (request, reply) {
if (request.raw[kWs]) {
reply.hijack()
try {
wsProxy.handleUpgrade(request, noop)
} catch (err) {
/* istanbul ignore next */
request.log.warn({ err }, 'websocket proxy error')
}
return
}
const { path, queryParams } = extractUrlComponents(request.url)
let dest = path

Expand All @@ -350,6 +308,17 @@ async function fastifyHttpProxy (fastify, opts) {
} else {
dest = dest.replace(this.prefix, rewritePrefix)
}

if (request.raw[kWs]) {
reply.hijack()
try {
wsProxy.handleUpgrade(request, dest || '/', noop)
} catch (err) {
/* istanbul ignore next */
request.log.warn({ err }, 'websocket proxy error')
}
return
}
reply.from(dest || '/', replyOpts)
}
}
Expand Down
100 changes: 100 additions & 0 deletions test/websocket.js
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,103 @@ test('Proxy websocket with custom upstream url', async (t) => {
server.close()
])
})

test('multiple websocket upstreams with host constraints', async (t) => {
t.plan(4)

const server = Fastify()

for (const name of ['foo', 'bar']) {
const origin = createServer()
const wss = new WebSocket.Server({ server: origin })
t.teardown(wss.close.bind(wss))
t.teardown(origin.close.bind(origin))

wss.once('connection', (ws) => {
ws.once('message', message => {
t.equal(message.toString(), `hello ${name}`)
// echo
ws.send(message)
})
})

await promisify(origin.listen.bind(origin))({ port: 0, host: '127.0.0.1' })
server.register(proxy, {
upstream: `ws://127.0.0.1:${origin.address().port}`,
websocket: true,
constraints: { host: name }
})
}

await server.listen({ port: 0, host: '127.0.0.1' })
t.teardown(server.close.bind(server))

const wsClients = []
for (const name of ['foo', 'bar']) {
const ws = new WebSocket(`ws://127.0.0.1:${server.server.address().port}`, { headers: { host: name } })
await once(ws, 'open')
ws.send(`hello ${name}`)
const [reply] = await once(ws, 'message')
t.equal(reply.toString(), `hello ${name}`)
wsClients.push(ws)
}

await Promise.all([
...wsClients.map(ws => once(ws, 'close')),
server.close()
])
})

test('multiple websocket upstreams with distinct server options', async (t) => {
t.plan(4)

const server = Fastify()

for (const name of ['foo', 'bar']) {
const origin = createServer()
const wss = new WebSocket.Server({ server: origin })
t.teardown(wss.close.bind(wss))
t.teardown(origin.close.bind(origin))

wss.once('connection', (ws, req) => {
t.equal(req.url, `/?q=${name}`)
ws.once('message', message => {
// echo
ws.send(message)
})
})

await promisify(origin.listen.bind(origin))({ port: 0, host: '127.0.0.1' })
server.register(proxy, {
upstream: `ws://127.0.0.1:${origin.address().port}`,
websocket: true,
constraints: { host: name },
wsServerOptions: {
verifyClient: ({ req }) => {
t.equal(req.url, `/?q=${name}`)
return true
}
}
})
}

await server.listen({ port: 0, host: '127.0.0.1' })
t.teardown(server.close.bind(server))

const wsClients = []
for (const name of ['foo', 'bar']) {
const ws = new WebSocket(
`ws://127.0.0.1:${server.server.address().port}/?q=${name}`,
{ headers: { host: name } }
)
await once(ws, 'open')
ws.send(`hello ${name}`)
await once(ws, 'message')
wsClients.push(ws)
}

await Promise.all([
...wsClients.map(ws => once(ws, 'close')),
server.close()
])
})

0 comments on commit 916d85c

Please sign in to comment.