From cdb48142e5ad820924f9caa168d7930484c8e76e Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Fri, 8 Sep 2023 03:30:32 +0200
Subject: [PATCH] Use enum object instead of classes for model types

Fixes https://github.com/xenova/transformers.js/issues/283
---
 src/models.js | 69 +++++++++++++++++++++++++--------------------------
 1 file changed, 34 insertions(+), 35 deletions(-)

diff --git a/src/models.js b/src/models.js
index cab89fe8f..b7f0c8214 100644
--- a/src/models.js
+++ b/src/models.js
@@ -88,14 +88,13 @@ const { InferenceSession, Tensor: ONNXTensor } = ONNX;
 
 //////////////////////////////////////////////////
 // Model types: used internally
-class ModelType { };
-
-// Either encoder-only or encoder-decoder (and will be decided by `model.config.is_encoder_decoder`)
-class EncoderOnlyModelType extends ModelType { };
-class EncoderDecoderModelType extends ModelType { };
-class Seq2SeqModelType extends EncoderDecoderModelType { };
-class Vision2SeqModelType extends EncoderDecoderModelType { };
-class DecoderOnlyModelType extends ModelType { };
+const MODEL_TYPES = {
+    EncoderOnly: 0,
+    EncoderDecoder: 1,
+    Seq2Seq: 2,
+    Vision2Seq: 3,
+    DecoderOnly: 4,
+}
 //////////////////////////////////////////////////
 
 
@@ -104,8 +103,8 @@ class DecoderOnlyModelType { };
 
 // Will be populated fully later
 const MODEL_TYPE_MAPPING = new Map([
-    ['CLIPTextModelWithProjection', EncoderOnlyModelType],
-    ['CLIPVisionModelWithProjection', EncoderOnlyModelType],
+    ['CLIPTextModelWithProjection', MODEL_TYPES.EncoderOnly],
+    ['CLIPVisionModelWithProjection', MODEL_TYPES.EncoderOnly],
 ]);
 
 /**
@@ -597,7 +596,7 @@ export class PreTrainedModel extends Callable {
         this._getStartBeams = null;
         this._updateBeam = null;
         this._forward = null;
-        if (modelType === DecoderOnlyModelType) {
+        if (modelType === MODEL_TYPES.DecoderOnly) {
             this.can_generate = true;
 
             this._runBeam = decoderRunBeam;
@@ -605,7 +604,7 @@ export class PreTrainedModel extends Callable {
             this._updateBeam = decoderUpdatebeam;
             this._forward = decoderForward;
 
-        } else if (modelType === Seq2SeqModelType || modelType === Vision2SeqModelType) {
+        } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
             this.can_generate = true;
 
             this._runBeam = seq2seqRunBeam;
@@ -613,10 +612,10 @@ export class PreTrainedModel extends Callable {
             this._updateBeam = seq2seqUpdatebeam;
             this._forward = seq2seqForward;
 
-        } else if (modelType === EncoderDecoderModelType) {
+        } else if (modelType === MODEL_TYPES.EncoderDecoder) {
             this._forward = encoderForward;
 
-        } else { // should be EncoderOnlyModelType
+        } else { // should be MODEL_TYPES.EncoderOnly
             this._forward = encoderForward;
         }
     }
@@ -675,13 +674,13 @@ export class PreTrainedModel extends Callable {
         let modelType = MODEL_TYPE_MAPPING.get(this.name);
 
         let info;
-        if (modelType === DecoderOnlyModelType) {
+        if (modelType === MODEL_TYPES.DecoderOnly) {
             info = await Promise.all([
                 AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
                 constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
             ]);
 
-        } else if (modelType === Seq2SeqModelType || modelType === Vision2SeqModelType) {
+        } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
             info = await Promise.all([
                 AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
                 constructSession(pretrained_model_name_or_path, 'encoder_model', options),
@@ -689,15 +688,15 @@ export class PreTrainedModel extends Callable {
                 getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
             ]);
 
-        } else if (modelType === EncoderDecoderModelType) {
+        } else if (modelType === MODEL_TYPES.EncoderDecoder) {
             info = await Promise.all([
                 AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
                 constructSession(pretrained_model_name_or_path, 'encoder_model', options),
                 constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
             ]);
 
-        } else { // should be EncoderOnlyModelType
-            if (modelType !== EncoderOnlyModelType) {
+        } else { // should be MODEL_TYPES.EncoderOnly
+            if (modelType !== MODEL_TYPES.EncoderOnly) {
                 console.warn(`Model type for ${this.name} not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
             }
             info = await Promise.all([
@@ -3554,22 +3553,22 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
 
 
 const MODEL_CLASS_TYPE_MAPPING = [
-    [MODEL_MAPPING_NAMES_ENCODER_ONLY, EncoderOnlyModelType],
-    [MODEL_MAPPING_NAMES_ENCODER_DECODER, EncoderDecoderModelType],
-    [MODEL_MAPPING_NAMES_DECODER_ONLY, DecoderOnlyModelType],
-    [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, Seq2SeqModelType],
-    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, DecoderOnlyModelType],
-    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, Vision2SeqModelType],
-    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_CTC_MAPPING_NAMES, EncoderOnlyModelType],
-    [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
+    [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
+    [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES.EncoderDecoder],
+    [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES.DecoderOnly],
+    [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
+    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
+    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
+    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
 ];
 
 for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
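
For reference, a minimal standalone sketch of the pattern this patch switches to. It assumes nothing beyond what the diff above shows; describeArchitecture() is a hypothetical helper that stands in for the branching done in the PreTrainedModel constructor and from_pretrained() above, and is not part of the patch.

// Sketch only -- not part of the patch. MODEL_TYPES and MODEL_TYPE_MAPPING
// mirror the diff above; describeArchitecture() is a hypothetical helper.
const MODEL_TYPES = {
    EncoderOnly: 0,
    EncoderDecoder: 1,
    Seq2Seq: 2,
    Vision2Seq: 3,
    DecoderOnly: 4,
};

// Model names now map to plain numbers instead of marker-class references.
const MODEL_TYPE_MAPPING = new Map([
    ['CLIPTextModelWithProjection', MODEL_TYPES.EncoderOnly],
    ['CLIPVisionModelWithProjection', MODEL_TYPES.EncoderOnly],
]);

function describeArchitecture(name) {
    const modelType = MODEL_TYPE_MAPPING.get(name); // undefined if unknown
    if (modelType === MODEL_TYPES.DecoderOnly) {
        return 'decoder-only';
    } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
        return 'seq2seq';
    } else if (modelType === MODEL_TYPES.EncoderDecoder) {
        return 'encoder-decoder';
    } else {
        // Unknown names fall through here, matching the encoder-only
        // fallback (and console.warn) in from_pretrained() above.
        return 'encoder-only';
    }
}

console.log(describeArchitecture('CLIPTextModelWithProjection')); // 'encoder-only'
console.log(describeArchitecture('SomeUnknownModel'));            // 'encoder-only' (fallback)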