Skip to content

Commit

Permalink
Fix more _call LSP errors + extra typings (#304)
Browse files Browse the repository at this point in the history
* types are only inferred through this assignments in constructor

* Typing PreprocessedImage

* LSP + other typings for ImageFeatureExtractor / SamImageProcessor / DetrFeatureExtractor

* Fix SamProcessor error

* Fix PretrainedOptions

* Fix double AnyTypedArray

* Update `unused` variable name

* Mark `_update` image function as private

* Update processor JSDoc

---------

Co-authored-by: Joshua Lochner <[email protected]>
  • Loading branch information
kungfooman and xenova authored Sep 27, 2023
1 parent df94965 commit 09cbb0c
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 54 deletions.
9 changes: 2 additions & 7 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ import { executionProviders, ONNX } from './backends/onnx.js';
import { medianFilter } from './transformers.js';
const { InferenceSession, Tensor: ONNXTensor } = ONNX;

/**
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
*/


//////////////////////////////////////////////////
// Model types: used internally
const MODEL_TYPES = {
Expand All @@ -113,7 +108,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* Constructs an InferenceSession using a model file located at the specified path.
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {PretrainedOptions} options Additional options for loading the model.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
* @returns {Promise<InferenceSession>} A Promise that resolves to an InferenceSession object.
* @private
*/
Expand Down Expand Up @@ -664,7 +659,7 @@ export class PreTrainedModel extends Callable {
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
* @param {PretrainedOptions} options Additional options for loading the model.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
*
* @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
*/
Expand Down
8 changes: 2 additions & 6 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -2063,10 +2063,6 @@ const TASK_ALIASES = {
"embeddings": "feature-extraction",
}

/**
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
*/

/**
* Utility factory method to build a [`Pipeline`] object.
*
Expand All @@ -2091,7 +2087,7 @@ const TASK_ALIASES = {
* - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
* - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
* @param {PretrainedOptions} [options] Optional parameters for the pipeline.
* @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
* @returns {Promise<Pipeline>} A Pipeline object for the specified task.
* @throws {Error} If an unsupported pipeline is requested.
*/
Expand Down Expand Up @@ -2158,7 +2154,7 @@ export async function pipeline(
* Helper function to get applicable model, tokenizer, or processor classes for a given model.
* @param {Map<string, any>} mapping The mapping of names to classes, arrays of classes, or null.
* @param {string} model The name of the model to load.
* @param {PretrainedOptions} pretrainedOptions The options to pass to the `from_pretrained` method.
* @param {import('./utils/hub.js').PretrainedOptions} pretrainedOptions The options to pass to the `from_pretrained` method.
* @private
*/
async function loadItems(mapping, model, pretrainedOptions) {
Expand Down
85 changes: 65 additions & 20 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
return toReturn;
}

/**
* Named tuple to indicate the order we are using is (height x width), even though
* the Graphics’ industry standard is (width x height).
* @typedef {[height: number, width: number]} HeightWidth
*/

/**
* Base class for feature extractors.
*
Expand All @@ -137,6 +143,13 @@ export class FeatureExtractor extends Callable {
}
}

/**
* @typedef {object} ImageFeatureExtractorResult
* @property {Tensor} pixel_values The pixel values of the batched preprocessed images.
* @property {HeightWidth[]} original_sizes Array of two-dimensional tuples like [[480, 640]].
* @property {HeightWidth[]} reshaped_input_sizes Array of two-dimensional tuples like [[1000, 1330]].
*/

/**
* Feature extractor for image models.
*
Expand Down Expand Up @@ -216,11 +229,18 @@ export class ImageFeatureExtractor extends FeatureExtractor {
return await image.resize(width, height, { resample });
}

/**
* @typedef {object} PreprocessedImage
* @property {HeightWidth} original_size The original size of the image.
* @property {HeightWidth} reshaped_input_size The reshaped input size of the image.
* @property {Tensor} pixel_values The pixel values of the preprocessed image.
*/

/**
* Preprocesses the given image.
*
* @param {RawImage} image The image to preprocess.
* @returns {Promise<any>} The preprocessed image as a Tensor.
* @returns {Promise<PreprocessedImage>} The preprocessed image.
*/
async preprocess(image) {

Expand Down Expand Up @@ -316,6 +336,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
image = await image.center_crop(crop_width, crop_height);
}

/** @type {HeightWidth} */
let reshaped_input_size = [image.height, image.width];

// TODO is it okay to pad before rescaling/normalizing?
Expand Down Expand Up @@ -374,22 +395,23 @@ export class ImageFeatureExtractor extends FeatureExtractor {
* Calls the feature extraction process on an array of image
* URLs, preprocesses each image, and concatenates the resulting
* features into a single Tensor.
* @param {any} images The URL(s) of the image(s) to extract features from.
* @returns {Promise<Object>} An object containing the concatenated pixel values (and other metadata) of the preprocessed images.
* @param {any[]} images The URL(s) of the image(s) to extract features from.
* @param {...any} args Additional arguments.
* @returns {Promise<ImageFeatureExtractorResult>} An object containing the concatenated pixel values (and other metadata) of the preprocessed images.
*/
async _call(images) {
async _call(images, ...args) {
if (!Array.isArray(images)) {
images = [images];
}

let imageData = await Promise.all(images.map(x => this.preprocess(x)));
/** @type {PreprocessedImage[]} */
const imageData = await Promise.all(images.map(x => this.preprocess(x)));

// TODO:

// Concatenate pixel values
// TEMP: Add batch dimension so that concat works
imageData.forEach(x => x.pixel_values.dims = [1, ...x.pixel_values.dims]);
let pixel_values = cat(imageData.map(x => x.pixel_values));
const pixel_values = cat(imageData.map(x => x.pixel_values));

return {
pixel_values: pixel_values,
Expand All @@ -411,6 +433,12 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
export class DonutFeatureExtractor extends ImageFeatureExtractor { }

/**
* @typedef {object} DetrFeatureExtractorResultProps
* @property {Tensor} pixel_mask
* @typedef {ImageFeatureExtractorResult & DetrFeatureExtractorResultProps} DetrFeatureExtractorResult
*/

/**
* Detr Feature Extractor.
*
Expand All @@ -420,24 +448,23 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
/**
* Calls the feature extraction process on an array of image URLs, preprocesses
* each image, and concatenates the resulting features into a single Tensor.
* @param {any} urls The URL(s) of the image(s) to extract features from.
* @returns {Promise<Object>} An object containing the concatenated pixel values of the preprocessed images.
* @param {any[]} urls The URL(s) of the image(s) to extract features from.
* @returns {Promise<DetrFeatureExtractorResult>} An object containing the concatenated pixel values of the preprocessed images.
*/
async _call(urls) {
let result = await super._call(urls);
const result = await super._call(urls);

// TODO support differently-sized images, for now assume all images are the same size.
// TODO support different mask sizes (not just 64x64)
// Currently, just fill pixel mask with 1s
let maskSize = [result.pixel_values.dims[0], 64, 64];
result.pixel_mask = new Tensor(
const maskSize = [result.pixel_values.dims[0], 64, 64];
const pixel_mask = new Tensor(
'int64',
// TODO: fix error below
new BigInt64Array(maskSize.reduce((a, b) => a * b)).fill(1n),
maskSize
);

return result;
return { ...result, pixel_mask };
}

/**
Expand Down Expand Up @@ -737,7 +764,22 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor {
}
}

/**
* @typedef {object} SamImageProcessorResult
* @property {Tensor} pixel_values
* @property {HeightWidth[]} original_sizes
* @property {HeightWidth[]} reshaped_input_sizes
* @property {Tensor} input_points
*/

export class SamImageProcessor extends ImageFeatureExtractor {
/**
* @param {any[]} images The URL(s) of the image(s) to extract features from.
* @param {*} input_points A 3D or 4D array, representing the input points provided by the user.
* - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
* - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
* @returns {Promise<SamImageProcessorResult>}
*/
async _call(images, input_points) {
let {
pixel_values,
Expand All @@ -747,6 +789,7 @@ export class SamImageProcessor extends ImageFeatureExtractor {

let shape = calculateDimensions(input_points);

// TODO: add support for 2D input_points
if (shape.length === 3) {
// Correct user's input
shape = [1, ...shape];
Expand Down Expand Up @@ -1284,15 +1327,20 @@ export class Processor extends Callable {
/**
* Calls the feature_extractor function with the given input.
* @param {any} input The input to extract features from.
* @param {...any} args Additional arguments.
* @returns {Promise<any>} A Promise that resolves with the extracted features.
*/
async _call(input) {
async _call(input, ...args) {
return await this.feature_extractor(input);
}
}

export class SamProcessor extends Processor {

/**
* @param {*} images
* @param {*} input_points
* @returns {Promise<any>}
*/
async _call(images, input_points) {
return await this.feature_extractor(images, input_points);
}
Expand Down Expand Up @@ -1334,9 +1382,6 @@ export class Wav2Vec2ProcessorWithLM extends Processor {
}

//////////////////////////////////////////////////
/**
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
*/
/**
* Helper class which is used to instantiate pretrained processors with the `from_pretrained` function.
* The chosen processor class is determined by the type specified in the processor config.
Expand Down Expand Up @@ -1400,7 +1445,7 @@ export class AutoProcessor {
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing processor files, e.g., `./my_model_directory/`.
* @param {PretrainedOptions} options Additional options for loading the processor.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the processor.
*
* @returns {Promise<Processor>} A new instance of the Processor class.
*/
Expand Down
12 changes: 4 additions & 8 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,11 @@ import {
CharTrie,
} from './utils/data-structures.js';

/**
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
*/

/**
* Loads a tokenizer from the specified path.
* @param {string} pretrained_model_name_or_path The path to the tokenizer directory.
* @param {PretrainedOptions} options Additional options for loading the tokenizer.
* @returns {Promise<Array>} A promise that resolves with information about the loaded tokenizer.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the tokenizer.
* @returns {Promise<any[]>} A promise that resolves with information about the loaded tokenizer.
*/
async function loadTokenizer(pretrained_model_name_or_path, options) {

Expand Down Expand Up @@ -2228,7 +2224,7 @@ export class PreTrainedTokenizer extends Callable {
* Loads a pre-trained tokenizer from the given `pretrained_model_name_or_path`.
*
* @param {string} pretrained_model_name_or_path The path to the pre-trained tokenizer.
* @param {PretrainedOptions} options Additional options for loading the tokenizer.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the tokenizer.
*
* @throws {Error} Throws an error if the tokenizer.json or tokenizer_config.json files are not found in the `pretrained_model_name_or_path`.
* @returns {Promise<PreTrainedTokenizer>} A new instance of the `PreTrainedTokenizer` class.
Expand Down Expand Up @@ -3814,7 +3810,7 @@ export class AutoTokenizer {
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing tokenizer files, e.g., `./my_model_directory/`.
* @param {PretrainedOptions} options Additional options for loading the tokenizer.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the tokenizer.
*
* @returns {Promise<PreTrainedTokenizer>} A new instance of the PreTrainedTokenizer class.
*/
Expand Down
2 changes: 0 additions & 2 deletions src/transformers.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// @ts-nocheck

/**
* @file Entry point for the Transformers.js library. Only the exports from this file
* are available to the end user, and are grouped as follows:
Expand Down
14 changes: 7 additions & 7 deletions src/utils/hub.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ if (!globalThis.ReadableStream) {

/**
* @typedef {Object} PretrainedOptions Options for loading a pretrained model.
* @property {boolean?} [options.quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files).
* @property {function} [options.progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates.
* @property {Object} [options.config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
* @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files).
* @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates.
* @property {Object} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
* - The model is a model provided by the library (loaded with the *model id* string of a pretrained model).
* - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory.
* @property {string} [options.cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
* @property {boolean} [options.local_files_only=false] Whether or not to only look at local files (e.g., not try downloading the model).
* @property {string} [options.revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
* @property {string} [cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
* @property {boolean} [local_files_only=false] Whether or not to only look at local files (e.g., not try downloading the model).
* @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
* since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
* NOTE: This setting is ignored for local requests.
* @property {string} [options.model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
* @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
*/

class FileResponse {
Expand Down
8 changes: 6 additions & 2 deletions src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ export class RawImage {
* @param {1|2|3|4} channels The number of channels.
*/
constructor(data, width, height, channels) {
this._update(data, width, height, channels);
this.data = data;
this.width = width;
this.height = height;
this.channels = channels;
}

/**
Expand Down Expand Up @@ -508,7 +511,8 @@ export class RawImage {
* @param {Uint8ClampedArray} data The new image data.
* @param {number} width The new width of the image.
* @param {number} height The new height of the image.
* @param {1|2|3|4} channels The new number of channels of the image.
* @param {1|2|3|4|null} [channels] The new number of channels of the image.
* @private
*/
_update(data, width, height, channels = null) {
this.data = data;
Expand Down
4 changes: 2 additions & 2 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {


/**
* @typedef {import('./maths.js').AnyTypedArray} AnyTypedArray
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
*/

/** @type {Object} */
Expand All @@ -25,7 +25,7 @@ const ONNXTensor = ONNX.Tensor;
export class Tensor extends ONNXTensor {
/**
* Create a new Tensor or copy an existing Tensor.
* @param {[string, Array|AnyTypedArray, number[]]|[ONNXTensor]} args
* @param {[string, DataArray, number[]]|[ONNXTensor]} args
*/
constructor(...args) {
if (args[0] instanceof ONNX.Tensor) {
Expand Down

0 comments on commit 09cbb0c

Please sign in to comment.