diff --git a/fourm/data/modality_info.py b/fourm/data/modality_info.py index a6572e7..dd85882 100644 --- a/fourm/data/modality_info.py +++ b/fourm/data/modality_info.py @@ -14,414 +14,442 @@ from functools import partial import fourm.utils.data_constants as data_constants -from fourm.data.modality_transforms import (CaptionTransform, DepthTransform, - DetectionTransform, MaskTransform, - NormalTransform, RGBTransform, - SemsegTransform, TokTransform, - CaptionEmbTransform, MetadataTransform, - HumanPoseTransform, ColorPaletteTransform, - SAMInstanceTokTransform, SAMInstanceTransform) -from fourm.models.decoder_embeddings import (ImageTokenDecoderEmbedding, - SequenceDecoderEmbedding) -from fourm.models.encoder_embeddings import (ImageEncoderEmbedding, - ImageTokenEncoderEmbedding, - SequenceEncoderEmbedding, - SequenceEmbEncoderEmbedding) +from fourm.data.modality_transforms import ( + CaptionTransform, + DepthTransform, + DetectionTransform, + MaskTransform, + NormalTransform, + RGBTransform, + SemsegTransform, + TokTransform, + CaptionEmbTransform, + MetadataTransform, + HumanPoseTransform, + ColorPaletteTransform, + SAMInstanceTokTransform, + SAMInstanceTransform, +) +from fourm.models.decoder_embeddings import ImageTokenDecoderEmbedding, SequenceDecoderEmbedding +from fourm.models.encoder_embeddings import ( + ImageEncoderEmbedding, + ImageTokenEncoderEmbedding, + SequenceEncoderEmbedding, + SequenceEmbEncoderEmbedding, +) from fourm.utils import generate_uint15_hash MODALITY_INFO = { # 4M-7 modalities - 'rgb@224': { - 'input_size': 224, - 'patch_size': 16, - 'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3), - 'decoder_embedding': None, - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'num_channels': 3, - 'id': generate_uint15_hash('rgb@224'), - 'path': 'rgb', - }, - 'rgb': { # used for tokenizer training - 'type': 'img', - 'num_channels': 3, - 'id': generate_uint15_hash('rgb'), - 'path': 'rgb', - }, - 'caption': { - 'vocab_size': 30_000, - 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), - 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), - 'min_tokens': 0, - 'max_tokens': 256, - 'type': 'seq', - 'id': generate_uint15_hash('caption'), - }, - 'det': { - 'vocab_size': 30_000, - 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), - 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), - 'min_tokens': 0, - 'max_tokens': 256, - 'type': 'seq', - 'id': generate_uint15_hash('det'), - }, - 'tok_rgb@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 16384, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_rgb@224'), - 'pretokenized': True, - }, - 'tok_depth@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_depth@224'), - 'pretokenized': True, - }, - 'depth': { # used for tokenizer training - 'type': 'img', - 
'num_channels': 1, - 'id': generate_uint15_hash('depth'), - }, - 'tok_normal@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_normal@224'), - 'pretokenized': True, - }, - 'normal': { # used for tokenizer training - 'type': 'img', - 'num_channels': 3, - 'id': generate_uint15_hash('normal'), - }, - 'tok_semseg@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 4096, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_semseg@224'), - 'pretokenized': True, - }, - 'semseg_coco': { # used for tokenizer training - 'type': 'img', - 'num_channels': 64, - 'num_labels': data_constants.COCO_SEMSEG_NUM_CLASSES, - 'id': generate_uint15_hash('semseg_coco'), - }, - 'tok_clip@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_clip@224'), - 'pretokenized': True, - }, - 'CLIP-B16': { # used for tokenizer training - 'type': 'feature_map', - 'num_channels': 512, - 'id': generate_uint15_hash('CLIP-B16'), + "rgb@224": { + "input_size": 224, + "patch_size": 16, + "encoder_embedding": partial(ImageEncoderEmbedding, num_channels=3), + "decoder_embedding": None, + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "num_channels": 3, + "id": generate_uint15_hash("rgb@224"), + "path": "rgb", + }, + "rgb": { # used for tokenizer training + "type": "img", + "num_channels": 3, + "id": generate_uint15_hash("rgb"), + "path": "rgb", + }, + "caption": { + "vocab_size": 30_000, + "encoder_embedding": partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), + "decoder_embedding": partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), + "min_tokens": 0, + "max_tokens": 256, + "type": "seq", + "id": generate_uint15_hash("caption"), + }, + "det": { + "vocab_size": 30_000, + "encoder_embedding": partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), + "decoder_embedding": partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0), + "min_tokens": 0, + "max_tokens": 256, + "type": "seq", + "id": generate_uint15_hash("det"), + }, + "tok_rgb@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 16384, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=16384), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=16384), + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "id": generate_uint15_hash("tok_rgb@224"), + "pretokenized": True, + }, + "tok_depth@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 
196 + "type": "img", + "id": generate_uint15_hash("tok_depth@224"), + "pretokenized": True, + }, + "depth": { # used for tokenizer training + "type": "img", + "num_channels": 1, + "id": generate_uint15_hash("depth"), + }, + "tok_normal@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "id": generate_uint15_hash("tok_normal@224"), + "pretokenized": True, + }, + "normal": { # used for tokenizer training + "type": "img", + "num_channels": 3, + "id": generate_uint15_hash("normal"), + }, + "tok_semseg@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 4096, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=4096), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=4096), + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "id": generate_uint15_hash("tok_semseg@224"), + "pretokenized": True, + }, + "semseg_coco": { # used for tokenizer training + "type": "img", + "num_channels": 64, + "num_labels": data_constants.COCO_SEMSEG_NUM_CLASSES, + "id": generate_uint15_hash("semseg_coco"), + }, + "tok_clip@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "id": generate_uint15_hash("tok_clip@224"), + "pretokenized": True, + }, + "CLIP-B16": { # used for tokenizer training + "type": "feature_map", + "num_channels": 512, + "id": generate_uint15_hash("CLIP-B16"), }, - # 4M-21 modalities - 't5_caption': { - 'encoder_embedding': partial(SequenceEmbEncoderEmbedding, max_length=77, padding_idx=0), - 'decoder_embedding': None, - 'min_tokens': 0, - 'max_tokens': 77, - 'type': 'seq_emb', - 'id': generate_uint15_hash('t5_caption'), - }, - 'metadata': { - 'vocab_size': 30_000, - 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True), - 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True), - 'min_tokens': 0, - 'max_tokens': 40, # At most 2x19=38 for 19 metadata types, +1 for EOS, +1 for sentinel - 'type': 'seq', - 'id': generate_uint15_hash('metadata'), - 'shared_vocab': ['caption'], - 'path': 'metadata', - }, - 'human_poses': { - 'vocab_size': 30_000, - 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True), - 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True), - 'min_tokens': 0, - 'max_tokens': 275, #7*39+1 EOS+1 S_1#263, #261 in one of the models, or 263 to have EOS #261+1+1 #238, - 'type': 'seq', - 'num_channels': 207, # for tokenization training, only the pose part is needed - 'id': generate_uint15_hash('human_poses'), - 'shared_vocab': ['caption'], - }, - 'color_palette': { - 'vocab_size': 30_000, - 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True), - 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, 
sincos_pos_emb=True), - 'min_tokens': 0, - 'max_tokens': 23, #7x3=21 for 7 colors, +1 for EOS, +1 for sentinel - 'type': 'seq', - 'id': generate_uint15_hash('color_palette'), - 'shared_vocab': ['caption'], - 'path': 'color_palette', - }, - 'sam_mask': { - 'encoder_embedding': None, - 'decoder_embedding': None, - 'min_tokens': 0, - 'max_tokens': 64, - 'type': 'img', - 'num_channels': 1, - 'id': generate_uint15_hash('sam_mask'), - }, - 'sam_instance': { - 'vocab_size': 30_000, - 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True), - 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True), - 'min_tokens': 0, - 'max_tokens': 290, - 'type': 'seq', - 'id': generate_uint15_hash('sam_instance'), - 'shared_vocab': ['caption'], - 'pretokenized': True, - }, - 'tok_canny_edge@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_canny_edge@224'), - 'pretokenized': True, - }, - 'canny_edge': { # used for tokenizer training - 'type': 'img', - 'num_channels': 1, - 'id': generate_uint15_hash('canny_edge'), - }, - 'tok_sam_edge@224': { - 'input_size': 224, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 196 - 'type': 'img', - 'id': generate_uint15_hash('tok_sam_edge@224'), - 'pretokenized': True, - }, - 'tok_dinov2@224': { - 'input_size': 224, - 'patch_size': 14, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 256 - 'type': 'img', - 'id': generate_uint15_hash('tok_dinov2@224'), - 'pretokenized': True, - }, - 'DINOv2-B14': { # used for tokenizer training - 'type': 'feature_map', - 'num_channels': 768, - 'id': generate_uint15_hash('DINOv2-B14'), - }, - 'tok_imagebind@224': { - 'input_size': 224, - 'patch_size': 14, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 256 - 'type': 'img', - 'id': generate_uint15_hash('tok_imagebind@224'), - 'pretokenized': True, - }, - 'ImageBind-H14': { # used for tokenizer training - 'type': 'feature_map', - 'num_channels': 1280, - 'id': generate_uint15_hash('ImageBind-H14'), - }, - 'tok_dinov2_global': { - 'vocab_size': 8192, - 'patch_size': 56, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False), - 'min_tokens': 0, - 'max_tokens': 16, - 'type': 'img', - 'id': generate_uint15_hash('tok_dinov2_global'), - 'pretokenized': True, - }, - 'DINOv2-B14-global': { # used for tokenizer training - 'type': 'feature_map', - 'num_channels': 768, - 'id': generate_uint15_hash('DINOv2-B14-global'), - }, - 'tok_imagebind_global': { - 'vocab_size': 8192, - 
'patch_size': 56, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False), - 'min_tokens': 0, - 'max_tokens': 16, - 'type': 'img', - 'id': generate_uint15_hash('tok_imagebind_global'), - 'pretokenized': True, - }, - 'ImageBind-H14-global': { # used for tokenizer training - 'type': 'feature_map', - 'num_channels': 1280, - 'id': generate_uint15_hash('ImageBind-H14-global'), + "t5_caption": { + "encoder_embedding": partial(SequenceEmbEncoderEmbedding, max_length=77, padding_idx=0), + "decoder_embedding": None, + "min_tokens": 0, + "max_tokens": 77, + "type": "seq_emb", + "id": generate_uint15_hash("t5_caption"), + }, + "metadata": { + "vocab_size": 30_000, + "encoder_embedding": partial( + SequenceEncoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True + ), + "decoder_embedding": partial( + SequenceDecoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True + ), + "min_tokens": 0, + "max_tokens": 40, # At most 2x19=38 for 19 metadata types, +1 for EOS, +1 for sentinel + "type": "seq", + "id": generate_uint15_hash("metadata"), + "shared_vocab": ["caption"], + "path": "metadata", + }, + "human_poses": { + "vocab_size": 30_000, + "encoder_embedding": partial( + SequenceEncoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True + ), + "decoder_embedding": partial( + SequenceDecoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True + ), + "min_tokens": 0, + "max_tokens": 275, # 7*39+1 EOS+1 S_1#263, #261 in one of the models, or 263 to have EOS #261+1+1 #238, + "type": "seq", + "num_channels": 207, # for tokenization training, only the pose part is needed + "id": generate_uint15_hash("human_poses"), + "shared_vocab": ["caption"], + }, + "color_palette": { + "vocab_size": 30_000, + "encoder_embedding": partial( + SequenceEncoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True + ), + "decoder_embedding": partial( + SequenceDecoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True + ), + "min_tokens": 0, + "max_tokens": 23, # 7x3=21 for 7 colors, +1 for EOS, +1 for sentinel + "type": "seq", + "id": generate_uint15_hash("color_palette"), + "shared_vocab": ["caption"], + "path": "color_palette", + }, + "sam_mask": { + "encoder_embedding": None, + "decoder_embedding": None, + "min_tokens": 0, + "max_tokens": 64, + "type": "img", + "num_channels": 1, + "id": generate_uint15_hash("sam_mask"), + }, + "sam_instance": { + "vocab_size": 30_000, + "encoder_embedding": partial( + SequenceEncoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True + ), + "decoder_embedding": partial( + SequenceDecoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True + ), + "min_tokens": 0, + "max_tokens": 290, + "type": "seq", + "id": generate_uint15_hash("sam_instance"), + "shared_vocab": ["caption"], + "pretokenized": True, + }, + "tok_canny_edge@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "id": generate_uint15_hash("tok_canny_edge@224"), + "pretokenized": True, + }, + "canny_edge": { # used for 
tokenizer training + "type": "img", + "num_channels": 1, + "id": generate_uint15_hash("canny_edge"), + }, + "tok_sam_edge@224": { + "input_size": 224, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 196 + "type": "img", + "id": generate_uint15_hash("tok_sam_edge@224"), + "pretokenized": True, + }, + "tok_dinov2@224": { + "input_size": 224, + "patch_size": 14, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 256 + "type": "img", + "id": generate_uint15_hash("tok_dinov2@224"), + "pretokenized": True, + }, + "DINOv2-B14": { # used for tokenizer training + "type": "feature_map", + "num_channels": 768, + "id": generate_uint15_hash("DINOv2-B14"), + }, + "tok_imagebind@224": { + "input_size": 224, + "patch_size": 14, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 256 + "type": "img", + "id": generate_uint15_hash("tok_imagebind@224"), + "pretokenized": True, + }, + "ImageBind-H14": { # used for tokenizer training + "type": "feature_map", + "num_channels": 1280, + "id": generate_uint15_hash("ImageBind-H14"), + }, + "tok_dinov2_global": { + "vocab_size": 8192, + "patch_size": 56, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False), + "min_tokens": 0, + "max_tokens": 16, + "type": "img", + "id": generate_uint15_hash("tok_dinov2_global"), + "pretokenized": True, + }, + "DINOv2-B14-global": { # used for tokenizer training + "type": "feature_map", + "num_channels": 768, + "id": generate_uint15_hash("DINOv2-B14-global"), + }, + "tok_imagebind_global": { + "vocab_size": 8192, + "patch_size": 56, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False), + "min_tokens": 0, + "max_tokens": 16, + "type": "img", + "id": generate_uint15_hash("tok_imagebind_global"), + "pretokenized": True, + }, + "ImageBind-H14-global": { # used for tokenizer training + "type": "feature_map", + "num_channels": 1280, + "id": generate_uint15_hash("ImageBind-H14-global"), }, - ### 224->448 super resolution modalities - 'rgb@448': { - 'input_size': 448, - 'patch_size': 16, - 'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3), - 'decoder_embedding': None, - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 784 - 'type': 'img', - 'num_channels': 3, - 'id': generate_uint15_hash('rgb@448'), - 'path': 'rgb', - }, - 'tok_rgb@448': { - 'input_size': 448, - 'patch_size': 16, - 'vocab_size': 16384, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 784 - 'type': 'img', - 'id': generate_uint15_hash('tok_rgb@448'), - 'pretokenized': True, - }, - 'tok_depth@448': { - 'input_size': 448, - 'patch_size': 16, - 
'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 784 - 'type': 'img', - 'id': generate_uint15_hash('tok_depth@448'), - 'pretokenized': True, - }, - 'tok_normal@448': { - 'input_size': 448, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 784 - 'type': 'img', - 'id': generate_uint15_hash('tok_normal@448'), - 'pretokenized': True, - }, - 'tok_semseg@448': { - 'input_size': 448, - 'patch_size': 16, - 'vocab_size': 4096, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 784 - 'type': 'img', - 'id': generate_uint15_hash('tok_semseg@448'), - 'pretokenized': True, - }, - 'tok_clip@448': { - 'input_size': 448, - 'patch_size': 16, - 'vocab_size': 8192, - 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192), - 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192), - 'min_tokens': 0, - 'max_tokens': None, # Will be set to 784 - 'type': 'img', - 'id': generate_uint15_hash('tok_clip@448'), - 'pretokenized': True, + "rgb@448": { + "input_size": 448, + "patch_size": 16, + "encoder_embedding": partial(ImageEncoderEmbedding, num_channels=3), + "decoder_embedding": None, + "min_tokens": 0, + "max_tokens": None, # Will be set to 784 + "type": "img", + "num_channels": 3, + "id": generate_uint15_hash("rgb@448"), + "path": "rgb", + }, + "tok_rgb@448": { + "input_size": 448, + "patch_size": 16, + "vocab_size": 16384, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=16384), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=16384), + "min_tokens": 0, + "max_tokens": None, # Will be set to 784 + "type": "img", + "id": generate_uint15_hash("tok_rgb@448"), + "pretokenized": True, + }, + "tok_depth@448": { + "input_size": 448, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 784 + "type": "img", + "id": generate_uint15_hash("tok_depth@448"), + "pretokenized": True, + }, + "tok_normal@448": { + "input_size": 448, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 784 + "type": "img", + "id": generate_uint15_hash("tok_normal@448"), + "pretokenized": True, + }, + "tok_semseg@448": { + "input_size": 448, + "patch_size": 16, + "vocab_size": 4096, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, vocab_size=4096), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=4096), + "min_tokens": 0, + "max_tokens": None, # Will be set to 784 + "type": "img", + "id": generate_uint15_hash("tok_semseg@448"), + "pretokenized": True, + }, + "tok_clip@448": { + "input_size": 448, + "patch_size": 16, + "vocab_size": 8192, + "encoder_embedding": partial(ImageTokenEncoderEmbedding, 
vocab_size=8192), + "decoder_embedding": partial(ImageTokenDecoderEmbedding, vocab_size=8192), + "min_tokens": 0, + "max_tokens": None, # Will be set to 784 + "type": "img", + "id": generate_uint15_hash("tok_clip@448"), + "pretokenized": True, }, } # Note: @res suffix is ignored for modality transforms MODALITY_TRANSFORMS = { # 4M-7 modalities - 'rgb': RGBTransform(imagenet_default_mean_and_std=True), - 'caption': CaptionTransform(aligned_captions=True), - 'det': DetectionTransform(det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0), - 'tok_rgb': TokTransform(), - 'tok_depth': TokTransform(), - 'tok_normal': TokTransform(), - 'tok_semseg': TokTransform(), - 'tok_clip': TokTransform(), + "rgb": RGBTransform(imagenet_default_mean_and_std=True), + "caption": CaptionTransform(aligned_captions=True), + "det": DetectionTransform( + det_threshold=0.6, det_max_instances=None, bbox_order="dist_to_orig", coord_bins=1000, min_visibility=0.0 + ), + "tok_rgb": TokTransform(), + "tok_depth": TokTransform(), + "tok_normal": TokTransform(), + "tok_semseg": TokTransform(), + "tok_clip": TokTransform(), # 4M-21 modalities - 't5_caption': CaptionEmbTransform(), - 'metadata': MetadataTransform(special_vmin=0, special_vmax=999, shuffle=True, random_trunc=False, return_chunks=True), - 'human_poses': HumanPoseTransform(coord_bins=1000), - 'color_palette': ColorPaletteTransform(coord_bins=1000), - 'sam_instance': SAMInstanceTokTransform(image_size=224, points_per_side=7, point_order='random'), - 'tok_canny_edge': TokTransform(), - 'tok_sam_edge': TokTransform(), - 'tok_dinov2': TokTransform(), - 'tok_imagebind': TokTransform(), - 'tok_dinov2_global': TokTransform(), - 'tok_imagebind_global': TokTransform(), + "t5_caption": CaptionEmbTransform(), + "metadata": MetadataTransform( + special_vmin=0, special_vmax=999, shuffle=True, random_trunc=False, return_chunks=True + ), + "human_poses": HumanPoseTransform(coord_bins=1000), + "color_palette": ColorPaletteTransform(coord_bins=1000), + "sam_instance": SAMInstanceTokTransform(image_size=224, points_per_side=7, point_order="random"), + "tok_canny_edge": TokTransform(), + "tok_sam_edge": TokTransform(), + "tok_dinov2": TokTransform(), + "tok_imagebind": TokTransform(), + "tok_dinov2_global": TokTransform(), + "tok_imagebind_global": TokTransform(), # Other - 'mask_valid': MaskTransform(mask_pool_size=1), + "mask_valid": MaskTransform(mask_pool_size=1), } MODALITY_TRANSFORMS_DIVAE = { - 'rgb': RGBTransform(imagenet_default_mean_and_std=False), - 'depth': DepthTransform(standardize_depth=True), - 'normal': NormalTransform(standardize_surface_normals=False), - 'mask_valid': MaskTransform(mask_pool_size=1), - 'semseg_coco': SemsegTransform(shift_idx_by_one=True), - 'canny_edge': RGBTransform(imagenet_default_mean_and_std=False), - 'human_poses': HumanPoseTransform(coord_bins=1000, only_pose=True), - 'sam_mask': SAMInstanceTransform(mask_size=64, max_instance_n=1), + "rgb": RGBTransform(imagenet_default_mean_and_std=False), + "depth": DepthTransform(standardize_depth=True), + "normal": NormalTransform(standardize_surface_normals=False), + "mask_valid": MaskTransform(mask_pool_size=1), + "semseg_coco": SemsegTransform(shift_idx_by_one=True), + "canny_edge": RGBTransform(imagenet_default_mean_and_std=False), + "human_poses": HumanPoseTransform(coord_bins=1000, only_pose=True), + "sam_mask": SAMInstanceTransform(mask_size=64, max_instance_n=1), } MODALITY_TRANSFORMS_VQCONTROLNET = { - 'rgb': 
RGBTransform(imagenet_default_mean_and_std=False), - 'mask_valid': MaskTransform(mask_pool_size=1), - 'caption': CaptionTransform(aligned_captions=True), + "rgb": RGBTransform(imagenet_default_mean_and_std=False), + "mask_valid": MaskTransform(mask_pool_size=1), + "caption": CaptionTransform(aligned_captions=True), } diff --git a/fourm/data/modality_transforms.py b/fourm/data/modality_transforms.py index f3c3563..bdc7dbf 100644 --- a/fourm/data/modality_transforms.py +++ b/fourm/data/modality_transforms.py @@ -29,23 +29,32 @@ from einops import rearrange, repeat, reduce from fourm.utils import to_2tuple -from fourm.utils.data_constants import (IMAGENET_DEFAULT_MEAN, - IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, - IMAGENET_SURFACE_NORMAL_STD, IMAGENET_SURFACE_NORMAL_MEAN, - IMAGENET_INCEPTION_STD, SEG_IGNORE_INDEX, PAD_MASK_VALUE) +from fourm.utils.data_constants import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + IMAGENET_INCEPTION_MEAN, + IMAGENET_SURFACE_NORMAL_STD, + IMAGENET_SURFACE_NORMAL_MEAN, + IMAGENET_INCEPTION_STD, + SEG_IGNORE_INDEX, + PAD_MASK_VALUE, +) # The @-symbol is used to specify the resolution of a modality. Syntax: modality@resolution def get_transform_key(mod_name): - return mod_name.split('@')[0] + return mod_name.split("@")[0] + def get_transform_resolution(mod_name, default_resolution, to_tuple=True): - res = int(mod_name.split('@')[1]) if '@' in mod_name else default_resolution + res = int(mod_name.split("@")[1]) if "@" in mod_name else default_resolution return to_2tuple(res) if to_tuple else res + def get_transform(mod_name, transforms_dict): return transforms_dict.get(get_transform_key(mod_name), IdentityTransform()) + def get_pil_resample_mode(resample_mode: str): """ Returns the PIL resampling mode for the given resample mode string. 
@@ -56,14 +65,15 @@ def get_pil_resample_mode(resample_mode: str): if resample_mode is None: return None elif resample_mode == "bilinear": - return Image.Resampling.BILINEAR if hasattr(Image, 'Resampling') else Image.BILINEAR + return Image.Resampling.BILINEAR if hasattr(Image, "Resampling") else Image.BILINEAR elif resample_mode == "bicubic": - return Image.Resampling.BICUBIC if hasattr(Image, 'Resampling') else Image.BICUBIC + return Image.Resampling.BICUBIC if hasattr(Image, "Resampling") else Image.BICUBIC elif resample_mode == "nearest": - return Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST + return Image.Resampling.NEAREST if hasattr(Image, "Resampling") else Image.NEAREST else: raise ValueError(f"Resample mode {resample_mode} is not supported.") + class UnifiedDataTransform(object): def __init__(self, transforms_dict, image_augmenter, resample_mode: str = None, add_sizes: bool = False, **kwargs): """Unified data augmentation for FourM @@ -93,12 +103,16 @@ def unified_image_augment(self, mod_dict, crop_settings): """ crop_coords, flip, orig_size, target_size, rand_aug_idx = self.image_augmenter(mod_dict, crop_settings) - + mod_dict = { k: self.transforms_dict[get_transform_key(k)].image_augment( - v, crop_coords=crop_coords, flip=flip, orig_size=orig_size, - target_size=get_transform_resolution(k, target_size), rand_aug_idx=rand_aug_idx, - resample_mode=self.resample_mode + v, + crop_coords=crop_coords, + flip=flip, + orig_size=orig_size, + target_size=get_transform_resolution(k, target_size), + rand_aug_idx=rand_aug_idx, + resample_mode=self.resample_mode, ) for k, v in mod_dict.items() } @@ -145,8 +159,16 @@ def preprocess(self, sample): pass @abstractmethod - def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + v, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): pass @abstractmethod @@ -164,7 +186,6 @@ def pil_loader(path: str) -> Image.Image: img = Image.open(path) return img - @staticmethod def image_hflip(img: Image, flip: bool): """Crop and resize an image @@ -206,10 +227,22 @@ def __init__(self, imagenet_default_mean_and_std=True, color_jitter=False, color def random_color_jitter(self, strength=0.5): # Color Jitter from Pix2Seq and SimCLR # Source: https://github.com/google-research/pix2seq/blob/main/data/data_utils.py#L114 - t = T.Compose([ - T.RandomApply([T.ColorJitter(brightness=0.8 * strength, contrast=0.8 * strength, saturation=0.8 * strength, hue=0.2 * strength)], p=0.8), - T.RandomApply([T.Grayscale(num_output_channels=3)], p=0.2), - ]) + t = T.Compose( + [ + T.RandomApply( + [ + T.ColorJitter( + brightness=0.8 * strength, + contrast=0.8 * strength, + saturation=0.8 * strength, + hue=0.2 * strength, + ) + ], + p=0.8, + ), + T.RandomApply([T.Grayscale(num_output_channels=3)], p=0.2), + ] + ) return t @@ -224,15 +257,23 @@ def load(self, path): return sample def preprocess(self, sample): - sample = sample.convert('RGB') - + sample = sample.convert("RGB") + if self.color_jitter: sample = self.color_jitter_transform(sample) return sample - def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + img, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + 
rand_aug_idx: Optional[int], + resample_mode: str = None, + ): img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode) img = self.image_hflip(img, flip) return img @@ -248,7 +289,7 @@ def __init__(self, standardize_depth=True): self.standardize_depth = standardize_depth def depth_to_tensor(self, img): - img = torch.Tensor( img / (2 ** 16 - 1.0) ) + img = torch.Tensor(img / (2**16 - 1.0)) img = img.unsqueeze(0) # 1 x H x W if self.standardize_depth: img = self.truncated_depth_standardization(img) @@ -264,7 +305,7 @@ def truncated_depth_standardization(depth, thresh: float = 0.1): """ # Flatten depth and remove bottom and top 10% of values trunc_depth = torch.sort(depth.reshape(-1), dim=0)[0] - trunc_depth = trunc_depth[int(thresh * trunc_depth.shape[0]): int((1 - thresh) * trunc_depth.shape[0])] + trunc_depth = trunc_depth[int(thresh * trunc_depth.shape[0]) : int((1 - thresh) * trunc_depth.shape[0])] return (depth - trunc_depth.mean()) / torch.sqrt(trunc_depth.var() + 1e-6) def load(self, path): @@ -274,8 +315,16 @@ def load(self, path): def preprocess(self, sample): return sample - def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + img, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode) img = self.image_hflip(img, flip) return img @@ -313,8 +362,16 @@ def image_hflip(self, img: Image, flip: bool): return img - def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + img, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode) img = self.image_hflip(img, flip) return img @@ -322,11 +379,13 @@ def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, t def postprocess(self, sample): sample = self.normal_to_tensor(sample) return sample - + class SemsegTransform(ImageTransform): - def __init__(self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None): + def __init__( + self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None + ): self.scale_factor = scale_factor self.shift_idx_by_one = shift_idx_by_one self.id_mapping = id_mapping @@ -336,7 +395,7 @@ def map_semseg_values(self, sample): sample = np.asarray(sample) mapping_fn = lambda x: self.id_mapping.get(x, x) sample = np.vectorize(mapping_fn)(sample) - sample = Image.fromarray(sample, mode='P') + sample = Image.fromarray(sample, mode="P") return sample def semseg_to_tensor(self, img): @@ -356,7 +415,7 @@ def load(self, path): return sample def preprocess(self, sample): - sample = sample.convert('P') + sample = sample.convert("P") if self.id_mapping is not None: sample = self.map_semseg_values(sample) @@ -364,15 +423,23 @@ def preprocess(self, sample): if self.shift_idx_by_one: sample = np.asarray(sample) sample = sample + 1 - sample = Image.fromarray(sample, mode='P') + sample = Image.fromarray(sample, mode="P") return sample - def image_augment(self, img, crop_coords: 
Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + img, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): # Value for padding with TF.crop is always 0. # Override resampling mode to 'nearest' for semseg - img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest') + img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode="nearest") img = self.image_hflip(img, flip) return img @@ -389,17 +456,16 @@ def __init__(self, mask_size=64, max_instance_n=20, bbox_area_threshold=0.0005): self.bbox_area_threshold = bbox_area_threshold def get_bbox(self, instance): - """ Gets bounding box of the given instance - """ - min_h, max_h = instance[:,:,1].min(), instance[:,:,1].max() - min_w, max_w = instance[:,:,0].min(), instance[:,:,0].max() + """Gets bounding box of the given instance""" + min_h, max_h = instance[:, :, 1].min(), instance[:, :, 1].max() + min_w, max_w = instance[:, :, 0].min(), instance[:, :, 0].max() return [min_h, min_w, max_h, max_w] def extend_instance_points(self, instance, border_fn): - """ Given an instance and a border function `border_fn`, extends the instance points with crossing points between the instance and + """Given an instance and a border function `border_fn`, extends the instance points with crossing points between the instance and the crop borders. The crossing points are obtained using border_fn. """ - p = instance[:,0] + p = instance[:, 0] p_next = np.roll(p, (-1), axis=(0)) final_points = [] for x, xn in zip(p, p_next): @@ -407,68 +473,67 @@ def extend_instance_points(self, instance, border_fn): for r in border_fn(x, xn): final_points.append(r.astype(np.int32)) p = np.stack(final_points) - return p[:,None] + return p[:, None] def remove_redundant_lines(self, orig_instance, instance): - """ Removes the redundant lines added during cropping. - """ + """Removes the redundant lines added during cropping.""" final_points = [] for p in instance: - distance = cv2.pointPolygonTest(orig_instance, (p[0,0].item(), p[0,1].item()), measureDist=True) + distance = cv2.pointPolygonTest(orig_instance, (p[0, 0].item(), p[0, 1].item()), measureDist=True) if distance >= 0: final_points.append(p[0]) - return np.stack(final_points)[:,None] + return np.stack(final_points)[:, None] def get_border_functions(self, crop_points): - """ Creates and returns a function `fn` using crop region coordinates given in crop_points. + """Creates and returns a function `fn` using crop region coordinates given in crop_points. `fn` receives two input points x and xn and returns all the crossing points between the line connecting x and xn, and the borders of the cropping rectangle. """ - p = crop_points[:,0] + p = crop_points[:, 0] p_next = np.roll(p, (-1), axis=(0)) + def fn(x, xn): output = [] c_diff = p_next - p x_diff = x - xn for diff, c in zip(c_diff, p): - A = np.array([ - [diff[0], x_diff[0]], - [diff[1], x_diff[1]] - ]) + A = np.array([[diff[0], x_diff[0]], [diff[1], x_diff[1]]]) b = x - c try: lmbda = np.linalg.solve(A, b) if 0 <= lmbda[0] <= 1 and 0 <= lmbda[1] <= 1: - output.append(lmbda[1] * xn + (1-lmbda[1]) * x) + output.append(lmbda[1] * xn + (1 - lmbda[1]) * x) except: continue return output + return fn def crop_sample(self, sample, crop_coords): - """ Crop the sample using crop coordinates. 
- """ + """Crop the sample using crop coordinates.""" top, left, h, w = crop_coords crop_region = (left, top, left + w, top + h) - crop_points = np.array([ - [crop_region[0], crop_region[1]], - [crop_region[2], crop_region[1]], - [crop_region[2], crop_region[3]], - [crop_region[0], crop_region[3]], - ])[:,None] + crop_points = np.array( + [ + [crop_region[0], crop_region[1]], + [crop_region[2], crop_region[1]], + [crop_region[2], crop_region[3]], + [crop_region[0], crop_region[3]], + ] + )[:, None] border_functions = self.get_border_functions(crop_points) cropped_sample = [] for instance in sample: instance = self.extend_instance_points(instance, border_functions) filter_condition = ( - (instance[:, :, 0] > crop_region[0]) & - (instance[:, :, 0] < crop_region[2]) & - (instance[:, :, 1] > crop_region[1]) & - (instance[:, :, 1] < crop_region[3]) + (instance[:, :, 0] > crop_region[0]) + & (instance[:, :, 0] < crop_region[2]) + & (instance[:, :, 1] > crop_region[1]) + & (instance[:, :, 1] < crop_region[3]) ) if not np.any(filter_condition): continue - + instance_copy = instance.copy() instance_copy[:, :, 0] = np.clip(instance[:, :, 0], a_min=crop_region[0], a_max=crop_region[2]) instance_copy[:, :, 1] = np.clip(instance[:, :, 1], a_min=crop_region[1], a_max=crop_region[3]) @@ -478,10 +543,9 @@ def crop_sample(self, sample, crop_coords): cropped_sample.append(instance_copy) return cropped_sample - + def resize_sample(self, sample, original_size, target_size): - """ Resize the sample - """ + """Resize the sample""" width_scale = target_size[1] / original_size[1] height_scale = target_size[0] / original_size[0] resized_sample = [] @@ -491,10 +555,9 @@ def resize_sample(self, sample, original_size, target_size): instance_copy[:, :, 1] = np.round(height_scale * instance_copy[:, :, 1]) resized_sample.append(instance_copy) return resized_sample - + def remove_tiny_instances(self, sample, image_size): - """ Remove instances that have an area ratio smaller than `bbox_area_threshold`. - """ + """Remove instances that have an area ratio smaller than `bbox_area_threshold`.""" filtered_sample = [] for instance in sample: min_h, min_w, max_h, max_w = self.get_bbox(instance) @@ -505,23 +568,21 @@ def remove_tiny_instances(self, sample, image_size): return filtered_sample def hflip(self, sample, width): - """ Horizontal flipping the instances in a sample. - """ + """Horizontal flipping the instances in a sample.""" flipped_sample = [] for instance in sample: instance_copy = instance.copy() instance_copy[:, :, 0] = width - instance_copy[:, :, 0] flipped_sample.append(instance_copy) return flipped_sample - + def get_binary_masks(self, sample): - """ Creates the binary mask of each instance in the sample. 
- """ + """Creates the binary mask of each instance in the sample.""" if self.max_instance_n is None: max_instance_n = len(sample) else: max_instance_n = self.max_instance_n - masks = np.zeros((max_instance_n, self.mask_size, self.mask_size)) + masks = np.zeros((max_instance_n, self.mask_size, self.mask_size)) bboxes = np.zeros((max_instance_n, 4)) valid = np.full(max_instance_n, False) for i, instance in enumerate(sample): @@ -529,8 +590,8 @@ def get_binary_masks(self, sample): min_h, min_w, max_h, max_w = bbox instance_copy = instance.copy() mask = np.zeros((self.mask_size, self.mask_size), dtype=np.uint8) - instance_copy[:,:,0] = (instance_copy[:,:,0] - min_w) / (max_w - min_w) * self.mask_size - instance_copy[:,:,1] = (instance_copy[:,:,1] - min_h) / (max_h - min_h) * self.mask_size + instance_copy[:, :, 0] = (instance_copy[:, :, 0] - min_w) / (max_w - min_w) * self.mask_size + instance_copy[:, :, 1] = (instance_copy[:, :, 1] - min_h) / (max_h - min_h) * self.mask_size cv2.drawContours(mask, [instance_copy], 0, (255), thickness=cv2.FILLED) masks[i] = mask / 255.0 bboxes[i] = np.array(bbox) @@ -546,10 +607,18 @@ def preprocess(self, sample): indecies = np.arange(len(sample)) else: indecies = np.random.choice(len(sample), size=self.max_instance_n, replace=False) - return [p['points'] for i, p in enumerate(sample) if i in indecies] - - def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + return [p["points"] for i, p in enumerate(sample) if i in indecies] + + def image_augment( + self, + v, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): v = self.crop_sample(v, crop_coords) _, _, h, w = crop_coords v = self.resize_sample(v, (h, w), target_size) @@ -561,9 +630,9 @@ def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, tar def postprocess(self, sample): sample, bboxes, valid = self.get_binary_masks(sample) return { - 'instance': torch.from_numpy(sample).to(torch.float32), - 'bbox': torch.from_numpy(bboxes).to(torch.float32), - 'valid': torch.from_numpy(valid) + "instance": torch.from_numpy(sample).to(torch.float32), + "bbox": torch.from_numpy(bboxes).to(torch.float32), + "valid": torch.from_numpy(valid), } @@ -571,14 +640,14 @@ class MaskTransform(ImageTransform): def __init__(self, mask_pool_size=1): assert isinstance(mask_pool_size, int) - self.mask_pool_size = mask_pool_size # Use to expand masks + self.mask_pool_size = mask_pool_size # Use to expand masks def mask_to_tensor(self, img): mask = TF.to_tensor(img) if self.mask_pool_size > 1: - mask = reduce(mask, 'c (h1 h2) (w1 w2) -> c h1 w1', 'min', h2=self.mask_pool_size, w2=self.mask_pool_size) - mask = repeat(mask, 'c h1 w1 -> c (h1 h2) (w1 w2)', h2=self.mask_pool_size, w2=self.mask_pool_size) - return (mask == 1.0) + mask = reduce(mask, "c (h1 h2) (w1 w2) -> c h1 w1", "min", h2=self.mask_pool_size, w2=self.mask_pool_size) + mask = repeat(mask, "c h1 w1 -> c (h1 h2) (w1 w2)", h2=self.mask_pool_size, w2=self.mask_pool_size) + return mask == 1.0 def load(self, path): sample = self.pil_loader(path) @@ -587,10 +656,18 @@ def load(self, path): def preprocess(self, sample): return sample - def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + img, + crop_coords: Tuple, + flip: bool, + 
orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): # Override resampling mode to 'nearest' for masks - img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest') + img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode="nearest") img = self.image_hflip(img, flip) return img @@ -611,10 +688,20 @@ def load(self, path): def preprocess(self, sample): return sample - def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + v, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): if rand_aug_idx is None: - raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used") + raise ValueError( + "Crop settings / augmentation index are missing but a pre-tokenized modality is being used" + ) v = torch.tensor(v[rand_aug_idx]) return v @@ -624,18 +711,26 @@ def postprocess(self, sample): class DetectionTransform(AbstractTransform): - def __init__(self, det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0, return_raw=False): + def __init__( + self, + det_threshold=0.6, + det_max_instances=None, + bbox_order="dist_to_orig", + coord_bins=1000, + min_visibility=0.0, + return_raw=False, + ): self.det_threshold = det_threshold self.det_max_instances = det_max_instances self.coord_bins = coord_bins self.min_visibility = min_visibility self.return_raw = return_raw - if bbox_order == 'area': + if bbox_order == "area": self.bbox_order = self.order_bboxes_by_area - elif bbox_order == 'score': + elif bbox_order == "score": self.bbox_order = self.order_bboxes_by_score - elif bbox_order == 'random': + elif bbox_order == "random": self.bbox_order = self.shuffle_bboxes else: self.bbox_order = self.order_bboxes_by_dist_to_orig @@ -661,14 +756,19 @@ def convert_detection_instance(self, instances): [xmin, ymin, xmax, ymax, class_name, score] """ - instances = [inst['boxes'] + [inst['class_name'], inst['score']] for inst in instances if inst['score'] >= self.det_threshold] + instances = [ + inst["boxes"] + [inst["class_name"], inst["score"]] + for inst in instances + if inst["score"] >= self.det_threshold + ] return instances def bboxes_hflip(self, bboxes: List[Tuple], image_size: Tuple, flip: bool): image_height, image_width = image_size if flip: - bboxes = [tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:]) - for bbox in bboxes] + bboxes = [ + tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:]) for bbox in bboxes + ] return bboxes @@ -686,9 +786,13 @@ def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_s orig_height, orig_width = orig_size top, left, h, w = crop_coords xmin, ymin, xmax, ymax = left, top, left + w, top + h - bboxes = [tuple(A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, - cols=orig_width)) + tuple(bbox[4:]) - for bbox in bboxes] + bboxes = [ + tuple( + A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, cols=orig_width) + ) + + tuple(bbox[4:]) + for bbox in bboxes + ] bboxes = A.core.bbox_utils.filter_bboxes(bboxes, rows=h, cols=w, min_visibility=self.min_visibility) # No need to resize, bounding boxes in albumentations 
format are scale invariant @@ -696,12 +800,12 @@ def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_s def order_and_filter_bboxes(self, bboxes): if self.det_max_instances is not None and len(bboxes) > self.det_max_instances: - bboxes = self.order_bboxes_by_score(bboxes)[:self.det_max_instances] + bboxes = self.order_bboxes_by_score(bboxes)[: self.det_max_instances] return self.bbox_order(bboxes) def convert_bboxes_to_string(self, bboxes: List[Tuple]): - """Convert bounding boxes to a string. + """Convert bounding boxes to a string. xmin, ymin, xmax, ymax are mapped to v0, v1, v2, v3 special tokens. Args: @@ -724,22 +828,30 @@ def convert_bboxes_to_string(self, bboxes: List[Tuple]): for (xmin, ymin, xmax, ymax, cls, score) in bboxes ] # Convert each bounding box to a string - bboxes = [' '.join(b) for b in bboxes] + bboxes = [" ".join(b) for b in bboxes] # Convert the list to a str - return ' '.join(bboxes) + return " ".join(bboxes) def load(self, path): - with open(path, 'r') as f: + with open(path, "r") as f: sample = json.load(f) return sample def preprocess(self, sample): - instances = sample['instances'] + instances = sample["instances"] return self.convert_detection_instance(instances) - def image_augment(self, bboxes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx=None, resample_mode: str = None): + def image_augment( + self, + bboxes: List[Tuple], + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx=None, + resample_mode: str = None, + ): bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size) bboxes = self.bboxes_hflip(bboxes, target_size, flip) bboxes = self.order_and_filter_bboxes(bboxes) @@ -760,21 +872,29 @@ def __init__(self, aligned_captions=True, no_aug=False): def load(self, path): # Caption can either be stored as .txt or .json.gz (in which case it's a list of dicts) - if path.endswith('.txt'): + if path.endswith(".txt"): sample = Path(path).read_text() - elif path.endswith('.json'): - with open(path, 'r') as f: + elif path.endswith(".json"): + with open(path, "r") as f: sample = json.load(f) - elif path.endswith('.json.gz'): - with gzip.open(path, 'rb') as f: + elif path.endswith(".json.gz"): + with gzip.open(path, "rb") as f: sample = json.load(f) return sample def preprocess(self, sample): return sample - def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): if isinstance(val, list) or isinstance(val, tuple): if self.aligned_captions: @@ -800,9 +920,9 @@ def __init__(self, aligned_captions=True, no_aug=False): self.no_aug = no_aug def load(self, path): - if path.endswith('.npz'): + if path.endswith(".npz"): sample = np.load(path) - sample = {'emb': sample['emb'], 'mask_valid': sample['mask_valid']} + sample = {"emb": sample["emb"], "mask_valid": sample["mask_valid"]} else: raise ValueError(f"Invalid file format for caption embedding: {path}") return sample @@ -810,11 +930,19 @@ def load(self, path): def preprocess(self, sample): return sample - def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): - - emb = val['emb'] - mask_valid = 
val['mask_valid'].astype(bool) + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): + + emb = val["emb"] + mask_valid = val["mask_valid"].astype(bool) num_sequences = emb.shape[0] if num_sequences > 1: @@ -832,25 +960,27 @@ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, t else: emb, mask_valid = emb[0], mask_valid[0] - emb = emb[mask_valid] # Keep only valid embeddings + emb = emb[mask_valid] # Keep only valid embeddings return emb def postprocess(self, sample): return torch.tensor(sample) - + class MetadataTransform(AbstractTransform): - def __init__(self, - special_vmin: int = 0, - special_vmax: int = 999, - shuffle: bool = True, - random_trunc: bool = False, - return_chunks: bool = True, - return_raw: bool = False, - image_dim_bin_size: int = 32,): - """Metadata transform that takes in a metadata dictionary and converts + def __init__( + self, + special_vmin: int = 0, + special_vmax: int = 999, + shuffle: bool = True, + random_trunc: bool = False, + return_chunks: bool = True, + return_raw: bool = False, + image_dim_bin_size: int = 32, + ): + """Metadata transform that takes in a metadata dictionary and converts it into a string, or list of strings (for chunked span masking). Uses special tokens v1 to denote metadata types, and v0 for their values. @@ -874,59 +1004,64 @@ def __init__(self, # Explicit map to make sure that additional entries do not change existing IDs # TODO: Make this work with other text tokenizers self.metadata_id_map = { - 'original_width': 'v1=0', - 'original_height': 'v1=1', - 'caption_n_chars': 'v1=2', - 'caption_n_words': 'v1=3', - 'caption_n_sentences': 'v1=4', - 'n_humans': 'v1=5', - 'n_sam_instances': 'v1=6', - 'n_coco_instances': 'v1=7', - 'coco_instance_diversity': 'v1=8', - 'colorfulness': 'v1=9', - 'brightness': 'v1=10', - 'contrast': 'v1=11', - 'saturation': 'v1=12', - 'entropy': 'v1=13', - 'walkability': 'v1=14', - 'objectness': 'v1=15', - 'semantic_diversity': 'v1=16', - 'geometric_complexity': 'v1=17', - 'occlusion_score': 'v1=18', - 'watermark_score': 'v1=19', - 'aesthetic_score': 'v1=20', + "original_width": "v1=0", + "original_height": "v1=1", + "caption_n_chars": "v1=2", + "caption_n_words": "v1=3", + "caption_n_sentences": "v1=4", + "n_humans": "v1=5", + "n_sam_instances": "v1=6", + "n_coco_instances": "v1=7", + "coco_instance_diversity": "v1=8", + "colorfulness": "v1=9", + "brightness": "v1=10", + "contrast": "v1=11", + "saturation": "v1=12", + "entropy": "v1=13", + "walkability": "v1=14", + "objectness": "v1=15", + "semantic_diversity": "v1=16", + "geometric_complexity": "v1=17", + "occlusion_score": "v1=18", + "watermark_score": "v1=19", + "aesthetic_score": "v1=20", } self.id_metadata_map = {v: k for k, v in self.metadata_id_map.items()} # Image-dimension modalities are binned into 32 bins - self.image_dim_modalities = ['original_height', 'original_width'] + self.image_dim_modalities = ["original_height", "original_width"] # Integer modalities that don't undergo any scaling (except for truncation) self.metadata_int_modalities = [ - 'caption_n_chars', 'caption_n_words', 'caption_n_sentences', - 'n_humans', 'n_sam_instances', 'n_coco_instances', - 'coco_instance_diversity', 'semantic_diversity', + "caption_n_chars", + "caption_n_words", + "caption_n_sentences", + "n_humans", + "n_sam_instances", + "n_coco_instances", + "coco_instance_diversity", + "semantic_diversity", ] # Bin boundaries 
for manually defined metadata modalities. # Lowest and highest bin boundaries are implicitly set to -inf and +inf self.metadata_manual_bins = { - 'watermark_score': [0.5], - 'aesthetic_score': [4.5, 5.5], + "watermark_score": [0.5], + "aesthetic_score": [4.5, 5.5], } # All other float or integer modalities that are binned into a defined number of bins # Dictionary entries are (vmin, vmax, num_bins) self.metadata_min_max_bins = { - 'colorfulness': (0, 150, 50), - 'brightness': (0, 255, 50), - 'contrast': (0, 127, 50), - 'saturation': (0, 255, 50), - 'entropy': (0, 10, 50), - 'walkability': (0, 1, 50), - 'objectness': (0, 1, 50), - 'geometric_complexity': (0, 0.75, 50), - 'occlusion_score': (0, 0.25, 50), + "colorfulness": (0, 150, 50), + "brightness": (0, 255, 50), + "contrast": (0, 127, 50), + "saturation": (0, 255, 50), + "entropy": (0, 10, 50), + "walkability": (0, 1, 50), + "objectness": (0, 1, 50), + "geometric_complexity": (0, 0.75, 50), + "occlusion_score": (0, 0.25, 50), } def image_dim_to_string(self, metadata, key, bin_size=32): @@ -941,9 +1076,9 @@ def int_metadata_to_string(self, metadata, key): def float_metadata_to_string(self, metadata, key, vmin, vmax, bins): value = max(vmin, min(metadata[key], vmax)) value = (value - vmin) / (vmax - vmin) - value = int(value * (bins-1)) + value = int(value * (bins - 1)) return f"{self.metadata_id_map[key]} v0={value}" - + def manual_bin_metadata_to_string(self, metadata, key): value = metadata[key] bin_idx = 0 @@ -952,7 +1087,7 @@ def manual_bin_metadata_to_string(self, metadata, key): break bin_idx += 1 return f"{self.metadata_id_map[key]} v0={bin_idx}" - + def metadata_to_string(self, metadata, keys: List[str] = None): keys = list(metadata.keys()) if keys is None else keys @@ -961,10 +1096,10 @@ def metadata_to_string(self, metadata, keys: List[str] = None): random.shuffle(keys) if self.random_trunc: # Randomly truncate - keys = keys[:random.randint(1,len(keys))] + keys = keys[: random.randint(1, len(keys))] metadata_strings = [] - + for key in keys: if key in self.image_dim_modalities: # Image dimension modalities @@ -985,18 +1120,26 @@ def metadata_to_string(self, metadata, keys: List[str] = None): if self.return_chunks: return metadata_strings else: - return ' '.join(metadata_strings) + return " ".join(metadata_strings) def load(self, path): - with open(path, 'r') as f: + with open(path, "r") as f: sample = json.load(f) return sample def preprocess(self, sample): return sample - def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx=None, resample_mode: str = None): + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx=None, + resample_mode: str = None, + ): return val def postprocess(self, metadata): @@ -1004,7 +1147,7 @@ def postprocess(self, metadata): return metadata metadata_str = self.metadata_to_string(metadata) return metadata_str - + class HumanPoseTransform(AbstractTransform): @@ -1015,35 +1158,50 @@ def __init__(self, coord_bins=1000, only_pose=False, return_raw=False): def convert_humanpose_instance(self, instances, only_pose=False): """Convert instances dict to list of lists where each list takes the form: - [human, xmin xmax ymin ymax global val1 val2 ... val10 pose val1 val2 ... val 207 shape val1 val2 ... val10 camera val1 val2 val3 val4] + [human, xmin xmax ymin ymax global val1 val2 ... val10 pose val1 val2 ... val 207 shape val1 val2 ... 
val10 camera val1 val2 val3 val4] Like for bounding boxes, xmin, ymin, xmax, and ymax map to v0, v1, v2, and v3 respectively. """ - if only_pose: # used for tokenizer training for pose + if only_pose: # used for tokenizer training for pose if len(instances) == 0: return torch.zeros(207) else: - return torch.from_numpy(np.array(instances['pred_smpl_params']['body_pose'][0]).flatten()).float() - if len(instances) == 0: #empty, i.e. there are no humans - return 'none' - + return torch.from_numpy(np.array(instances["pred_smpl_params"]["body_pose"][0]).flatten()).float() + if len(instances) == 0: # empty, i.e. there are no humans + return "none" + for k in instances: - if k!='pred_smpl_params': + if k != "pred_smpl_params": instances[k] = torch.from_numpy(np.array(instances[k])) - smpl_params = (instances['pred_smpl_params']) + smpl_params = instances["pred_smpl_params"] for k in smpl_params: smpl_params[k] = torch.from_numpy(np.array(smpl_params[k])) - total_num_instances = len(instances['bbox_xyxy']) + total_num_instances = len(instances["bbox_xyxy"]) instances_converted = [] for ii in range(total_num_instances): - instances_converted.append(['human'] + (np.array(instances['bbox_xyxy'][ii]).flatten().tolist()) + ['global'] + (np.array(instances['pred_smpl_params']['global_orient'][ii]).flatten().tolist()) + ['pose'] + (instances['pose_tokenized'][ii].flatten().tolist()) + ['shape'] + (instances['pred_smpl_params']['betas'][ii].flatten().tolist()) + ['camera'] + (instances['pred_cam'][ii].flatten().tolist())) + instances_converted.append( + ["human"] + + (np.array(instances["bbox_xyxy"][ii]).flatten().tolist()) + + ["global"] + + (np.array(instances["pred_smpl_params"]["global_orient"][ii]).flatten().tolist()) + + ["pose"] + + (instances["pose_tokenized"][ii].flatten().tolist()) + + ["shape"] + + (instances["pred_smpl_params"]["betas"][ii].flatten().tolist()) + + ["camera"] + + (instances["pred_cam"][ii].flatten().tolist()) + ) return instances_converted - def humanposes_crop_and_resize(self, humanposes: List[Tuple], crop_coords: Tuple, orig_size: Tuple,): - """Crop and resize human poses (and their bounding boxes) - """ + def humanposes_crop_and_resize( + self, + humanposes: List[Tuple], + crop_coords: Tuple, + orig_size: Tuple, + ): + """Crop and resize human poses (and their bounding boxes)""" orig_height, orig_width = orig_size top, left, h, w = crop_coords @@ -1055,23 +1213,24 @@ def humanposes_crop_and_resize(self, humanposes: List[Tuple], crop_coords: Tuple bbox_curr[1::2] = bbox_curr[1::2] / orig_height xmin, ymin, xmax, ymax = left, top, left + w, top + h - bbox_curr = A.bbox_crop(bbox_curr, x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, - cols=orig_width) + bbox_curr = A.bbox_crop( + bbox_curr, x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, cols=orig_width + ) bbox_curr = np.array(bbox_curr) - if np.all(bbox_curr[1::2]<0) or np.all(bbox_curr[0::2]<0): #bbox is out of range, remove it + if np.all(bbox_curr[1::2] < 0) or np.all(bbox_curr[0::2] < 0): # bbox is out of range, remove it continue - if np.all(bbox_curr[1::2]>1.0) or np.all(bbox_curr[0::2]>1.0): #bbox is out of range, remove it + if np.all(bbox_curr[1::2] > 1.0) or np.all(bbox_curr[0::2] > 1.0): # bbox is out of range, remove it continue - bbox_curr = np.clip(bbox_curr, a_min=0, a_max=1.) 
+ bbox_curr = np.clip(bbox_curr, a_min=0, a_max=1.0) instance[1:5] = bbox_curr humanposes_converted_resized.append(instance) # now return all instances, or none if there is no instance - if len(humanposes_converted_resized)>0: + if len(humanposes_converted_resized) > 0: pass - else: #no valid masks remains - return 'none' + else: # no valid masks remains + return "none" humanpose_returned = humanposes_converted_resized @@ -1079,14 +1238,14 @@ def humanposes_crop_and_resize(self, humanposes: List[Tuple], crop_coords: Tuple def convert_humanposes_to_string(self, all_humanposes: List[Tuple]): """Convert humanposes to a string - range of global orientation: [-1, 1] - range of object pose: [-1, 1] - range of shape (betas): [-3, 3] - range of camera: [-1, 19] + range of global orientation: [-1, 1] + range of object pose: [-1, 1] + range of shape (betas): [-3, 3] + range of camera: [-1, 19] """ bins = self.coord_bins - instance_final_all = '' + instance_final_all = "" for humanposes in all_humanposes: human = humanposes[0] @@ -1094,76 +1253,96 @@ def convert_humanposes_to_string(self, all_humanposes: List[Tuple]): glob = humanposes[5] global_orient = np.array(humanposes[6:15]) pose = humanposes[15] - pose_params = np.array(humanposes[16:24]) - shape = humanposes[24] - shape_params = np.array(humanposes[25:35]) - camera = humanposes[35] - camera_params = np.clip(np.array(humanposes[36:]), a_min=-1., a_max=19.) + pose_params = np.array(humanposes[16:24]) + shape = humanposes[24] + shape_params = np.array(humanposes[25:35]) + camera = humanposes[35] + camera_params = np.clip(np.array(humanposes[36:]), a_min=-1.0, a_max=19.0) bboxes_new = [ - f"v0={round(bboxes[0] * (bins - 1))}", - f"v1={round(bboxes[1] * (bins - 1))}", - f"v2={round(bboxes[2] * (bins - 1))}", - f"v3={round(bboxes[3] * (bins - 1))}"] + f"v0={round(bboxes[0] * (bins - 1))}", + f"v1={round(bboxes[1] * (bins - 1))}", + f"v2={round(bboxes[2] * (bins - 1))}", + f"v3={round(bboxes[3] * (bins - 1))}", + ] - global_orient = 499.5*global_orient + global_orient = 499.5 * global_orient global_orient_new = [] for ii in range(len(global_orient)): - global_orient_curr = f"v0={round(global_orient[ii]+499.5)}" + global_orient_curr = f"v0={round(global_orient[ii]+499.5)}" global_orient_new.append(global_orient_curr) pose_params_new = [] for ii in range(len(pose_params)): - if pose_params[ii]<512: - pose_params_curr = f"v0={round(pose_params[ii])}" - else: - pose_params_curr = f"v1={round(pose_params[ii] - 512)}" + if pose_params[ii] < 512: + pose_params_curr = f"v0={round(pose_params[ii])}" + else: + pose_params_curr = f"v1={round(pose_params[ii] - 512)}" pose_params_new.append(pose_params_curr) - shape_params = 166.5*shape_params + shape_params = 166.5 * shape_params shape_params_new = [] for ii in range(len(shape_params)): - shape_params_curr = f"v0={round(shape_params[ii]+499.5)}" + shape_params_curr = f"v0={round(shape_params[ii]+499.5)}" shape_params_new.append(shape_params_curr) - camera_params = 49.95*camera_params + camera_params = 49.95 * camera_params camera_params_new = [] for ii in range(len(camera_params)): - camera_params_curr = f"v0={round(camera_params[ii]+49.95)}" + camera_params_curr = f"v0={round(camera_params[ii]+49.95)}" camera_params_new.append(camera_params_curr) - - #randomly shuffle everything except bbox part of the sequence - all_strings = [[pose]+pose_params_new, [glob] + global_orient_new, [camera] + camera_params_new, [shape] + shape_params_new ] + + # randomly shuffle everything except bbox part of the sequence + 
all_strings = [ + [pose] + pose_params_new, + [glob] + global_orient_new, + [camera] + camera_params_new, + [shape] + shape_params_new, + ] rand_perm = torch.randperm(4) - instance_final = [human] + bboxes_new + all_strings[rand_perm[0]] + all_strings[rand_perm[1]] + all_strings[rand_perm[2]] + all_strings[rand_perm[3]] - - - instance_final = ', '.join(instance_final) + instance_final = ( + [human] + + bboxes_new + + all_strings[rand_perm[0]] + + all_strings[rand_perm[1]] + + all_strings[rand_perm[2]] + + all_strings[rand_perm[3]] + ) + + instance_final = ", ".join(instance_final) instance_final = instance_final.replace(",", "") - instance_final_all = instance_final_all + instance_final + ' ' + instance_final_all = instance_final_all + instance_final + " " - return instance_final_all + return instance_final_all def load(self, path): - with open(path, 'r') as f: + with open(path, "r") as f: sample = json.load(f) return sample def preprocess(self, sample): - instances = sample + instances = sample instances = self.convert_humanpose_instance(instances, only_pose=self.only_pose) return instances - def image_augment(self, humanposes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx=None, resample_mode: str = None): - if humanposes=='none' or self.only_pose: + def image_augment( + self, + humanposes: List[Tuple], + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx=None, + resample_mode: str = None, + ): + if humanposes == "none" or self.only_pose: return humanposes humanposes = self.humanposes_crop_and_resize(humanposes, crop_coords, orig_size) return humanposes def postprocess(self, humanposes): - if humanposes=='none' or self.only_pose: + if humanposes == "none" or self.only_pose: return humanposes if not self.return_raw else [] if self.return_raw: return humanposes @@ -1178,9 +1357,8 @@ def __init__(self, coord_bins=1000, return_raw=False): self.return_raw = return_raw def convert_palette_instance(self, instances): - """Convert colors to v0= v0= ... 
- """ - length = random.randint(1,7) + """Convert colors to v0= v0= ...""" + length = random.randint(1, 7) instances_converted = np.array(instances[0][str(length)]).flatten().tolist() return instances_converted @@ -1189,36 +1367,43 @@ def palette_hflip(self, palettes: List[Tuple], image_size: Tuple, flip: bool): return palettes def convert_palettes_to_string(self, all_palettes: List[Tuple]): - """Convert palettes to a string - """ + """Convert palettes to a string""" colors = [] len_palettes = len(all_palettes) - colors.append(f"v1={round(len_palettes/3)}") # start with the length of the color palette to avoid confusion + colors.append(f"v1={round(len_palettes/3)}") # start with the length of the color palette to avoid confusion for ii in range(len(all_palettes)): color_new = f"v0={round(all_palettes[ii])}" colors.append(color_new) - + instance_final_all = colors - instance_final_all = ', '.join(instance_final_all) + instance_final_all = ", ".join(instance_final_all) instance_final_all = instance_final_all.replace(",", "") - return instance_final_all + return instance_final_all def load(self, path): - with open(path, 'r') as f: + with open(path, "r") as f: sample = json.load(f) return sample def preprocess(self, sample): if self.return_raw: return sample - instances = sample + instances = sample instances = self.convert_palette_instance(instances) return instances - def image_augment(self, palettes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx=None, resample_mode: str = None): + def image_augment( + self, + palettes: List[Tuple], + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx=None, + resample_mode: str = None, + ): return palettes def postprocess(self, palettes): @@ -1226,37 +1411,40 @@ def postprocess(self, palettes): return palettes palettes = self.convert_palettes_to_string(palettes) return palettes - + class SAMInstanceTokTransform(AbstractTransform): - def __init__(self, image_size=224, points_per_side=7, point_order='random'): + def __init__(self, image_size=224, points_per_side=7, point_order="random"): self.H, self.W = to_2tuple(image_size) self.points_per_h, self.points_per_w = to_2tuple(points_per_side) - assert point_order in ['random', 'grid'] + assert point_order in ["random", "grid"] self.point_order = point_order def get_query_points(self): - if self.point_order == 'grid': + if self.point_order == "grid": # Create and cache grid query points - if not hasattr(self, 'grid_query_points'): - y, x = np.meshgrid(np.linspace(0, self.H, self.points_per_h + 2)[1:-1], np.linspace(0, self.W, self.points_per_w + 2)[1:-1]) + if not hasattr(self, "grid_query_points"): + y, x = np.meshgrid( + np.linspace(0, self.H, self.points_per_h + 2)[1:-1], + np.linspace(0, self.W, self.points_per_w + 2)[1:-1], + ) grid = np.stack((x, y), axis=2).astype(np.int32) self.grid_query_points = grid.reshape(-1, 2) return self.grid_query_points - elif self.point_order == 'random': + elif self.point_order == "random": # Randomly sample query points y = np.random.randint(0, self.H, self.points_per_h) x = np.random.randint(0, self.W, self.points_per_w) - return np.concatenate((x[:,None], y[:,None]), axis=1) + return np.concatenate((x[:, None], y[:, None]), axis=1) else: raise ValueError(f"Query point order mode {self.point_order} is not supported.") def get_target_tokens(self, sample, query_points): - instances_coords = [coords[0] for coords in sample['points']] - tokens = sample['token_ids'] - bboxes = 
sample['bbox'] - + instances_coords = [coords[0] for coords in sample["points"]] + tokens = sample["token_ids"] + bboxes = sample["bbox"] + instance_tokens_per_qpoint = dict() for point in query_points: point = (int(point[0].item()), int(point[1].item())) @@ -1267,7 +1455,7 @@ def get_target_tokens(self, sample, query_points): # If the query point is inside the instance, add its corresponding token if distance >= 0: instance_tokens_per_qpoint[point].append((tok, bbox)) - + return instance_tokens_per_qpoint def convert_target_tokens_to_string(self, target_tokens): @@ -1276,37 +1464,39 @@ def convert_target_tokens_to_string(self, target_tokens): # Randomly shuffle query points order (mainly for grid order) random.shuffle(query_points) for point in query_points: - + # Add query point coordinates to the string - result_text.append('point') - result_text.append(f'v0={point[1]}') - result_text.append(f'v1={point[0]}') - + result_text.append("point") + result_text.append(f"v0={point[1]}") + result_text.append(f"v1={point[0]}") + # Randomly shuffle the order of instance tokens per query point random.shuffle(target_tokens[point]) if len(target_tokens[point]) == 0: # If no instances tokens are found, add 'none' to the string - result_text.append('none') + result_text.append("none") else: for tok, bbox in target_tokens[point]: - result_text.append(f'polygon') - + result_text.append(f"polygon") + # Add bounding box coordinates to the string ymin, xmin, ymax, xmax = bbox.astype(np.int32) - result_text.extend([ - f'v0={xmin}', - f'v1={ymin}', - f'v2={xmax}', - f'v3={ymax}', - ]) - + result_text.extend( + [ + f"v0={xmin}", + f"v1={ymin}", + f"v2={xmax}", + f"v3={ymax}", + ] + ) + # Add instance tokens ids to the string for idx in tok.tolist(): if idx < 512: - result_text.append(f'v0={idx}') + result_text.append(f"v0={idx}") else: - result_text.append(f'v1={idx - 512}') - + result_text.append(f"v1={idx - 512}") + return " ".join(result_text) def load(self, path): @@ -1315,13 +1505,23 @@ def load(self, path): def preprocess(self, sample): for s in sample: - s['token_ids'] = s['token_ids'].astype(np.int32) + s["token_ids"] = s["token_ids"].astype(np.int32) return sample - def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + v, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): if rand_aug_idx is None: - raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used") + raise ValueError( + "Crop settings / augmentation index are missing but a pre-tokenized modality is being used" + ) v = v[rand_aug_idx] return v @@ -1341,8 +1541,16 @@ def load(self, path): def preprocess(self, sample): raise NotImplementedError("CropSettingsTransform does not support preprocessing") - def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): raise NotImplementedError("CropSettingsTransform is not meant to be used for image augmentation") def postprocess(self, sample): @@ -1357,8 +1565,16 @@ def load(self, path): def preprocess(self, sample): return sample - def image_augment(self, val, 
crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): return val def postprocess(self, sample): @@ -1368,20 +1584,28 @@ def postprocess(self, sample): class JSONTransform(AbstractTransform): def load(self, path): - if path.endswith('.json'): - with open(path, 'r') as f: + if path.endswith(".json"): + with open(path, "r") as f: sample = json.load(f) - elif path.endswith('.json.gz'): - with gzip.open(path, 'rb') as f: + elif path.endswith(".json.gz"): + with gzip.open(path, "rb") as f: sample = json.load(f) return sample def preprocess(self, sample): return sample - def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, - rand_aug_idx: Optional[int], resample_mode: str = None): + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): return val def postprocess(self, sample): - return sample \ No newline at end of file + return sample
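
For readers skimming the MetadataTransform hunks above, the float binning is easier to see with concrete numbers. Below is a minimal sketch that mirrors float_metadata_to_string, assuming the 'brightness' entry from the tables in the diff (type token v1=10, range 0 to 255, 50 bins); the helper name float_metadata_token and the value 128 are illustrative only.

# Illustrative sketch of MetadataTransform's float binning. The scale/offset values
# come from the tables in the diff above; the helper name and example value are invented.

def float_metadata_token(key_id: str, value: float, vmin: float, vmax: float, bins: int) -> str:
    # Clamp to [vmin, vmax], normalize to [0, 1], then quantize into `bins` buckets.
    value = max(vmin, min(value, vmax))
    value = (value - vmin) / (vmax - vmin)
    return f"{key_id} v0={int(value * (bins - 1))}"

# 'brightness' maps to the type token v1=10 and is binned over (0, 255) with 50 bins,
# so a brightness of 128 lands in bin 24:
print(float_metadata_token("v1=10", 128, 0, 255, 50))  # -> "v1=10 v0=24"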
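
The human-pose string conversion above quantizes each SMPL parameter group into the 0 to 999 value-token range with fixed affine maps (499.5 * x + 499.5 for global orientation, 166.5 * x + 499.5 for shape betas, 49.95 * x + 49.95 for the clipped camera parameters). A small worked sketch follows; to_v0_token and the sample values are invented, and only the scale and offset factors come from the code in the diff.

# Illustrative sketch of the affine quantization used in convert_humanposes_to_string.
# The parameter values below are invented; only the scale/offset factors follow the diff.

def to_v0_token(x: float, scale: float, offset: float) -> str:
    # Maps x (assumed to lie in the stated input range) into [0, 999] and emits a v0= token.
    return f"v0={round(scale * x + offset)}"

print(to_v0_token(0.25, 499.5, 499.5))  # global orientation in [-1, 1] -> "v0=624"
print(to_v0_token(1.0, 166.5, 499.5))   # shape (betas) in [-3, 3]      -> "v0=666"
print(to_v0_token(19.0, 49.95, 49.95))  # camera in [-1, 19]            -> "v0=999"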
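
The SAM-instance target string built in convert_target_tokens_to_string interleaves query points with the instances that contain them. The sketch below shows that token grammar only (point and polygon markers, bbox coordinates as v0 to v3, token ids emitted as v0 below 512 and v1 with an offset of 512 otherwise); render_query_point and all coordinates and ids are invented.

# Illustrative sketch of the SAM-instance target string layout (all numbers invented).

def render_query_point(x: int, y: int, instances) -> str:
    # Each query point is encoded as "point v0=<y> v1=<x>", followed either by "none"
    # or, per instance, by "polygon", its bbox as v0..v3, and its token ids
    # (v0 for ids < 512, v1 for ids >= 512, offset by 512).
    parts = ["point", f"v0={y}", f"v1={x}"]
    if not instances:
        parts.append("none")
    for (ymin, xmin, ymax, xmax), token_ids in instances:
        parts += ["polygon", f"v0={xmin}", f"v1={ymin}", f"v2={xmax}", f"v3={ymax}"]
        parts += [f"v0={i}" if i < 512 else f"v1={i - 512}" for i in token_ids]
    return " ".join(parts)

print(render_query_point(32, 57, [((10, 20, 90, 110), [3, 600])]))
# -> "point v0=57 v1=32 polygon v0=20 v1=10 v2=110 v3=90 v0=3 v1=88"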