@VisualConstructionModel.register("RELTRExtract", "PyTorch")
class RELTRTorch(VisualConstructionModel):
    """OpenKS executor wrapping RelTR training and single-image inference.

    The constructor only builds the argparse option set (``self.parser``);
    ``train`` and ``evaluate`` resolve the actual run options from it.
    """

    # Visual Genome entity labels, indexed by predicted class id.
    CLASSES = ['N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench',
               'bike',
               'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
               'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
               'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
               'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
               'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
               'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
               'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
               'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole',
               'post',
               'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
               'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
               'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
               'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
               'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']

    # Visual Genome predicate labels, indexed by predicted relation id.
    REL_CLASSES = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
                   'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
                   'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on',
                   'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over',
                   'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
                   'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']

    def __init__(self, name: str = 'pytorch-RELTRExtractor', dataset=None, args=None):
        """Register the executor and declare every command-line option.

        Args:
            name: Model name forwarded to the base class. It has a default so
                the executor can be created with ``executor(args=args)``.
            dataset: Forwarded unchanged to the base class.
            args: Forwarded unchanged to the base class.
        """
        super().__init__(name=name, dataset=dataset, args=args)
        parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
        parser.add_argument('--lr', default=1e-4, type=float)
        parser.add_argument('--lr_backbone', default=1e-5, type=float)
        parser.add_argument('--batch_size', default=2, type=int)
        parser.add_argument('--weight_decay', default=1e-4, type=float)
        parser.add_argument('--epochs', default=150, type=int)
        parser.add_argument('--lr_drop', default=100, type=int)
        parser.add_argument('--clip_max_norm', default=0.1, type=float,
                            help='gradient clipping max norm')

        # Model parameters
        parser.add_argument('--frozen_weights', type=str, default=None,
                            help="Path to the pretrained model. If set, only the mask head will be trained")
        # * Backbone
        parser.add_argument('--backbone', default='resnet50', type=str,
                            help="Name of the convolutional backbone to use")
        parser.add_argument('--dilation', action='store_true',
                            help="If true, we replace stride with dilation in the last convolutional block (DC5)")
        parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                            help="Type of positional embedding to use on top of the image features")

        # * Transformer
        parser.add_argument('--enc_layers', default=6, type=int,
                            help="Number of encoding layers in the transformer")
        parser.add_argument('--dec_layers', default=6, type=int,
                            help="Number of decoding layers in the transformer")
        parser.add_argument('--dim_feedforward', default=2048, type=int,
                            help="Intermediate size of the feedforward layers in the transformer blocks")
        parser.add_argument('--hidden_dim', default=256, type=int,
                            help="Size of the embeddings (dimension of the transformer)")
        parser.add_argument('--dropout', default=0.1, type=float,
                            help="Dropout applied in the transformer")
        parser.add_argument('--nheads', default=8, type=int,
                            help="Number of attention heads inside the transformer's attentions")
        parser.add_argument('--num_entities', default=100, type=int,
                            help="Number of query slots")
        parser.add_argument('--num_triplets', default=200, type=int,
                            help="Number of query slots")
        parser.add_argument('--pre_norm', action='store_true')

        # Loss
        parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                            help="Disables auxiliary decoding losses (loss at each layer)")
        # * Matcher
        parser.add_argument('--set_cost_class', default=1, type=float,
                            help="Class coefficient in the matching cost")
        parser.add_argument('--set_cost_bbox', default=5, type=float,
                            help="L1 box coefficient in the matching cost")
        parser.add_argument('--set_cost_giou', default=2, type=float,
                            help="giou box coefficient in the matching cost")
        parser.add_argument('--set_iou_threshold', default=0.7, type=float,
                            help="giou box coefficient in the matching cost")

        # * Loss coefficients
        parser.add_argument('--bbox_loss_coef', default=5, type=float)
        parser.add_argument('--giou_loss_coef', default=2, type=float)
        parser.add_argument('--rel_loss_coef', default=1, type=float)
        parser.add_argument('--eos_coef', default=0.1, type=float,
                            help="Relative classification weight of the no-object class")

        # dataset parameters
        parser.add_argument('--dataset', default='vg')
        parser.add_argument('--ann_path', default='./data/vg/', type=str)
        parser.add_argument('--img_folder', default='/home/cong/Dokumente/tmp/data/visualgenome/images/', type=str)
        # Read by evaluate(); it was referenced there but never declared.
        parser.add_argument('--img_path', default=None, type=str,
                            help="Path of the image used by single-image inference")

        parser.add_argument('--output_dir', default='',
                            help='path where to save, empty for no saving')
        parser.add_argument('--device', default='cuda',
                            help='device to use for training / testing')
        parser.add_argument('--seed', default=42, type=int)
        parser.add_argument('--resume', default='', help='resume from checkpoint')
        parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                            help='start epoch')
        parser.add_argument('--eval', action='store_true')
        parser.add_argument('--num_workers', default=2, type=int)

        # distributed training parameters
        parser.add_argument('--world_size', default=1, type=int,
                            help='number of distributed processes')
        parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

        parser.add_argument('--return_interm_layers', action='store_true',
                            help="Return the fpn if there is the tag")

        self.parser = parser

    def parse_args(self, args):
        """Return *args* unchanged (no additional parsing is performed here)."""
        return args

    def _resolve_run_args(self, args, description):
        """Return the options namespace for a run.

        If the caller passed an ``argparse.Namespace`` as the first positional
        argument it is used as-is; otherwise the process command line is parsed
        with ``self.parser`` as parent.
        """
        if args and isinstance(args[0], argparse.Namespace):
            return args[0]
        parser = argparse.ArgumentParser(description, parents=[self.parser])
        return parser.parse_args()

    def evaluate(self, *args):
        """Run single-image inference and plot the top predicted triplets.

        Loads the checkpoint given by ``--resume``, runs the model once on the
        image given by ``--img_path`` and shows, for up to 10 confident
        triplets, the subject/object cross-attention maps and both boxes.

        Args:
            *args: Optionally one ``argparse.Namespace`` with the run options;
                when omitted the command line is parsed.

        Raises:
            ValueError: If no checkpoint or no image path was supplied.
        """
        opts = self._resolve_run_args(args, 'RelTR inference script')
        if not opts.resume:
            raise ValueError("evaluate() needs a checkpoint: pass --resume <path>")
        img_path = getattr(opts, 'img_path', None)
        if not img_path:
            raise ValueError("evaluate() needs an input image: pass --img_path <path>")

        transform = T.Compose([
            T.Resize(800),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # for output bounding box post-processing
        def box_cxcywh_to_xyxy(x):
            x_c, y_c, w, h = x.unbind(1)
            b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
                 (x_c + 0.5 * w), (y_c + 0.5 * h)]
            return torch.stack(b, dim=1)

        def rescale_bboxes(out_bbox, size):
            img_w, img_h = size
            b = box_cxcywh_to_xyxy(out_bbox)
            b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
            return b

        model, _, _ = build_model(opts)
        # map_location keeps loading independent of the device the checkpoint was saved on.
        ckpt = torch.load(opts.resume, map_location='cpu')
        model.load_state_dict(ckpt['model'])
        model.eval()

        # The normalisation above is 3-channel, so force RGB (grayscale/RGBA inputs would fail).
        im = Image.open(img_path).convert('RGB')

        # mean-std normalize the input image (batch-size: 1)
        img = transform(im).unsqueeze(0)

        # use lists to store the hook outputs via up-values
        conv_features, dec_attn_weights_sub, dec_attn_weights_obj = [], [], []

        hooks = [
            model.backbone[-2].register_forward_hook(
                lambda module, inputs, output: conv_features.append(output)
            ),
            model.transformer.decoder.layers[-1].cross_attn_sub.register_forward_hook(
                lambda module, inputs, output: dec_attn_weights_sub.append(output[1])
            ),
            model.transformer.decoder.layers[-1].cross_attn_obj.register_forward_hook(
                lambda module, inputs, output: dec_attn_weights_obj.append(output[1])
            )
        ]
        try:
            # A single gradient-free forward pass yields both predictions and hooked tensors.
            with torch.no_grad():
                outputs = model(img)
        finally:
            for hook in hooks:
                hook.remove()

        # keep only predictions with 0.3+ confidence on relation, subject and object
        probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
        probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
        probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
        keep = torch.logical_and(probas.max(-1).values > 0.3,
                                 torch.logical_and(probas_sub.max(-1).values > 0.3,
                                                   probas_obj.max(-1).values > 0.3))

        # convert boxes from [0; 1] to image scales
        sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size)
        obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size)

        # rank the kept queries by the product of the three confidences
        topk = 10
        keep_queries = torch.nonzero(keep, as_tuple=True)[0]
        indices = torch.argsort(
            -probas[keep_queries].max(-1)[0]
            * probas_sub[keep_queries].max(-1)[0]
            * probas_obj[keep_queries].max(-1)[0])[:topk]
        keep_queries = keep_queries[indices]

        if len(indices) == 0:
            # plt.subplots(ncols=0) would raise; there is nothing to draw.
            print("No triplet passed the 0.3 confidence threshold.")
            return

        # don't need the lists anymore
        conv_features = conv_features[0]
        dec_attn_weights_sub = dec_attn_weights_sub[0]
        dec_attn_weights_obj = dec_attn_weights_obj[0]

        # get the feature map shape
        h, w = conv_features['0'].tensors.shape[-2:]

        # squeeze=False keeps `axs` 2-D even when only one triplet is plotted.
        fig, axs = plt.subplots(ncols=len(indices), nrows=3, figsize=(22, 7), squeeze=False)
        for idx, ax_i, (sxmin, symin, sxmax, symax), (oxmin, oymin, oxmax, oymax) in \
                zip(keep_queries, axs.T, sub_bboxes_scaled[indices], obj_bboxes_scaled[indices]):
            ax = ax_i[0]
            ax.imshow(dec_attn_weights_sub[0, idx].view(h, w))
            ax.axis('off')
            ax.set_title(f'query id: {idx.item()}')
            ax = ax_i[1]
            ax.imshow(dec_attn_weights_obj[0, idx].view(h, w))
            ax.axis('off')
            ax = ax_i[2]
            ax.imshow(im)
            ax.add_patch(plt.Rectangle((sxmin, symin), sxmax - sxmin, symax - symin,
                                       fill=False, color='blue', linewidth=2.5))
            ax.add_patch(plt.Rectangle((oxmin, oymin), oxmax - oxmin, oymax - oymin,
                                       fill=False, color='orange', linewidth=2.5))

            ax.axis('off')
            ax.set_title(
                self.CLASSES[probas_sub[idx].argmax()] + ' '
                + self.REL_CLASSES[probas[idx].argmax()] + ' '
                + self.CLASSES[probas_obj[idx].argmax()], fontsize=10)

        fig.tight_layout()
        plt.show()

    def train(self, *args):
        """Train RelTR, or only evaluate it when ``--eval`` is set.

        Args:
            *args: Optionally one ``argparse.Namespace`` with the run options;
                when omitted the command line is parsed.

        Raises:
            ValueError: If ``--frozen_weights`` is given without a truthy
                ``masks`` option.
        """
        args = self._resolve_run_args(args, 'RelTR training and evaluation script')
        if args.output_dir:
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)

        utils.init_distributed_mode(args)
        print("git:\n {}\n".format(utils.get_sha()))
        # `masks` is not declared by the parser above, hence the getattr.
        if args.frozen_weights is not None and not getattr(args, 'masks', False):
            raise ValueError("Frozen training is meant for segmentation only")
        print(args)

        device = torch.device(args.device)

        # fix the seed for reproducibility
        seed = args.seed + utils.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        model, criterion, postprocessors = build_model(args)
        model.to(device)

        model_without_ddp = model
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)

        # The backbone gets its own learning rate; everything else uses args.lr.
        param_dicts = [
            {"params": [p for n, p in model_without_ddp.named_parameters()
                        if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in model_without_ddp.named_parameters()
                           if "backbone" in n and p.requires_grad],
                "lr": args.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

        dataset_train = build_dataset(image_set='train', args=args)
        dataset_val = build_dataset(image_set='val', args=args)

        if args.distributed:
            sampler_train = DistributedSampler(dataset_train)
            sampler_val = DistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = torch.utils.data.RandomSampler(dataset_train)
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

        batch_sampler_train = torch.utils.data.BatchSampler(
            sampler_train, args.batch_size, drop_last=True)

        data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                       collate_fn=utils.collate_fn, num_workers=args.num_workers)
        data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                     drop_last=False, collate_fn=utils.collate_fn,
                                     num_workers=args.num_workers)

        base_ds = get_coco_api_from_dataset(dataset_val)

        if args.frozen_weights is not None:
            checkpoint = torch.load(args.frozen_weights, map_location='cpu')
            model_without_ddp.detr.load_state_dict(checkpoint['model'])

        output_dir = Path(args.output_dir)
        resumed_epoch = None
        if args.resume:
            checkpoint = torch.load(args.resume, map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
            resumed_epoch = checkpoint.get('epoch')
            has_train_state = all(key in checkpoint for key in ('optimizer', 'lr_scheduler', 'epoch'))
            if not args.eval and has_train_state:
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1

        if args.eval:
            # Only report the epoch when a checkpoint was actually loaded.
            if resumed_epoch is not None:
                print('It is the {}th checkpoint'.format(resumed_epoch))
            test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds,
                                                  device, args)
            if args.output_dir:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
            return

        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                sampler_train.set_epoch(epoch)
            train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch,
                                          args.clip_max_norm)
            lr_scheduler.step()
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']  # anti-crash
                # extra checkpoint before LR drop and every 5 epochs
                if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 5 == 0:
                    checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

            test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds,
                                                  device, args)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       output_dir / "eval" / name)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    def run(self, mode="train"):
        """Dispatch to ``train`` or ``evaluate``.

        Args:
            mode: ``"train"`` or ``"eval"``. ``"single"`` is not implemented;
                any other value is ignored.

        Raises:
            ValueError: If *mode* is ``"single"``.
        """
        if mode == "train":
            self.train()
        elif mode == "eval":
            self.evaluate()
        elif mode == "single":
            raise ValueError("UnImplemented mode!")
