diff --git a/.gitignore b/.gitignore
index 7d776a671..884f1c41f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,5 @@ experiment-*
 .mypy_cache/*
 not_tracked_dir/
 .vscode
+*.pt
+checkpoint/log.txt
diff --git a/engine.py b/engine.py
index ac5ea6ff4..239c68d58 100644
--- a/engine.py
+++ b/engine.py
@@ -23,26 +23,27 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
     header = 'Epoch: [{}]'.format(epoch)
-    print_freq = 10
+    print_freq = 100
+    scaler = torch.cuda.amp.GradScaler()
 
     for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
         samples = samples.to(device)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        outputs = model(samples)
-        loss_dict = criterion(outputs, targets)
-        weight_dict = criterion.weight_dict
-        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
-
-        # reduce losses over all GPUs for logging purposes
-        loss_dict_reduced = utils.reduce_dict(loss_dict)
-        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
-                                      for k, v in loss_dict_reduced.items()}
-        loss_dict_reduced_scaled = {k: v * weight_dict[k]
-                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
-        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
-
-        loss_value = losses_reduced_scaled.item()
+        with torch.cuda.amp.autocast():
+            outputs = model(samples)
+            loss_dict = criterion(outputs, targets)
+            weight_dict = criterion.weight_dict
+            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+
+            # reduce losses over all GPUs for logging purposes
+            loss_dict_reduced = utils.reduce_dict(loss_dict)
+            loss_dict_reduced_unscaled = {f'{k}_unscaled': v
+                                          for k, v in loss_dict_reduced.items()}
+            loss_dict_reduced_scaled = {k: v * weight_dict[k]
+                                        for k, v in loss_dict_reduced.items() if k in weight_dict}
+            losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
+
+            loss_value = losses_reduced_scaled.item()
 
         if not math.isfinite(loss_value):
             print("Loss is {}, stopping training".format(loss_value))
@@ -50,10 +51,12 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
             sys.exit(1)
 
         optimizer.zero_grad()
-        losses.backward()
+        scaler.scale(losses).backward()
+
         if max_norm > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-        optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()
 
         metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
         metric_logger.update(class_error=loss_dict_reduced['class_error'])
diff --git a/main.py b/main.py
index e5f9eff80..8c7e9f961 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@
 import random
 import time
 from pathlib import Path
+import os
 
 import numpy as np
 import torch
@@ -84,7 +85,7 @@
     parser.add_argument('--coco_panoptic_path', type=str)
     parser.add_argument('--remove_difficult', action='store_true')
 
-    parser.add_argument('--output_dir', default='',
+    parser.add_argument('--output_dir', default='checkpoint',
                         help='path where to save, empty for no saving')
     parser.add_argument('--device', default='cuda',
                         help='device to use for training / testing')
@@ -93,7 +94,7 @@
     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                         help='start epoch')
     parser.add_argument('--eval', action='store_true')
-    parser.add_argument('--num_workers', default=2, type=int)
+    parser.add_argument('--num_workers', default=8, type=int)
 
     # distributed training parameters
     parser.add_argument('--world_size', default=1, type=int,
@@ -139,24 +140,32 @@
                                   weight_decay=args.weight_decay)
     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
 
-    dataset_train = build_dataset(image_set='train', args=args)
-    dataset_val = build_dataset(image_set='val', args=args)
+    if not os.path.exists('datasets/dataloaders.pt'):
+        dataset_train = build_dataset(image_set='train', args=args)
+        dataset_val = build_dataset(image_set='val', args=args)
 
-    if args.distributed:
-        sampler_train = DistributedSampler(dataset_train)
-        sampler_val = DistributedSampler(dataset_val, shuffle=False)
-    else:
-        sampler_train = torch.utils.data.RandomSampler(dataset_train)
-        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-
-    batch_sampler_train = torch.utils.data.BatchSampler(
-        sampler_train, args.batch_size, drop_last=True)
+        if args.distributed:
+            sampler_train = DistributedSampler(dataset_train)
+            sampler_val = DistributedSampler(dataset_val, shuffle=False)
+        else:
+            sampler_train = torch.utils.data.RandomSampler(dataset_train)
+            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+        data_to_save = {'train_dataset': dataset_train, 'train_sampler': sampler_train,
+                        'val_dataset': dataset_val, 'val_sampler': sampler_val}
+        torch.save(data_to_save, 'datasets/dataloaders.pt')
+    else:
+        loaded = torch.load('datasets/dataloaders.pt')
+        dataset_train = loaded['train_dataset']
+        dataset_val = loaded['val_dataset']
+        sampler_train = loaded['train_sampler']
+        sampler_val = loaded['val_sampler']
+
+    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)
 
     data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
-                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
+                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
     data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
-                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
-
+                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
+
     if args.dataset_file == "coco_panoptic":
         # We also evaluate AP during panoptic training, on original coco DS
         coco_val = datasets.coco.build("val", args)
diff --git a/models/position_encoding.py b/models/position_encoding.py
index 73ae39edf..3321ee293 100644
--- a/models/position_encoding.py
+++ b/models/position_encoding.py
@@ -16,7 +16,7 @@ class PositionEmbeddingSine(nn.Module):
     """
     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
         super().__init__()
-        self.num_pos_feats = num_pos_feats
+        self.num_pos_feats = num_pos_feats  # this comes from args.hidden_dim // 2 in build_position_encoding
         self.temperature = temperature
         self.normalize = normalize
         if scale is not None and normalize is False:
@@ -38,7 +38,7 @@ def forward(self, tensor_list: NestedTensor):
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)  # frequency denominators for sin(pos / 10000^(2i / num_pos_feats)), with num_pos_feats = 0.5 * args.hidden_dim; a constant, it does not depend on the input
 
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
diff --git a/references/End-to-End-Object-Detection-with-Transformers.pdf b/references/End-to-End-Object-Detection-with-Transformers.pdf
new file mode 100644
index 000000000..5448496a2
Binary files /dev/null and b/references/End-to-End-Object-Detection-with-Transformers.pdf differ
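
Note on the engine.py hunk: PyTorch's documented AMP recipe unscales gradients before clipping, whereas the hunk above clips while the gradients are still multiplied by the loss-scale factor, so the effective max_norm on the true gradients is far smaller than intended; the GradScaler is also re-created on every call to train_one_epoch, resetting its scale state each epoch. A minimal sketch of the documented pattern, with a placeholder linear model and loss standing in for DETR's (requires a CUDA device):

import torch

device = torch.device('cuda')
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()   # create once for the whole run, not per epoch
max_norm = 0.1

samples = torch.randn(8, 16, device=device)
targets = torch.randint(0, 4, (8,), device=device)

with torch.cuda.amp.autocast():        # forward and loss in mixed precision
    losses = criterion(model(samples), targets)

optimizer.zero_grad()
scaler.scale(losses).backward()        # backward on the scaled loss
if max_norm > 0:
    scaler.unscale_(optimizer)         # unscale first, so max_norm applies to the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)                 # skips the update if any gradient is inf/nan
scaler.update()                        # adapts the scale factor for the next step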
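
Note on the main.py hunk: torch.save/torch.load pickle arbitrary Python objects, so caching the built datasets and samplers in datasets/dataloaders.pt works, but the save fails if the datasets/ directory does not exist, and a pickled DistributedSampler keeps the epoch it was saved with, so set_epoch must still be called every epoch as DETR's training loop already does. A hedged sketch of the same cache-or-build pattern with the directory created on demand; build_expensive_dataset is a placeholder, not a DETR function:

import os
import torch

CACHE_PATH = 'datasets/dataloaders.pt'

def build_expensive_dataset():
    # stand-in for build_dataset(image_set=..., args=...), which parses COCO annotations
    return list(range(1000))

if not os.path.exists(CACHE_PATH):
    dataset_train = build_expensive_dataset()
    os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)  # the hunk assumes datasets/ already exists
    torch.save({'train_dataset': dataset_train}, CACHE_PATH)
else:
    dataset_train = torch.load(CACHE_PATH)['train_dataset']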
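
Note on the position_encoding.py comments: dim_t depends only on num_pos_feats and temperature, never on the input image, which is what "constant" means there. A small worked example with a toy num_pos_feats of 4 (DETR uses hidden_dim // 2):

import torch

num_pos_feats, temperature = 4, 10000
dim_t = torch.arange(num_pos_feats, dtype=torch.float32)   # tensor([0., 1., 2., 3.])
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)  # tensor([1., 1., 100., 100.])
# Values pair up because even channels later go through sin() and odd channels
# through cos() at the same frequency; dividing a coordinate by dim_t gives the phase:
pos = torch.tensor([3.0])
print(pos[:, None] / dim_t)  # tensor([[3.0000, 3.0000, 0.0300, 0.0300]])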