Add support for ModernBert (huggingface#1104)
* Fix token decode in fill-mask pipeline

* Add support for ModernBERT

* Add modernbert unit tests

* Cleanup bert unit tests

* Add unit test for `sequence_length > local_attention_window`
xenova authored Dec 19, 2024
1 parent 610391d commit 1691557
Showing 7 changed files with 377 additions and 93 deletions.
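
In practice, the new support is exercised through the usual Transformers.js APIs. A minimal sketch of a fill-mask call with a ModernBERT checkpoint (the model ID is illustrative; any ModernBERT checkpoint with ONNX weights should work):

```js
import { pipeline } from "@huggingface/transformers";

// Fill-mask with a ModernBERT checkpoint (illustrative model ID).
const unmasker = await pipeline("fill-mask", "answerdotai/ModernBERT-base");
const output = await unmasker("Paris is the [MASK] of France.");
// With the token-decode fix in this commit, each result's `token_str`
// comes from tokenizer.decode() and renders as readable text.
console.log(output);
```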
1 change: 1 addition & 0 deletions README.md
@@ -366,6 +366,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **MobileNetV4** (from Google Inc.) released with the paper [MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) by Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard.
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
1. **[ModernBERT](https://huggingface.co/docs/transformers/model_doc/modernbert)** (from Answer.AI) released with the paper [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard, Iacopo Poli.
1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
1. **[Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine)** (from Useful Sensors) released with the paper [Moonshine: Speech Recognition for Live Transcription and Voice Commands](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -81,6 +81,7 @@
1. **MobileNetV4** (from Google Inc.) released with the paper [MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) by Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard.
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
1. **[ModernBERT](https://huggingface.co/docs/transformers/model_doc/modernbert)** (from Answer.AI) released with the paper [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard, Iacopo Poli.
1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
1. **[Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine)** (from Useful Sensors) released with the paper [Moonshine: Speech Recognition for Live Transcription and Voice Commands](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
47 changes: 47 additions & 0 deletions src/models.js
@@ -1951,6 +1951,49 @@ export class BertForQuestionAnswering extends BertPreTrainedModel {
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// ModernBert models
export class ModernBertPreTrainedModel extends PreTrainedModel { }
export class ModernBertModel extends ModernBertPreTrainedModel { }

export class ModernBertForMaskedLM extends ModernBertPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
*/
async _call(model_inputs) {
return new MaskedLMOutput(await super._call(model_inputs));
}
}

export class ModernBertForSequenceClassification extends ModernBertPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
*/
async _call(model_inputs) {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}

export class ModernBertForTokenClassification extends ModernBertPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// NomicBert models
export class NomicBertPreTrainedModel extends PreTrainedModel { }
@@ -6921,6 +6964,7 @@ export class PretrainedMixin {

const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['bert', ['BertModel', BertModel]],
['modernbert', ['ModernBertModel', ModernBertModel]],
['nomic_bert', ['NomicBertModel', NomicBertModel]],
['roformer', ['RoFormerModel', RoFormerModel]],
['electra', ['ElectraModel', ElectraModel]],
@@ -7059,6 +7103,7 @@ const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([

const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
['modernbert', ['ModernBertForSequenceClassification', ModernBertForSequenceClassification]],
['roformer', ['RoFormerForSequenceClassification', RoFormerForSequenceClassification]],
['electra', ['ElectraForSequenceClassification', ElectraForSequenceClassification]],
['esm', ['EsmForSequenceClassification', EsmForSequenceClassification]],
@@ -7080,6 +7125,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([

const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', ['BertForTokenClassification', BertForTokenClassification]],
['modernbert', ['ModernBertForTokenClassification', ModernBertForTokenClassification]],
['roformer', ['RoFormerForTokenClassification', RoFormerForTokenClassification]],
['electra', ['ElectraForTokenClassification', ElectraForTokenClassification]],
['esm', ['EsmForTokenClassification', EsmForTokenClassification]],
@@ -7148,6 +7194,7 @@ const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([

const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
['bert', ['BertForMaskedLM', BertForMaskedLM]],
['modernbert', ['ModernBertForMaskedLM', ModernBertForMaskedLM]],
['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
['electra', ['ElectraForMaskedLM', ElectraForMaskedLM]],
['esm', ['EsmForMaskedLM', EsmForMaskedLM]],
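Registering `'modernbert'` in these mappings is what lets the Auto classes resolve the new model classes from a checkpoint's `model_type`. A minimal sketch (illustrative model ID, assuming ONNX weights are available on the Hub):

```js
import { AutoTokenizer, AutoModelForMaskedLM } from "@huggingface/transformers";

// Resolves to ModernBertForMaskedLM via MODEL_FOR_MASKED_LM_MAPPING_NAMES,
// because the checkpoint's config declares `"model_type": "modernbert"`.
const tokenizer = await AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base");
const model = await AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base");

const inputs = tokenizer("The capital of France is [MASK].");
const { logits } = await model(inputs); // dims: [batch_size, sequence_length, vocab_size]
```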
2 changes: 1 addition & 1 deletion src/pipelines.js
@@ -688,7 +688,7 @@ export class FillMaskPipeline extends (/** @type {new (options: TextPipelineCons
return {
score: values[i],
token: Number(x),
-      token_str: this.tokenizer.model.vocab[x],
+      token_str: this.tokenizer.decode([x]),
sequence: this.tokenizer.decode(sequence, { skip_special_tokens: true }),
}
}));
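
This one-line fix replaces a raw vocabulary lookup with a proper single-token decode. ModernBERT uses a byte-level BPE tokenizer, so raw vocabulary entries carry internal byte markers (such as `Ġ` for a leading space) that would otherwise leak into `token_str`; WordPiece vocabularies (as in classic BERT) mostly render acceptably, which is presumably why the lookup survived until now. A sketch of the difference, reusing the `tokenizer` from the sketch above (printed strings are illustrative):

```js
// Compare raw vocab surface forms against single-token decodes.
const ids = tokenizer.encode("Paris is the capital of France.");
for (const id of ids) {
  // e.g. "Ġcapital" (raw byte-level BPE entry) vs " capital" (decoded)
  console.log(tokenizer.model.vocab[id], "vs", tokenizer.decode([id]));
}
```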
18 changes: 4 additions & 14 deletions tests/models/bert/test_modeling_bert.js
@@ -100,14 +100,9 @@ export default () => {
async () => {
const inputs = tokenizer("hello");
const { logits } = await model(inputs);
-      const target = [[0.00043986947275698185, -0.030218850821256638]].flat();
+      const target = [[0.00043986947275698185, -0.030218850821256638]];
expect(logits.dims).toEqual([1, 2]);
-      logits
-        .tolist()
-        .flat()
-        .forEach((item, i) => {
-          expect(item).toBeCloseTo(target[i], 5);
-        });
+      expect(logits.tolist()).toBeCloseToNested(target, 5);
},
MAX_TEST_EXECUTION_TIME,
);
@@ -120,14 +115,9 @@
const target = [
[0.00043986947275698185, -0.030218850821256638],
[0.0003853091038763523, -0.03022204339504242],
-      ].flat();
+      ];
expect(logits.dims).toEqual([2, 2]);
-      logits
-        .tolist()
-        .flat()
-        .forEach((item, i) => {
-          expect(item).toBeCloseTo(target[i], 5);
-        });
+      expect(logits.tolist()).toBeCloseToNested(target, 5);
},
MAX_TEST_EXECUTION_TIME,
);
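The flatten-and-loop assertions above are collapsed into a custom `toBeCloseToNested` matcher from the shared test setup. A minimal sketch of such a matcher, assuming it is registered via Jest's `expect.extend`:

```js
// Minimal sketch of a recursive closeness matcher (the real helper is
// assumed to live in the shared test setup, e.g. tests/init.js).
expect.extend({
  toBeCloseToNested(received, expected, numDigits = 2) {
    const close = (a, b) =>
      Array.isArray(a) && Array.isArray(b)
        ? a.length === b.length && a.every((v, i) => close(v, b[i]))
        : Math.abs(a - b) < 0.5 * 10 ** -numDigits; // same tolerance as toBeCloseTo
    const pass = close(received, expected);
    return {
      pass,
      message: () =>
        `expected ${JSON.stringify(received)}${pass ? " not" : ""} to be close to ${JSON.stringify(expected)}`,
    };
  },
});
```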
180 changes: 180 additions & 0 deletions tests/models/modernbert/test_modeling_modernbert.js
@@ -0,0 +1,180 @@
import { PreTrainedTokenizer, ModernBertModel, ModernBertForMaskedLM, ModernBertForSequenceClassification, ModernBertForTokenClassification } from "../../../src/transformers.js";

import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../../init.js";

export default () => {
describe("ModernBertModel", () => {
const model_id = "hf-internal-testing/tiny-random-ModernBertModel";

/** @type {ModernBertModel} */
let model;
/** @type {PreTrainedTokenizer} */
let tokenizer;
beforeAll(async () => {
model = await ModernBertModel.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
}, MAX_MODEL_LOAD_TIME);

it(
"batch_size=1",
async () => {
const inputs = tokenizer("hello");
const { last_hidden_state } = await model(inputs);
expect(last_hidden_state.dims).toEqual([1, 3, 32]);
expect(last_hidden_state.mean().item()).toBeCloseTo(-0.08922556787729263, 5);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"batch_size>1",
async () => {
const inputs = tokenizer(["hello", "hello world"], { padding: true });
const { last_hidden_state } = await model(inputs);
expect(last_hidden_state.dims).toEqual([2, 4, 32]);
expect(last_hidden_state.mean().item()).toBeCloseTo(0.048988230526447296, 5);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"sequence_length > local_attention_window",
async () => {
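        // This passage tokenizes to 397 tokens (see the expected dims below), long
        // enough to exceed the tiny model's local attention window and exercise
        // ModernBERT's sliding-window (local) attention path alongside global attention.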
const text = "The sun cast long shadows across the weathered cobblestones as Thomas made his way through the ancient city. The evening air carried whispers of autumn, rustling through the golden leaves that danced and swirled around his feet. His thoughts wandered to the events that had brought him here, to this moment, in this forgotten corner of the world. The old buildings loomed above him, their facades telling stories of centuries past. Windows reflected the dying light of day, creating a kaleidoscope of amber and rose that painted the narrow streets. The distant sound of church bells echoed through the maze of alleyways, marking time's steady march forward. In his pocket, he fingered the small brass key that had belonged to his grandfather. Its weight seemed to grow heavier with each step, a tangible reminder of the promise he had made. The mystery of its purpose had consumed his thoughts for weeks, leading him through archives and dusty libraries, through conversations with local historians and elderly residents who remembered the old days. As the evening deepened into dusk, streetlamps flickered to life one by one, creating pools of warm light that guided his way. The smell of wood smoke and distant cooking fires drifted through the air, reminding him of childhood evenings spent by the hearth, listening to his grandfather's tales of hidden treasures and secret passages. His footsteps echoed against the stone walls, a rhythmic accompaniment to his journey. Each step brought him closer to his destination, though uncertainty still clouded his mind about what he might find. The old map in his other pocket, creased and worn from constant consultation, had led him this far. The street ahead narrowed, and the buildings seemed to lean in closer, their upper stories nearly touching above his head. The air grew cooler in this shadowed passage, and his breath formed small clouds in front of him. Something about this place felt different, charged with possibility and ancient secrets. He walked down the [MASK]";
const inputs = tokenizer(text);
const { last_hidden_state } = await model(inputs);
expect(last_hidden_state.dims).toEqual([1, 397, 32]);
expect(last_hidden_state.mean().item()).toBeCloseTo(-0.06889555603265762, 5);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});

describe("ModernBertForMaskedLM", () => {
const model_id = "hf-internal-testing/tiny-random-ModernBertForMaskedLM";

const texts = ["The goal of life is [MASK].", "Paris is the [MASK] of France."];

/** @type {ModernBertForMaskedLM} */
let model;
/** @type {PreTrainedTokenizer} */
let tokenizer;
beforeAll(async () => {
model = await ModernBertForMaskedLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
}, MAX_MODEL_LOAD_TIME);

it(
"batch_size=1",
async () => {
const inputs = tokenizer(texts[0]);
const { logits } = await model(inputs);
expect(logits.dims).toEqual([1, 9, 50368]);
expect(logits.mean().item()).toBeCloseTo(0.0053214821964502335, 5);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"batch_size>1",
async () => {
const inputs = tokenizer(texts, { padding: true });
const { logits } = await model(inputs);
expect(logits.dims).toEqual([2, 9, 50368]);
expect(logits.mean().item()).toBeCloseTo(0.009154772385954857, 5);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});

describe("ModernBertForSequenceClassification", () => {
const model_id = "hf-internal-testing/tiny-random-ModernBertForSequenceClassification";

/** @type {ModernBertForSequenceClassification} */
let model;
/** @type {PreTrainedTokenizer} */
let tokenizer;
beforeAll(async () => {
model = await ModernBertForSequenceClassification.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
}, MAX_MODEL_LOAD_TIME);

it(
"batch_size=1",
async () => {
const inputs = tokenizer("hello");
const { logits } = await model(inputs);
const target = [[-0.7050137519836426, 2.343430519104004]];
expect(logits.dims).toEqual([1, 2]);
expect(logits.tolist()).toBeCloseToNested(target, 5);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"batch_size>1",
async () => {
const inputs = tokenizer(["hello", "hello world"], { padding: true });
const { logits } = await model(inputs);
const target = [
[-0.7050137519836426, 2.343430519104004],
[-2.6860175132751465, 3.993380546569824],
];
expect(logits.dims).toEqual([2, 2]);
expect(logits.tolist()).toBeCloseToNested(target, 5);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});

describe("ModernBertForTokenClassification", () => {
const model_id = "hf-internal-testing/tiny-random-ModernBertForTokenClassification";

/** @type {ModernBertForTokenClassification} */
let model;
/** @type {PreTrainedTokenizer} */
let tokenizer;
beforeAll(async () => {
model = await ModernBertForTokenClassification.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
}, MAX_MODEL_LOAD_TIME);

it(
"batch_size=1",
async () => {
const inputs = tokenizer("hello");
const { logits } = await model(inputs);
expect(logits.dims).toEqual([1, 3, 2]);
expect(logits.mean().item()).toBeCloseTo(1.0337047576904297, 5);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"batch_size>1",
async () => {
const inputs = tokenizer(["hello", "hello world"], { padding: true });
const { logits } = await model(inputs);
expect(logits.dims).toEqual([2, 4, 2]);
expect(logits.mean().item()).toBeCloseTo(-1.3397092819213867, 5);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});
};