
Commit

Major re-format, added qualitative examples, possibility to run without labels
gmberton committed Mar 10, 2024
1 parent fff30bc commit 68f127f
Showing 28 changed files with 94 additions and 76 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -42,7 +42,7 @@ Predictions can be easily visualized through the `num_preds_to_save` parameter.

```
python3 main.py --method=cosplace --backbone=ResNet18 --descriptors_dimension=512 \
--num_preds_to_save=3 --exp_name=cosplace_on_stlucia \
--num_preds_to_save=3 --log_dir=cosplace_on_stlucia \
--database_folder=../VPR-datasets-downloader/datasets/st_lucia/images/test/database \
--queries_folder=../VPR-datasets-downloader/datasets/st_lucia/images/test/queries
```
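With this commit the pipeline can also run without ground-truth labels, via the new `--no_labels` flag (see parser.py below). A hypothetical invocation mirroring the example above (folder paths are illustrative):

```
python3 main.py --method=cosplace --backbone=ResNet18 --descriptors_dimension=512 \
                --no_labels --num_preds_to_save=3 --log_dir=cosplace_no_labels \
                --database_folder=path/to/database --queries_folder=path/to/queries
```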
Binary file added assets/database/db1.jpg
Binary file added assets/database/db10.jpg
Binary file added assets/database/db11.jpg
Binary file added assets/database/db12.jpg
Binary file added assets/database/db13.jpg
Binary file added assets/database/db14.jpg
Binary file added assets/database/db15.jpg
Binary file added assets/database/db16.jpg
Binary file added assets/database/db17.jpg
Binary file added assets/database/db2.jpg
Binary file added assets/database/db3.jpg
Binary file added assets/database/db4.jpg
Binary file added assets/database/db5.jpg
Binary file added assets/database/db6.jpg
Binary file added assets/database/db7.jpg
Binary file added assets/database/db8.jpg
Binary file added assets/database/db9.jpg
Binary file added assets/queries/q1.jpg
Binary file added assets/queries/q2.jpg
Binary file added assets/queries/q3.jpg
Binary file added assets/queries/q4.jpg
Binary file added assets/queries/q5.jpg
13 changes: 6 additions & 7 deletions commons.py
@@ -3,9 +3,10 @@
import sys
import logging
import traceback
from pathlib import Path


def setup_logging(output_folder: str, stdout: str = "debug", info_filename: str = "info.log",
def setup_logging(log_dir: Path, stdout: str = "debug", info_filename: str = "info.log",
debug_filename: str = "debug.log"):
"""After calling this function, you can easily log messages from anywhere
in your code without passing any object to your functions.
@@ -15,7 +16,7 @@ def setup_logging(output_folder: str, stdout: str = "debug", info_filename: str
Parameters
----------
output_folder : str, the folder where to save the logging files.
log_dir : str, the folder where to save the logging files.
stdout : str, can be "debug" or "info".
If stdout == "debug", print in stdout any time logging.debug(msg)
(or logging.info(msg)) is called.
@@ -26,9 +27,7 @@ def setup_logging(output_folder: str, stdout: str = "debug", info_filename: str
logging.debug(msg) or logging.info(msg).
"""
if os.path.exists(output_folder):
raise FileExistsError(f"{output_folder} already exists!")
os.makedirs(output_folder)
log_dir.mkdir(parents=True)
# logging.Logger.manager.loggerDict.keys() to check which loggers are in use
logging.getLogger('matplotlib.font_manager').disabled = True
logging.getLogger('shapely').disabled = True
@@ -39,13 +38,13 @@ def setup_logging(output_folder: str, stdout: str = "debug", info_filename: str
logging.getLogger('PIL').setLevel(logging.INFO) # turn off logging tag for some images

if info_filename is not None:
info_file_handler = logging.FileHandler(f'{output_folder}/{info_filename}')
info_file_handler = logging.FileHandler(log_dir / info_filename)
info_file_handler.setLevel(logging.INFO)
info_file_handler.setFormatter(base_formatter)
logger.addHandler(info_file_handler)

if debug_filename is not None:
debug_file_handler = logging.FileHandler(f'{output_folder}/{debug_filename}')
debug_file_handler = logging.FileHandler(log_dir / debug_filename)
debug_file_handler.setLevel(logging.DEBUG)
debug_file_handler.setFormatter(base_formatter)
logger.addHandler(debug_file_handler)
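For reference, a minimal usage sketch of the reworked `setup_logging` (the `logs/my_run` path is illustrative; note that `Path.mkdir(parents=True)` still raises `FileExistsError` when the folder already exists, preserving the old guard):

```
from pathlib import Path
import logging
import commons

log_dir = Path("logs") / "my_run"              # illustrative output folder
commons.setup_logging(log_dir, stdout="info")  # creates log_dir and attaches file handlers
logging.info("printed to stdout and written to logs/my_run/info.log")
logging.debug("written only to logs/my_run/debug.log")
```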
44 changes: 23 additions & 21 deletions main.py
@@ -5,6 +5,7 @@
import logging
import numpy as np
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
@@ -17,24 +18,24 @@

args = parser.parse_arguments()
start_time = datetime.now()
output_folder = f"logs/{args.exp_name}/{start_time.strftime('%Y-%m-%d_%H-%M-%S')}"
commons.setup_logging(output_folder, stdout="info")
log_dir = Path("logs") / args.log_dir / start_time.strftime('%Y-%m-%d_%H-%M-%S')
commons.setup_logging(log_dir, stdout="info")
logging.info(" ".join(sys.argv))
logging.info(f"Arguments: {args}")
logging.info(f"Testing with {args.method} with a {args.backbone} backbone and descriptors dimension {args.descriptors_dimension}")
logging.info(f"The outputs are being saved in {output_folder}")
logging.info(f"The outputs are being saved in {log_dir}")

model = vpr_models.get_model(args.method, args.backbone, args.descriptors_dimension)
model = model.eval().to(args.device)

test_ds = TestDataset(args.database_folder, args.queries_folder,
positive_dist_threshold=args.positive_dist_threshold,
image_size=args.image_size)
image_size=args.image_size, use_labels=args.use_labels)
logging.info(f"Testing on {test_ds}")

with torch.inference_mode():
logging.debug("Extracting database descriptors for evaluation/testing")
database_subset_ds = Subset(test_ds, list(range(test_ds.database_num)))
database_subset_ds = Subset(test_ds, list(range(test_ds.num_database)))
database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
batch_size=args.batch_size)
all_descriptors = np.empty((len(test_ds), args.descriptors_dimension), dtype="float32")
@@ -45,16 +46,16 @@

logging.debug("Extracting queries descriptors for evaluation/testing using batch size 1")
queries_subset_ds = Subset(test_ds,
list(range(test_ds.database_num, test_ds.database_num + test_ds.queries_num)))
list(range(test_ds.num_database, test_ds.num_database + test_ds.num_queries)))
queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
batch_size=1)
for images, indices in tqdm(queries_dataloader, ncols=100):
descriptors = model(images.to(args.device))
descriptors = descriptors.cpu().numpy()
all_descriptors[indices.numpy(), :] = descriptors

queries_descriptors = all_descriptors[test_ds.database_num:]
database_descriptors = all_descriptors[:test_ds.database_num]
queries_descriptors = all_descriptors[test_ds.num_database:]
database_descriptors = all_descriptors[:test_ds.num_database]

# Use a kNN to find predictions
faiss_index = faiss.IndexFlatL2(args.descriptors_dimension)
@@ -65,23 +66,24 @@
_, predictions = faiss_index.search(queries_descriptors, max(args.recall_values))

# For each query, check if the predictions are correct
positives_per_query = test_ds.get_positives()
recalls = np.zeros(len(args.recall_values))
for query_index, preds in enumerate(predictions):
for i, n in enumerate(args.recall_values):
if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
recalls[i:] += 1
break

# Divide by queries_num and multiply by 100, so the recalls are in percentages
recalls = recalls / test_ds.queries_num * 100
recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)])
logging.info(recalls_str)
if args.use_labels:
positives_per_query = test_ds.get_positives()
recalls = np.zeros(len(args.recall_values))
for query_index, preds in enumerate(predictions):
for i, n in enumerate(args.recall_values):
if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
recalls[i:] += 1
break

# Divide by num_queries and multiply by 100, so the recalls are in percentages
recalls = recalls / test_ds.num_queries * 100
recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)])
logging.info(recalls_str)

# Save visualizations of predictions
if args.num_preds_to_save != 0:
logging.info("Saving final predictions")
# For each query save num_preds_to_save predictions
visualizations.save_preds(predictions[:, :args.num_preds_to_save], test_ds,
output_folder, args.save_only_wrong_preds)
log_dir, args.save_only_wrong_preds, args.use_labels)

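To make the retrieval and recall logic above concrete, here is a self-contained sketch on random data; the dimensions, counts, and fake ground truth are illustrative, not the repo's defaults:

```
import faiss
import numpy as np

dim, num_database, num_queries = 128, 1000, 50
database_descriptors = np.random.rand(num_database, dim).astype("float32")
queries_descriptors = np.random.rand(num_queries, dim).astype("float32")

faiss_index = faiss.IndexFlatL2(dim)   # exact L2 search, as in main.py
faiss_index.add(database_descriptors)
recall_values = [1, 5, 10]
_, predictions = faiss_index.search(queries_descriptors, max(recall_values))

# Fake ground truth: each query gets 10 random database images as positives.
positives_per_query = [np.random.choice(num_database, 10, replace=False)
                       for _ in range(num_queries)]

recalls = np.zeros(len(recall_values))
for query_index, preds in enumerate(predictions):
    for i, n in enumerate(recall_values):
        if np.any(np.in1d(preds[:n], positives_per_query[query_index])):
            recalls[i:] += 1   # a hit within the top n also counts for larger n
            break
recalls = recalls / num_queries * 100
print(", ".join(f"R@{v}: {r:.1f}" for v, r in zip(recall_values, recalls)))
```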
9 changes: 7 additions & 2 deletions parser.py
@@ -24,12 +24,15 @@ def parse_arguments():
help="_")
parser.add_argument("--batch_size", type=int, default=4,
help="set to 1 if database images may have different resolution")
parser.add_argument("--exp_name", type=str, default="default",
help="experiment name, output logs will be saved under logs/exp_name")
parser.add_argument("--log_dir", type=str, default="default",
help="experiment name, output logs will be saved under logs/log_dir")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
help="_")
parser.add_argument("--recall_values", type=int, nargs="+", default=[1, 5, 10, 20],
help="values for recall (e.g. recall@1, recall@5)")
parser.add_argument("--no_labels", action="store_true",
help="set to true if you have no labels and just want to "
"do standard image retrieval given two folders of queries and DB")
parser.add_argument("--num_preds_to_save", type=int, default=0,
help="set != 0 if you want to save predictions for each query")
parser.add_argument("--save_only_wrong_preds", action="store_true",
@@ -40,6 +43,8 @@
"smallest edge of all images to this value, while keeping aspect ratio")
args = parser.parse_args()

args.use_labels = not args.no_labels

if args.method == "netvlad":
if args.backbone not in [None, "VGG16"]:
raise ValueError("When using NetVLAD the backbone must be None or VGG16")
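The new `--no_labels` switch is immediately folded into a positive `args.use_labels` flag, so downstream code can test label availability directly. A minimal sketch of the pattern:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no_labels", action="store_true",
                    help="run retrieval without ground-truth labels")
args = parser.parse_args(["--no_labels"])  # simulate passing the flag
args.use_labels = not args.no_labels       # derived positive flag
assert args.use_labels is False
```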
56 changes: 29 additions & 27 deletions test_dataset.py
@@ -46,7 +46,8 @@ def read_images_paths(dataset_folder):


class TestDataset(data.Dataset):
def __init__(self, database_folder, queries_folder, positive_dist_threshold=25, image_size=None):
def __init__(self, database_folder, queries_folder, positive_dist_threshold=25,
image_size=None, use_labels=True):
"""Dataset with images from database and queries, used for validation and test.
Parameters
----------
@@ -62,51 +63,52 @@ def __init__(self, database_folder, queries_folder, positive_dist_threshold=25,
self.database_paths = read_images_paths(database_folder)
self.queries_paths = read_images_paths(queries_folder)

# Read UTM coordinates, which must be contained within the paths
# The format must be path/to/file/@utm_easting@utm_northing@[email protected]
try:
# This is just a sanity check
image_path = self.database_paths[0]
utm_east = float(image_path.split("@")[1])
utm_north = float(image_path.split("@")[2])
except:
raise ValueError("The path of images should be path/to/file/@utm_east@utm_north@[email protected] "
f"but it is {image_path}, which does not contain the UTM coordinates.")
self.images_paths = list(self.database_paths) + list(self.queries_paths)

self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float)
self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float)
self.num_database = len(self.database_paths)
self.num_queries = len(self.queries_paths)

# Find positives_per_query, which are within positive_dist_threshold (default 25 meters)
knn = NearestNeighbors(n_jobs=-1)
knn.fit(self.database_utms)
self.positives_per_query = knn.radius_neighbors(self.queries_utms,
radius=positive_dist_threshold,
return_distance=False)
if use_labels:
# Read UTM coordinates, which must be contained within the paths
# The format must be path/to/file/@utm_easting@utm_northing@[email protected]
try:
# This is just a sanity check
image_path = self.database_paths[0]
utm_east = float(image_path.split("@")[1])
utm_north = float(image_path.split("@")[2])
except:
raise ValueError("The path of images should be path/to/file/@utm_east@utm_north@[email protected] "
f"but it is {image_path}, which does not contain the UTM coordinates.")

self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float)
self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float)

# Find positives_per_query, which are within positive_dist_threshold (default 25 meters)
knn = NearestNeighbors(n_jobs=-1)
knn.fit(self.database_utms)
self.positives_per_query = knn.radius_neighbors(self.queries_utms,
radius=positive_dist_threshold,
return_distance=False)

self.images_paths = [p for p in self.database_paths]
self.images_paths += [p for p in self.queries_paths]

self.database_num = len(self.database_paths)
self.queries_num = len(self.queries_paths)
transformations = [
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
if image_size:
transformations.append(transforms.Resize(size=image_size, antialias=True))
self.base_transform = transforms.Compose(transformations)
self.transform = transforms.Compose(transformations)

def __getitem__(self, index):
image_path = self.images_paths[index]
pil_img = Image.open(image_path).convert("RGB")
normalized_img = self.base_transform(pil_img)
normalized_img = self.transform(pil_img)
return normalized_img, index

def __len__(self):
return len(self.images_paths)

def __repr__(self):
return f"< #queries: {self.queries_num}; #database: {self.database_num} >"
return f"< #queries: {self.num_queries}; #database: {self.num_database} >"

def get_positives(self):
return self.positives_per_query
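When labels are used, they come entirely from the file names. A short sketch of the `@utm_east@utm_north@` convention the dataset parses (the coordinates below are made up):

```
# Expected format: path/to/file/@utm_east@utm_north@[email protected]
path = "database/@327000.5@6945000.2@[email protected]"  # made-up example
utm_east = float(path.split("@")[1])   # 327000.5
utm_north = float(path.split("@")[2])  # 6945000.2
```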
46 changes: 28 additions & 18 deletions visualizations.py
@@ -1,5 +1,4 @@

import os
import cv2
import numpy as np
from tqdm import tqdm
@@ -11,19 +10,19 @@
H = 512
W = 512
TEXT_H = 175
FONTSIZE = 80
FONTSIZE = 50
SPACE = 50 # Space between two images


def write_labels_to_image(labels=["text1", "text2"]):
"""Creates an image with vertical text, spaced along rows."""
"""Creates an image with text"""
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", FONTSIZE)
img = Image.new('RGB', ((W * len(labels)) + 50 * (len(labels)-1), TEXT_H), (1, 1, 1))
d = ImageDraw.Draw(img)
for i, text in enumerate(labels):
_, _, w, h = d.textbbox((0,0), text, font=font)
d.text(((W+SPACE)*i + W//2 - w//2, 1), text, fill=(0, 0, 0), font=font)
return np.array(img)
return np.array(img)[:100] # Remove some empty space


def draw(img, c=(0, 255, 0), thickness=20):
@@ -34,12 +33,12 @@ def draw(img, c=(0, 255, 0), thickness=20):
return cv2.line(img, (p[3, 0], p[3, 1]), (p[0, 0], p[0, 1]), c, thickness=thickness*2)


def build_prediction_image(images_paths, preds_correct=None):
def build_prediction_image(images_paths, preds_correct):
"""Build a row of images, where the first is the query and the rest are predictions.
For each image, if is_correct then draw a green/red box.
"""
assert len(images_paths) == len(preds_correct)
labels = ["Query"] + [f"Pr{i} - {is_correct}" for i, is_correct in enumerate(preds_correct[1:])]
labels = ["Query"] + [f"Pred{i} - {is_correct}" for i, is_correct in enumerate(preds_correct[1:])]
num_images = len(images_paths)
images = [np.array(Image.open(path).convert("RGB")) for path in images_paths]
for img, correct in zip(images, preds_correct):
@@ -63,19 +62,20 @@ def build_prediction_image(images_paths, preds_correct=None):
return final_image


def save_file_with_paths(query_path, preds_paths, positives_paths, output_path):
def save_file_with_paths(query_path, preds_paths, positives_paths, output_path, use_labels=True):
file_content = []
file_content.append("Query path:")
file_content.append(query_path + "\n")
file_content.append("Predictions paths:")
file_content.append("\n".join(preds_paths) + "\n")
file_content.append("Positives paths:")
file_content.append("\n".join(positives_paths) + "\n")
if use_labels:
file_content.append("Positives paths:")
file_content.append("\n".join(positives_paths) + "\n")
with open(output_path, "w") as file:
_ = file.write("\n".join(file_content))


def save_preds(predictions, eval_ds, output_folder, save_only_wrong_preds=None):
def save_preds(predictions, eval_ds, log_dir, save_only_wrong_preds=None, use_labels=True):
"""For each query, save an image containing the query and its predictions,
and a file with the paths of the query, its predictions and its positives.
@@ -84,35 +84,45 @@ def save_preds(predictions, eval_ds, output_folder, save_only_wrong_preds=None):
predictions : np.array of shape [num_queries x num_preds_to_viz], with the preds
for each query
eval_ds : TestDataset
output_folder : str / Path with the path to save the predictions
log_dir : Path with the path to save the predictions
save_only_wrong_preds : bool, if True save only the wrongly predicted queries,
i.e. the ones where the first pred is incorrect (further than 25 m)
"""
positives_per_query = eval_ds.get_positives()
os.makedirs(f"{output_folder}/preds", exist_ok=True)
for query_index, preds in enumerate(tqdm(predictions, ncols=80, desc=f"Saving preds in {output_folder}")):
if use_labels:
positives_per_query = eval_ds.get_positives()

viz_dir = (log_dir / "preds")
viz_dir.mkdir()
for query_index, preds in enumerate(tqdm(predictions, ncols=80, desc=f"Saving preds in {viz_dir}")):
query_path = eval_ds.queries_paths[query_index]
list_of_images_paths = [query_path]
# List of None (query), True (correct preds) or False (wrong preds)
preds_correct = [None]
for pred_index, pred in enumerate(preds):
pred_path = eval_ds.database_paths[pred]
list_of_images_paths.append(pred_path)
is_correct = pred in positives_per_query[query_index]
if use_labels:
is_correct = pred in positives_per_query[query_index]
else:
is_correct = None
preds_correct.append(is_correct)

if save_only_wrong_preds and preds_correct[1]:
continue

prediction_image = build_prediction_image(list_of_images_paths, preds_correct)
pred_image_path = f"{output_folder}/preds/{query_index:03d}.jpg"
pred_image_path = viz_dir / f"{query_index:03d}.jpg"
prediction_image.save(pred_image_path)

positives_paths = [eval_ds.database_paths[idx] for idx in positives_per_query[query_index]]
if use_labels:
positives_paths = [eval_ds.database_paths[idx] for idx in positives_per_query[query_index]]
else:
positives_paths = None
save_file_with_paths(
query_path=list_of_images_paths[0],
preds_paths=list_of_images_paths[1:],
positives_paths=positives_paths,
output_path=f"{output_folder}/preds/{query_index:03d}.txt"
output_path=viz_dir / f"{query_index:03d}.txt",
use_labels=use_labels
)

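The green/red correctness boxes come from the `draw` helper above. A rough stand-in using `cv2.copyMakeBorder`, which frames the image instead of drawing perimeter lines on it (sizes and colors are illustrative):

```
import cv2
import numpy as np

img = np.zeros((512, 512, 3), dtype=np.uint8)       # placeholder image
is_correct = True
color = (0, 255, 0) if is_correct else (255, 0, 0)  # green if correct, else red
framed = cv2.copyMakeBorder(img, 20, 20, 20, 20,
                            cv2.BORDER_CONSTANT, value=color)
```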