Skip to content

Commit

Permalink
[WIP] Add support for deepseek-ai/Janus-1.3B
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Oct 30, 2024
1 parent 03f6662 commit d040e81
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ function getNormalizedConfig(config) {
case 'musicgen':
init_normalized_config = getNormalizedConfig(config.decoder);
break;
case 'multi_modality':
init_normalized_config = getNormalizedConfig(config.language_config);
break;

// Decoder-only models
case 'gpt2':
Expand Down Expand Up @@ -216,14 +219,12 @@ function getNormalizedConfig(config) {
*/
export function getKeyValueShapes(config, {
prefix = 'past_key_values',
batch_size=1,
} = {}) {
/** @type {Record<string, number[]>} */
const decoderFeeds = {};
const normalized_config = config.normalized_config;

// TODO support batches (i.e., batch_size > 1)
const batch_size = 1;

if (normalized_config.is_encoder_decoder && (
'num_encoder_heads' in normalized_config && 'num_decoder_heads' in normalized_config
)) {
Expand Down
186 changes: 183 additions & 3 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ import {
} from './utils/generic.js';

import {
isIntegralNumber,
mergeArrays,
pick,
} from './utils/core.js';
Expand Down Expand Up @@ -99,6 +98,7 @@ import {

import {
cat,
full,
full_like,
mean,
ones,
Expand All @@ -108,6 +108,7 @@ import {
Tensor,
zeros_like,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';

import { dynamic_time_warping, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
Expand All @@ -128,6 +129,7 @@ const MODEL_TYPES = {
MaskGeneration: 5,
ImageTextToText: 6,
Musicgen: 7,
MultiModality: 8,
}
//////////////////////////////////////////////////

Expand Down Expand Up @@ -386,7 +388,7 @@ async function sessionRun(session, inputs) {
} catch (e) {
// This usually occurs when the inputs are of the wrong type.
console.error(`An error occurred during model execution: "${e}".`);
console.error('Inputs given to model:', checkedInputs);
console.error('Inputs given to model:', checkedInputs)
throw e;
}
}
Expand Down Expand Up @@ -716,6 +718,52 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
}
}

/**
 * Prepares model inputs for one step of the multi-modality (Janus) generation loop.
 *
 * Handles classifier-free guidance (CFG) by duplicating the batch into a
 * conditional + unconditional half, and supplies dummy/placeholder image inputs
 * on decode steps where no real pixel values are needed.
 *
 * @param {Object} self The model instance (unused here; kept for interface parity
 *   with the other `*_prepare_inputs_for_generation` helpers).
 * @param {Tensor} input_ids The current input ids (unused; `model_inputs.input_ids` is used).
 * @param {Record<string, Tensor>} model_inputs The inputs to update in place.
 * @param {Object} generation_config The generation configuration (reads
 *   `guidance_scale` and `pad_token_id`).
 * @returns {Record<string, Tensor>} The (mutated) `model_inputs`.
 */
function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    const has_past_key_values = !!model_inputs.past_key_values;

    // Classifier-free guidance: run a doubled batch (conditional + unconditional).
    if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
        if (has_past_key_values) {
            // Decode step: simply duplicate the newly generated token for both halves.
            model_inputs.input_ids = cat([
                model_inputs.input_ids,
                model_inputs.input_ids,
            ], 0);
            // NOTE: attention_mask handled in generation
        } else {
            // Prefill step: the unconditional half is all pad tokens, fully masked out.
            model_inputs.input_ids = cat([
                model_inputs.input_ids,
                full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
            ], 0);
            model_inputs.attention_mask = cat([
                model_inputs.attention_mask,
                full_like(model_inputs.attention_mask, 0n),
            ], 0);
        }
    }

    if (has_past_key_values || !model_inputs.pixel_values) {
        // Provide an empty-batch pixel tensor so the session signature is satisfied
        // even when there is no image input for this step.
        model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
    }

    if (has_past_key_values) {
        // TODO(WIP): support num_img_tokens > 0 and batch_size > 1; currently the
        // decode step always processes exactly one text token with no image tokens.
        const num_img_tokens = 0;
        const num_text_tokens = 1;
        const has_image = num_img_tokens > 0;

        const batch_size = 1;
        // `images_seq_mask`: false for the text positions, true for image positions.
        model_inputs.images_seq_mask = new Tensor(
            'bool',
            new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
            [batch_size, num_img_tokens + num_text_tokens],
        );
        model_inputs.images_emb_mask = new Tensor(
            'bool',
            new Array(num_img_tokens).fill(has_image),
            [batch_size, 1, num_img_tokens],
        );
    }
    return model_inputs;
}

//////////////////////////////////////////////////

//////////////////////////////////////////////////
Expand Down Expand Up @@ -769,6 +817,11 @@ export class PreTrainedModel extends Callable {
this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
break;

case MODEL_TYPES.MultiModality:
this.can_generate = true;
this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
break;

default:
// should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
Expand Down Expand Up @@ -912,6 +965,21 @@ export class PreTrainedModel extends Callable {
}, options),
]);

} else if (modelType === MODEL_TYPES.MultiModality) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'language_model',
lm_head: 'lm_head',
gen_head: 'gen_head',
gen_img_embeds: 'gen_img_embeds',
image_decode: 'image_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);

} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
Expand Down Expand Up @@ -1658,7 +1726,8 @@ export class PreTrainedModel extends Callable {
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
const empty = (dtype === 'float16') ? new Uint16Array() : [];

const shapes = getKeyValueShapes(this.config);
const batch_size = decoderFeeds[this.main_input_name].dims[0];
const shapes = getKeyValueShapes(this.config, { batch_size });

for (const name in shapes) {
decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
Expand Down Expand Up @@ -5954,6 +6023,111 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel

//////////////////////////////////////////////////

/**
 * Base class for multi-modality (Janus) models; inherits loading and
 * session-management behavior from `PreTrainedModel`.
 */
export class MultiModalityPreTrainedModel extends PreTrainedModel { }
/**
 * Janus multi-modality causal language model (e.g., `deepseek-ai/Janus-1.3B`),
 * supporting both text generation and image generation.
 *
 * The model is split across several sessions (see the `multi_modality` branch of
 * `from_pretrained`): `prepare_inputs_embeds` (embed text/image inputs),
 * `gen_img_embeds` (embed generated image tokens), `model` (the shared decoder),
 * `lm_head` / `gen_head` (text vs. image generation heads), and `image_decode`
 * (image tokens -> pixels).
 */
export class MultiModalityCausalLM extends MultiModalityPreTrainedModel {
    forward_params = [
        // prepare_inputs_embeds
        'input_ids',
        'pixel_values',
        'images_seq_mask',
        'images_emb_mask',

        // language_model
        'attention_mask',
        'position_ids',
        'past_key_values',
    ];

    constructor(...args) {
        super(...args);

        // State-based approach to switch out which heads to use during generation:
        // 'text' (default) or 'image'. Set by `generate` / `generate_images`.
        this._generation_mode = 'text';
    }

    /**
     * Runs one forward pass, routing through the embedding session and generation
     * head appropriate for the current generation mode.
     * @param {Record<string, Tensor>} model_inputs
     * @returns {Promise<Record<string, Tensor>>} Merged outputs of all sub-sessions.
     */
    async forward(model_inputs) {
        const mode = this._generation_mode ?? 'text';

        // TODO support re-using PKVs for input_ids.dims[1] !== 1
        // if (model_inputs.past_key_values) {
        //     // && model_inputs.input_ids.dims[1] === 1
        // }

        let output_1;
        if (mode === 'text' || !model_inputs.past_key_values) {
            // Text mode, or the prefill step of image mode: embed the raw inputs.
            const session = this.sessions['prepare_inputs_embeds'];
            const prep_inputs = pick(model_inputs, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        } else {
            // Subsequent image-generation steps: embed the generated image tokens.
            const session = this.sessions['gen_img_embeds'];
            const prep_inputs = pick({
                image_ids: model_inputs.input_ids,
            }, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        }

        const input_2 = { ...model_inputs, ...output_1 };
        const output_2 = await decoderForward(this, input_2);

        const head_name = mode === 'text' ? 'lm_head' : 'gen_head';
        const head = this.sessions[head_name];
        if (!head) {
            // Report the *name* of the missing head. (Previously this interpolated
            // the undefined session object, yielding `Unable to find "undefined" …`.)
            throw new Error(`Unable to find "${head_name}" generation head`);
        }

        const output_3 = await sessionRun(head, pick(output_2, head.inputNames));

        return {
            ...output_1,
            ...output_2,
            ...output_3,
        };
    }

    /**
     * Generates text tokens.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     */
    async generate(options) {
        this._generation_mode = 'text';
        return super.generate(options);
    }

    /**
     * Generates an image for each item in the batch.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     * @returns {Promise<RawImage[]>} The generated images.
     */
    async generate_images(options) {
        this._generation_mode = 'image';

        const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1];
        const all_tokens = await super.generate(options);

        // Strip the prompt; keep only the newly generated image tokens.
        const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null]);

        const image_decode = this.sessions['image_decode'];
        const { decoded_image } = await sessionRun(image_decode, {
            generated_tokens,
        });

        // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)`
        const clamped = decoded_image
            .add_(1)
            .mul_(255 / 2)
            .clamp_(0, 255)
            .to('uint8');

        // Return as a list of images
        const images = [];
        for (const tensor of clamped) {
            images.push(RawImage.fromTensor(tensor));
        }
        return images;
    }
}

//////////////////////////////////////////////////
// AutoModels, used to simplify construction of PreTrainedModels
// (uses config to instantiate correct class)
Expand Down Expand Up @@ -6232,6 +6406,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]],
]);

