Skip to content

Commit

Permalink
Merge pull request #31 from fraunhoferportugal/dev
Browse files Browse the repository at this point in the history
Image extractors and version updates
  • Loading branch information
ivo-facoco authored Dec 10, 2024
2 parents 64c9344 + dd96e8e commit a13045d
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 31 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ repos:
description: Automatically upgrade syntax for newer versions of the language
args: [--py36-plus]
- repo: https://github.com/jendrikseipp/vulture
rev: v2.13
rev: v2.14
hooks:
- id: vulture
name: vulture - finds unused code in Python programs
Expand Down Expand Up @@ -244,13 +244,13 @@ repos:
language: system
files: requirements/*.txt
- repo: https://github.com/PyCQA/bandit
rev: 1.7.10
rev: 1.8.0
hooks:
- id: bandit
args: ["-c", "pyproject.toml"]
additional_dependencies: [".[toml]"]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.10
rev: 1.8.0
hooks:
- id: bandit
name: bandit - find common security issues in Python code.
Expand Down
19 changes: 17 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@
All notable changes to this project will be documented in this file.
This format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.1.6] - 2024-12-10
Minor patch release with a new image feature extraction method and documentation updates.

### Added
- `extract_features_from_dataloader` method to the image `BaseExtractor` class

### Changed
- Allowed for `numpy >= 2.0.0` in the requirements
- Allowed `pydantic>2.9.0` in the requirements
- Moved internal method in the image `BaseExtractor` to underscore method

### Fixed
- Simplified the `StandardTransform` in images to work directly with tensors


## [0.1.5] - 2024-11-29
Introduce new time-series metrics and documentation updates.

Expand All @@ -20,12 +35,12 @@ Introduce new time-series metrics and documentation updates.

### Fixed
- PyPI security issues due to direct external `pydom` dependency


## [0.1.4] - 2024-11-21
Taxonomy rework and documentation updates.

### Added
### Added
- readthedocs slug in the README file
- References to tabular metrics

Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.5
0.1.6
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# https://github.com/microsoft/vscode-python/blob/master/CHANGELOG.md#enhancements-1
[tool.poetry]
name = "pymdma"
version = "0.1.5"
version = "0.1.6"
description = "Multimodal Data Metrics for Auditing real and synthetic data"
authors = ["Fraunhofer AICOS <[email protected]>"]
maintainers = [
Expand Down Expand Up @@ -36,10 +36,10 @@ exclude = ["tests/*", "notebooks/*", "docs/*", "src/**/main.py",
python = ">=3.9, <3.13"
loguru = {version = ">=0.7.2, <0.8.0"}
matplotlib = {version = ">=3.4.3, <4.0.0"}
numpy = {version = ">=1.22.0, <2.0.0"}
numpy = {version = ">=1.22.0, <2.5.0"}
piq = {version = ">=0.8.0, <1.0.0"}
pot = {version = ">=0.9.4, <0.10.0"}
pydantic = {version = ">=2.8.2, <2.9.0"}
pydantic = {version = ">=2.8.2, <3.0.0"}
python-dotenv = {version = ">=1.0.0, <2.0.0"}
torch = {version = ">=2.1.0, <2.5.0"}
gudhi = {version = ">=3.9.0, <=4.0.0"}
Expand Down Expand Up @@ -98,7 +98,7 @@ optional = true
black = "^23.11.0"
flake8 = "^6.1.0"
isort = "^5.12.0"
mypy = "^1.7.1"
mypy = "^1.13.0"
pre-commit = "^3.5.0"
pytest = "^7.4.3"
pytest-cov = "^4.1.0"
Expand Down
2 changes: 1 addition & 1 deletion src/pymdma/common/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ def __init__(self, name: str) -> None:
self.name = name

