From 1ffe60b99f9f09b445c9b0060f1d905fc52a2fc0 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 15 Sep 2023 15:54:33 +0200 Subject: [PATCH] Store mapping between class and name --- src/models.js | 291 ++++++++++++++++++++++++++------------------------ 1 file changed, 152 insertions(+), 139 deletions(-) diff --git a/src/models.js b/src/models.js index 9265ab50f..ef6447e36 100644 --- a/src/models.js +++ b/src/models.js @@ -101,11 +101,11 @@ const MODEL_TYPES = { ////////////////////////////////////////////////// // Helper functions -// Will be populated fully later -const MODEL_TYPE_MAPPING = new Map([ - ['CLIPTextModelWithProjection', MODEL_TYPES.EncoderOnly], - ['CLIPVisionModelWithProjection', MODEL_TYPES.EncoderOnly], -]); +// NOTE: These will be populated fully later +const MODEL_TYPE_MAPPING = new Map(); +const MODEL_NAME_TO_CLASS_MAPPING = new Map(); +const MODEL_CLASS_TO_NAME_MAPPING = new Map(); + /** * Constructs an InferenceSession using a model file located at the specified path. @@ -586,9 +586,8 @@ export class PreTrainedModel extends Callable { this.config = config; this.session = session; - // Set `runBeam` function: - const className = this.constructor.name; - const modelType = MODEL_TYPE_MAPPING.get(className); + const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); + const modelType = MODEL_TYPE_MAPPING.get(modelName); this.can_generate = false; this._runBeam = null; @@ -670,7 +669,8 @@ export class PreTrainedModel extends Callable { model_file_name, } - let modelType = MODEL_TYPE_MAPPING.get(this.name); + const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this); + const modelType = MODEL_TYPE_MAPPING.get(modelName); let info; if (modelType === MODEL_TYPES.DecoderOnly) { @@ -696,7 +696,7 @@ export class PreTrainedModel extends Callable { } else { // should be MODEL_TYPES.EncoderOnly if (modelType !== MODEL_TYPES.EncoderOnly) { - console.warn(`Model type for ${this.name} not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`) + console.warn(`Model type for '${modelName}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`) } info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), @@ -906,10 +906,12 @@ export class PreTrainedModel extends Callable { } = {}, ) { if (!this.can_generate) { - const possibleTypes = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(this.config.model_type); - let errorMessage = `The current model class (${this.constructor.name}) is not compatible with \`.generate()\`, as it doesn't have a language model head.` - if (possibleTypes) { - errorMessage += ` Please use one of the following classes instead: {'${possibleTypes.constructor.name}'}`; + // TODO: support multiple options + const possibleInfo = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(this.config.model_type); + const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); + let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.` + if (possibleInfo) { + errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`; } throw Error(errorMessage); } @@ -3348,174 +3350,174 @@ export class PretrainedMixin { } const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ - ['bert', BertModel], - ['camembert', CamembertModel], - ['deberta', DebertaModel], - ['deberta-v2', DebertaV2Model], - ['mpnet', MPNetModel], - ['albert', AlbertModel], - ['distilbert', DistilBertModel], - ['roberta', RobertaModel], - ['xlm', XLMModel], - ['xlm-roberta', XLMRobertaModel], - ['clip', CLIPModel], - ['mobilebert', MobileBertModel], - ['squeezebert', SqueezeBertModel], - ['wav2vec2', Wav2Vec2Model], - ['wavlm', WavLMModel], - - ['detr', DetrModel], - ['vit', ViTModel], - ['mobilevit', MobileViTModel], - ['beit', BeitModel], - ['deit', DeiTModel], - ['resnet', ResNetModel], - ['swin', SwinModel], - ['yolos', YolosModel], - - ['sam', SamModel], // TODO change to encoder-decoder when model is split correctly + ['bert', ['BertModel', BertModel]], + ['camembert', ['CamembertModel', CamembertModel]], + ['deberta', ['DebertaModel', DebertaModel]], + ['deberta-v2', ['DebertaV2Model', DebertaV2Model]], + ['mpnet', ['MPNetModel', MPNetModel]], + ['albert', ['AlbertModel', AlbertModel]], + ['distilbert', ['DistilBertModel', DistilBertModel]], + ['roberta', ['RobertaModel', RobertaModel]], + ['xlm', ['XLMModel', XLMModel]], + ['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]], + ['clip', ['CLIPModel', CLIPModel]], + ['mobilebert', ['MobileBertModel', MobileBertModel]], + ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]], + ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]], + ['wavlm', ['WavLMModel', WavLMModel]], + + ['detr', ['DetrModel', DetrModel]], + ['vit', ['ViTModel', ViTModel]], + ['mobilevit', ['MobileViTModel', MobileViTModel]], + ['beit', ['BeitModel', BeitModel]], + ['deit', ['DeiTModel', DeiTModel]], + ['resnet', ['ResNetModel', ResNetModel]], + ['swin', ['SwinModel', SwinModel]], + ['yolos', ['YolosModel', YolosModel]], + + ['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly ]); const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ - ['t5', T5Model], - ['mt5', MT5Model], - ['bart', BartModel], - ['mbart', MBartModel], - ['marian', MarianModel], - ['whisper', WhisperModel], - ['m2m_100', M2M100Model], + ['t5', ['T5Model', T5Model]], + ['mt5', ['MT5Model', MT5Model]], + ['bart', ['BartModel', BartModel]], + ['mbart', ['MBartModel', MBartModel]], + ['marian', ['MarianModel', MarianModel]], + ['whisper', ['WhisperModel', WhisperModel]], + ['m2m_100', ['M2M100Model', M2M100Model]], ]); const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ - ['bloom', BloomModel], - ['gpt2', GPT2Model], - ['gptj', GPTJModel], - ['gpt_bigcode', GPTBigCodeModel], - ['gpt_neo', GPTNeoModel], - ['gpt_neox', GPTNeoXModel], - ['codegen', CodeGenModel], - ['llama', LlamaModel], - ['mpt', MptModel], - ['opt', OPTModel], + ['bloom', ['BloomModel', BloomModel]], + ['gpt2', ['GPT2Model', GPT2Model]], + ['gptj', ['GPTJModel', GPTJModel]], + ['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]], + ['gpt_neo', ['GPTNeoModel', GPTNeoModel]], + ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]], + ['codegen', ['CodeGenModel', CodeGenModel]], + ['llama', ['LlamaModel', LlamaModel]], + ['mpt', ['MptModel', MptModel]], + ['opt', ['OPTModel', OPTModel]], ]); const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['bert', BertForSequenceClassification], - ['camembert', CamembertForSequenceClassification], - ['deberta', DebertaForSequenceClassification], - ['deberta-v2', DebertaV2ForSequenceClassification], - ['mpnet', MPNetForSequenceClassification], - ['albert', AlbertForSequenceClassification], - ['distilbert', DistilBertForSequenceClassification], - ['roberta', RobertaForSequenceClassification], - ['xlm', XLMForSequenceClassification], - ['xlm-roberta', XLMRobertaForSequenceClassification], - ['bart', BartForSequenceClassification], - ['mbart', MBartForSequenceClassification], - ['mobilebert', MobileBertForSequenceClassification], - ['squeezebert', SqueezeBertForSequenceClassification], + ['bert', ['BertForSequenceClassification', BertForSequenceClassification]], + ['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]], + ['deberta', ['DebertaForSequenceClassification', DebertaForSequenceClassification]], + ['deberta-v2', ['DebertaV2ForSequenceClassification', DebertaV2ForSequenceClassification]], + ['mpnet', ['MPNetForSequenceClassification', MPNetForSequenceClassification]], + ['albert', ['AlbertForSequenceClassification', AlbertForSequenceClassification]], + ['distilbert', ['DistilBertForSequenceClassification', DistilBertForSequenceClassification]], + ['roberta', ['RobertaForSequenceClassification', RobertaForSequenceClassification]], + ['xlm', ['XLMForSequenceClassification', XLMForSequenceClassification]], + ['xlm-roberta', ['XLMRobertaForSequenceClassification', XLMRobertaForSequenceClassification]], + ['bart', ['BartForSequenceClassification', BartForSequenceClassification]], + ['mbart', ['MBartForSequenceClassification', MBartForSequenceClassification]], + ['mobilebert', ['MobileBertForSequenceClassification', MobileBertForSequenceClassification]], + ['squeezebert', ['SqueezeBertForSequenceClassification', SqueezeBertForSequenceClassification]], ]); const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['bert', BertForTokenClassification], - ['camembert', CamembertForTokenClassification], - ['deberta', DebertaForTokenClassification], - ['deberta-v2', DebertaV2ForTokenClassification], - ['mpnet', MPNetForTokenClassification], - ['distilbert', DistilBertForTokenClassification], - ['roberta', RobertaForTokenClassification], - ['xlm', XLMForTokenClassification], - ['xlm-roberta', XLMRobertaForTokenClassification], + ['bert', ['BertForTokenClassification', BertForTokenClassification]], + ['camembert', ['CamembertForTokenClassification', CamembertForTokenClassification]], + ['deberta', ['DebertaForTokenClassification', DebertaForTokenClassification]], + ['deberta-v2', ['DebertaV2ForTokenClassification', DebertaV2ForTokenClassification]], + ['mpnet', ['MPNetForTokenClassification', MPNetForTokenClassification]], + ['distilbert', ['DistilBertForTokenClassification', DistilBertForTokenClassification]], + ['roberta', ['RobertaForTokenClassification', RobertaForTokenClassification]], + ['xlm', ['XLMForTokenClassification', XLMForTokenClassification]], + ['xlm-roberta', ['XLMRobertaForTokenClassification', XLMRobertaForTokenClassification]], ]); const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ - ['t5', T5ForConditionalGeneration], - ['mt5', MT5ForConditionalGeneration], - ['bart', BartForConditionalGeneration], - ['mbart', MBartForConditionalGeneration], - ['whisper', WhisperForConditionalGeneration], - ['marian', MarianMTModel], - ['m2m_100', M2M100ForConditionalGeneration], + ['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]], + ['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]], + ['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]], + ['mbart', ['MBartForConditionalGeneration', MBartForConditionalGeneration]], + ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]], + ['marian', ['MarianMTModel', MarianMTModel]], + ['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]], ]); const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ - ['bloom', BloomForCausalLM], - ['gpt2', GPT2LMHeadModel], - ['gptj', GPTJForCausalLM], - ['gpt_bigcode', GPTBigCodeForCausalLM], - ['gpt_neo', GPTNeoForCausalLM], - ['gpt_neox', GPTNeoXForCausalLM], - ['codegen', CodeGenForCausalLM], - ['llama', LlamaForCausalLM], - ['mpt', MptForCausalLM], - ['opt', OPTForCausalLM], + ['bloom', ['BloomForCausalLM', BloomForCausalLM]], + ['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]], + ['gptj', ['GPTJForCausalLM', GPTJForCausalLM]], + ['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]], + ['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]], + ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]], + ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]], + ['llama', ['LlamaForCausalLM', LlamaForCausalLM]], + ['mpt', ['MptForCausalLM', MptForCausalLM]], + ['opt', ['OPTForCausalLM', OPTForCausalLM]], ]); const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ - ['bert', BertForMaskedLM], - ['camembert', CamembertForMaskedLM], - ['deberta', DebertaForMaskedLM], - ['deberta-v2', DebertaV2ForMaskedLM], - ['mpnet', MPNetForMaskedLM], - ['albert', AlbertForMaskedLM], - ['distilbert', DistilBertForMaskedLM], - ['roberta', RobertaForMaskedLM], - ['xlm', XLMWithLMHeadModel], - ['xlm-roberta', XLMRobertaForMaskedLM], - ['mobilebert', MobileBertForMaskedLM], - ['squeezebert', SqueezeBertForMaskedLM], + ['bert', ['BertForMaskedLM', BertForMaskedLM]], + ['camembert', ['CamembertForMaskedLM', CamembertForMaskedLM]], + ['deberta', ['DebertaForMaskedLM', DebertaForMaskedLM]], + ['deberta-v2', ['DebertaV2ForMaskedLM', DebertaV2ForMaskedLM]], + ['mpnet', ['MPNetForMaskedLM', MPNetForMaskedLM]], + ['albert', ['AlbertForMaskedLM', AlbertForMaskedLM]], + ['distilbert', ['DistilBertForMaskedLM', DistilBertForMaskedLM]], + ['roberta', ['RobertaForMaskedLM', RobertaForMaskedLM]], + ['xlm', ['XLMWithLMHeadModel', XLMWithLMHeadModel]], + ['xlm-roberta', ['XLMRobertaForMaskedLM', XLMRobertaForMaskedLM]], + ['mobilebert', ['MobileBertForMaskedLM', MobileBertForMaskedLM]], + ['squeezebert', ['SqueezeBertForMaskedLM', SqueezeBertForMaskedLM]], ]); const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ - ['bert', BertForQuestionAnswering], - ['camembert', CamembertForQuestionAnswering], - ['deberta', DebertaForQuestionAnswering], - ['deberta-v2', DebertaV2ForQuestionAnswering], - ['mpnet', MPNetForQuestionAnswering], - ['albert', AlbertForQuestionAnswering], - ['distilbert', DistilBertForQuestionAnswering], - ['roberta', RobertaForQuestionAnswering], - ['xlm', XLMForQuestionAnswering], - ['xlm-roberta', XLMRobertaForQuestionAnswering], - ['mobilebert', MobileBertForQuestionAnswering], - ['squeezebert', SqueezeBertForQuestionAnswering], + ['bert', ['BertForQuestionAnswering', BertForQuestionAnswering]], + ['camembert', ['CamembertForQuestionAnswering', CamembertForQuestionAnswering]], + ['deberta', ['DebertaForQuestionAnswering', DebertaForQuestionAnswering]], + ['deberta-v2', ['DebertaV2ForQuestionAnswering', DebertaV2ForQuestionAnswering]], + ['mpnet', ['MPNetForQuestionAnswering', MPNetForQuestionAnswering]], + ['albert', ['AlbertForQuestionAnswering', AlbertForQuestionAnswering]], + ['distilbert', ['DistilBertForQuestionAnswering', DistilBertForQuestionAnswering]], + ['roberta', ['RobertaForQuestionAnswering', RobertaForQuestionAnswering]], + ['xlm', ['XLMForQuestionAnswering', XLMForQuestionAnswering]], + ['xlm-roberta', ['XLMRobertaForQuestionAnswering', XLMRobertaForQuestionAnswering]], + ['mobilebert', ['MobileBertForQuestionAnswering', MobileBertForQuestionAnswering]], + ['squeezebert', ['SqueezeBertForQuestionAnswering', SqueezeBertForQuestionAnswering]], ]); const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ - ['vision-encoder-decoder', VisionEncoderDecoderModel], + ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], ]); const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['vit', ViTForImageClassification], - ['mobilevit', MobileViTForImageClassification], - ['beit', BeitForImageClassification], - ['deit', DeiTForImageClassification], - ['resnet', ResNetForImageClassification], - ['swin', SwinForImageClassification], + ['vit', ['ViTForImageClassification', ViTForImageClassification]], + ['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]], + ['beit', ['BeitForImageClassification', BeitForImageClassification]], + ['deit', ['DeiTForImageClassification', DeiTForImageClassification]], + ['resnet', ['ResNetForImageClassification', ResNetForImageClassification]], + ['swin', ['SwinForImageClassification', SwinForImageClassification]], ]); const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([ - ['detr', DetrForObjectDetection], - ['yolos', YolosForObjectDetection], + ['detr', ['DetrForObjectDetection', DetrForObjectDetection]], + ['yolos', ['YolosForObjectDetection', YolosForObjectDetection]], ]); const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([ - ['detr', DetrForSegmentation], + ['detr', ['DetrForSegmentation', DetrForSegmentation]], ]); const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([ - ['sam', SamModel], + ['sam', ['SamModel', SamModel]], ]); const MODEL_FOR_CTC_MAPPING_NAMES = new Map([ - ['wav2vec2', Wav2Vec2ForCTC], - ['wavlm', WavLMForCTC], + ['wav2vec2', ['Wav2Vec2ForCTC', Wav2Vec2ForCTC]], + ['wavlm', ['WavLMForCTC', WavLMForCTC]], ]); const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['wav2vec2', Wav2Vec2ForSequenceClassification], - ['wavlm', WavLMForSequenceClassification], + ['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]], + ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], ]); @@ -3540,12 +3542,23 @@ const MODEL_CLASS_TYPE_MAPPING = [ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { // @ts-ignore - for (const model of mappings.values()) { - // @ts-ignore - MODEL_TYPE_MAPPING.set(model.name, type); + for (const [name, model] of mappings.values()) { + MODEL_TYPE_MAPPING.set(name, type); + MODEL_CLASS_TO_NAME_MAPPING.set(model, name); + MODEL_NAME_TO_CLASS_MAPPING.set(name, model); } } +const CUSTOM_MAPPING = [ + ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], + ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly], +] +for (const [name, model, type] of CUSTOM_MAPPING) { + MODEL_TYPE_MAPPING.set(name, type); + MODEL_CLASS_TO_NAME_MAPPING.set(model, name); + MODEL_NAME_TO_CLASS_MAPPING.set(name, model); +} + /** * Helper class which is used to instantiate pretrained models with the `from_pretrained` function.