Skip to content

Commit

Permalink
Merge branch 'main' into add-owlvit
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova authored Nov 16, 2023
2 parents 4b616e2 + 4e4148c commit da0bc91
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3035,9 +3035,9 @@ export class LlamaPreTrainedModel extends PreTrainedModel {
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id

this.num_heads = this.config.num_attention_heads
this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads
this.num_layers = this.config.num_hidden_layers
this.dim_kv = this.config.hidden_size / this.num_heads;
this.dim_kv = this.config.hidden_size / this.config.num_attention_heads
}
}
/**
Expand Down
2 changes: 2 additions & 0 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {

}

export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
export class ViTFeatureExtractor extends ImageFeatureExtractor { }
export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
Expand Down Expand Up @@ -1565,6 +1566,7 @@ export class AutoProcessor {
ViTFeatureExtractor,
MobileViTFeatureExtractor,
OwlViTFeatureExtractor,
CLIPFeatureExtractor,
ConvNextFeatureExtractor,
BeitFeatureExtractor,
DeiTFeatureExtractor,
Expand Down
24 changes: 20 additions & 4 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ describe('Processors', () => {
detr: 'facebook/detr-resnet-50',
yolos: 'hustvl/yolos-small-300',
owlvit: 'google/owlvit-base-patch32',
clip: 'openai/clip-vit-base-patch16',
}

const TEST_IMAGES = {
Expand Down Expand Up @@ -173,7 +174,7 @@ describe('Processors', () => {
it(MODELS.deit, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.deit))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -189,7 +190,7 @@ describe('Processors', () => {
it(MODELS.beit, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.beit))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -206,7 +207,7 @@ describe('Processors', () => {
it(MODELS.detr, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.detr))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image);

Expand All @@ -227,7 +228,7 @@ describe('Processors', () => {
it(MODELS.yolos, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.yolos))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -253,6 +254,21 @@ describe('Processors', () => {

compare(original_sizes, [[480, 640]]);
compare(reshaped_input_sizes, [[768, 768]]);

// CLIPFeatureExtractor
// - tests center crop (do_center_crop=true, crop_size=224)
it(MODELS.clip, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.clip))

{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

compare(pixel_values.dims, [1, 3, 224, 224]);
compare(avg(pixel_values.data), -0.06678297738282096);

compare(original_sizes, [[408, 612]]);
compare(reshaped_input_sizes, [[224, 224]]);
}
}, MAX_TEST_EXECUTION_TIME);
});
Expand Down

0 comments on commit da0bc91

Please sign in to comment.