diff --git a/tunnel-server/src/memoize.test.ts b/tunnel-server/src/memoize.test.ts new file mode 100644 index 00000000..d476694f --- /dev/null +++ b/tunnel-server/src/memoize.test.ts @@ -0,0 +1,64 @@ +import { afterAll, beforeAll, beforeEach, describe, it, expect, jest } from '@jest/globals' +import { memoizeForDuration } from './memoize' + +describe('memoizeForDuration', () => { + beforeAll(() => { + jest.useFakeTimers() + }) + afterAll(() => { + jest.useRealTimers() + }) + + let fn: jest.Mock<() => number> + let memoized: () => number + + beforeEach(() => { + fn = jest.fn(() => 12) + memoized = memoizeForDuration(fn, 1000) + }) + + describe('before the first call', () => { + it('does not call the specified function', () => { + expect(fn).not.toHaveBeenCalled() + }) + }) + + describe('on the first call', () => { + let v: number + beforeEach(() => { + v = memoized() + }) + it('calls the specified function', () => { + expect(fn).toHaveBeenCalledTimes(1) + }) + it('returns the memoized value', () => { + expect(v).toBe(12) + }) + + describe('on the second call, when the expiry duration has not passed', () => { + beforeEach(() => { + jest.advanceTimersByTime(999) + v = memoized() + }) + it('does not call the specified function again', () => { + expect(fn).toHaveBeenCalledTimes(1) + }) + it('returns the memoized value', () => { + expect(v).toBe(12) + }) + }) + + describe('on the second call, when the expiry duration has passed', () => { + beforeEach(() => { + jest.advanceTimersByTime(1000) + v = memoized() + }) + it('calls the specified function again', () => { + expect(fn).toHaveBeenCalledTimes(2) + }) + it('returns the memoized value', () => { + expect(v).toBe(12) + }) + }) + }) +}) diff --git a/tunnel-server/src/memoize.ts b/tunnel-server/src/memoize.ts new file mode 100644 index 00000000..65652279 --- /dev/null +++ b/tunnel-server/src/memoize.ts @@ -0,0 +1,9 @@ +export const memoizeForDuration = (f: () => T, milliseconds: number) => { + let cache: { value: T; expiry: number } | undefined + return () => { + if (!cache || cache.expiry <= Date.now()) { + cache = { value: f(), expiry: Date.now() + milliseconds } + } + return cache.value + } +} diff --git a/tunnel-server/src/ssh/base-server.ts b/tunnel-server/src/ssh/base-server.ts index 70918b88..059d9228 100644 --- a/tunnel-server/src/ssh/base-server.ts +++ b/tunnel-server/src/ssh/base-server.ts @@ -10,6 +10,7 @@ import { calculateJwkThumbprintUri, exportJWK } from 'jose' import { ForwardRequest, parseForwardRequest } from '../forward-request' import { createDestroy } from '../destroy-server' import { onceWithTimeout } from '../events' +import { memoizeForDuration } from '../memoize' const clientIdFromPublicSsh = (key: Buffer) => crypto.createHash('sha1').update(key).digest('base64url').replace(/[_-]/g, '') @@ -109,11 +110,12 @@ export const baseSshServer = ( let authContext: ssh2.AuthContext let key: ssh2.ParsedKey - const ping = async (milliseconds: number) => { - const result = onceWithTimeout(client, 'rekey', { milliseconds, fallback: () => 'timeout' as const }) + const PING_TIMEOUT = 5000 + const ping = memoizeForDuration(async () => { + const result = onceWithTimeout(client, 'rekey', { milliseconds: PING_TIMEOUT, fallback: () => 'timeout' as const }) client.rekey() - return await result !== 'timeout' - } + return (await result) !== 'timeout' + }, PING_TIMEOUT) client .on('authentication', async ctx => {