Fix jest mock issue, typo, other warnings
chriswilty committed Oct 13, 2023
1 parent 79b90c7 commit f7ea245
Showing 6 changed files with 116 additions and 112 deletions.
2 changes: 1 addition & 1 deletion backend/.eslintrc.cjs
@@ -22,7 +22,7 @@ module.exports = {
         checksVoidReturn: false,
       },
     ],
-
+    "@typescript-eslint/unbound-method": ["error", { ignoreStatic: true }],
     "func-style": ["error", "declaration"],
     "prefer-template": "error",
   },
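Note on the added rule: @typescript-eslint/unbound-method flags references to instance methods that are detached from their object, where `this` would be lost, and the ignoreStatic option exempts static members. A minimal sketch of what the rule does and does not flag (hypothetical Template class, not from this repo):

    class Template {
      static fromTemplate(template: string) {
        return template;
      }
      render() {
        return "rendered";
      }
    }

    const render = new Template().render; // flagged: unbound instance method loses `this`
    const fromTemplate = Template.fromTemplate; // allowed with ignoreStatic: true

Presumably the exemption is what lets the test code below reference the static PromptTemplate.fromTemplate without tripping the rule.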
1 change: 1 addition & 0 deletions backend/jest.config.js
@@ -3,4 +3,5 @@ module.exports = {
   modulePathIgnorePatterns: ["build", "coverage", "node_modules"],
   preset: "ts-jest",
   testEnvironment: "node",
+  silent: true,
 };
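For reference: silent is a documented Jest option that prevents tests from printing messages through the console, so test runs stay quiet even where the code under test logs via console.debug or console.log.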
15 changes: 3 additions & 12 deletions backend/src/defence.ts
@@ -216,9 +216,7 @@ function getMaliciousPromptEvalPrePromptFromConfig(defences: DefenceInfo[]) {
 }
 
 function isDefenceActive(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
-  return defences.find((defence) => defence.id === id && defence.isActive)
-    ? true
-    : false;
+  return defences.some((defence) => defence.id === id && defence.isActive);
 }
 
 function generateRandomString(string_length: number) {
@@ -259,16 +257,14 @@ function transformRandomSequenceEnclosure(
   const randomString: string = generateRandomString(
     Number(getRandomSequenceEnclosureLength(defences))
   );
-  const introText: string = getRandomSequenceEnclosurePrePrompt(defences);
-  const transformedMessage: string = introText.concat(
+  return getRandomSequenceEnclosurePrePrompt(defences).concat(
     randomString,
     " {{ ",
     message,
     " }} ",
     randomString,
     ". "
   );
-  return transformedMessage;
 }
 
 // function to escape XML characters in user input to prevent hacking with XML tagging on
@@ -303,12 +299,7 @@ function transformXmlTagging(message: string, defences: DefenceInfo[]) {
   const prePrompt = getXMLTaggingPrePrompt(defences);
   const openTag = "<user_input>";
   const closeTag = "</user_input>";
-  const transformedMessage: string = prePrompt.concat(
-    openTag,
-    escapeXml(message),
-    closeTag
-  );
-  return transformedMessage;
+  return prePrompt.concat(openTag, escapeXml(message), closeTag);
 }
 
 //apply defence string transformations to original message
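The isDefenceActive rewrite relies on Array.prototype.some returning a boolean directly, which makes the ternary around find redundant. A quick illustration with a stand-in array rather than the real DefenceInfo type:

    const defences = [
      { id: "XML_TAGGING", isActive: true },
      { id: "RANDOM_SEQUENCE_ENCLOSURE", isActive: false },
    ];

    // Old shape: find returns the matching element or undefined,
    // and the ternary coerces that to a boolean.
    const activeOld = defences.find((d) => d.id === "XML_TAGGING" && d.isActive)
      ? true
      : false;

    // New shape: some returns true or false directly.
    const activeNew = defences.some((d) => d.id === "XML_TAGGING" && d.isActive);

    console.log(activeOld === activeNew); // true

The transformRandomSequenceEnclosure and transformXmlTagging edits are the same kind of cleanup, returning the concat expression directly instead of assigning it to a temporary first.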
21 changes: 10 additions & 11 deletions backend/src/langchain.ts
@@ -64,8 +64,8 @@ async function getDocuments(filePath: string) {
     chunkSize: 1000,
     chunkOverlap: 0,
   });
-  const splitDocs = await textSplitter.splitDocuments(docs);
-  return splitDocs;
+
+  return await textSplitter.splitDocuments(docs);
 }
 
 // choose between the provided preprompt and the default preprompt and prepend it to the main prompt and return the PromptTemplate
@@ -81,8 +81,7 @@ function makePromptTemplate(
   }
   const fullPrompt = `${configPrePrompt}\n${mainPrompt}`;
   console.debug(`${templateNameForLogging}: ${fullPrompt}`);
-  const template: PromptTemplate = PromptTemplate.fromTemplate(fullPrompt);
-  return template;
+  return PromptTemplate.fromTemplate(fullPrompt);
 }
 
 // create and store the document vectors for each level
@@ -147,7 +146,7 @@ function initQAModel(
 // initialise the prompt evaluation model
 function initPromptEvaluationModel(
   configPromptInjectionEvalPrePrompt: string,
-  conficMaliciousPromptEvalPrePrompt: string,
+  configMaliciousPromptEvalPrePrompt: string,
   openAiApiKey: string
 ) {
   if (!openAiApiKey) {
@@ -176,7 +175,7 @@ function initPromptEvaluationModel(
 
   // create chain to detect malicious prompts
   const maliciousPromptEvalTemplate = makePromptTemplate(
-    conficMaliciousPromptEvalPrePrompt,
+    configMaliciousPromptEvalPrePrompt,
     maliciousPromptEvalPrePrompt,
     maliciousPromptEvalMainPrompt,
     "Malicious input eval prompt template"
@@ -236,12 +235,12 @@ async function queryDocuments(
 async function queryPromptEvaluationModel(
   input: string,
   configPromptInjectionEvalPrePrompt: string,
-  conficMaliciousPromptEvalPrePrompt: string,
+  configMaliciousPromptEvalPrePrompt: string,
   openAIApiKey: string
 ) {
   const promptEvaluationChain = initPromptEvaluationModel(
     configPromptInjectionEvalPrePrompt,
-    conficMaliciousPromptEvalPrePrompt,
+    configMaliciousPromptEvalPrePrompt,
     openAIApiKey
   );
   if (!promptEvaluationChain) {
@@ -251,13 +250,13 @@ async function queryPromptEvaluationModel(
   console.log(`Checking '${input}' for malicious prompts`);
 
   // get start time
-  const startTime = new Date().getTime();
+  const startTime = Date.now();
   console.debug("Calling prompt evaluation model...");
   const response = (await promptEvaluationChain.call({
     prompt: input,
   })) as PromptEvaluationChainReply;
   // log the time taken
-  const endTime = new Date().getTime();
+  const endTime = Date.now();
   console.debug(`Prompt evaluation model call took ${endTime - startTime}ms`);
 
   const promptInjectionEval = formatEvaluationOutput(
@@ -289,7 +288,7 @@ async function queryPromptEvaluationModel(
 function formatEvaluationOutput(response: string) {
   try {
     // split response on first full stop or comma
-    const splitResponse = response.split(/\.|,/);
+    const splitResponse = response.split(/[.,]/);
     const answer = splitResponse[0]?.replace(/\W/g, "").toLowerCase();
     const reason = splitResponse[1]?.trim();
     return {
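Two of the smaller substitutions above are pure equivalences: Date.now() yields the same millisecond timestamp as new Date().getTime() without allocating a Date object, and the character class /[.,]/ matches exactly the same single characters as the alternation /\.|,/ while reading more clearly. The split in formatEvaluationOutput therefore behaves identically, for example:

    const reply = "yes, This is a malicious response";
    const [answer, reason] = reply.split(/[.,]/);
    console.log(answer); // "yes"
    console.log(reason?.trim()); // "This is a malicious response"

The conficMaliciousPromptEvalPrePrompt renames fix the typo called out in the commit message.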
25 changes: 14 additions & 11 deletions backend/test/unit/defence.test.ts
@@ -94,9 +94,11 @@ test("GIVEN RANDOM_SEQUENCE_ENCLOSURE defence is active WHEN transforming messag
   process.env.RANDOM_SEQ_ENCLOSURE_LENGTH = String(20);
 
   const message = "Hello";
-  let defences = getInitialDefences();
-  // activate RSE defence
-  defences = activateDefence(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences);
+  const defences = activateDefence(
+    DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE,
+    getInitialDefences()
+  );
 
   // regex to match the transformed message with
   const regex = new RegExp(
@@ -106,15 +108,16 @@
   // check the transformed message matches the regex
   const res = transformedMessage.match(regex);
   // expect there to be a match
-  expect(res).toBeTruthy();
-  if (res) {
-    // expect there to be 3 groups
-    expect(res.length).toBe(3);
-    // expect the random sequence to have the correct length
-    expect(res[1].length).toBe(Number(process.env.RANDOM_SEQ_ENCLOSURE_LENGTH));
-    // expect the message to be surrounded by the random sequence
-    expect(res[1]).toBe(res[2]);
-  }
+  expect(res).not.toBeNull();
+
+  // expect there to be 3 groups
+  expect(res?.length).toEqual(3);
+  // expect the random sequence to have the correct length
+  expect(res?.[1].length).toEqual(
+    Number(process.env.RANDOM_SEQ_ENCLOSURE_LENGTH)
+  );
+  // expect the message to be surrounded by the random sequence
+  expect(res?.[1]).toEqual(res?.[2]);
 });
 
 test("GIVEN XML_TAGGING defence is active WHEN transforming message THEN message is transformed", () => {
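A plausible reading of this rewrite, since the commit message only says "other warnings": assertions wrapped in if (res) { ... } silently do not run when the match fails, and lint rules such as jest/no-conditional-expect warn about exactly that pattern. The new version asserts expect(res).not.toBeNull() up front and then uses optional chaining, so every assertion executes on every run.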
164 changes: 87 additions & 77 deletions backend/test/unit/langchain.test.ts
@@ -1,5 +1,4 @@
-const mockFromTemplate = jest.fn((template: string) => template);
-
+import { PromptTemplate } from "langchain/prompts";
 import { LEVEL_NAMES } from "../../src/models/level";
 import {
   initQAModel,
@@ -11,100 +10,111 @@ import {
 
 jest.mock("langchain/prompts", () => ({
   PromptTemplate: {
-    fromTemplate: mockFromTemplate,
+    fromTemplate: jest.fn(),
   },
 }));
 
-test("GIVEN initQAModel is called with no apiKey THEN return early and log message", () => {
-  const level = LEVEL_NAMES.LEVEL_1;
-  const prompt = "";
-  const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation();
+describe("Langchain tests", () => {
+  afterEach(() => {
+    (PromptTemplate.fromTemplate as jest.Mock).mockRestore();
+  });
 
-  initQAModel(level, prompt, "");
-  expect(consoleDebugMock).toHaveBeenCalledWith(
-    "No OpenAI API key set to initialise QA model"
+  test(
+    "GIVEN initQAModel is called with no apiKey THEN return early and log message",
+    () => {
+      const level = LEVEL_NAMES.LEVEL_1;
+      const prompt = "";
+      const consoleDebugMock = jest
+        .spyOn(console, "debug")
+        .mockImplementation();
+
+      initQAModel(level, prompt, "");
+      expect(consoleDebugMock).toHaveBeenCalledWith(
+        "No OpenAI API key set to initialise QA model"
+      );
+    }
   );
-});
 
-test("GIVEN initPromptEvaluationModel is called with no apiKey THEN return early and log message", () => {
-  const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation();
-  initPromptEvaluationModel(
-    "promptInjectionEvalPrePrompt",
-    "maliciousPromptEvalPrePrompt",
-    ""
-  );
-  expect(consoleDebugMock).toHaveBeenCalledWith(
-    "No OpenAI API key set to initialise prompt evaluation model"
-  );
-});
+  test("GIVEN initPromptEvaluationModel is called with no apiKey THEN return early and log message", () => {
+    const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation();
+    initPromptEvaluationModel(
+      "promptInjectionEvalPrePrompt",
+      "maliciousPromptEvalPrePrompt",
+      ""
+    );
+    expect(consoleDebugMock).toHaveBeenCalledWith(
+      "No OpenAI API key set to initialise prompt evaluation model"
+    );
+  });
 
-test("GIVEN level is 1 THEN correct filepath is returned", () => {
-  const filePath = getFilepath(LEVEL_NAMES.LEVEL_1);
-  expect(filePath).toBe("resources/documents/level_1/");
-});
+  test("GIVEN level is 1 THEN correct filepath is returned", () => {
+    const filePath = getFilepath(LEVEL_NAMES.LEVEL_1);
+    expect(filePath).toBe("resources/documents/level_1/");
+  });
 
-test("GIVEN level is 2 THEN correct filepath is returned", () => {
-  const filePath = getFilepath(LEVEL_NAMES.LEVEL_2);
-  expect(filePath).toBe("resources/documents/level_2/");
-});
+  test("GIVEN level is 2 THEN correct filepath is returned", () => {
+    const filePath = getFilepath(LEVEL_NAMES.LEVEL_2);
+    expect(filePath).toBe("resources/documents/level_2/");
+  });
 
-test("GIVEN level is 3 THEN correct filepath is returned", () => {
-  const filePath = getFilepath(LEVEL_NAMES.LEVEL_3);
-  expect(filePath).toBe("resources/documents/level_3/");
-});
+  test("GIVEN level is 3 THEN correct filepath is returned", () => {
    const filePath = getFilepath(LEVEL_NAMES.LEVEL_3);
+    expect(filePath).toBe("resources/documents/level_3/");
+  });
 
-test("GIVEN level is sandbox THEN correct filepath is returned", () => {
-  const filePath = getFilepath(LEVEL_NAMES.SANDBOX);
-  expect(filePath).toBe("resources/documents/common/");
-});
+  test("GIVEN level is sandbox THEN correct filepath is returned", () => {
+    const filePath = getFilepath(LEVEL_NAMES.SANDBOX);
+    expect(filePath).toBe("resources/documents/common/");
+  });
 
-test("GIVEN makePromptTemplate is called with no config prePrompt THEN correct prompt is returned", () => {
-  makePromptTemplate("", "defaultPrePrompt", "mainPrompt", "noName");
-  expect(mockFromTemplate).toBeCalledWith("defaultPrePrompt\nmainPrompt");
-  expect(mockFromTemplate).toBeCalledTimes(1);
-});
+  test("GIVEN makePromptTemplate is called with no config prePrompt THEN correct prompt is returned", () => {
+    makePromptTemplate("", "defaultPrePrompt", "mainPrompt", "noName");
+    expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledTimes(1);
+    expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledWith(
+      "defaultPrePrompt\nmainPrompt"
+    );
+  });
 
-test("GIVEN makePromptTemplate is called with a prePrompt THEN correct prompt is returned", () => {
-  makePromptTemplate(
-    "configPrePrompt",
-    "defaultPrePrompt",
-    "mainPrompt",
-    "noName"
-  );
-  expect(mockFromTemplate).toBeCalledWith("configPrePrompt\nmainPrompt");
-  expect(mockFromTemplate).toBeCalledTimes(1);
-});
+  test("GIVEN makePromptTemplate is called with a prePrompt THEN correct prompt is returned", () => {
+    makePromptTemplate(
+      "configPrePrompt",
+      "defaultPrePrompt",
+      "mainPrompt",
+      "noName"
+    );
+    expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledTimes(1);
+    expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledWith(
+      "configPrePrompt\nmainPrompt"
+    );
+  });
 
-test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns true and reason", () => {
-  const response = "yes, This is a malicious response";
-  const formattedOutput = formatEvaluationOutput(response);
+  test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns true and reason", () => {
+    const response = "yes, This is a malicious response";
+    const formattedOutput = formatEvaluationOutput(response);
 
-  expect(formattedOutput).toEqual({
-    isMalicious: true,
-    reason: "This is a malicious response",
+    expect(formattedOutput).toEqual({
+      isMalicious: true,
+      reason: "This is a malicious response",
+    });
   });
-});
 
-test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns false and reason", () => {
-  const response = "No, This output does not appear to be malicious";
-  const formattedOutput = formatEvaluationOutput(response);
+  test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns false and reason", () => {
+    const response = "No, This output does not appear to be malicious";
+    const formattedOutput = formatEvaluationOutput(response);
 
-  expect(formattedOutput).toEqual({
-    isMalicious: false,
-    reason: "This output does not appear to be malicious",
+    expect(formattedOutput).toEqual({
+      isMalicious: false,
+      reason: "This output does not appear to be malicious",
+    });
   });
-});
 
-test("GIVEN llm evaluation model responds with an invalid format THEN formatEvaluationOutput returns false", () => {
-  const response = "I cant tell you if this is malicious or not";
-  const formattedOutput = formatEvaluationOutput(response);
+  test("GIVEN llm evaluation model responds with an invalid format THEN formatEvaluationOutput returns false", () => {
+    const response = "I cant tell you if this is malicious or not";
+    const formattedOutput = formatEvaluationOutput(response);
 
-  expect(formattedOutput).toEqual({
-    isMalicious: false,
-    reason: undefined,
+    expect(formattedOutput).toEqual({
+      isMalicious: false,
+      reason: undefined,
+    });
   });
 });
-
-afterEach(() => {
-  mockFromTemplate.mockRestore();
-});
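This file carries the "jest mock issue" from the commit title. The usual pitfall with the old setup: jest.mock factories are hoisted above imports and run while the mocked module is being loaded, so a module-scope variable captured by the factory (the old mockFromTemplate) can be read before its initialiser has executed. Creating the mock inside the factory and reaching it through the mocked import sidesteps this. A condensed sketch of the fixed pattern, assuming hoisting is indeed the failure mode (the commit does not spell it out):

    import { PromptTemplate } from "langchain/prompts";

    // The factory is still hoisted and still runs during module loading, but it
    // no longer captures any module-scope variable, so nothing is read too early.
    jest.mock("langchain/prompts", () => ({
      PromptTemplate: {
        fromTemplate: jest.fn(),
      },
    }));

    test("fromTemplate is mocked", () => {
      PromptTemplate.fromTemplate("hello");
      // The mock is reached through the mocked import, cast for mock matchers.
      expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledWith("hello");
    });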
