From 03e5c7df8cadaf817b4b0d1ff82e6cb7e3c65620 Mon Sep 17 00:00:00 2001 From: Julien Vignoud Date: Wed, 23 Oct 2024 13:58:01 +0200 Subject: [PATCH] *: replace line by line text loaders by chunk by chunk text loaders. Loaders now yield token sequences of length blockSize --- cli/src/benchmark_gpt.ts | 43 +++----- cli/src/train_gpt.ts | 42 +++++--- discojs-node/src/loaders.spec.ts | 78 +++++++++++++- discojs-node/src/loaders/text.ts | 100 ++++++++++++++++-- discojs-web/src/loaders.spec.ts | 55 +++++++--- discojs-web/src/loaders/text.ts | 77 ++++++++++---- .../preprocessing/text_preprocessing.spec.ts | 29 ++++- .../data/preprocessing/text_preprocessing.ts | 24 +++++ discojs/src/dataset/types.ts | 2 +- discojs/src/default_tasks/wikitext.ts | 6 +- discojs/src/models/gpt/model.ts | 7 +- discojs/src/models/index.ts | 2 +- discojs/src/models/tokenizer.ts | 36 +++++++ discojs/src/validation/validator.ts | 7 ++ docs/examples/wikitext.ts | 13 +-- server/tests/e2e/federated.spec.ts | 12 ++- webapp/cypress.config.ts | 4 + webapp/cypress/e2e/testing.cy.ts | 23 ++-- webapp/cypress/support/e2e.ts | 2 + .../dataset_input/FileSelection.vue | 24 +++-- .../dataset_input/TextDatasetInput.vue | 27 +++-- .../src/components/testing/PredictSteps.vue | 1 + webapp/src/components/testing/TestSteps.vue | 5 +- .../src/components/training/TrainingSteps.vue | 1 + 24 files changed, 490 insertions(+), 130 deletions(-) diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts index 6fc06f39f..77653d920 100644 --- a/cli/src/benchmark_gpt.ts +++ b/cli/src/benchmark_gpt.ts @@ -2,7 +2,7 @@ import { parse } from "ts-command-line-args"; import * as tf from "@tensorflow/tfjs" import { AutoTokenizer } from "@xenova/transformers"; -import { fetchTasks, models, async_iterator, defaultTasks, processing } from "@epfml/discojs"; +import { fetchTasks, models, async_iterator, defaultTasks } from "@epfml/discojs"; import { loadModelFromDisk, loadText } from '@epfml/discojs-node' import { Server } from "server"; @@ -86,36 +86,17 @@ async function main(args: Required): Promise { // to make sure the dataset is batched and tokenized correctly task.trainingInformation.batchSize = batchSize task.trainingInformation.maxSequenceLength = contextLength - const dataset = loadText('../datasets/wikitext/wiki.train.tokens') - - const maxLength = task.trainingInformation.maxSequenceLength ?? 
(tokenizer.model_max_length as number) + 1 - // TODO will be easier when preproccessing is redone - const preprocessedDataset = intoTFGenerator( - dataset - .map((line) => - processing.tokenizeAndLeftPad(line, tokenizer, maxLength), - ) - .batch(batchSize) - .map((batch) => - tf.tidy(() => ({ - xs: tf.tensor2d( - batch.map((tokens) => tokens.slice(0, -1)).toArray(), - ), - ys: tf.stack( - batch - .map( - (tokens) => - tf.oneHot( - tokens.slice(1), - tokenizer.model.vocab.length, - ) as tf.Tensor2D, - ) - .toArray(), - ) as tf.Tensor3D, - })), - ), - ); - + const dataset = loadText( + '../datasets/wikitext/wiki.train.tokens', + tokenizer, config.blockSize, batchSize + ) + // TODO will be easier when preprocessing is redone + const preprocessedDataset = intoTFGenerator(dataset).map((tokens: number[]) => { + const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length) + const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32') + return {xs, ys} + }).batch(batchSize) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> + // Init and train the model const model = new models.GPT(config) console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`) diff --git a/cli/src/train_gpt.ts b/cli/src/train_gpt.ts index 60466bde0..cf5adf3b4 100644 --- a/cli/src/train_gpt.ts +++ b/cli/src/train_gpt.ts @@ -1,38 +1,52 @@ import * as tf from "@tensorflow/tfjs-node" import { AutoTokenizer } from "@xenova/transformers"; -import { models, processing } from "@epfml/discojs"; +import { models } from "@epfml/discojs"; +import { loadText } from '@epfml/discojs-node' -async function main(): Promise { - const data = "Lorem ipsum dolor sit amet, consectetur adipis" - const datasetSource = new tf.data.FileDataSource(Buffer.from(data)) - const textDataset = new tf.data.TextLineDataset(datasetSource) +function intoTFGenerator( + iter: AsyncIterable, +): tf.data.Dataset { + // @ts-expect-error generator + return tf.data.generator(async function* () { + yield* iter; + }); +} +async function main(): Promise { + const config: models.GPTConfig = { modelType: 'gpt-nano', lr: 0.01, - maxIter: 50, + maxIter: 10, evaluateEvery:50, maxEvalBatches: 10, blockSize: 16, vocabSize: 50257, debug: false } - + + const batchSize = 8 const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') - const tokenDataset = textDataset.map((text: string) => { - const tokens = processing.tokenizeAndLeftPad(text, tokenizer, config.blockSize + 1) + const textDataset = loadText( + "../datasets/wikitext/wiki.train.tokens", + tokenizer, config.blockSize, batchSize + ) + + const tokenDataset = intoTFGenerator(textDataset).map((tokens: number[]) => { const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length) const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32') return {xs, ys} - }).repeat().batch(16) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> + }).batch(batchSize) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> const model = new models.GPT(config) - - for await (const logs of model.train(tokenDataset, undefined)) { - console.log(logs) + for (let i = 0; i < 6; i++) { + console.log(`Epoch ${i}`) + for await (const logs of model.train(tokenDataset, undefined)) { + console.log(logs) + } } - const generation = await model.generate("Lorem", tokenizer, { maxNewTokens: 10, doSample: false, topk: 5, temperature:0.1 }) + const generation = await model.generate("First", tokenizer, { maxNewTokens: 10, doSample: false, topk: 5, 
temperature:0.1 }) console.log(generation) } diff --git a/discojs-node/src/loaders.spec.ts b/discojs-node/src/loaders.spec.ts index c1f94d5a9..4772d68c9 100644 --- a/discojs-node/src/loaders.spec.ts +++ b/discojs-node/src/loaders.spec.ts @@ -2,6 +2,8 @@ import * as fs from "node:fs/promises"; import { withFile } from "tmp-promise"; import { describe, it } from "mocha"; import { expect } from "chai"; +import { models } from "@epfml/discojs"; +import { AutoTokenizer } from "@xenova/transformers"; import { loadCSV, @@ -50,13 +52,83 @@ describe("image directory parser", () => { }); describe("text parser", () => { + it("parses basic file", async () => { + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') await withFile(async ({ path }) => { - await fs.writeFile(path, ["a", "b", "c"].join("\n")); + const text = ["a", "b", "c"].join("\n") + await fs.writeFile(path, text); + // set block size to 4 to get 1 sequence of 4 tokens + 1 label token + const parsed = loadText(path, tokenizer, 4, 1); + expect(await parsed.size()).to.equal(1); + const next = await parsed[Symbol.asyncIterator]().next() + expect(next.done).to.be.false; + + const tokens = next.value as number[] + const expectedTokens = models.tokenize(tokenizer, text, { + padding: false, + truncation: false, + return_tensor: false + }) + expect(tokens).to.deep.equal(expectedTokens); + }); + }); + + it("yields the correct block size", async () => { + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') + await withFile(async ({ path }) => { + const text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." + const expectedTokens = models.tokenize(tokenizer, text, { + padding: false, + truncation: false, + return_tensor: false + }) + await fs.writeFile(path, text); - const parsed = loadText(path); + // set block size to 4 to get 1 sequence of 4 tokens + 1 label token + // so we expect 5 tokens per read + const blockSize = 4 + const parsed = loadText(path, tokenizer, blockSize, 1); + // expect the number of sequences to be the total number of tokens divided by blockSize + // we use floor because the last incomplete sequence is dropped + expect(await parsed.size()).to.equal(Math.floor(expectedTokens.length / blockSize)); + + let i = 0 + for await (const tokens of parsed) { + // each sequence should have length blockSize + 1 (for the label) + expect(tokens).to.deep.equal(expectedTokens.slice(i, i + blockSize + 1)); + // but the window should move by blockSize only + i += blockSize + } + }) + }); + + it("reads multiple chunks", async () => { + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') + await withFile(async ({ path }) => { + const text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec sed risus maximus, ultricies ex sed, dictum elit. Curabitur faucibus egestas enim et auctor. Quisque vel dignissim turpis. Curabitur justo tellus, elementum sit amet erat eget, auctor ornare nisi. Nunc tortor odio, ultrices id leo vitae, euismod congue ex. Curabitur arcu leo, sagittis quis felis nec, imperdiet aliquet tellus. Integer a mollis nulla. Quisque pulvinar lectus eget nisi pharetra, non molestie magna ullamcorper. Sed porttitor diam non blandit molestie. Duis tristique arcu ut efficitur efficitur. Fusce et ullamcorper tortor. Pellentesque a accumsan lacus, nec mollis risus. Nunc quis eros a orci ultricies cursus. Maecenas sodales ipsum a magna malesuada efficitur. Maecenas at sapien blandit, egestas nisi eu, mollis elit." 
+ const expectedTokens = models.tokenize(tokenizer, text, { + padding: false, + truncation: false, + return_tensor: false + }) + await fs.writeFile(path, text); - expect(await parsed.size()).to.equal(3); + // set block size to 4 to get 1 sequence of 4 tokens + 1 label token + // so we expect 5 tokens per read + const blockSize = 4 + const parsed = loadText(path, tokenizer, blockSize, 1, 1); // set the min chunk size allowed to 1 bit + // expect the number of sequences to be the total number of tokens divided by blockSize + // we use floor because the last incomplete sequence is dropped + expect(await parsed.size()).to.equal(Math.floor(expectedTokens.length / blockSize)); + + let i = 0 + for await (const tokens of parsed) { + // each sequence should have length blockSize + 1 (for the label) + expect(tokens).to.deep.equal(expectedTokens.slice(i, i + blockSize + 1)); + // but the window should move by blockSize only + i += blockSize + } }); }); }); diff --git a/discojs-node/src/loaders/text.ts b/discojs-node/src/loaders/text.ts index c1ae840a2..f4105fa80 100644 --- a/discojs-node/src/loaders/text.ts +++ b/discojs-node/src/loaders/text.ts @@ -1,14 +1,98 @@ -import * as fs from "node:fs/promises"; -import * as readline from "node:readline/promises"; +import createDebug from "debug"; +import { createReadStream } from 'node:fs'; -import { Dataset, Text } from "@epfml/discojs"; +import { PreTrainedTokenizer } from '@xenova/transformers'; +import { Dataset, Text, models } from "@epfml/discojs"; -export function load(path: string): Dataset { +const debug = createDebug("discojs-node:loaders:text"); + +/** + * Returns a Dataset that streams and tokenizes text to yield tokenized sequences + * one at a time. Each sequence has size `blockSize` + 1, where the first `blockSize` + * tokens are the input and the last token is the label. The following sequence + * starts with the last token of the previous sequence (so the previous label is now the + * first input token). + * In other words, the dataset yields sequences of size `blockSize` + 1 but with an overlap + * of 1 token between each sequence. + * + * @param path path to the text file to read + * @param tokenizer the tokenizer to use, should match the model that will be trained + * @param blockSize the context length, the maximum number of tokens of input sequences + * @param batchSize default to 1, the number of input sequences (of `blockSize` tokens) in each batch. + * The batch size is only used to configure the chunk size of the file stream such that each chunk is + * big enough to contain at least one batch. 
+ * @param minChunkSize defaults to 16KiB, the minimum size of each chunk in bytes
+ * @returns a dataset of tokenized input and label sequences
+ */
+export function load(path: string, tokenizer: PreTrainedTokenizer,
+  blockSize: number, batchSize: number = 1, minChunkSize = 16384): Dataset<Text> {
   return new Dataset(async function* () {
-    const input = (await fs.open(path)).createReadStream({ encoding: "utf8" });
+    if (batchSize < 1 || blockSize < 1 || minChunkSize < 1)
+      throw new Error("batchSize, blockSize and minChunkSize must be positive integers");
+    // we want each chunk to be at least bigger than the block size (each chunk corresponds to a block)
+    // (or even bigger than batch size * block size so that each chunk corresponds to a batch)
+    const chunkTokenSize = batchSize * (blockSize + 1) // + 1 for the next word label ys
+    // We read 8 * 8 = 64 bytes per expected token to ensure we have enough tokens
+    // For reference, the GPT-2 tokenizer encodes 3 to 4 bytes per token on average
+    const chunkBitSize = Math.max(minChunkSize, chunkTokenSize * 8 * 8);
+    debug("Setting the chunk size to %o bytes", chunkBitSize)
+    // Create a stream to read the text file chunk by chunk
+    const stream = createReadStream(path, {
+      encoding: "utf8",
+      highWaterMark: chunkBitSize
+    });
 
-    // `readline` is a bit overkill but seems standard
-    // https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line
-    yield* readline.createInterface({ input, crlfDelay: Infinity });
+    // iterate over the chunks
+    let endOfPreviousChunk = ""
+    let iteration = 0
+    for await (const chunk of stream) {
+      if (typeof chunk !== 'string') throw new Error('Expected file stream to yield string')
+      debug("Reading chunk of size %o", chunk.length)
+      // tokenize the whole chunk at once,
+      // concatenated with potential leftovers from the previous chunk
+      const tokens = models.tokenize(tokenizer, endOfPreviousChunk + chunk, {
+        padding: false,
+        truncation: false,
+        return_tensor: false,
+      })
+      if (tokens.length < blockSize + 1) {
+        // throw if it happens on the 1st iteration
+        if (iteration === 0)
+          throw new Error(`the chunk (${tokens.length} tokens) is too small ` +
+            `to get a sequence of length blockSize (${blockSize + 1} tokens). ` +
+            `Either the text file or the chunk size (${chunkBitSize} bytes) is too small.`);
+        // if this isn't the first iteration we simply skip
+        // as we expect the last chunk to be potentially smaller than the block size
+        debug("chunk smaller than block size, loading next chunk")
+        continue
+      }
+      debug("batches per chunk: %o", tokens.length / (batchSize * blockSize))
+      let currentPosition = 0;
+      // yield one block of tokens at a time
+      while (currentPosition + blockSize + 1 <= tokens.length) {
+        yield tokens.slice(currentPosition, currentPosition + blockSize + 1);
+        currentPosition += blockSize; // don't add + 1 here
+      }
+      // keep the last tokens for the next chunk
+      // if this was the last chunk the remaining tokens are discarded
+      if (currentPosition < tokens.length) {
+        // We actually need to decode the tokens to get the leftover text
+        // instead of simply keeping the remaining tokens.
+        // this is because the tokens may be different once prepended to the next chunk
+        // e.g. if the remaining text is ". 
A" and the next chunk starts with "nother" + // the tokenization will be different than if we simply concatenate the remaining tokens + endOfPreviousChunk = tokenizer.decode( + tokens.slice(currentPosition), + { skip_special_tokens: true } + ) + debug("End of chunk, remaining text: '%s'", endOfPreviousChunk) + } else { + // Note that the difference between tokenizing and then concatenating + // vs concatenating and then tokenizing can happen if their is no + // remaining text. We consider this difference negligible + endOfPreviousChunk = ""; + } + iteration++; + } }); } diff --git a/discojs-web/src/loaders.spec.ts b/discojs-web/src/loaders.spec.ts index 603eb292c..73cbde648 100644 --- a/discojs-web/src/loaders.spec.ts +++ b/discojs-web/src/loaders.spec.ts @@ -1,5 +1,7 @@ +import { AutoTokenizer } from "@xenova/transformers"; import { describe, it, expect } from "vitest"; +import { models } from "@epfml/discojs"; import { loadCSV, loadText } from "./loaders/index.js"; async function arrayFromAsync(iter: AsyncIterable): Promise { @@ -22,22 +24,51 @@ describe("csv parser", () => { }); describe("text parser", () => { - it("loads", async () => { + it("loads a simple sequence", async () => { + const text = ["first", "second", "third"].join("\n") + + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') + const expectedTokens = models.tokenize(tokenizer, text, { + padding: false, + truncation: false, + return_tensor: false, + }) + // jsdom doesn't implement .text on File/Blob // trick from https://github.com/jsdom/jsdom/issues/2555 - const text = await ( - await fetch( - // data URL content need to be url-encoded - ["data:,first", "second", "third"].join("%0A"), - ) + const file = await ( + await fetch( "data:," + encodeURIComponent(text)) ).blob(); + const parsed = loadText(file, tokenizer, 4); - const parsed = loadText(text); + expect(await parsed.size()).to.equal(1); // expect a single sequence + expect((await arrayFromAsync(parsed))[0]).to.deep.equal(expectedTokens); + }); - expect(await arrayFromAsync(parsed)).to.have.ordered.members([ - "first", - "second", - "third", - ]); + it("yields the correct block size", async () => { + const text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed quis faucibus ipsum." 
+ + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') + const expectedTokens = models.tokenize(tokenizer, text, { + padding: false, + truncation: false, + return_tensor: false + }) + + const file = await ( + await fetch("data:," + encodeURIComponent(text)) + ).blob(); + + const blockSize = 4 + const parsed = loadText(file, tokenizer, blockSize); + expect(await parsed.size()).to.equal(Math.floor(expectedTokens.length / blockSize)); + + let i = 0 + for await (const tokens of parsed) { + // each sequence should have length blockSize + 1 (for the label) + expect(tokens).to.deep.equal(expectedTokens.slice(i, i + blockSize + 1)); + // but the window should move by blockSize only + i += blockSize + } }); }); diff --git a/discojs-web/src/loaders/text.ts b/discojs-web/src/loaders/text.ts index 0aee95d74..b6e29c787 100644 --- a/discojs-web/src/loaders/text.ts +++ b/discojs-web/src/loaders/text.ts @@ -1,35 +1,76 @@ -import { Dataset, Text } from "@epfml/discojs"; +import createDebug from "debug"; +import { Dataset, Text, models } from "@epfml/discojs"; +import { PreTrainedTokenizer } from '@xenova/transformers'; -class LineStream extends TransformStream { - constructor() { - let current_line = ""; +const debug = createDebug("discojs-web:loaders:text"); +/** + * Stream and tokenize text to yield tokenized sequences + * one at a time. Each sequence has size `blockSize` + 1, where the first `blockSize` + * tokens are the input and the last token is the label. The following sequence + * starts with the last token of the previous sequence (so the previous label is now the + * first input token). + * In other words, the stream yields sequences of size `blockSize` + 1 but with an overlap + * of 1 token between each sequence. + * + * @param file the file to read + * @param tokenizer the tokenizer to use, should match the model that will be trained + * @param blockSize the context length, the maximum number of tokens of input sequences + */ +class TokenizerStream extends TransformStream { + constructor(tokenizer: PreTrainedTokenizer, blockSize: number) { + let endOfPreviousChunk = "" + let chunkNumber = 0 super({ transform: (chunk, controller) => { - const [head, ...lines] = chunk.split(/\r\n|\r|\n/); - const first_line = current_line + head; - - if (lines.length === 0) { - current_line = first_line; - return; + debug("yield TokenizerStream chunk of length: %o", chunk.length); + // tokenize the whole chunk at once + const tokens = models.tokenize(tokenizer, endOfPreviousChunk + chunk, { + padding: false, + truncation: false, + return_tensor: false, + }); + if (tokens.length < blockSize + 1) { + // throw if it happens on the 1st chunk + if (chunkNumber === 0) + throw new Error(`the chunk (${tokens.length} tokens) is too small ` + + `to get a sequence of length blockSize (${blockSize + 1} tokens). 
` + + `Either the text file or the chunk size is too small.`); + // if this isn't the first iteration we simply skip + // as we expect the last chunk to be potentially smaller than the block size + debug("chunk smaller than block size, loading next chunk") + return } - - controller.enqueue(first_line); - for (const line of lines.slice(0, -1)) controller.enqueue(line); - - current_line = lines[lines.length - 1]; + let currentPosition = 0; + // yield one block of tokens at a time + // add 1 to include the next token for the prediction label + while (currentPosition + blockSize + 1 <= tokens.length) { + controller.enqueue(tokens.slice(currentPosition, currentPosition + blockSize + 1)) + currentPosition += blockSize; // no +1 here + } + // keep the last tokens for the next chunk + // if this was the last chunk the remaining tokens are discarded + if (currentPosition < tokens.length) { + endOfPreviousChunk = tokenizer.decode( + tokens.slice(currentPosition), + { skip_special_tokens: true } + ) + } + else endOfPreviousChunk = ""; + chunkNumber++; }, - flush: (controller) => controller.enqueue(current_line), + // No flush, discard the last tokens }); } } -export function load(file: Blob): Dataset { +export function load(file: Blob, tokenizer: PreTrainedTokenizer, + blockSize: number): Dataset { return new Dataset(async function* () { const reader = file .stream() .pipeThrough(new TextDecoderStream()) - .pipeThrough(new LineStream()) + .pipeThrough(new TokenizerStream(tokenizer, blockSize)) .getReader(); while (true) { diff --git a/discojs/src/dataset/data/preprocessing/text_preprocessing.spec.ts b/discojs/src/dataset/data/preprocessing/text_preprocessing.spec.ts index 4e1b7862e..78501269b 100644 --- a/discojs/src/dataset/data/preprocessing/text_preprocessing.spec.ts +++ b/discojs/src/dataset/data/preprocessing/text_preprocessing.spec.ts @@ -5,7 +5,7 @@ import type { Task } from '../../../index.js' import * as tf from '@tensorflow/tfjs' describe('text preprocessing', function () { - const [tokenize, leftPadding] = TEXT_PREPROCESSING + const [toTFTensor, tokenize, leftPadding] = TEXT_PREPROCESSING // Use a function to create different task object for each test (otherwise the tokenizer gets cached) function initMockTask(): Task { return { @@ -97,5 +97,32 @@ describe('text preprocessing', function () { } throw new Error("invalid tokenizer name should have thrown an error") }) + + it('can convert a token sequence to tensors', async () => { + // Create a task where the model has a context length of 10 + const task = initMockTask() + // Create a token sequence of length 10 tokens + 1(for the label) + const tokens = [0,1,2,3,4,5,6,7,8,9,10] + const blockSize = tokens.length - 1 + task.trainingInformation.maxSequenceLength = blockSize + + const { xs, ys } = await toTFTensor.apply(Promise.resolve(tokens), task) as { xs: tf.Tensor1D, ys: tf.Tensor2D } + const xsArray = await xs.array() + const ysArray = await ys.array() + + // xsArray should simply be the input token sequence without the label token + expect(xsArray).to.be.deep.equal(tokens.slice(0, blockSize)) + + // ysArray should have shape (10, 50257), 50257 being the size of the vocab for gpt2 + expect(ysArray.length).to.be.equal(blockSize) + expect(ysArray[0].length).to.be.equal(50257) + + // ys should be a one hot encoding of the next token in xs + // So the sum of each row should be equal to 1 + const expectedOneHot = Array.from({ length: blockSize }).map(_ => 1) + expect(await ys.sum(-1).array()).to.be.deep.equal(expectedOneHot) + // and in each 
row the index of the 1 should be the token id + expect(await ys.argMax(-1).array()).to.be.deep.equal(tokens.slice(1)) + }) }) diff --git a/discojs/src/dataset/data/preprocessing/text_preprocessing.ts b/discojs/src/dataset/data/preprocessing/text_preprocessing.ts index 7481b6f9c..5b793c38e 100644 --- a/discojs/src/dataset/data/preprocessing/text_preprocessing.ts +++ b/discojs/src/dataset/data/preprocessing/text_preprocessing.ts @@ -9,6 +9,7 @@ import { models } from '../../../index.js' * Available text preprocessing types. */ export enum TextPreprocessing { + ToTFTensor, Tokenize, LeftPadding } @@ -119,10 +120,33 @@ const tokenize: PreprocessingFunction = { } } +/** + * Given a sequence of block size + 1 tokens, creates a 1D input tensor and a 2D one-hot encoded label tensor + */ +const toTFTensor: PreprocessingFunction = { + type: TextPreprocessing.ToTFTensor, + apply: async (x: Promise, task: Task): Promise<{ xs: tf.Tensor1D, ys: tf.Tensor2D }> => { + const tokens = await x + if (!isNumberArray(tokens)) + throw new Error("The toTFTensor preprocessing expects a array of numbers as input") + + const tokenizer = await models.getTaskTokenizer(task) + const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + if (tokens.length != maxLength + 1) + throw new Error(`The toTFTensor preprocessing expects an array of length ${maxLength + 1}`) + // The inputs are truncated down to exactly maxSequenceLength + const xs = tf.tensor1d(tokens.slice(0, maxLength), 'int32') + const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length) as tf.Tensor2D + + return { xs, ys } + } +} + /** * Available text preprocessing functions. */ export const AVAILABLE_PREPROCESSING = List.of( + toTFTensor, tokenize, leftPadding ).sortBy((e) => e.type) diff --git a/discojs/src/dataset/types.ts b/discojs/src/dataset/types.ts index 79c510ea0..28afb43db 100644 --- a/discojs/src/dataset/types.ts +++ b/discojs/src/dataset/types.ts @@ -2,4 +2,4 @@ import { Image } from "./image.js" export { Image }; export type Tabular = Partial>; -export type Text = string; +export type Text = number[]; diff --git a/discojs/src/default_tasks/wikitext.ts b/discojs/src/default_tasks/wikitext.ts index 518c060c2..1302b24d0 100644 --- a/discojs/src/default_tasks/wikitext.ts +++ b/discojs/src/default_tasks/wikitext.ts @@ -25,7 +25,7 @@ export const wikitext: TaskProvider = { }, trainingInformation: { dataType: 'text', - preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding], + preprocessingFunctions: [data.TextPreprocessing.ToTFTensor], scheme: 'federated', aggregationStrategy: 'mean', minNbOfParticipants: 2, @@ -34,9 +34,9 @@ export const wikitext: TaskProvider = { // But if set to 0 then the webapp doesn't display the validation metrics validationSplit: 0.1, roundDuration: 2, - batchSize: 1, // If set too high (e.g. 
16) firefox raises a WebGL error + batchSize: 8, // If set too high firefox raises a WebGL error tokenizer: 'Xenova/gpt2', - maxSequenceLength: 128, + maxSequenceLength: 64, tensorBackend: 'gpt' } } diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index c9c9aff6d..6609e3d0c 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -59,7 +59,6 @@ class GPTModel extends tf.LayersModel { const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> await callbacks.onTrainBegin?.() - for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) { let accuracyFraction: [number, number] = [0, 0]; let averageLoss = 0 @@ -75,7 +74,7 @@ class GPTModel extends tf.LayersModel { let preprocessingTime = performance.now() await Promise.all([xs.data(), ys.data()]) preprocessingTime = performance.now() - preprocessingTime - + // TODO include as a tensor inside the model const accTensor = tf.tidy(() => { const logits = this.apply(xs) @@ -92,7 +91,7 @@ class GPTModel extends tf.LayersModel { if (typeof accSum !== 'number') throw new Error('got multiple accuracy sum') accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize]; - tf.dispose([accTensor]) + tf.dispose([accTensor]) const lossTensor = tf.tidy(() => { const { grads, value: lossTensor } = this.optimizer.computeGradients(() => { @@ -141,7 +140,7 @@ class GPTModel extends tf.LayersModel { tf.dispose([xs, ys]) } let logs: tf.Logs = { - 'loss': averageLoss / iteration, + 'loss': averageLoss / (iteration - 1), // -1 because iteration got incremented at the end of the loop 'acc': accuracyFraction[0] / accuracyFraction[1], } if (evalDataset !== undefined) { diff --git a/discojs/src/models/index.ts b/discojs/src/models/index.ts index 267e41bec..cf8422edb 100644 --- a/discojs/src/models/index.ts +++ b/discojs/src/models/index.ts @@ -4,4 +4,4 @@ export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js"; export { GPT } from './gpt/index.js' export { GPTConfig } from './gpt/config.js' export { TFJS } from './tfjs.js' -export { getTaskTokenizer } from './tokenizer.js' +export { getTaskTokenizer, tokenize } from './tokenizer.js' diff --git a/discojs/src/models/tokenizer.ts b/discojs/src/models/tokenizer.ts index 00d1aa967..da573d43d 100644 --- a/discojs/src/models/tokenizer.ts +++ b/discojs/src/models/tokenizer.ts @@ -25,4 +25,40 @@ export async function getTaskTokenizer(task: Task): Promise task.trainingInformation.tokenizer = tokenizer } return tokenizer +} + +function isArrayOfNumber(raw: unknown): raw is number[] { + return Array.isArray(raw) && raw.every((e) => typeof e === "number"); +} + +interface TokenizingConfig { + padding?: boolean, + truncation?: boolean, + return_tensor?: boolean + text_pair?: string | null, + add_special_tokens?: boolean, + max_length?: number, + return_token_type_ids?: boolean, +} + +/** + * Wrapper around Transformers.js tokenizer to handle type checking and format the output. 
+ * + * @param tokenizer the tokenizer object + * @param text the text to tokenize + * @param config TokenizingConfig, the tokenizing parameters when using `tokenizer` + * @returns number[] the tokenized text + */ +export function tokenize(tokenizer: PreTrainedTokenizer, text: string, config: TokenizingConfig): number[] { + const tokenizerResult: unknown = tokenizer(text, config); + + if ( + typeof tokenizerResult !== "object" || + tokenizerResult === null || + !("input_ids" in tokenizerResult) || + !isArrayOfNumber(tokenizerResult.input_ids) + ) + throw new Error("tokenizer returned unexpected type"); + + return tokenizerResult.input_ids } \ No newline at end of file diff --git a/discojs/src/validation/validator.ts b/discojs/src/validation/validator.ts index 437cbf73a..64eda99bb 100644 --- a/discojs/src/validation/validator.ts +++ b/discojs/src/validation/validator.ts @@ -94,6 +94,13 @@ export class Validator { ) throw new Error("unexpected shape of dataset"); + // TODO: implement WebWorker to remove this wait + // https://github.com/epfml/disco/issues/758 + // When running on cpu the inference hogs the main thread + // and freezes the UI + if (tf.getBackend() === "cpu") { + await new Promise((resolve) => setTimeout(resolve, 100)); + } const prediction = await this.#model.predict(row.xs); tf.dispose(row); let predictions: number[]; diff --git a/docs/examples/wikitext.ts b/docs/examples/wikitext.ts index 4bc2490f1..440ce8572 100644 --- a/docs/examples/wikitext.ts +++ b/docs/examples/wikitext.ts @@ -18,11 +18,15 @@ async function main(): Promise { // Toggle TRAIN_MODEL to either train and save a new model from scratch or load an existing model const TRAIN_MODEL = true + + // Retrieve the tokenizer + const tokenizer = await models.getTaskTokenizer(task) if (TRAIN_MODEL) { + const blockSize = task.trainingInformation.maxSequenceLength ?? 
128 + const batchSize = task.trainingInformation.batchSize // Load the wikitext dataset from the `datasets` folder - const dataset = loadText("../../datasets/wikitext/wiki.train.tokens").chain( - loadText("../../datasets/wikitext/wiki.valid.tokens"), - ); + const dataset = loadText("../../datasets/wikitext/wiki.train.tokens", tokenizer, blockSize, batchSize) + .chain(loadText("../../datasets/wikitext/wiki.valid.tokens", tokenizer, blockSize, batchSize)); // Initialize a Disco instance and start training a language model const disco = new Disco(task, url, { scheme: 'federated' }) @@ -36,9 +40,6 @@ async function main(): Promise { // Load the trained model model = await loadModelFromDisk(`${modelFolder}/${modelFileName}`) as models.GPT } - - // Retrieve the tokenizer used during training - const tokenizer = await models.getTaskTokenizer(task) const prompt = 'The game began development in 2010 , carrying over a large portion' const generation = await model.generate(prompt, tokenizer) console.log(generation) diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 99794645c..3f55af653 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -4,7 +4,7 @@ import type * as http from "node:http"; import path from "node:path"; import type { RoundLogs, RoundStatus, WeightsContainer } from "@epfml/discojs"; -import { Disco, defaultTasks } from "@epfml/discojs"; +import { Disco, defaultTasks, models } from "@epfml/discojs"; import { loadCSV, loadImagesInDir, loadText } from "@epfml/discojs-node"; import { Server } from "../../src/index.js"; @@ -93,10 +93,16 @@ describe("end-to-end federated", () => { async function wikitextUser(): Promise { const task = defaultTasks.wikitext.getTask(); task.trainingInformation.epochs = 2; - + const tokenizer = await models.getTaskTokenizer(task) + const blockSize = task.trainingInformation.maxSequenceLength ?? 
8 + const batchSize = task.trainingInformation.batchSize const dataset = loadText( path.join(DATASET_DIR, "wikitext", "wiki.train.tokens"), - ).chain(loadText(path.join(DATASET_DIR, "wikitext", "wiki.valid.tokens"))); + tokenizer, blockSize, batchSize + ).chain(loadText( + path.join(DATASET_DIR, "wikitext", "wiki.valid.tokens"), + tokenizer, blockSize, batchSize + )); const disco = new Disco(task, url, { scheme: "federated" }); diff --git a/webapp/cypress.config.ts b/webapp/cypress.config.ts index 1031a48be..07adc2ef3 100644 --- a/webapp/cypress.config.ts +++ b/webapp/cypress.config.ts @@ -9,6 +9,10 @@ export default defineConfig({ on("task", { readdir: async (p: string) => (await fs.readdir(p)).map((filename) => path.join(p, filename)), + log: (message) => { + console.log(message) + return null + }, }); }, }, diff --git a/webapp/cypress/e2e/testing.cy.ts b/webapp/cypress/e2e/testing.cy.ts index 0647d0675..4347b5968 100644 --- a/webapp/cypress/e2e/testing.cy.ts +++ b/webapp/cypress/e2e/testing.cy.ts @@ -46,20 +46,27 @@ it("can test lus_covid", () => { it("can start and stop testing of wikitext", () => { setupServerWith(defaultTasks.wikitext); - cy.visit("/#/evaluate"); cy.contains("button", "download").click(); cy.contains("button", "test").click(); - + + // input the dataset cy.contains("label", "select text").selectFile( "../datasets/wikitext/wiki.test.tokens", ); + + // NOTE: internet connection needed + // wait for the tokenizer to load and the filename to display + // otherwise the training starts before the dataset is ready + cy.contains("Connect your data") + .parent() + .parent() + .contains("wiki.test.tokens", { timeout: 20_000 }); + cy.contains("button", "next").click(); + + cy.get('[data-cy="start-test"]').click() - cy.contains("Test & validate") - .parent() - .parent() - .contains("button", "test") - .click(); - cy.contains("button", "stop testing").click(); + cy.get('[data-cy="stop-test"]') + .click({ waitForAnimations: false }); }); diff --git a/webapp/cypress/support/e2e.ts b/webapp/cypress/support/e2e.ts index 1d8e97004..d42a4a47a 100644 --- a/webapp/cypress/support/e2e.ts +++ b/webapp/cypress/support/e2e.ts @@ -74,3 +74,5 @@ beforeEach( req.onsuccess = resolve; }), ); + +beforeEach(() => { localStorage.debug = "discojs*,webapp*" }); diff --git a/webapp/src/components/dataset_input/FileSelection.vue b/webapp/src/components/dataset_input/FileSelection.vue index 0a7fceb24..8e3c595c5 100644 --- a/webapp/src/components/dataset_input/FileSelection.vue +++ b/webapp/src/components/dataset_input/FileSelection.vue @@ -67,14 +67,21 @@ v-if="files !== undefined" class="pt-4 flex flex-col items-center pb-5" > -
      [template markup lost in extraction: this hunk restructures how the summary
      "Number of selected files: {{ files.size }}" and the first file name
      "{{ files.first()?.name ?? "none" }}" are displayed]
@@ -89,6 +96,7 @@ diff --git a/webapp/src/components/testing/PredictSteps.vue b/webapp/src/components/testing/PredictSteps.vue index 98ccf7df6..8754f0307 100644 --- a/webapp/src/components/testing/PredictSteps.vue +++ b/webapp/src/components/testing/PredictSteps.vue @@ -32,6 +32,7 @@
diff --git a/webapp/src/components/testing/TestSteps.vue b/webapp/src/components/testing/TestSteps.vue index 946a3fee7..f0f8ea8c9 100644 --- a/webapp/src/components/testing/TestSteps.vue +++ b/webapp/src/components/testing/TestSteps.vue @@ -26,6 +26,7 @@
@@ -39,12 +40,12 @@ against a chosen dataset of yours. Below, once you assessed the model, you can compare the ground truth and the predicted values
      [button markup lost in extraction: the "test" and "stop testing" buttons keep their labels;
      since the updated Cypress test selects [data-cy="start-test"] and [data-cy="stop-test"],
      these hunks presumably add those attributes]
diff --git a/webapp/src/components/training/TrainingSteps.vue b/webapp/src/components/training/TrainingSteps.vue index b9b8c7655..f5a2ffe0a 100644 --- a/webapp/src/components/training/TrainingSteps.vue +++ b/webapp/src/components/training/TrainingSteps.vue @@ -20,6 +20,7 @@
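
For reference, a minimal usage sketch (not part of the patch) of how the new chunk-based text loader is meant to be consumed. The API follows the Node loader introduced above; the dataset path, block size and batch size are illustrative assumptions.

import { AutoTokenizer } from "@xenova/transformers";
import { loadText } from "@epfml/discojs-node";

async function preview(): Promise<void> {
  // The tokenizer must match the model that will consume the sequences (GPT-2 here).
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
  const blockSize = 64; // context length, illustrative value
  const batchSize = 8;  // only used to size the file-read chunks

  // Each yielded item is a number[] of blockSize + 1 token ids:
  // the first blockSize ids are the model input, the last id is the next-token label.
  const dataset = loadText("../datasets/wikitext/wiki.train.tokens", tokenizer, blockSize, batchSize);

  for await (const tokens of dataset) {
    const xs = tokens.slice(0, blockSize); // input window
    const ys = tokens.slice(1);            // same window shifted by one token
    console.log(xs.length, ys.length);     // 64 64
    break;                                 // only inspect the first sequence in this sketch
  }
}

void preview();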