def get_coco_api_from_dataset(dataset):
    """Return the ``coco`` attribute of the COCO-style dataset behind *dataset*.

    ``torch.utils.data.Subset`` wrappers are unwrapped (any nesting depth)
    before the check.

    Returns:
        ``dataset.coco`` when the unwrapped dataset is a
        ``torchvision.datasets.CocoDetection``; ``None`` otherwise.
    """
    while isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return None


def build_dataset(image_set, args):
    """Build the dataset for *image_set* according to ``args.dataset``.

    Args:
        image_set: Split name forwarded to the COCO-style builder.
        args: Options object; ``args.dataset`` must be ``'vg'`` or ``'oi'``.

    Raises:
        ValueError: If ``args.dataset`` names an unsupported dataset.
    """
    # Both supported datasets are read through the same COCO-style loader.
    if args.dataset in ('vg', 'oi'):
        return build_coco(image_set, args)
    raise ValueError(f'dataset {args.dataset} not supported')
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-style detection dataset that also carries relationship annotations.

    Relationship annotations are read from a ``rel.json`` file located next to
    the annotation file.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks):
        """Load the annotations and the relationship labels for one split.

        Args:
            img_folder: Image directory, forwarded to the torchvision base class.
            ann_file: Path of the annotation JSON; its file name selects the
                split (contains ``train`` -> train, ``val`` -> val, else test).
            transforms: Optional callable ``(img, target) -> (img, target)``.
            return_masks: Whether ``ConvertCocoPolysToMask`` also builds masks.
        """
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

        ann_file = Path(ann_file)
        with (ann_file.parent / 'rel.json').open('r') as f:
            all_rels = json.load(f)

        # Decide the split from the file name only, so that a directory such as
        # ".../training/val.json" cannot select the wrong split.
        file_name = ann_file.name
        if 'train' in file_name:
            self.rel_annotations = all_rels['train']
        elif 'val' in file_name:
            self.rel_annotations = all_rels['val']
        else:
            self.rel_annotations = all_rels['test']

        self.rel_categories = all_rels['rel_categories']

    def __getitem__(self, idx):
        """Return ``(image, target)`` for index *idx* after conversion and transforms."""
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        # rel.json is keyed by the image id as a string.
        rel_target = self.rel_annotations[str(image_id)]

        target = {'image_id': image_id, 'annotations': target, 'rel_annotations': rel_target}

        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode COCO polygon segmentations into a stacked ``uint8`` mask tensor.

    Returns:
        Tensor of shape ``(len(segmentations), height, width)``; an empty
        ``(0, height, width)`` tensor when there are no segmentations.
    """
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # Merge all polygon parts of one object into a single mask.
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    """Convert raw COCO annotations of one image into a tensor target dict."""

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        """Build the target dict for *image*.

        Args:
            image: Object exposing ``size`` as ``(width, height)``.
            target: Dict with keys ``image_id``, ``annotations`` and
                ``rel_annotations``.

        Returns:
            ``(image, target)`` where the new target holds ``boxes`` (xyxy,
            clamped to the image), ``labels``, ``image_id``, ``area``,
            ``iscrowd``, ``orig_size``, ``size``, ``rel_annotations`` and,
            when available, ``masks`` / ``keypoints``.
        """
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        # Crowd annotations are dropped.
        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # xywh -> xyxy, then clamp to the image bounds
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Keep only boxes with positive width and height.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        # TODO add relation gt in the target
        rel_annotations = target['rel_annotations']

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])
        # TODO add relation gt in the target
        target['rel_annotations'] = torch.tensor(rel_annotations)

        return image, target


def make_coco_transforms(image_set):
    """Return the transform pipeline for the ``'train'`` or ``'val'`` split.

    Raises:
        ValueError: If *image_set* is neither ``'train'`` nor ``'val'``.
    """

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    # T.RandomSizeCrop(384, 600),  # TODO: cropping causes that some boxes are dropped then no tensor in the relation part! What should we do?
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set, args):
    """Create the ``CocoDetection`` dataset for *image_set*.

    Args:
        image_set: ``'train'`` or ``'val'``. For ``'val'`` the test annotations
            are used when ``args.eval`` is set.
        args: Options providing ``ann_path`` (a prefix that the annotation file
            name is appended to, so a directory needs its trailing separator),
            ``img_folder`` and ``eval``.

    Raises:
        ValueError: If *image_set* is neither ``'train'`` nor ``'val'``.
    """

    ann_path = args.ann_path
    img_folder = args.img_folder

    # TODO: adapt vg as coco
    if image_set == 'train':
        ann_file = ann_path + 'train.json'
    elif image_set == 'val':
        if args.eval:
            ann_file = ann_path + 'test.json'
        else:
            ann_file = ann_path + 'val.json'
    else:
        # Previously fell through and failed on the unbound `ann_file`.
        raise ValueError(f'unknown {image_set}')

    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set),
                            return_masks=False)
    return dataset
class CocoEvaluator(object):
    """Accumulate predictions and run pycocotools evaluation per IoU type."""

    def __init__(self, coco_gt, iou_types):
        """Create one ``COCOeval`` per requested IoU type.

        Args:
            coco_gt: Ground-truth COCO API object (used as-is, not copied).
            iou_types: List or tuple of IoU type names.

        Raises:
            TypeError: If *iou_types* is not a list or tuple.
        """
        if not isinstance(iou_types, (list, tuple)):
            raise TypeError("iou_types must be a list or tuple")
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        """Evaluate a batch of predictions keyed by image id and store the results."""
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            # Only the per-image results are needed; the returned ids are not reused.
            _, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        """Concatenate the stored results and merge them across processes."""
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        """Call ``accumulate()`` on every underlying evaluator."""
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        """Print the summary of every underlying evaluator."""
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        """Convert *predictions* into COCO result dicts for *iou_type*.

        Raises:
            ValueError: If *iou_type* is not ``bbox``, ``segm`` or ``keypoints``.
        """
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        """Return box results (xywh) for every non-empty prediction."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        """Return RLE-encoded mask results for every non-empty prediction."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            # Binarise the predicted masks at 0.5.
            masks = prediction["masks"] > 0.5
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            # NOTE(review): `mask[0]` assumes each mask has a leading singleton
            # dimension -- confirm against the producer of `prediction["masks"]`.
            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        """Return flattened keypoint results for every non-empty prediction."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        'keypoints': keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    """Convert an ``(N, 4)`` tensor of xyxy boxes to xywh."""
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    """Gather ids and results from all processes and de-duplicate by image id.

    Returns:
        ``(merged_img_ids, merged_eval_imgs)`` with ids unique and sorted and
        the last axis of the results reordered to match.
    """
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = list(all_eval_imgs)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    """Store the merged ids and per-image results on *coco_eval*."""
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


def evaluate(self):
    '''
    Run per image evaluation on given images.

    :param self: a pycocotools ``COCOeval`` instance (called as a free function)
    :return: ``(imgIds, evalImgs)`` where ``evalImgs`` is an array of shape
             ``(len(catIds), len(areaRng), len(imgIds))``
    :raises ValueError: if ``params.iouType`` is not segm, bbox or keypoints
    '''
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    else:
        # Previously fell through and failed on the unbound `computeIoU`.
        raise ValueError("Unknown iou type {}".format(p.iouType))
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
def crop(image, target, region):
    """Crop an image and update its annotation dict to match.

    Args:
        image: image accepted by ``F.crop``.
        target: annotation dict; it is shallow-copied, not modified in place.
        region: ``(top, left, height, width)`` of the crop window.

    Returns:
        ``(cropped_image, target)``. "size" is set to ``[height, width]``;
        "boxes" are shifted into the crop frame and clipped to it, "area" is
        recomputed from the clipped boxes, "masks" are sliced, and instances
        whose box (or, without boxes, whose mask) is empty are removed from
        every per-instance field that is present.
    """
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        # NOTE(review): other code in this change uses "rel_annotations" rows
        # that index into "boxes"; dropping boxes here does not remap them --
        # confirm crop is never applied to scene-graph targets.
        for field in fields:
            # Optional annotations may be absent; filtering them unconditionally
            # used to raise KeyError.
            if field in target:
                target[field] = target[field][keep]

    return cropped_image, target


def hflip(image, target):
    """Mirror an image horizontally and update boxes and masks accordingly.

    Box columns 0 and 2 are swapped and reflected about the image width, so
    the result keeps column 0 <= column 2 when the input did. The target dict
    is shallow-copied before it is changed.
    """
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target
def pad(image, target, padding):
    """Pad an image on its right and bottom edges and update the target.

    ``padding[0]`` is the amount added on the right, ``padding[1]`` the amount
    added at the bottom. The target (if any) is shallow-copied; its "size" is
    set from the padded image and "masks" are padded the same way.
    """
    # assumes that we only pad on the bottom right corners
    right, bottom = padding[0], padding[1]
    padded_image = F.pad(image, (0, 0, right, bottom))
    if target is None:
        return padded_image, None

    new_target = target.copy()
    # should we do something wrt the original size?
    new_target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in new_target:
        new_target['masks'] = torch.nn.functional.pad(new_target['masks'], (0, right, 0, bottom))
    return padded_image, new_target


class RandomCrop:
    """Crop a random window of a fixed ``size`` from image and target."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        return crop(img, target, T.RandomCrop.get_params(img, self.size))