// Maps `model_type` -> [class name, model class] for multi-modality models.
const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map(Object.entries({
    multi_modality: ['MultiModalityCausalLM', MultiModalityCausalLM],
}));


const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
['bert', ['BertForMaskedLM', BertForMaskedLM]],
['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
Expand Down Expand Up @@ -6404,6 +6583,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
[MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality],
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
Expand Down
1 change: 1 addition & 0 deletions src/models/image_processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export * from './donut/image_processing_donut.js'
export * from './dpt/image_processing_dpt.js'
export * from './efficientnet/image_processing_efficientnet.js'
export * from './glpn/image_processing_glpn.js'
export * from './janus/image_processing_janus.js'
export * from './jina_clip/image_processing_jina_clip.js'
export * from './mask2former/image_processing_mask2former.js'
export * from './maskformer/image_processing_maskformer.js'
Expand Down
26 changes: 26 additions & 0 deletions src/models/janus/image_processing_janus.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

import {
ImageProcessor,
} from "../../base/image_processors_utils.js";

/**
 * Image processor for Janus models (`VLMImageProcessor` upstream).
 *
 * Pads images to a square of `config.image_size`, using the configured
 * `background_color` (rescaled by `rescale_factor`) as the fill value.
 */
export class VLMImageProcessor extends ImageProcessor {
    constructor(config) {
        // User-provided config takes precedence over these padding defaults.
        const defaults = {
            do_pad: true,
            pad_size: {
                width: config.image_size,
                height: config.image_size,
            },
        };
        super({ ...defaults, ...config });

        // Pre-scale the background color so padding matches rescaled pixel values.
        const { background_color } = this.config;
        this.constant_values = background_color.map((channel) => channel * this.rescale_factor);
    }

    pad_image(pixelData, imgDims, padSize, options) {
        // Default to centered padding with the background fill; callers may override.
        const padOptions = {
            constant_values: this.constant_values,
            center: true,
            ...options,
        };
        return super.pad_image(pixelData, imgDims, padSize, padOptions);
    }
}
Loading

0 comments on commit d040e81

Please sign in to comment.