Merge pull request #59 from BerndDoser/data
WIP: Refactoring
BerndDoser authored Dec 14, 2023
2 parents b5c0b2e + c9cf449 commit 91636a2
Showing 107 changed files with 3,980 additions and 3,265 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -38,7 +38,7 @@ jobs:
path: ~/.local
key: poetry-1.7.1-0

- name: Install peotry
- name: Install poetry
uses: snok/install-poetry@v1
with:
version: 1.7.1
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -13,7 +13,7 @@
"python.testing.pytestEnabled": true,
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": false,
"editor.formatOnSave": true,
},
"python.analysis.autoImportCompletions": true
}
13 changes: 9 additions & 4 deletions README.md
@@ -1,11 +1,14 @@
[![Build Status](https://github.com/HITS-AIN/Spherinator/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/HITS-AIN/Spherinator/actions/workflows/python-package.yml?branch=main)
![versions](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11-blue)

# Spherinator
# Spherinator & HiPSter

Provides simple autoencoders to project images to the surface of a sphere inluding a tool to creat HiPS representations for browsing.
The `Spherinator` uses [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) to implement a convolutional neural network (CNN) based variational autoencoder (VAE) with a spherical latent space.
The `HiPSter` creates the HiPS tilings and the catalog which can be visualized interactively on the surface of a sphere with [Aladin Lite](https://github.com/cds-astro/aladin-lite).

![HiPSter example](efigi.png "Example of autoencoded HiPS tiling for efigi data of nearby galaxies in SDSS")
<p align="center">
<img src="docs/P404_f2.png" width="400" height="400">
</p>
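
The new README text above describes the core design: a convolutional VAE whose latent codes live on the surface of a sphere. As a purely illustrative sketch of that idea (not the project's actual model code; the class name, layer sizes, and latent dimension are invented for the example), the encoder output can be normalized to unit length so every latent code becomes a point on the sphere:

```python
import torch
import torch.nn as nn


class ToySphericalEncoder(nn.Module):
    """Illustration only: maps an image batch to points on the unit sphere."""

    def __init__(self, latent_dim: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(16, latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.fc(self.features(x))
        # Projecting onto the unit sphere gives each image a position that
        # HiPS-style tools can place on a spherical map.
        return torch.nn.functional.normalize(z, dim=-1)
```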


## Git clone with submodules
@@ -57,11 +60,13 @@ Examples:
The following command generates a HiPS representation and a catalog showing the real images located on the latent space using the trained model.

```bash
./hipster.py all --checkpoint <checkpoint-file>.ckpt
./hipster.py --checkpoint <checkpoint-file>.ckpt
```

Call `./hipster.py --help` for more information.

For visualization, a Python HTTP server can be started by executing `python3 -m http.server 8082` within the HiPSter output file.
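
As a concrete illustration of that visualization step, the commands below serve the generated HiPS tiling locally; the directory name `HiPSter/` is only an assumed example of where `hipster.py` wrote its output.

```bash
# Assumed output location; replace with the directory hipster.py actually produced.
cd HiPSter/
python3 -m http.server 8082
# Aladin Lite can then load the tiling from http://localhost:8082
```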


## Profiling

19 changes: 6 additions & 13 deletions callbacks/log_reconstruction_callback.py
Expand Up @@ -5,7 +5,7 @@
import torchvision.transforms.functional as functional
from lightning.pytorch.callbacks import Callback
from matplotlib import figure
from torchvision import transforms
import numpy as np

matplotlib.use("Agg")

@@ -24,9 +24,8 @@ def on_train_epoch_end(self, trainer, pl_module):
return

# Generate some random samples from the validation set
samples = next(iter(trainer.train_dataloader))["image"]
samples = samples[: self.num_samples]
samples = samples.to(pl_module.device)
data = next(iter(trainer.train_dataloader))
samples = data[: self.num_samples].to(pl_module.device)

# Generate reconstructions of the samples using the model
with torch.no_grad():
@@ -48,7 +47,7 @@ def on_train_epoch_end(self, trainer, pl_module):
rotate, [pl_module.crop_size, pl_module.crop_size]
)
scaled = functional.resize(
crop, [pl_module.input_size, pl_module.input_size], antialias=False
crop, [pl_module.input_size, pl_module.input_size], antialias=True
)

if pl_module.__class__.__name__ == "RotationalAutoencoder":
@@ -62,20 +61,14 @@ def on_train_epoch_end(self, trainer, pl_module):
best_scaled[best_recon_idx] = scaled[best_recon_idx]
best_recon[best_recon_idx] = recon[best_recon_idx]

normalize = transforms.Lambda(
lambda x: (x - torch.min(x)) / (torch.max(x) - torch.min(x))
)
normalize(best_scaled)
normalize(best_recon)

# Plot the original samples and their reconstructions side by side
fig = figure.Figure(figsize=(6, 2 * self.num_samples))
ax = fig.subplots(self.num_samples, 2)
for i in range(self.num_samples):
ax[i, 0].imshow(best_scaled[i].cpu().detach().numpy().T)
ax[i, 0].imshow(np.clip(best_scaled[i].cpu().detach().numpy().T, 0, 1))
ax[i, 0].set_title("Original")
ax[i, 0].axis("off")
ax[i, 1].imshow(best_recon[i].cpu().detach().numpy().T)
ax[i, 1].imshow(np.clip(best_recon[i].cpu().detach().numpy().T, 0, 1))
ax[i, 1].set_title("Reconstruction")
ax[i, 1].axis("off")
fig.tight_layout()
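
The hunks above drop the min-max `transforms.Lambda` normalization in favor of clipping at plot time (note that the removed `normalize(...)` calls discarded the returned tensor, so they had no effect on the plotted arrays). A small standalone sketch of the two approaches, with made-up values for illustration:

```python
import numpy as np
import torch

# Example image tensor with values slightly outside [0, 1].
img = torch.rand(3, 64, 64) * 1.4 - 0.2

# Removed approach: rescale the whole tensor into [0, 1].
rescaled = (img - img.min()) / (img.max() - img.min())

# New approach at plot time: clip out-of-range values so imshow renders them cleanly.
clipped = np.clip(img.numpy().T, 0, 1)  # .T moves channels last for imshow
```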
14 changes: 8 additions & 6 deletions data/__init__.py
@@ -13,17 +13,19 @@
"""

from .illustris_sdss_dataset import IllustrisSdssDataset
from .illustris_sdss_dataset_with_metadata import IllustrisSdssDatasetWithMetadata
from .illustris_sdss_data_module import IllustrisSdssDataModule
from .galaxy_zoo_dataset import GalaxyZooDataset
from .galaxy_zoo_data_module import GalaxyZooDataModule
from .shapes_dataset import ShapesDataset
from .shapes_data_module import ShapesDataModule

__all__ = [
'IllustrisSdssDataset',
'IllustrisSdssDataModule',
'GalaxyZooDataset',
'GalaxyZooDataModule',
'ShapesDataset',
'ShapesDataModule'
"IllustrisSdssDataset",
"IllustrisSdssDatasetWithMetadata",
"IllustrisSdssDataModule",
"GalaxyZooDataset",
"GalaxyZooDataModule",
"ShapesDataset",
"ShapesDataModule",
]
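
For reference, the reformatted `__all__` keeps the public API of the package unchanged, so imports such as the following still work (assuming the repository root is on `PYTHONPATH`):

```python
from data import GalaxyZooDataModule, IllustrisSdssDatasetWithMetadata
```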
154 changes: 114 additions & 40 deletions data/galaxy_zoo_data_module.py
@@ -1,52 +1,126 @@
import lightning.pytorch as pl
from pathlib import Path

import torchvision.transforms.v2 as transforms
from torch.utils.data import DataLoader
from torchvision import transforms

import data.galaxy_zoo_dataset as galaxy_zoo_dataset
import data.preprocessing as preprocessing
from data.galaxy_zoo_dataset import GalaxyZooDataset
from models.spherinator_module import SpherinatorModule

from .spherinator_data_module import SpherinatorDataModule


class GalaxyZooDataModule(pl.LightningDataModule):
class GalaxyZooDataModule(SpherinatorDataModule):
"""Defines access to the Galaxy Zoo data as a data module."""

def __init__(self, data_dir: str = "./", batch_size: int = 32, extension: str = "jpg", shuffle: bool = True, num_workers: int = 16):
def __init__(
self,
data_directory: str = "./",
batch_size: int = 32,
extension: str = "jpg",
shuffle: bool = True,
num_workers: int = 16,
):
"""Initialize GalaxyZooDataModule
Args:
data_directory (str): The directories to scan for data files.
batch_size (int, optional): The batch size for training. Defaults to 32.
extension (str, optional): The kind of files to search for. Defaults to "jpg".
shuffle (bool, optional): Wether or not to shuffle whe reading. Defaults to True.
num_workers (int, optional): How many worker to use for loading. Defaults to 16.
"""
super().__init__()
self.data_dir = data_dir
self.train_transform = transforms.Compose([
preprocessing.DielemanTransformation(
rotation_range=[0,360],
translation_range=[4./424,4./424],
scaling_range=[1/1.1,1.1],
flip=0.5),
preprocessing.CropAndScale((424,424), (424,424))
])
self.val_transform = transforms.Compose([
preprocessing.CropAndScale((424,424), (424,424))
])

self.data_directory = data_directory
self.batch_size = batch_size
self.extension = extension
self.shuffle = shuffle
self.num_workers = num_workers

self.transform_train = transforms.Compose(
[
preprocessing.DielemanTransformation(
rotation_range=[0, 360],
translation_range=[4.0 / 424, 4.0 / 424],
scaling_range=[1 / 1.1, 1.1],
flip=0.5,
),
transforms.CenterCrop((363, 363)),
transforms.Resize((424, 424), antialias=True),
]
)
self.transform_processing = transforms.CenterCrop((363, 363))
self.transform_images = self.transform_train
self.transform_thumbnail_images = transforms.Compose(
[
self.transform_processing,
transforms.Resize((100, 100), antialias=True),
]
)

def setup(self, stage: str):
if stage == "fit":
self.data_train = galaxy_zoo_dataset.GalaxyZooDataset(data_directory=self.data_dir,
extension=self.extension,
transform=self.train_transform)

self.dataloader_train = DataLoader(self.data_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers)
elif stage =="val":
self.data_val = galaxy_zoo_dataset.GalaxyZooDataset(data_directory=self.data_dir,
extension=self.extension,
transform=self.val_transform)
self.dataloader_val = DataLoader(self.data_train,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)

def train_dataloader(self):
return self.dataloader_train

def val_dataloader(self):
return self.dataloader_val
"""Sets up the data set and data loaders.
Args:
stage (str): Defines for which stage the data is needed.
"""
if not stage in ["fit", "processing", "images", "thumbnail_images"]:
raise ValueError(f"Stage {stage} not supported.")

if stage == "fit" and self.data_train is None:
self.data_train = GalaxyZooDataset(
data_directory=self.data_directory,
extension=self.extension,
transform=self.transform_train,
)
self.dataloader_train = DataLoader(
self.data_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
)
elif stage == "processing" and self.data_processing is None:
self.data_processing = GalaxyZooDataset(
data_directory=self.data_directory,
extension=self.extension,
transform=self.transform_processing,
)
self.dataloader_processing = DataLoader(
self.data_train,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
elif stage == "images" and self.data_images is None:
self.data_images = GalaxyZooDataset(
data_directory=self.data_directory,
extension=self.extension,
transform=self.transform_images,
)
self.dataloader_images = DataLoader(
self.data_train,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
elif stage == "thumbnail_images" and self.data_thumbnail_images is None:
self.data_thumbnail_images = GalaxyZooDataset(
data_directory=self.data_directory,
extension=self.extension,
transform=self.transform_thumbnail_images,
)
self.dataloader_thumbnail_images = DataLoader(
self.data_train,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

def write_catalog(
self, model: SpherinatorModule, catalog_file: Path, hipster_url: str, title: str
):
"""Writes a catalog to disk."""
self.setup("processing")
with open(catalog_file, "w", encoding="utf-8") as output:
output.write("#filename,RMSD,rotation,x,y,z\n")
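
Building on the new stage-based `setup`, a hypothetical usage sketch of the refactored data module could look like this (the path and parameter values are assumptions for illustration, not taken from the repository):

```python
from data import GalaxyZooDataModule

# Assumed arguments for illustration only.
dm = GalaxyZooDataModule(data_directory="/data/galaxy_zoo", batch_size=64, num_workers=8)

# Each stage lazily builds its own dataset and DataLoader attribute.
dm.setup("fit")               # -> dm.dataloader_train (Dieleman augmentation)
dm.setup("thumbnail_images")  # -> dm.dataloader_thumbnail_images (100x100 thumbnails)

for batch in dm.dataloader_train:
    pass  # in practice Lightning drives the training loop
```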