class RandomSizeCrop:
    """Crop a random window whose sides are drawn from [min_size, max_size].

    Each side is additionally capped by the corresponding image dimension.
    """

    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        # Width is sampled before height; keep this order so the random
        # stream is consumed exactly as before.
        crop_w = random.randint(self.min_size, min(img.width, self.max_size))
        crop_h = random.randint(self.min_size, min(img.height, self.max_size))
        window = T.RandomCrop.get_params(img, [crop_h, crop_w])
        return crop(img, target, window)


class CenterCrop:
    """Crop the central ``(height, width)`` window given by ``size``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        img_w, img_h = img.size
        out_h, out_w = self.size
        top = int(round((img_h - out_h) / 2.))
        left = int(round((img_w - out_w) / 2.))
        return crop(img, target, (top, left, out_h, out_w))


class RandomHorizontalFlip:
    """Flip image and target horizontally with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        flip = random.random() < self.p
        if not flip:
            return img, target
        return hflip(img, target)


class RandomResize:
    """Resize to a size picked uniformly at random from ``sizes``."""

    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        return resize(img, target, random.choice(self.sizes), self.max_size)


class RandomPad:
    """Pad right and bottom by independent random amounts in [0, max_pad]."""

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        # The x amount is drawn first, then y.
        amounts = (random.randint(0, self.max_pad), random.randint(0, self.max_pad))
        return pad(img, target, amounts)


