diff --git a/backend/package-lock.json b/backend/package-lock.json
index cdf0ea602..c98a35edd 100644
--- a/backend/package-lock.json
+++ b/backend/package-lock.json
@@ -17,7 +17,8 @@
         "memorystore": "^1.6.7",
         "openai": "^4.19.0",
         "openai-chat-tokens": "^0.2.8",
-        "pdf-parse": "^1.1.1"
+        "pdf-parse": "^1.1.1",
+        "query-types": "^0.1.4"
       },
       "devDependencies": {
         "@jest/globals": "^29.7.0",
@@ -6847,6 +6848,11 @@
         "url": "https://github.com/sponsors/ljharb"
       }
     },
+    "node_modules/query-types": {
+      "version": "0.1.4",
+      "resolved": "https://registry.npmjs.org/query-types/-/query-types-0.1.4.tgz",
+      "integrity": "sha512-DvOCeSGI+RJeZ8+13NYTZUGbGva3P85GmATpY2ojj0mIsNfw70xeTpQW51JaYDZyQ2Li/svGT+FUdd4CBUoDcQ=="
+    },
     "node_modules/queue-microtask": {
       "version": "1.2.3",
       "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
diff --git a/backend/package.json b/backend/package.json
index 10865e543..c56b25471 100644
--- a/backend/package.json
+++ b/backend/package.json
@@ -23,7 +23,8 @@
 		"memorystore": "^1.6.7",
 		"openai": "^4.19.0",
 		"openai-chat-tokens": "^0.2.8",
-		"pdf-parse": "^1.1.1"
+		"pdf-parse": "^1.1.1",
+		"query-types": "^0.1.4"
 	},
 	"devDependencies": {
 		"@jest/globals": "^29.7.0",
diff --git a/backend/src/app.ts b/backend/src/app.ts
index 5a965a1ec..7bf9a1c21 100644
--- a/backend/src/app.ts
+++ b/backend/src/app.ts
@@ -1,11 +1,13 @@
 import cors from 'cors';
 import express from 'express';
+import queryTypes from 'query-types';
 
 import nonSessionRoutes from './nonSessionRoutes';
 import sessionRoutes from './sessionRoutes';
 
 export default express()
 	.use(express.json())
+	.use(queryTypes.middleware())
 	.use(
 		cors({
 			origin: process.env.CORS_ALLOW_ORIGIN,
diff --git a/backend/src/controller/startController.ts b/backend/src/controller/startController.ts
index fd7bd30c5..8ce03d812 100644
--- a/backend/src/controller/startController.ts
+++ b/backend/src/controller/startController.ts
@@ -1,6 +1,9 @@
 import { Response } from 'express';
 
-import { StartGetRequest } from '@src/models/api/StartGetRequest';
+import {
+	StartGetRequest,
+	StartResponse,
+} from '@src/models/api/StartGetRequest';
 import { LEVEL_NAMES, isValidLevel } from '@src/models/level';
 import { getValidOpenAIModels } from '@src/openai';
 import {
@@ -36,7 +39,7 @@ function handleStart(req: StartGetRequest, res: Response) {
 		defences: req.session.levelState[level].defences,
 		availableModels: getValidOpenAIModels(),
 		systemRoles,
-	});
+	} as StartResponse);
 }
 
 export { handleStart };
diff --git a/backend/src/models/api/StartGetRequest.ts b/backend/src/models/api/StartGetRequest.ts
index 188b6a310..344c12adf 100644
--- a/backend/src/models/api/StartGetRequest.ts
+++ b/backend/src/models/api/StartGetRequest.ts
@@ -5,15 +5,20 @@ import { Defence } from '@src/models/defence';
 import { EmailInfo } from '@src/models/email';
 import { LEVEL_NAMES } from '@src/models/level';
 
+export type StartResponse = {
+	emails: EmailInfo[];
+	chatHistory: ChatMessage[];
+	defences: Defence[];
+	availableModels: string[];
+	systemRoles: {
+		level: LEVEL_NAMES;
+		systemRole: string;
+	}[];
+};
+
 export type StartGetRequest = Request<
 	never,
-	{
-		emails: EmailInfo[];
-		chatHistory: ChatMessage[];
-		defences: Defence[];
-		availableModels: string[];
-		systemRoles: string[];
-	},
+	StartResponse,
 	never,
 	{
 		level?: LEVEL_NAMES;
diff --git a/backend/test/api/start.test.ts b/backend/test/api/start.test.ts
new file mode 100644
index 000000000..bef928f59
--- /dev/null
+++ b/backend/test/api/start.test.ts
@@ -0,0 +1,54 @@
+import { beforeAll, describe, expect, it, jest } from '@jest/globals';
+import { OpenAI } from 'openai';
+import request from 'supertest';
+
+import app from '@src/app';
+import { StartResponse } from '@src/models/api/StartGetRequest';
+import { LEVEL_NAMES } from '@src/models/level';
+
+jest.mock('openai');
+
+const PATH = '/start';
+
+describe('/start endpoints', () => {
+	const mockListFn = jest.fn<() => Promise<OpenAI.ModelsPage>>();
+	jest.mocked(OpenAI).mockImplementation(
+		() =>
+			({
+				models: {
+					list: mockListFn,
+				},
+			} as unknown as jest.MockedObject<OpenAI>)
+	);
+
+	beforeAll(() => {
+		mockListFn.mockResolvedValue({
+			data: [{ id: 'gpt-3.5-turbo' }],
+		} as OpenAI.ModelsPage);
+	});
+
+	it.each(Object.values(LEVEL_NAMES))(
+		'WHEN given valid level [%s] THEN it responds with 200',
+		async (level) =>
+			request(app)
+				.get(`${PATH}?level=${level}`)
+				.expect(200)
+				.expect('Content-Type', /application\/json/)
+				.then((response) => {
+					const { chatHistory, emails } = response.body as StartResponse;
+					expect(chatHistory).toEqual([]);
+					expect(emails).toEqual([]);
+				})
+	);
+
+	it.each([-1, 4, 'SANDBOX'])(
+		'WHEN given invalid level [%s] THEN it responds with 400',
+		async (level) =>
+			request(app)
+				.get(`${PATH}?level=${level}`)
+				.expect(400)
+				.then((response) => {
+					expect(response.text).toEqual('Invalid level');
+				})
+	);
+});
diff --git a/backend/test/setupEnvVars.ts b/backend/test/setupEnvVars.ts
index 24cd1fce3..a2d47c58b 100644
--- a/backend/test/setupEnvVars.ts
+++ b/backend/test/setupEnvVars.ts
@@ -1,2 +1,3 @@
 // set the environment variables
 process.env.OPENAI_API_KEY = 'sk-12345';
+process.env.SESSION_SECRET = "shhh! Don't tell anyone...";
diff --git a/backend/typings/query-types.d.ts b/backend/typings/query-types.d.ts
new file mode 100644
index 000000000..8a7524645
--- /dev/null
+++ b/backend/typings/query-types.d.ts
@@ -0,0 +1,5 @@
+declare module 'query-types' {
+	import type { NextHandleFunction } from 'connect';
+
+	function middleware(): NextHandleFunction;
+}
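
Note on the new query-types dependency (background, not part of the change set): Express parses every query-string value as a string, so a numeric LEVEL_NAMES value would arrive in handleStart as '1' rather than 1 and a numeric validity check would reject it. queryTypes.middleware() walks req.query before the route handlers run and coerces numeric- and boolean-looking strings into real numbers and booleans. A minimal sketch of the effect, assuming the same wiring as app.ts above; the /echo route and port here are hypothetical, for illustration only:

	// Hypothetical standalone demo of query-types coercion (not in this diff).
	import express from 'express';
	import queryTypes from 'query-types';

	express()
		.use(queryTypes.middleware())
		.get('/echo', (req, res) => {
			// Without the middleware, req.query.level is the string '1';
			// with it, the value is the number 1, so it can match a
			// numeric enum member directly.
			res.json({ level: req.query.level, type: typeof req.query.level });
		})
		.listen(3001);

	// GET /echo?level=1        ->  { "level": 1, "type": "number" }
	// GET /echo?level=SANDBOX  ->  { "level": "SANDBOX", "type": "string" }

This also suggests why 'SANDBOX' sits alongside the out-of-range numbers in the 400 test: a non-numeric string is left uncoerced, so it never matches a numeric level value. The local typings file exists because query-types ships no type declarations of its own; members of an ambient declare module block are implicitly exported, which is presumably what lets `import queryTypes from 'query-types'` resolve, assuming esModuleInterop is enabled in the backend tsconfig.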