diff --git a/packages/aws-lambda-graphql/README.md b/packages/aws-lambda-graphql/README.md index b222304..f77476f 100644 --- a/packages/aws-lambda-graphql/README.md +++ b/packages/aws-lambda-graphql/README.md @@ -56,6 +56,7 @@ All options from Apollo Lambda Server and - **waitForInitialization** (`optional`) - if connection is not initialised on GraphQL operation, wait for connection to be initialised or throw prohibited connection error. If `onConnect` is specified then we wait for initialisation otherwise we don't wait. (this is usefull if you're performing authentication in `onConnect`). - **retryCount** (`number`, `optional`, `default 10`) - how many times should we try to check the connection state? - **timeout** (`number`, `optional`, `default 50ms`) - how long should we wait (in milliseconds) until we try to check the connection state again? + - **connectionEndpoint** (`string`, `optional`) - if specified, the connection endpoint will be registered with this value as opposed to extracted from the event payload (as `${domainName}/${stage}`) #### `createHttpHandler()` diff --git a/packages/aws-lambda-graphql/src/Server.ts b/packages/aws-lambda-graphql/src/Server.ts index 842d5ad..b7d4e02 100644 --- a/packages/aws-lambda-graphql/src/Server.ts +++ b/packages/aws-lambda-graphql/src/Server.ts @@ -118,6 +118,12 @@ export interface ServerConfig< */ timeout?: number; }; + + /** + * If specified, the connection endpoint will be registered with this value as opposed to extracted from the event payload + * + */ + connectionEndpoint?: string; }; } @@ -270,12 +276,14 @@ export class Server< // based on routeKey, do actions switch (event.requestContext.routeKey) { case '$connect': { - const { onWebsocketConnect } = this.subscriptionOptions || {}; + const { onWebsocketConnect, connectionEndpoint } = + this.subscriptionOptions || {}; // register connection // if error is thrown during registration, connection is rejected // we can implement some sort of authorization here - const endpoint = extractEndpointFromEvent(event); + const endpoint = + connectionEndpoint || extractEndpointFromEvent(event); const connection = await this.connectionManager.registerConnection({ endpoint, diff --git a/packages/aws-lambda-graphql/src/__tests__/Server.test.ts b/packages/aws-lambda-graphql/src/__tests__/Server.test.ts index 97b36fd..7831c10 100644 --- a/packages/aws-lambda-graphql/src/__tests__/Server.test.ts +++ b/packages/aws-lambda-graphql/src/__tests__/Server.test.ts @@ -212,6 +212,49 @@ describe('Server', () => { ); }); + it('registers connection with the endpoint value of connectionEndpoint option', async () => { + (connectionManager.registerConnection as jest.Mock).mockResolvedValueOnce( + {}, + ); + (connectionManager.setConnectionData as jest.Mock).mockResolvedValueOnce( + {}, + ); + const handlerWithConnectionEndpoint = new Server({ + connectionManager, + eventProcessor: new MemoryEventProcessor(), + schema: createSchema(), + subscriptionManager, + subscriptions: { + connectionEndpoint: 'customdomain', + }, + }).createWebSocketHandler(); + + await expect( + handlerWithConnectionEndpoint( + { + requestContext: { + connectionId: '1', + domainName: 'domain', + routeKey: '$connect', + stage: 'stage', + } as any, + } as any, + {} as any, + ), + ).resolves.toEqual( + expect.objectContaining({ + body: '', + statusCode: 200, + }), + ); + + expect(connectionManager.registerConnection).toHaveBeenCalledTimes(1); + expect(connectionManager.registerConnection).toHaveBeenCalledWith({ + endpoint: 'customdomain', + connectionId: '1', + }); + }); + it('refuses connection when onWebsocketConnect returns false', async () => { (connectionManager.registerConnection as jest.Mock).mockResolvedValueOnce( {},