Skip to content

Commit

Permalink
Use enum object instead of classes for model types
Browse files Browse the repository at this point in the history
Fixes #283
  • Loading branch information
xenova committed Sep 8, 2023
1 parent a140648 commit cdb4814
Showing 1 changed file with 34 additions and 35 deletions.
69 changes: 34 additions & 35 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,13 @@ const { InferenceSession, Tensor: ONNXTensor } = ONNX;

//////////////////////////////////////////////////
// Model types: used internally
class ModelType { };

// Either encoder-only or encoder-decoder (and will be decided by `model.config.is_encoder_decoder`)
class EncoderOnlyModelType extends ModelType { };
class EncoderDecoderModelType extends ModelType { };
class Seq2SeqModelType extends EncoderDecoderModelType { };
class Vision2SeqModelType extends EncoderDecoderModelType { };
class DecoderOnlyModelType extends ModelType { };
const MODEL_TYPES = {
EncoderOnly: 0,
EncoderDecoder: 1,
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
}
//////////////////////////////////////////////////


Expand All @@ -104,8 +103,8 @@ class DecoderOnlyModelType extends ModelType { };

// Will be populated fully later
const MODEL_TYPE_MAPPING = new Map([
['CLIPTextModelWithProjection', EncoderOnlyModelType],
['CLIPVisionModelWithProjection', EncoderOnlyModelType],
['CLIPTextModelWithProjection', MODEL_TYPES.EncoderOnly],
['CLIPVisionModelWithProjection', MODEL_TYPES.EncoderOnly],
]);

/**
Expand Down Expand Up @@ -597,26 +596,26 @@ export class PreTrainedModel extends Callable {
this._getStartBeams = null;
this._updateBeam = null;
this._forward = null;
if (modelType === DecoderOnlyModelType) {
if (modelType === MODEL_TYPES.DecoderOnly) {
this.can_generate = true;

this._runBeam = decoderRunBeam;
this._getStartBeams = decoderStartBeams;
this._updateBeam = decoderUpdatebeam;
this._forward = decoderForward;

} else if (modelType === Seq2SeqModelType || modelType === Vision2SeqModelType) {
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
this.can_generate = true;

this._runBeam = seq2seqRunBeam;
this._getStartBeams = seq2seqStartBeams;
this._updateBeam = seq2seqUpdatebeam;
this._forward = seq2seqForward;

} else if (modelType === EncoderDecoderModelType) {
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
this._forward = encoderForward;

} else { // should be EncoderOnlyModelType
} else { // should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
}
}
Expand Down Expand Up @@ -675,29 +674,29 @@ export class PreTrainedModel extends Callable {
let modelType = MODEL_TYPE_MAPPING.get(this.name);

let info;
if (modelType === DecoderOnlyModelType) {
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
]);

} else if (modelType === Seq2SeqModelType || modelType === Vision2SeqModelType) {
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'encoder_model', options),
constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);

} else if (modelType === EncoderDecoderModelType) {
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'encoder_model', options),
constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
]);

} else { // should be EncoderOnlyModelType
if (modelType !== EncoderOnlyModelType) {
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for ${this.name} not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
}
info = await Promise.all([
Expand Down Expand Up @@ -3554,22 +3553,22 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([


const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_MAPPING_NAMES_ENCODER_ONLY, EncoderOnlyModelType],
[MODEL_MAPPING_NAMES_ENCODER_DECODER, EncoderDecoderModelType],
[MODEL_MAPPING_NAMES_DECODER_ONLY, DecoderOnlyModelType],
[MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, Seq2SeqModelType],
[MODEL_WITH_LM_HEAD_MAPPING_NAMES, DecoderOnlyModelType],
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, Vision2SeqModelType],
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_CTC_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType],
[MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
[MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES.EncoderDecoder],
[MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES.DecoderOnly],
[MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
];

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
Expand Down

0 comments on commit cdb4814

Please sign in to comment.