class RandomSelect:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(img, target)


class ToTensor:
    """Convert the image with ``F.to_tensor``; the target passes through."""

    def __call__(self, img, target):
        tensor = F.to_tensor(img)
        return tensor, target


class RandomErasing:
    """Apply ``T.RandomErasing`` to the image; the target passes through."""

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        erased = self.eraser(img)
        return erased, target
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Train ``model`` for one pass over ``data_loader``.

    For every batch the weighted sum of the criterion's losses is
    back-propagated, gradients are clipped to ``max_norm`` when it is
    positive, and the optimizer takes a step. The process exits with status 1
    if the reduced loss is not finite.

    Returns:
        Dict mapping each logged meter name to its global average.
    """
    error_names = ('class_error', 'sub_error', 'obj_error', 'rel_error')

    model.train()
    criterion.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    for error_name in error_names:
        metric_logger.add_meter(error_name, utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 500

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{key: value.to(device) for key, value in tgt.items()} for tgt in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # Only losses that have a weight contribute to the optimised total.
        losses = sum(loss_dict[name] * weight_dict[name] for name in loss_dict if name in weight_dict)

        # reduce losses over all GPUs for logging purposes
        reduced = utils.reduce_dict(loss_dict)
        reduced_unscaled = {f'{name}_unscaled': value for name, value in reduced.items()}
        reduced_scaled = {name: value * weight_dict[name]
                          for name, value in reduced.items() if name in weight_dict}
        loss_value = sum(reduced_scaled.values()).item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **reduced_scaled, **reduced_unscaled)
        for error_name in error_names:
            metric_logger.update(**{error_name: reduced[error_name]})
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
model.eval() + criterion.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + metric_logger.add_meter('sub_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + metric_logger.add_meter('obj_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + metric_logger.add_meter('rel_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Test:' + + # initilize evaluator + evaluator = BasicSceneGraphEvaluator.all_modes(multiple_preds=False) + if args.eval: + evaluator_list = [] + for index, name in enumerate(data_loader.dataset.rel_categories): + if index == 0: + continue + evaluator_list.append((index, name, BasicSceneGraphEvaluator.all_modes())) + else: + evaluator_list = None + + iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) + coco_evaluator = CocoEvaluator(base_ds, iou_types) + + for samples, targets in metric_logger.log_every(data_loader, 100, header): + + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + outputs = model(samples) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + evaluate_rel_batch(outputs, targets, evaluator, evaluator_list) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = {k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() if k in weight_dict} + loss_dict_reduced_unscaled = {f'{k}_unscaled': v + for k, v in loss_dict_reduced.items()} + metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + metric_logger.update(sub_error=loss_dict_reduced['sub_error']) + metric_logger.update(obj_error=loss_dict_reduced['obj_error']) + 
def evaluate_rel_batch(outputs, targets, evaluator, evaluator_list):
    """Feed one batch of predictions and ground truth to the scene-graph evaluators.

    Args:
        outputs: model output dict; reads "sub_boxes", "obj_boxes",
            "sub_logits", "obj_logits" and "rel_logits", each indexed by the
            position of the sample in the batch.
        targets: list of per-image target dicts; reads "boxes", "labels",
            "rel_annotations" and "orig_size".
        evaluator: mapping whose 'sgdet' entry receives every sample.
        evaluator_list: optional list of ``(predicate_id, name, evaluator)``
            tuples; each evaluator only receives the ground-truth relations
            whose last column equals its predicate id. ``None`` disables this.
    """
    for batch, target in enumerate(targets):
        target_bboxes_scaled = rescale_bboxes(target['boxes'].cpu(), torch.flip(target['orig_size'],dims=[0]).cpu()).clone().numpy() # recovered boxes with original size

        gt_entry = {'gt_classes': target['labels'].cpu().clone().numpy(),
                    'gt_relations': target['rel_annotations'].cpu().clone().numpy(),
                    'gt_boxes': target_bboxes_scaled}

        sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][batch].cpu(), torch.flip(target['orig_size'],dims=[0]).cpu()).clone().numpy()
        obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][batch].cpu(), torch.flip(target['orig_size'],dims=[0]).cpu()).clone().numpy()

        # The last class column is dropped before taking the best class
        # (presumably the "no object" slot -- confirm against the model head).
        pred_sub_scores, pred_sub_classes = torch.max(outputs['sub_logits'][batch].softmax(-1)[:, :-1], dim=1)
        pred_obj_scores, pred_obj_classes = torch.max(outputs['obj_logits'][batch].softmax(-1)[:, :-1], dim=1)
        # First and last predicate columns are excluded before the softmax.
        rel_scores = outputs['rel_logits'][batch][:,1:-1].softmax(-1)

        pred_entry = {'sub_boxes': sub_bboxes_scaled,
                      'sub_classes': pred_sub_classes.cpu().clone().numpy(),
                      'sub_scores': pred_sub_scores.cpu().clone().numpy(),
                      'obj_boxes': obj_bboxes_scaled,
                      'obj_classes': pred_obj_classes.cpu().clone().numpy(),
                      'obj_scores': pred_obj_scores.cpu().clone().numpy(),
                      'rel_scores': rel_scores.cpu().clone().numpy()}

        evaluator['sgdet'].evaluate_scene_graph_entry(gt_entry, pred_entry)

        if evaluator_list is not None:
            for pred_id, _, evaluator_rel in evaluator_list:
                gt_entry_rel = gt_entry.copy()
                # np.isin replaces np.in1d, which is deprecated since NumPy 2.0;
                # for this 1-D column the result is identical.
                mask = np.isin(gt_entry_rel['gt_relations'][:, -1], pred_id)
                gt_entry_rel['gt_relations'] = gt_entry_rel['gt_relations'][mask, :]
                if gt_entry_rel['gt_relations'].shape[0] == 0:
                    continue
                evaluator_rel['sgdet'].evaluate_scene_graph_entry(gt_entry_rel, pred_entry)
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.

    Args:
        n: number of channels.
        eps: value added to the running variance before the reciprocal square
            root. Defaults to 1e-5, the value that used to be hard-coded.
    """

    def __init__(self, n, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        # eps is a plain attribute, so the state_dict layout is unchanged.
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # This module registers no num_batches_tracked buffer; drop the key so
        # checkpoints that contain it do not report an unexpected key.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply the fixed per-channel affine normalisation to ``x``."""
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):
    """Wrap a backbone so it returns selected layers as NestedTensors.

    Args:
        backbone: network passed to ``IntermediateLayerGetter``.
        train_backbone: if False every backbone parameter is frozen; if True
            only parameters whose name contains 'layer2', 'layer3' or 'layer4'
            stay trainable.
        num_channels: stored as ``self.num_channels`` for downstream modules.
        return_interm_layers: return layer1..layer4 (keys "0".."3") instead of
            only layer4 (key "0").
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            # `and` binds tighter than `or`; the parentheses only make that explicit.
            if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name):
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and resize the padding mask to each feature map."""
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # Pretrained weights are requested only when is_main_process() is true.
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    """Chain a backbone (index 0) with a position-embedding module (index 1)."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return the backbone feature maps and a position encoding for each."""
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    """Build the backbone + position-encoding ``Joiner`` from ``args``.

    Reads ``args.lr_backbone`` (the backbone is trained only when it is > 0),
    ``args.return_interm_layers``, ``args.backbone`` and ``args.dilation``.
    """
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.return_interm_layers
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network"""

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, iou_threshold: float = 0.7):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            iou_threshold: Minimum IoU with a ground-truth entity of the predicted class for a
                subject/object prediction to count as a good detection in forward()
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.iou_threshold = iou_threshold
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits":  Tensor of dim [batch_size, num_entities, num_entity_classes] with the entity classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_entities, 4] with the predicted box coordinates
                 "sub_logits":  Tensor of dim [batch_size, num_triplets, num_entity_classes] with the subject classification logits
                 "sub_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted subject box coordinates
                 "obj_logits":  Tensor of dim [batch_size, num_triplets, num_entity_classes] with the object classification logits
                 "obj_boxes": Tensor of dim [batch_size, num_triplets, 4] with the predicted object box coordinates
                 "rel_logits":  Tensor of dim [batch_size, num_triplets, num_predicate_classes] with the predicate classification logits

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
                 "image_id": Image index
                 "orig_size": Tensor of dim [2] with the height and width
                 "size": Tensor of dim [2] with the height and width after transformation
                 "rel_annotations": Tensor of dim [num_gt_triplet, 3] with the subject index/object index/predicate class
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected entity predictions (in order)
                - index_j is the indices of the corresponding selected entity targets (in order)
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected triplet predictions (in order)
                - index_j is the indices of the corresponding selected triplet targets (in order)
            Subject loss weight: tensor of dim [batch_size, num_triplets] holding 1 or 0 per triplet query
            Object loss weight: tensor of dim [batch_size, num_triplets] holding 1 or 0 per triplet query
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        num_queries_rel = outputs["rel_logits"].shape[1]
        # Constants of the focal-style classification cost used below.
        alpha = 0.25
        gamma = 2.0

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
        out_bbox = outputs["pred_boxes"].flatten(0, 1)

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the entity classification cost. We borrow the cost function from Deformable DETR (https://arxiv.org/abs/2010.04159)
        neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

        # Compute the L1 cost between entity boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between entity boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final entity cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        # Split the target axis per image and solve one assignment per image.
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

        # Concat the subject/object/predicate labels and subject/object boxes
        sub_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 0]] for v in targets])
        sub_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 0]] for v in targets])
        obj_tgt_bbox = torch.cat([v['boxes'][v['rel_annotations'][:, 1]] for v in targets])
        obj_tgt_ids = torch.cat([v['labels'][v['rel_annotations'][:, 1]] for v in targets])
        rel_tgt_ids = torch.cat([v["rel_annotations"][:, 2] for v in targets])

        sub_prob = outputs["sub_logits"].flatten(0, 1).sigmoid()
        sub_bbox = outputs["sub_boxes"].flatten(0, 1)
        obj_prob = outputs["obj_logits"].flatten(0, 1).sigmoid()
        obj_bbox = outputs["obj_boxes"].flatten(0, 1)
        rel_prob = outputs["rel_logits"].flatten(0, 1).sigmoid()

        # Compute the subject matching cost based on class and box.
        neg_cost_class_sub = (1 - alpha) * (sub_prob ** gamma) * (-(1 - sub_prob + 1e-8).log())
        pos_cost_class_sub = alpha * ((1 - sub_prob) ** gamma) * (-(sub_prob + 1e-8).log())
        cost_sub_class = pos_cost_class_sub[:, sub_tgt_ids] - neg_cost_class_sub[:, sub_tgt_ids]
        cost_sub_bbox = torch.cdist(sub_bbox, sub_tgt_bbox, p=1)
        cost_sub_giou = -generalized_box_iou(box_cxcywh_to_xyxy(sub_bbox), box_cxcywh_to_xyxy(sub_tgt_bbox))

        # Compute the object matching cost based on class and box.
        neg_cost_class_obj = (1 - alpha) * (obj_prob ** gamma) * (-(1 - obj_prob + 1e-8).log())
        pos_cost_class_obj = alpha * ((1 - obj_prob) ** gamma) * (-(obj_prob + 1e-8).log())
        cost_obj_class = pos_cost_class_obj[:, obj_tgt_ids] - neg_cost_class_obj[:, obj_tgt_ids]
        cost_obj_bbox = torch.cdist(obj_bbox, obj_tgt_bbox, p=1)
        cost_obj_giou = -generalized_box_iou(box_cxcywh_to_xyxy(obj_bbox), box_cxcywh_to_xyxy(obj_tgt_bbox))

        # Compute the predicate matching cost, based on class only.
        neg_cost_class_rel = (1 - alpha) * (rel_prob ** gamma) * (-(1 - rel_prob + 1e-8).log())
        pos_cost_class_rel = alpha * ((1 - rel_prob) ** gamma) * (-(rel_prob + 1e-8).log())
        cost_rel_class = pos_cost_class_rel[:, rel_tgt_ids] - neg_cost_class_rel[:, rel_tgt_ids]

        # Final triplet cost matrix (the predicate term has a fixed weight of 0.5)
        C_rel = self.cost_bbox * cost_sub_bbox + self.cost_bbox * cost_obj_bbox + \
                self.cost_class * cost_sub_class + self.cost_class * cost_obj_class + 0.5 * cost_rel_class + \
                self.cost_giou * cost_sub_giou + self.cost_giou * cost_obj_giou
        C_rel = C_rel.view(bs, num_queries_rel, -1).cpu()

        sizes1 = [len(v["rel_annotations"]) for v in targets]
        indices1 = [linear_sum_assignment(c[i]) for i, c in enumerate(C_rel.split(sizes1, -1))]

        # assignment strategy to avoid assigning to some good predictions:
        # a triplet query whose predicted subject has the class of some ground-truth
        # entity of its image and overlaps it with IoU >= iou_threshold gets weight 0,
        # unless the query was selected by the triplet matching above (weight 1).
        # NOTE(review): box_iou(...)[0] presumably selects the IoU matrix from an
        # (iou, union) pair -- confirm against util.box_ops.
        sub_weight = torch.ones((bs, num_queries_rel)).to(out_prob.device)
        good_sub_detection = torch.logical_and((outputs["sub_logits"].flatten(0, 1)[:, :-1].argmax(-1)[:, None] == tgt_ids),
                                               (box_iou(box_cxcywh_to_xyxy(sub_bbox), box_cxcywh_to_xyxy(tgt_bbox))[0] >= self.iou_threshold))
        for i, c in enumerate(good_sub_detection.split(sizes, -1)):
            # c holds the columns of image i's entities; keep only image i's query rows.
            sub_weight[i, c.sum(-1)[i*num_queries_rel:(i+1)*num_queries_rel].to(torch.bool)] = 0
            sub_weight[i, indices1[i][0]] = 1

        # Same strategy for the predicted objects.
        obj_weight = torch.ones((bs, num_queries_rel)).to(out_prob.device)
        good_obj_detection = torch.logical_and((outputs["obj_logits"].flatten(0, 1)[:, :-1].argmax(-1)[:, None] == tgt_ids),
                                               (box_iou(box_cxcywh_to_xyxy(obj_bbox), box_cxcywh_to_xyxy(tgt_bbox))[0] >= self.iou_threshold))
        for i, c in enumerate(good_obj_detection.split(sizes, -1)):
            obj_weight[i, c.sum(-1)[i*num_queries_rel:(i+1)*num_queries_rel].to(torch.bool)] = 0
            obj_weight[i, indices1[i][0]] = 1

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices],\
               [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices1],\
               sub_weight, obj_weight


def build_matcher(args):
    # Builds the matcher from the set_cost_* and set_iou_threshold options on args.
    return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou, iou_threshold=args.set_iou_threshold)
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.

    One embedding table is learned for rows and one for columns; the encoding
    of a location is the concatenation of its column and row embeddings.

    Args:
        num_pos_feats: size of each row/column embedding (the output has
            ``2 * num_pos_feats`` channels).
        max_size: number of rows/columns the tables can encode. Defaults to
            50, the value that used to be hard-coded; feature maps with more
            rows or columns than this cannot be encoded.
    """
    def __init__(self, num_pos_feats=256, max_size=50):
        super().__init__()
        self.max_size = max_size
        self.row_embed = nn.Embedding(max_size, num_pos_feats)
        self.col_embed = nn.Embedding(max_size, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both tables with uniform random values."""
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return the encoding for ``tensor_list.tensors``.

        The result has shape ``[batch, 2 * num_pos_feats, H, W]`` where the
        batch size, H and W are taken from ``tensor_list.tensors``.

        Raises:
            IndexError: if H or W exceeds ``max_size``.
        """
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        if h > self.max_size or w > self.max_size:
            # Fail with a readable message before the embedding lookup does.
            raise IndexError(
                f"feature map of size {h}x{w} exceeds the learned position table size {self.max_size}")
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    """Create the position-encoding module selected by ``args.position_embedding``.

    'v2'/'sine' builds a normalised ``PositionEmbeddingSine`` and 'v3'/'learned'
    a ``PositionEmbeddingLearned``, each with ``args.hidden_dim // 2`` features.

    Raises:
        ValueError: for any other value of ``args.position_embedding``.
    """
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
+ +import torch +import torch.nn.functional as F +from torch import nn +from util import box_ops +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized) +from .backbone import build_backbone +from .matcher import build_matcher +from .transformer import build_transformer + +class RelTR(nn.Module): + """ RelTR: Relation Transformer for Scene Graph Generation """ + def __init__(self, backbone, transformer, num_classes, num_rel_classes, num_entities, num_triplets, aux_loss=False, matcher=None): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of entity classes + num_entities: number of entity queries + num_triplets: number of coupled subject/object queries + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.num_entities = num_entities + self.transformer = transformer + hidden_dim = transformer.d_model + self.hidden_dim = hidden_dim + + self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) + self.backbone = backbone + self.aux_loss = aux_loss + + self.entity_embed = nn.Embedding(num_entities, hidden_dim*2) + self.triplet_embed = nn.Embedding(num_triplets, hidden_dim*3) + self.so_embed = nn.Embedding(2, hidden_dim) # subject and object encoding + + # entity prediction + self.entity_class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.entity_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + # mask head + self.so_mask_conv = nn.Sequential(torch.nn.Upsample(size=(28, 28)), + nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=3, bias=True), + nn.ReLU(inplace=True), + nn.BatchNorm2d(64), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.BatchNorm2d(32)) + self.so_mask_fc = nn.Sequential(nn.Linear(2048, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 128)) + + # predicate classification + self.rel_class_embed = MLP(hidden_dim*2+128, hidden_dim, num_rel_classes + 1, 2) + + # subject/object label classfication and box regression + self.sub_class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.sub_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.obj_class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.obj_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + + def forward(self, samples: NestedTensor): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the entity classification logits (including no-object) for all entity queries. 
+ Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": the normalized entity boxes coordinates for all entity queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "sub_logits": the subject classification logits + - "obj_logits": the object classification logits + - "sub_boxes": the normalized subject boxes coordinates + - "obj_boxes": the normalized object boxes coordinates + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + src, mask = features[-1].decompose() + assert mask is not None + hs, hs_t, so_masks, _ = self.transformer(self.input_proj(src), mask, self.entity_embed.weight, + self.triplet_embed.weight, pos[-1], self.so_embed.weight) + so_masks = so_masks.detach() + so_masks = self.so_mask_conv(so_masks.view(-1, 2, src.shape[-2],src.shape[-1])).view(hs_t.shape[0], hs_t.shape[1], hs_t.shape[2],-1) + so_masks = self.so_mask_fc(so_masks) + + hs_sub, hs_obj = torch.split(hs_t, self.hidden_dim, dim=-1) + + outputs_class = self.entity_class_embed(hs) + outputs_coord = self.entity_bbox_embed(hs).sigmoid() + + outputs_class_sub = self.sub_class_embed(hs_sub) + outputs_coord_sub = self.sub_bbox_embed(hs_sub).sigmoid() + + outputs_class_obj = self.obj_class_embed(hs_obj) + outputs_coord_obj = self.obj_bbox_embed(hs_obj).sigmoid() + + outputs_class_rel = self.rel_class_embed(torch.cat((hs_sub, hs_obj, so_masks), dim=-1)) + + out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], + 'sub_logits': outputs_class_sub[-1], 'sub_boxes': outputs_coord_sub[-1], + 
class SetCriterion(nn.Module):
    """ This class computes the loss for RelTR.
    The process happens in two steps:
        1) the matcher computes an assignment between ground truth and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)

    NOTE(review): every loss indexes the matcher result as ``indices[0]``
    (entity matching), ``indices[1]`` (triplet matching) and ``indices[2]`` /
    ``indices[3]`` (per-query weights for the subject / object logits) --
    confirm against the matcher implementation.
    """
    def __init__(self, num_classes, num_rel_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            num_rel_classes: number of predicate categories, omitting the special no-relation category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object / no-relation category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        # Stored so loss_relations can label unmatched queries with the real
        # "no-relation" index instead of a dataset-specific constant.
        self.num_rel_classes = num_rel_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # Per-class cross-entropy weights: 1 everywhere, eos_coef on the last
        # ("empty") class.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

        empty_weight_rel = torch.ones(num_rel_classes+1)
        empty_weight_rel[-1] = self.eos_coef
        self.register_buffer('empty_weight_rel', empty_weight_rel)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Entity/subject/object Classification loss

        Reads 'pred_logits', 'sub_logits' and 'obj_logits' from ``outputs`` and
        the keys "labels" and "rel_annotations" from each target dict.
        Returns {'loss_ce': ...} plus, when ``log`` is true, the
        'class_error', 'sub_error' and 'obj_error' metrics.
        """
        assert 'pred_logits' in outputs

        pred_logits = outputs['pred_logits']

        # Entity queries: unmatched queries are labelled with the no-object
        # index (num_classes).
        idx = self._get_src_permutation_idx(indices[0])
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices[0])])
        target_classes = torch.full(pred_logits.shape[:2], self.num_classes, dtype=torch.int64, device=pred_logits.device)
        target_classes[idx] = target_classes_o

        sub_logits = outputs['sub_logits']
        obj_logits = outputs['obj_logits']

        # Triplet queries: columns 0 and 1 of "rel_annotations" index into
        # "labels" to give the subject and object class respectively.
        rel_idx = self._get_src_permutation_idx(indices[1])
        target_rels_classes_o = torch.cat([t["labels"][t["rel_annotations"][J, 0]] for t, (_, J) in zip(targets, indices[1])])
        target_relo_classes_o = torch.cat([t["labels"][t["rel_annotations"][J, 1]] for t, (_, J) in zip(targets, indices[1])])

        target_sub_classes = torch.full(sub_logits.shape[:2], self.num_classes, dtype=torch.int64, device=sub_logits.device)
        target_obj_classes = torch.full(obj_logits.shape[:2], self.num_classes, dtype=torch.int64, device=obj_logits.device)

        target_sub_classes[rel_idx] = target_rels_classes_o
        target_obj_classes[rel_idx] = target_relo_classes_o

        # Entity, subject and object queries are scored by one cross-entropy
        # call, concatenated along the query axis.
        target_classes = torch.cat((target_classes, target_sub_classes, target_obj_classes), dim=1)
        src_logits = torch.cat((pred_logits, sub_logits, obj_logits), dim=1)

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction='none')

        # Entity queries count fully; subject / object queries are scaled by
        # half of the matcher-provided indices[2] / indices[3].
        loss_weight = torch.cat((torch.ones(pred_logits.shape[:2]).to(pred_logits.device), indices[2]*0.5, indices[3]*0.5), dim=-1)
        losses = {'loss_ce': (loss_ce * loss_weight).sum()/self.empty_weight[target_classes].sum()}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(pred_logits[idx], target_classes_o)[0]
            losses['sub_error'] = 100 - accuracy(sub_logits[rel_idx], target_rels_classes_o)[0]
            losses['obj_error'] = 100 - accuracy(obj_logits[rel_idx], target_relo_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty relations
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['rel_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["rel_annotations"]) for v in targets], device=device)
        # Count the number of predictions that are NOT the last ("empty") class
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the entity/subject/object bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices[0])
        pred_boxes = outputs['pred_boxes'][idx]
        target_entry_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices[0])], dim=0)

        rel_idx = self._get_src_permutation_idx(indices[1])
        target_rels_boxes = torch.cat([t['boxes'][t["rel_annotations"][i, 0]] for t, (_, i) in zip(targets, indices[1])], dim=0)
        target_relo_boxes = torch.cat([t['boxes'][t["rel_annotations"][i, 1]] for t, (_, i) in zip(targets, indices[1])], dim=0)
        rels_boxes = outputs['sub_boxes'][rel_idx]
        relo_boxes = outputs['obj_boxes'][rel_idx]

        # Matched entity, subject and object boxes are stacked so both box
        # losses are computed in a single pass.
        src_boxes = torch.cat((pred_boxes, rels_boxes, relo_boxes), dim=0)
        target_boxes = torch.cat((target_entry_boxes, target_rels_boxes, target_relo_boxes), dim=0)
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        # Only the diagonal is needed: prediction k is paired with target k.
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def loss_relations(self, outputs, targets, indices, num_boxes, log=True):
        """Compute the predicate classification loss

        Matched triplet queries are supervised with column 2 of
        "rel_annotations"; every other query is labelled with the no-relation
        index (num_rel_classes).
        """
        assert 'rel_logits' in outputs

        src_logits = outputs['rel_logits']
        idx = self._get_src_permutation_idx(indices[1])
        target_classes_o = torch.cat([t["rel_annotations"][J,2] for t, (_, J) in zip(targets, indices[1])])
        # Use the configured predicate count rather than a literal 51, so the
        # criterion works for any relation vocabulary size.
        target_classes = torch.full(src_logits.shape[:2], self.num_rel_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight_rel)

        losses = {'loss_rel': loss_ce}
        if log:
            losses['rel_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    def _get_src_permutation_idx(self, indices):
        """Return (batch_idx, src_idx) index tensors selecting the matched predictions."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return (batch_idx, tgt_idx) index tensors selecting the matched targets."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss named ``loss``: 'labels', 'cardinality', 'boxes' or 'relations'."""
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'relations': self.loss_relations
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        Returns:
             dict mapping loss / metric names to values; entries computed from
             'aux_outputs' carry a '_<layer index>' suffix.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)
        # The last-layer matching is kept on the module for later inspection.
        self.indices = indices

        # Compute the average number of targets (entities + relations) across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"])+len(t["rel_annotations"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    kwargs = {}
                    if loss == 'labels' or loss == 'relations':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
def build(args):
    """Build the RelTR model, its criterion and its post-processors from ``args``.

    Reads ``args.dataset``, ``args.device``, ``args.num_entities``,
    ``args.num_triplets``, ``args.aux_loss``, ``args.bbox_loss_coef``,
    ``args.giou_loss_coef``, ``args.rel_loss_coef``, ``args.dec_layers`` and
    ``args.eos_coef``; ``args`` is also forwarded to ``build_backbone``,
    ``build_transformer`` and ``build_matcher``.

    Returns:
        (model, criterion, postprocessors): the ``RelTR`` module, the
        ``SetCriterion`` already moved to ``args.device``, and a dict
        ``{'bbox': PostProcess()}``.

    Raises:
        NotImplementedError: for ``args.dataset == 'oi'``, whose class counts
            are not defined yet.
    """
    if args.dataset == 'oi':
        # TODO: openimage v6
        # Fail before any sub-module is built: without class counts the
        # failure would otherwise only appear later, as a TypeError inside
        # the model constructor.
        raise NotImplementedError("dataset 'oi' (Open Images v6) is not supported yet")

    num_classes = 151
    num_rel_classes = 51

    device = torch.device(args.device)

    backbone = build_backbone(args)

    transformer = build_transformer(args)
    matcher = build_matcher(args)
    model = RelTR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_rel_classes=num_rel_classes,
        num_entities=args.num_entities,
        num_triplets=args.num_triplets,
        aux_loss=args.aux_loss,
        matcher=matcher)

    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    weight_dict['loss_rel'] = args.rel_loss_coef

    # TODO this is a hack
    if args.aux_loss:
        # Every intermediate decoder layer gets the same weights under a
        # '_<layer index>' suffix, matching the keys SetCriterion emits.
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ['labels', 'boxes', 'cardinality', "relations"]

    criterion = SetCriterion(num_classes, num_rel_classes, matcher=matcher, weight_dict=weight_dict,
                             eos_coef=args.eos_coef, losses=losses)
    criterion.to(device)
    postprocessors = {'bbox': PostProcess()}

    return model, criterion, postprocessors
class Transformer(nn.Module):
    """Encoder plus triplet decoder used by RelTR.

    The encoder runs over the flattened feature map; the decoder consumes the
    entity and triplet queries together with the encoder memory and also
    returns the subject / object attention maps.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        # The final encoder LayerNorm only exists in the pre-norm variant.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                    dropout, activation, normalize_before),
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation),
            num_decoder_layers,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, src, mask, entity_embed, triplet_embed, pos_embed, so_embed):
        """Encode ``src`` and decode the entity / triplet queries.

        ``src`` and ``pos_embed`` are NxCxHxW and are flattened to HWxNxC.
        ``entity_embed`` is split along dim 1 into two C-wide halves and
        ``triplet_embed`` into a C-wide and a 2C-wide part; in both cases the
        first part is passed to the decoder as the query position and the
        second as the initial query content.

        Returns:
            (hs, hs_t, so_masks, memory): the entity and triplet decoder
            outputs with dims 1 and 2 swapped, the subject / object attention
            maps concatenated on a new dim 3 and reshaped to HxW, and the
            encoder memory reshaped back to NxCxHxW.
        """
        batch, channels, height, width = src.shape

        def per_batch(queries):
            # Insert a batch axis at dim 1 and copy the queries for every sample.
            return queries.unsqueeze(1).repeat(1, batch, 1)

        def as_spatial_map(attn):
            # Give the attention weights a singleton axis at dim 3 and unfold
            # their last axis back into (H, W).
            return attn.reshape(attn.shape[0], batch, attn.shape[2], 1, height, width)

        # flatten NxCxHxW to HWxNxC
        flat_src = src.flatten(2).permute(2, 0, 1)
        flat_pos = pos_embed.flatten(2).permute(2, 0, 1)
        padding_mask = mask.flatten(1)

        entity_query_pos, entity_tgt = torch.split(entity_embed, channels, dim=1)
        triplet_query_pos, triplet_tgt = torch.split(triplet_embed, [channels, 2 * channels], dim=1)

        memory = self.encoder(flat_src, src_key_padding_mask=padding_mask, pos=flat_pos)
        hs, hs_t, sub_maps, obj_maps = self.decoder(
            per_batch(entity_tgt), per_batch(triplet_tgt), memory,
            memory_key_padding_mask=padding_mask,
            pos=flat_pos,
            entity_pos=per_batch(entity_query_pos),
            triplet_pos=per_batch(triplet_query_pos),
            so_pos=so_embed,
        )

        so_masks = torch.cat((as_spatial_map(sub_maps), as_spatial_map(obj_maps)), dim=3)
        restored_memory = memory.permute(1, 2, 0).view(batch, channels, height, width)

        return hs.transpose(1, 2), hs_t.transpose(1, 2), so_masks, restored_memory
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` clones of ``decoder_layer`` applied in sequence.

    Each layer updates the entity and triplet queries and also yields the
    subject / object maps produced in that layer.
    """

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

    def forward(self, entity, triplet, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None, entity_pos: Optional[Tensor] = None,
                triplet_pos: Optional[Tensor] = None, so_pos: Optional[Tensor] = None):
        """Run every decoder layer and collect the outputs.

        Returns:
            A 4-tuple (entity, triplet, sub_maps, obj_maps). Each tensor has a
            leading layer axis: one slice per layer when
            ``return_intermediate`` is true, otherwise a single slice holding
            the last layer's output, so callers can always index ``[-1]``.
        """
        output_entity = entity
        output_triplet = triplet
        intermediate_entity = []
        intermediate_triplet = []
        intermediate_submaps = []
        intermediate_objmaps = []

        for layer in self.layers:
            output_entity, output_triplet, sub_maps, obj_maps = layer(output_entity, output_triplet, entity_pos, triplet_pos, so_pos,
                                                                      memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                                                                      tgt_key_padding_mask=tgt_key_padding_mask,
                                                                      memory_key_padding_mask=memory_key_padding_mask, pos=pos)

            if self.return_intermediate:
                intermediate_entity.append(output_entity)
                intermediate_triplet.append(output_triplet)
                intermediate_submaps.append(sub_maps)
                intermediate_objmaps.append(obj_maps)

        if self.return_intermediate:
            return torch.stack(intermediate_entity), torch.stack(intermediate_triplet), \
                   torch.stack(intermediate_submaps), torch.stack(intermediate_objmaps)

        # Without intermediates the method used to fall through and return
        # None; return the final layer's outputs with the same leading axis
        # as the stacked case instead.
        return output_entity.unsqueeze(0), output_triplet.unsqueeze(0), \
               sub_maps.unsqueeze(0), obj_maps.unsqueeze(0)
self.dropout3_entity = nn.Dropout(dropout) + self.linear2_entity = nn.Linear(dim_feedforward, d_model) + self.dropout4_entity = nn.Dropout(dropout) + self.norm3_entity = nn.LayerNorm(d_model) + + self.linear1_sub = nn.Linear(d_model, dim_feedforward) + self.dropout3_sub = nn.Dropout(dropout) + self.linear2_sub = nn.Linear(dim_feedforward, d_model) + self.dropout4_sub = nn.Dropout(dropout) + self.norm3_sub = nn.LayerNorm(d_model) + + self.linear1_obj = nn.Linear(d_model, dim_feedforward) + self.dropout3_obj = nn.Dropout(dropout) + self.linear2_obj = nn.Linear(dim_feedforward, d_model) + self.dropout4_obj = nn.Dropout(dropout) + self.norm3_obj = nn.LayerNorm(d_model) + + def forward_ffn_entity(self, tgt): + tgt2 = self.linear2_entity(self.dropout3_entity(self.activation(self.linear1_entity(tgt)))) + tgt = tgt + self.dropout4_entity(tgt2) + tgt = self.norm3_entity(tgt) + return tgt + def forward_ffn_sub(self, tgt): + tgt2 = self.linear2_sub(self.dropout3_sub(self.activation(self.linear1_sub(tgt)))) + tgt = tgt + self.dropout4_sub(tgt2) + tgt = self.norm3_sub(tgt) + return tgt + def forward_ffn_obj(self, tgt): + tgt2 = self.linear2_obj(self.dropout3_obj(self.activation(self.linear1_obj(tgt)))) + tgt = tgt + self.dropout4_obj(tgt2) + tgt = self.norm3_obj(tgt) + return tgt + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt_entity, tgt_triplet, entity_pos, triplet_pos, so_pos, + memory, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + + # entity layer + q_entity = k_entity = self.with_pos_embed(tgt_entity, entity_pos) + tgt2_entity = self.self_attn_entity(q_entity, k_entity, value=tgt_entity, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt_entity = tgt_entity + self.dropout2_entity(tgt2_entity) + tgt_entity 
= self.norm2_entity(tgt_entity) + + tgt2_entity = self.cross_attn_entity(query=self.with_pos_embed(tgt_entity, entity_pos), + key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt_entity = tgt_entity + self.dropout1_entity(tgt2_entity) + tgt_entity = self.norm1_entity(tgt_entity) + tgt_entity = self.forward_ffn_entity(tgt_entity) + + # triplet layer + # coupled self attention + t_num = triplet_pos.shape[0] + h_dim = triplet_pos.shape[2] + tgt_sub, tgt_obj = torch.split(tgt_triplet, h_dim, dim=-1) + q_sub = k_sub = self.with_pos_embed(self.with_pos_embed(tgt_sub, triplet_pos), so_pos[0]) + q_obj = k_obj = self.with_pos_embed(self.with_pos_embed(tgt_obj, triplet_pos), so_pos[1]) + q_so = torch.cat((q_sub, q_obj), dim=0) + k_so = torch.cat((k_sub, k_obj), dim=0) + tgt_so = torch.cat((tgt_sub, tgt_obj), dim=0) + + tgt2_so = self.self_attn_so(q_so, k_so, tgt_so)[0] + tgt_so = tgt_so + self.dropout2_so(tgt2_so) + tgt_so = self.norm2_so(tgt_so) + tgt_sub, tgt_obj = torch.split(tgt_so, t_num, dim=0) + + # subject branch - decoupled visual attention + tgt2_sub, sub_maps = self.cross_attn_sub(query=self.with_pos_embed(tgt_sub, triplet_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt_sub = tgt_sub + self.dropout1_sub(tgt2_sub) + tgt_sub = self.norm1_sub(tgt_sub) + + # subject branch - decoupled entity attention + tgt2_sub = self.cross_sub_entity(query=self.with_pos_embed(tgt_sub, triplet_pos), + key=tgt_entity, value=tgt_entity)[0] + tgt_sub = tgt_sub + self.dropout2_sub(tgt2_sub) + tgt_sub = self.norm2_sub(tgt_sub) + tgt_sub = self.forward_ffn_sub(tgt_sub) + + # object branch - decoupled visual attention + tgt2_obj, obj_maps = self.cross_attn_obj(query=self.with_pos_embed(tgt_obj, triplet_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + 
def _get_activation_fn(activation):
    """Return an activation function given a string

    Accepted names are "relu", "gelu" and "glu".

    Raises:
        RuntimeError: if ``activation`` is not one of the accepted names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name (it used to omit "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1).

    The four coordinates are read from, and written back to, the last
    dimension of ``x``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h).

    The four coordinates are read from, and written back to, the last
    dimension of ``x``.
    """
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Compute pairwise IoU between two sets of xyxy boxes.

    Args:
        boxes1: tensor indexed as ``[N, 4]``.
        boxes2: tensor indexed as ``[M, 4]``.

    Returns:
        Tuple ``(iou, union)``, both ``[N, M]``.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), boxes1
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), boxes2
    iou, union = box_iou(boxes1, boxes2)

    # Smallest box enclosing each pair.
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks,
    (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Coordinate vectors live on the masks' device so that CUDA masks do not
    # trigger a device mismatch; broadcasting against [N, H, W] replaces the
    # explicit meshgrid and yields the same per-pixel coordinates.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device).view(h, 1)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device).view(1, w)

    x_mask = masks * x
    x_max = x_mask.flatten(1).max(-1)[0]
    # Pixels outside the mask are pushed to a large value so they never win the min.
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)


def get_union_box(boxes1, boxes2):
    """
    :param boxes1: subject cxcywh
    :param boxes2: object cxcywh
    :return: union box cxcywh (row-wise smallest box enclosing both inputs)
    """
    boxes1 = box_cxcywh_to_xyxy(boxes1)
    boxes2 = box_cxcywh_to_xyxy(boxes2)

    lt = torch.min(boxes1[:, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])

    union_boxes = torch.cat((lt, rb), dim=1)

    return box_xyxy_to_cxcywh(union_boxes)


def rescale_bboxes(out_bbox, size):
    """Convert cxcywh boxes to xyxy and scale them by an image size.

    Args:
        out_bbox: boxes in (cx, cy, w, h) format.
        size: ``(img_w, img_h)`` pair used as the scale factors.

    Returns:
        Boxes in (x0, y0, x1, y1) format multiplied by
        ``[img_w, img_h, img_w, img_h]``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' device so CUDA predictions can be rescaled.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)
    return b * scale
+""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__.split('.')[1]) < 7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + 
data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + 
start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: 
{branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# 
_onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def 
save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. 
+ """ + if float(torchvision.__version__.split('.')[1]) < 7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)