@abstractmethod
def extract_features_dataloader(self, dataloader):
def _extract_features_dataloader(self, dataloader):
pass
4 changes: 2 additions & 2 deletions src/pymdma/image/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ def get_embeddings(
extractor = ExtractorFactory.model_from_name(model_name) if extractor is None else extractor

# extractor = model_instance if model_instance is not None else FeatureExtractor(model_name, device=self.device)
reference_feats, _labels, _reference_ids = extractor.extract_features_dataloader(
reference_feats, _labels, _reference_ids = extractor._extract_features_dataloader(
self.reference_loader,
device=self.device,
)
synthetic_feats, _labels, synthetic_ids = extractor.extract_features_dataloader(
synthetic_feats, _labels, synthetic_ids = extractor._extract_features_dataloader(
self.target_loader,
device=self.device,
)
Expand Down
100 changes: 84 additions & 16 deletions src/pymdma/image/models/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from loguru import logger
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm

from pymdma.common.definitions import EmbedderInterface
Expand All @@ -26,13 +27,15 @@ def __init__(
self.interp = interpolation
self.preprocess_transform = preprocess_transform

self._to_tensor = transforms.PILToTensor()

def __call__(self, image: Image.Image) -> torch.Tensor:
image = self.preprocess_transform(image) if self.preprocess_transform is not None else image
image = image.resize(self.img_size, self.interp)
# bring image to the range [0, 1] and normalize to [-1, 1]
image = np.array(image).astype(np.float32) / 255.0
image = self._to_tensor(image).float() / 255.0
image = image * 2.0 - 1.0
return torch.from_numpy(image).permute(2, 0, 1).float()
return image


class BaseExtractor(torch.nn.Module, EmbedderInterface):
Expand All @@ -55,14 +58,20 @@ def extract_features_from_files(
device: str = "cpu",
preprocess_transform: Optional[Callable] = None,
) -> np.ndarray:
"""Extract features from a list of image files.
Args:
files (List[Path]): list of paths to image files
batch_size (int): batch size for feature extraction. Defaults to 50.
Returns:
np.ndarray: array of features
"""Extract features from a list of image files. Converts images to
tensors and normalizes them to the [-1, 1] range.
Parameters
----------
files : List[Path]
list of paths to image files
batch_size : int, optional
batch size for feature extraction, by default 50
Returns
-------
np.ndarray
array of features with shape (n_samples, n_features)
"""
if batch_size > len(files):
# print("Warning: batch size is bigger than the data size. " "Setting batch size to data size")
Expand Down Expand Up @@ -90,19 +99,78 @@ def extract_features_from_files(
return np.concatenate(act_array, axis=0)

@torch.no_grad()
def extract_features_dataloader(
def extract_features_from_dataloader(
    self,
    dataloader: DataLoader,
    normalize: bool = False,
    device: str = "cpu",
) -> np.ndarray:
    """Extract features from every batch produced by a DataLoader.

    Parameters
    ----------
    dataloader : DataLoader
        PyTorch DataLoader whose batches are indexable and hold the image
        batch tensor as their first element.
    normalize : bool, optional
        Whether to rescale the images with ``images * 2.0 - 1.0`` before the
        forward pass (maps the [0, 1] range to [-1, 1], i.e. 0.5 mean and
        0.5 std across channels). Assumes that the images are in the
        [0, 1] range, by default False
    device : str, optional
        device to use, by default "cpu"

    Notes
    -----
    Has the following assumptions:
        - Dataloader outputs a tuple with the first element being the image batch tensor.
        - Image tensors are of shape (batch_size, channels, height, width)
        - Either the tensors are in the [0, 1] range, in which case you should
          set `normalize` to True, or they have already been normalized to the
          [-1, 1] range (0.5 mean and 0.5 std across channels) or to another
          range (e.g. ImageNet normalization), in which case leave it False.

    It is recommended to disable the `shuffle` option in the DataLoader for consistency.
    Depending on the model, you might need to resize the images to the model's input size.

    One batch is drawn from the dataloader up front for validation, in
    addition to the full pass used for extraction.

    Returns
    -------
    np.ndarray
        array of features with shape (n_samples, n_features)
    """
    self.extractor = self.extractor.to(device, dtype=torch.float32)

    # validation dry run: peek at a single batch before the full pass
    sample_batch = next(iter(dataloader))
    assert isinstance(
        sample_batch[0], torch.Tensor
    ), f"First element of the tuple must be a torch.Tensor. Got {type(sample_batch[0])}."

    # Spatial size (everything after the batch and channel dimensions).
    # Report the *shape*, not a slice of the tensor data, in the warning.
    image_size = tuple(sample_batch[0].shape[2:])
    if image_size != self.input_size:
        logger.warning(
            f"Model default input size {self.input_size} does not match the size of the images in the dataloader {image_size}. Might lead to execution errors."
        )

    act_array = []
    for batch in tqdm(dataloader, total=len(dataloader)):
        images = batch[0].to(device, dtype=torch.float32)
        if normalize:
            # [0, 1] -> [-1, 1]
            images = images * 2.0 - 1.0
        act_array.append(self(images).detach().cpu().numpy())
    return np.concatenate(act_array, axis=0)

@torch.no_grad()
def _extract_features_dataloader(
self,
dataloader: DataLoader,
device: str = "cpu",
preprocess_transform: Optional[Callable] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Use selected model to extract features from all images in
dataloader.
"""Internal method to extract features from a DataLoader.
Args:
dataloader (DataLoader): image dataloader
Parameters
----------
dataloader : DataLoader
image dataloader
Returns:
Tuple[np.ndarray, np.ndarray]: array of features and array of image labels
Tuple[np.ndarray, np.ndarray, np.ndarray]: extracted features, labels, and image ids
"""
logger.info("Extracting image features.")
act_array = []
Expand Down
4 changes: 2 additions & 2 deletions src/pymdma/image/models/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def forward(self, x):


class DinoExtractor(BaseExtractor):
def __init__(self, model_name) -> None:
def __init__(self, model_name, input_size: tuple[int, int] = (224, 224)) -> None:
super().__init__(
input_size=(224, 224),
input_size=input_size,
interpolation=Image.Resampling.BICUBIC,
)

Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from fastapi.testclient import TestClient
from PIL import Image
from torchvision.transforms import transforms

from pymdma.api.run_api import app
from pymdma.config import data_dir
Expand Down Expand Up @@ -65,6 +66,19 @@ def get_extractor(name):
return get_extractor


@pytest.fixture(scope="module")
def image_transforms():
def get_transforms(input_size: Tuple[int], interpolation: int = Image.BILINEAR):
return transforms.Compose(
[
transforms.Resize(input_size, interpolation=interpolation),
transforms.ToTensor(),
]
)

return get_transforms


# ###################################################################################################
# ################################## Time-Series Fixtures ###########################################
# ###################################################################################################
Expand Down
39 changes: 39 additions & 0 deletions tests/test_image_import.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from torch.utils.data import DataLoader

from pymdma.image.data.simple_dataset import SimpleDataset
from pymdma.image.measures.input_val.annotation import coco as ann
from pymdma.image.measures.input_val.data import no_reference as no_ref_quality
from pymdma.image.measures.input_val.data import reference as ref_quality
Expand Down Expand Up @@ -111,3 +113,40 @@ def test_extractor_models(image_feature_extractor, synth_image_filenames, extrac
dataset_level, instance_level = result.value
assert dataset_level > 0.90, "Dataset level is below threshold"
assert all([inst == 1 for inst in instance_level]), "Same image instance should be precise"


@pytest.mark.parametrize(
    "extractor_name, input_size",
    [
        ("inception_fid", (299, 299)),
        ("vgg16", (224, 224)),
        ("dino_vits8", (224, 224)),
        ("vit_b_16", (224, 224)),
    ],
)
def test_extractor_methods(
    image_feature_extractor, synth_image_filenames, image_transforms, extractor_name, input_size
):
    """Check that file-based and dataloader-based extraction agree.

    The same images are pushed through `extract_features_from_files` and
    `extract_features_from_dataloader`; the resulting feature arrays must
    have one row per image, matching summary statistics, and be mutually
    precise under `ImprovedPrecision`.
    """
    extractor = image_feature_extractor(extractor_name)

    dataset = SimpleDataset(synth_image_filenames, image_transforms(input_size, extractor.interpolation))
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    features_files = extractor.extract_features_from_files(synth_image_filenames)
    features_dataloader = extractor.extract_features_from_dataloader(dataloader, normalize=True)

    assert features_files.shape[0] == len(synth_image_filenames), "Feature length does not match input length"
    assert features_dataloader.shape[0] == len(synth_image_filenames), "Feature length does not match input length"

    # Compare floating-point statistics with a tolerance: exact equality is
    # brittle across batch sizes, devices and library versions.
    assert np.isclose(
        features_dataloader.mean(), features_files.mean()
    ), "Feature extraction from files and dataloader should be the same"
    assert np.isclose(
        features_dataloader.std(), features_files.std()
    ), "Feature extraction from files and dataloader should be the same"

    prec = ImprovedPrecision()
    result = prec.compute(features_files, features_dataloader)
    dataset_level, instance_level = result.value
    assert dataset_level > 0.98, "Dataset level is below threshold"
    assert all(inst == 1 for inst in instance_level), "Same image instance should be precise"

0 comments on commit a13045d

Please sign in to comment.