Skip to content

Commit

Permalink
Store mapping between class and name
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Sep 15, 2023
1 parent 281eaf0 commit 1ffe60b
Showing 1 changed file with 152 additions and 139 deletions.
291 changes: 152 additions & 139 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ const MODEL_TYPES = {
//////////////////////////////////////////////////
// Helper functions

// Will be populated fully later
const MODEL_TYPE_MAPPING = new Map([
['CLIPTextModelWithProjection', MODEL_TYPES.EncoderOnly],
['CLIPVisionModelWithProjection', MODEL_TYPES.EncoderOnly],
]);
// NOTE: These will be populated fully later
// Registered model name (string) -> model type (a `MODEL_TYPES` value).
const MODEL_TYPE_MAPPING = new Map();
// Registered model name (string) -> model class.
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
// Model class -> registered model name (string). Keyed by the class itself
// (e.g. `this.constructor`), and the resulting name is what
// `MODEL_TYPE_MAPPING` is queried with.
const MODEL_CLASS_TO_NAME_MAPPING = new Map();


/**
* Constructs an InferenceSession using a model file located at the specified path.
Expand Down Expand Up @@ -586,9 +586,8 @@ export class PreTrainedModel extends Callable {
this.config = config;
this.session = session;

// Set `runBeam` function:
const className = this.constructor.name;
const modelType = MODEL_TYPE_MAPPING.get(className);
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);

this.can_generate = false;
this._runBeam = null;
Expand Down Expand Up @@ -670,7 +669,8 @@ export class PreTrainedModel extends Callable {
model_file_name,
}

let modelType = MODEL_TYPE_MAPPING.get(this.name);
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);

let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
Expand All @@ -696,7 +696,7 @@ export class PreTrainedModel extends Callable {

} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for ${this.name} not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
console.warn(`Model type for '${modelName}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
}
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
Expand Down Expand Up @@ -906,10 +906,12 @@ export class PreTrainedModel extends Callable {
} = {},
) {
if (!this.can_generate) {
const possibleTypes = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(this.config.model_type);
let errorMessage = `The current model class (${this.constructor.name}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`
if (possibleTypes) {
errorMessage += ` Please use one of the following classes instead: {'${possibleTypes.constructor.name}'}`;
// TODO: support multiple options
const possibleInfo = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(this.config.model_type);
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`
if (possibleInfo) {
errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`;
}
throw Error(errorMessage);
}
Expand Down Expand Up @@ -3348,174 +3350,174 @@ export class PretrainedMixin {
}

const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['bert', BertModel],
['camembert', CamembertModel],
['deberta', DebertaModel],
['deberta-v2', DebertaV2Model],
['mpnet', MPNetModel],
['albert', AlbertModel],
['distilbert', DistilBertModel],
['roberta', RobertaModel],
['xlm', XLMModel],
['xlm-roberta', XLMRobertaModel],
['clip', CLIPModel],
['mobilebert', MobileBertModel],
['squeezebert', SqueezeBertModel],
['wav2vec2', Wav2Vec2Model],
['wavlm', WavLMModel],

['detr', DetrModel],
['vit', ViTModel],
['mobilevit', MobileViTModel],
['beit', BeitModel],
['deit', DeiTModel],
['resnet', ResNetModel],
['swin', SwinModel],
['yolos', YolosModel],

['sam', SamModel], // TODO change to encoder-decoder when model is split correctly
['bert', ['BertModel', BertModel]],
['camembert', ['CamembertModel', CamembertModel]],
['deberta', ['DebertaModel', DebertaModel]],
['deberta-v2', ['DebertaV2Model', DebertaV2Model]],
['mpnet', ['MPNetModel', MPNetModel]],
['albert', ['AlbertModel', AlbertModel]],
['distilbert', ['DistilBertModel', DistilBertModel]],
['roberta', ['RobertaModel', RobertaModel]],
['xlm', ['XLMModel', XLMModel]],
['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]],
['clip', ['CLIPModel', CLIPModel]],
['mobilebert', ['MobileBertModel', MobileBertModel]],
['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
['wavlm', ['WavLMModel', WavLMModel]],

['detr', ['DetrModel', DetrModel]],
['vit', ['ViTModel', ViTModel]],
['mobilevit', ['MobileViTModel', MobileViTModel]],
['beit', ['BeitModel', BeitModel]],
['deit', ['DeiTModel', DeiTModel]],
['resnet', ['ResNetModel', ResNetModel]],
['swin', ['SwinModel', SwinModel]],
['yolos', ['YolosModel', YolosModel]],

['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly
]);

const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
['t5', T5Model],
['mt5', MT5Model],
['bart', BartModel],
['mbart', MBartModel],
['marian', MarianModel],
['whisper', WhisperModel],
['m2m_100', M2M100Model],
['t5', ['T5Model', T5Model]],
['mt5', ['MT5Model', MT5Model]],
['bart', ['BartModel', BartModel]],
['mbart', ['MBartModel', MBartModel]],
['marian', ['MarianModel', MarianModel]],
['whisper', ['WhisperModel', WhisperModel]],
['m2m_100', ['M2M100Model', M2M100Model]],
]);


const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['bloom', BloomModel],
['gpt2', GPT2Model],
['gptj', GPTJModel],
['gpt_bigcode', GPTBigCodeModel],
['gpt_neo', GPTNeoModel],
['gpt_neox', GPTNeoXModel],
['codegen', CodeGenModel],
['llama', LlamaModel],
['mpt', MptModel],
['opt', OPTModel],
['bloom', ['BloomModel', BloomModel]],
['gpt2', ['GPT2Model', GPT2Model]],
['gptj', ['GPTJModel', GPTJModel]],
['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]],
['gpt_neo', ['GPTNeoModel', GPTNeoModel]],
['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
['codegen', ['CodeGenModel', CodeGenModel]],
['llama', ['LlamaModel', LlamaModel]],
['mpt', ['MptModel', MptModel]],
['opt', ['OPTModel', OPTModel]],
]);

const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', BertForSequenceClassification],
['camembert', CamembertForSequenceClassification],
['deberta', DebertaForSequenceClassification],
['deberta-v2', DebertaV2ForSequenceClassification],
['mpnet', MPNetForSequenceClassification],
['albert', AlbertForSequenceClassification],
['distilbert', DistilBertForSequenceClassification],
['roberta', RobertaForSequenceClassification],
['xlm', XLMForSequenceClassification],
['xlm-roberta', XLMRobertaForSequenceClassification],
['bart', BartForSequenceClassification],
['mbart', MBartForSequenceClassification],
['mobilebert', MobileBertForSequenceClassification],
['squeezebert', SqueezeBertForSequenceClassification],
['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]],
['deberta', ['DebertaForSequenceClassification', DebertaForSequenceClassification]],
['deberta-v2', ['DebertaV2ForSequenceClassification', DebertaV2ForSequenceClassification]],
['mpnet', ['MPNetForSequenceClassification', MPNetForSequenceClassification]],
['albert', ['AlbertForSequenceClassification', AlbertForSequenceClassification]],
['distilbert', ['DistilBertForSequenceClassification', DistilBertForSequenceClassification]],
['roberta', ['RobertaForSequenceClassification', RobertaForSequenceClassification]],
['xlm', ['XLMForSequenceClassification', XLMForSequenceClassification]],
['xlm-roberta', ['XLMRobertaForSequenceClassification', XLMRobertaForSequenceClassification]],
['bart', ['BartForSequenceClassification', BartForSequenceClassification]],
['mbart', ['MBartForSequenceClassification', MBartForSequenceClassification]],
['mobilebert', ['MobileBertForSequenceClassification', MobileBertForSequenceClassification]],
['squeezebert', ['SqueezeBertForSequenceClassification', SqueezeBertForSequenceClassification]],
]);

const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', BertForTokenClassification],
['camembert', CamembertForTokenClassification],
['deberta', DebertaForTokenClassification],
['deberta-v2', DebertaV2ForTokenClassification],
['mpnet', MPNetForTokenClassification],
['distilbert', DistilBertForTokenClassification],
['roberta', RobertaForTokenClassification],
['xlm', XLMForTokenClassification],
['xlm-roberta', XLMRobertaForTokenClassification],
['bert', ['BertForTokenClassification', BertForTokenClassification]],
['camembert', ['CamembertForTokenClassification', CamembertForTokenClassification]],
['deberta', ['DebertaForTokenClassification', DebertaForTokenClassification]],
['deberta-v2', ['DebertaV2ForTokenClassification', DebertaV2ForTokenClassification]],
['mpnet', ['MPNetForTokenClassification', MPNetForTokenClassification]],
['distilbert', ['DistilBertForTokenClassification', DistilBertForTokenClassification]],
['roberta', ['RobertaForTokenClassification', RobertaForTokenClassification]],
['xlm', ['XLMForTokenClassification', XLMForTokenClassification]],
['xlm-roberta', ['XLMRobertaForTokenClassification', XLMRobertaForTokenClassification]],
]);

const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
['t5', T5ForConditionalGeneration],
['mt5', MT5ForConditionalGeneration],
['bart', BartForConditionalGeneration],
['mbart', MBartForConditionalGeneration],
['whisper', WhisperForConditionalGeneration],
['marian', MarianMTModel],
['m2m_100', M2M100ForConditionalGeneration],
['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]],
['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]],
['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]],
['mbart', ['MBartForConditionalGeneration', MBartForConditionalGeneration]],
['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
['marian', ['MarianMTModel', MarianMTModel]],
['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]],
]);

const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
['bloom', BloomForCausalLM],
['gpt2', GPT2LMHeadModel],
['gptj', GPTJForCausalLM],
['gpt_bigcode', GPTBigCodeForCausalLM],
['gpt_neo', GPTNeoForCausalLM],
['gpt_neox', GPTNeoXForCausalLM],
['codegen', CodeGenForCausalLM],
['llama', LlamaForCausalLM],
['mpt', MptForCausalLM],
['opt', OPTForCausalLM],
['bloom', ['BloomForCausalLM', BloomForCausalLM]],
['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]],
['gptj', ['GPTJForCausalLM', GPTJForCausalLM]],
['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]],
['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]],
['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
['mpt', ['MptForCausalLM', MptForCausalLM]],
['opt', ['OPTForCausalLM', OPTForCausalLM]],
]);

const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
['bert', BertForMaskedLM],
['camembert', CamembertForMaskedLM],
['deberta', DebertaForMaskedLM],
['deberta-v2', DebertaV2ForMaskedLM],
['mpnet', MPNetForMaskedLM],
['albert', AlbertForMaskedLM],
['distilbert', DistilBertForMaskedLM],
['roberta', RobertaForMaskedLM],
['xlm', XLMWithLMHeadModel],
['xlm-roberta', XLMRobertaForMaskedLM],
['mobilebert', MobileBertForMaskedLM],
['squeezebert', SqueezeBertForMaskedLM],
['bert', ['BertForMaskedLM', BertForMaskedLM]],
['camembert', ['CamembertForMaskedLM', CamembertForMaskedLM]],
['deberta', ['DebertaForMaskedLM', DebertaForMaskedLM]],
['deberta-v2', ['DebertaV2ForMaskedLM', DebertaV2ForMaskedLM]],
['mpnet', ['MPNetForMaskedLM', MPNetForMaskedLM]],
['albert', ['AlbertForMaskedLM', AlbertForMaskedLM]],
['distilbert', ['DistilBertForMaskedLM', DistilBertForMaskedLM]],
['roberta', ['RobertaForMaskedLM', RobertaForMaskedLM]],
['xlm', ['XLMWithLMHeadModel', XLMWithLMHeadModel]],
['xlm-roberta', ['XLMRobertaForMaskedLM', XLMRobertaForMaskedLM]],
['mobilebert', ['MobileBertForMaskedLM', MobileBertForMaskedLM]],
['squeezebert', ['SqueezeBertForMaskedLM', SqueezeBertForMaskedLM]],
]);

const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
['bert', BertForQuestionAnswering],
['camembert', CamembertForQuestionAnswering],
['deberta', DebertaForQuestionAnswering],
['deberta-v2', DebertaV2ForQuestionAnswering],
['mpnet', MPNetForQuestionAnswering],
['albert', AlbertForQuestionAnswering],
['distilbert', DistilBertForQuestionAnswering],
['roberta', RobertaForQuestionAnswering],
['xlm', XLMForQuestionAnswering],
['xlm-roberta', XLMRobertaForQuestionAnswering],
['mobilebert', MobileBertForQuestionAnswering],
['squeezebert', SqueezeBertForQuestionAnswering],
['bert', ['BertForQuestionAnswering', BertForQuestionAnswering]],
['camembert', ['CamembertForQuestionAnswering', CamembertForQuestionAnswering]],
['deberta', ['DebertaForQuestionAnswering', DebertaForQuestionAnswering]],
['deberta-v2', ['DebertaV2ForQuestionAnswering', DebertaV2ForQuestionAnswering]],
['mpnet', ['MPNetForQuestionAnswering', MPNetForQuestionAnswering]],
['albert', ['AlbertForQuestionAnswering', AlbertForQuestionAnswering]],
['distilbert', ['DistilBertForQuestionAnswering', DistilBertForQuestionAnswering]],
['roberta', ['RobertaForQuestionAnswering', RobertaForQuestionAnswering]],
['xlm', ['XLMForQuestionAnswering', XLMForQuestionAnswering]],
['xlm-roberta', ['XLMRobertaForQuestionAnswering', XLMRobertaForQuestionAnswering]],
['mobilebert', ['MobileBertForQuestionAnswering', MobileBertForQuestionAnswering]],
['squeezebert', ['SqueezeBertForQuestionAnswering', SqueezeBertForQuestionAnswering]],
]);

const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
['vision-encoder-decoder', VisionEncoderDecoderModel],
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
]);

const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
['vit', ViTForImageClassification],
['mobilevit', MobileViTForImageClassification],
['beit', BeitForImageClassification],
['deit', DeiTForImageClassification],
['resnet', ResNetForImageClassification],
['swin', SwinForImageClassification],
['vit', ['ViTForImageClassification', ViTForImageClassification]],
['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]],
['beit', ['BeitForImageClassification', BeitForImageClassification]],
['deit', ['DeiTForImageClassification', DeiTForImageClassification]],
['resnet', ['ResNetForImageClassification', ResNetForImageClassification]],
['swin', ['SwinForImageClassification', SwinForImageClassification]],
]);

const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
['detr', DetrForObjectDetection],
['yolos', YolosForObjectDetection],
['detr', ['DetrForObjectDetection', DetrForObjectDetection]],
['yolos', ['YolosForObjectDetection', YolosForObjectDetection]],
]);

const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
['detr', DetrForSegmentation],
['detr', ['DetrForSegmentation', DetrForSegmentation]],
]);

const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
['sam', SamModel],
['sam', ['SamModel', SamModel]],
]);

const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
['wav2vec2', Wav2Vec2ForCTC],
['wavlm', WavLMForCTC],
['wav2vec2', ['Wav2Vec2ForCTC', Wav2Vec2ForCTC]],
['wavlm', ['WavLMForCTC', WavLMForCTC]],
]);

const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
['wav2vec2', Wav2Vec2ForSequenceClassification],
['wavlm', WavLMForSequenceClassification],
['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]],
['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]],
]);


Expand All @@ -3540,12 +3542,23 @@ const MODEL_CLASS_TYPE_MAPPING = [

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
// @ts-ignore
for (const model of mappings.values()) {
// @ts-ignore
MODEL_TYPE_MAPPING.set(model.name, type);
for (const [name, model] of mappings.values()) {
MODEL_TYPE_MAPPING.set(name, type);
MODEL_CLASS_TO_NAME_MAPPING.set(model, name);
MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
}
}

// Classes registered individually, as [name, class, model type] triples,
// rather than through the per-architecture mappings handled above.
const CUSTOM_MAPPING = [
    ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
    ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly],
];
CUSTOM_MAPPING.forEach(([customName, customClass, customType]) => {
    // Keep all three lookup tables in sync for each custom entry.
    MODEL_NAME_TO_CLASS_MAPPING.set(customName, customClass);
    MODEL_CLASS_TO_NAME_MAPPING.set(customClass, customName);
    MODEL_TYPE_MAPPING.set(customName, customType);
});


/**
* Helper class which is used to instantiate pretrained models with the `from_pretrained` function.
Expand Down

0 comments on commit 1ffe60b

Please sign in to comment.