Skip to content

Commit

Permalink
Ensure WASM fallback does not crash in GH actions (#402)
Browse files Browse the repository at this point in the history
* Ensure WASM fallback does not crash in GH actions

* Add unit test for WordPiece `max_input_chars_per_word`

* Cleanup

* Set max test concurrency to 1
  • Loading branch information
xenova authored Nov 19, 2023
1 parent 19daf2d commit b8719b1
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"dev": "webpack serve --no-client-overlay",
"build": "webpack && npm run typegen",
"generate-tests": "python -m tests.generate_tests",
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose",
"test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
"readme": "python ./docs/scripts/build_readme.py",
"docs-api": "node ./docs/scripts/generate.js",
"docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
Expand Down
4 changes: 2 additions & 2 deletions tests/generation.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ describe('Generation parameters', () => {

// List all models which will be tested
const models = [
'Xenova/LaMini-Flan-T5-77M', // encoder-decoder
'Xenova/LaMini-GPT-124M', // decoder-only
'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
'MBZUAI/LaMini-GPT-124M', // decoder-only
];

// encoder-decoder model
Expand Down
6 changes: 6 additions & 0 deletions tests/init.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import { onnxruntimeBackend } from "onnxruntime-node/dist/backend";
import ONNX_COMMON from "onnxruntime-common";

export function init() {
// In rare cases (specifically when running unit tests with GitHub actions), possibly due to
// a large number of concurrent executions, onnxruntime might fall back to using the WASM backend.
// In this case, we set the number of threads to 1 to avoid errors like:
// - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. Received "blob:nodedata:..."`
ONNX_COMMON.env.wasm.numThreads = 1;

// A workaround to define a new backend for onnxruntime, which
// will not throw an error when running tests with jest.
// For more information, see: https://github.com/jestjs/jest/issues/11864#issuecomment-1261468011
Expand Down
5 changes: 1 addition & 4 deletions tests/tensor.test.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@

import { AutoProcessor, Tensor } from '../src/transformers.js';

import { MAX_TEST_EXECUTION_TIME, m } from './init.js';
import { Tensor } from '../src/transformers.js';
import { compare } from './test_utils.js';

import { cat, mean, stack } from '../src/utils/tensor.js';

describe('Tensor operations', () => {
Expand Down
13 changes: 11 additions & 2 deletions tests/tokenizers.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import { AutoTokenizer } from '../src/transformers.js';
import { getFile } from '../src/utils/hub.js';
import { m, MAX_TEST_EXECUTION_TIME } from './init.js';
import { compare } from './test_utils.js';

// Load test data generated by the python tests
// TODO do this dynamically?
Expand Down Expand Up @@ -41,10 +42,18 @@ describe('Tokenizers', () => {

describe('Edge cases', () => {
it('should not crash when encoding a very long string', async () => {
let tokenizer = await AutoTokenizer.from_pretrained('t5-small');
let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');

let text = String.prototype.repeat.call('Hello world! ', 50000);
let encoded = await tokenizer(text);
let encoded = tokenizer(text);
expect(encoded.input_ids.data.length).toBeGreaterThan(100000);
}, MAX_TEST_EXECUTION_TIME);

it('should not take too long', async () => {
let tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');

let text = String.prototype.repeat.call('a', 50000);
let token_ids = tokenizer.encode(text);
compare(token_ids, [101, 100, 102])
}, 5000); // NOTE: 5 seconds
});

0 comments on commit b8719b1

Please sign in to comment.