Commit
Merge pull request #59 from BerndDoser/data
WIP: Refactoring
Showing 107 changed files with 3,980 additions and 3,265 deletions.
Of the 107 changed files, only the diff of the GalaxyZooDataModule is expanded below; the shown hunk ends inside write_catalog.

@@ -1,52 +1,126 @@
-import lightning.pytorch as pl
+from pathlib import Path
+
+import torchvision.transforms.v2 as transforms
 from torch.utils.data import DataLoader
-from torchvision import transforms
 
-import data.galaxy_zoo_dataset as galaxy_zoo_dataset
 import data.preprocessing as preprocessing
+from data.galaxy_zoo_dataset import GalaxyZooDataset
+from models.spherinator_module import SpherinatorModule
+
+from .spherinator_data_module import SpherinatorDataModule
 
 
-class GalaxyZooDataModule(pl.LightningDataModule):
+class GalaxyZooDataModule(SpherinatorDataModule):
     """Defines access to the Galaxy Zoo data as a data module."""
 
-    def __init__(self, data_dir: str = "./", batch_size: int = 32, extension: str = "jpg", shuffle: bool = True, num_workers: int = 16):
+    def __init__(
+        self,
+        data_directory: str = "./",
+        batch_size: int = 32,
+        extension: str = "jpg",
+        shuffle: bool = True,
+        num_workers: int = 16,
+    ):
+        """Initialize GalaxyZooDataModule
+        Args:
+            data_directory (str): The directory to scan for data files.
+            batch_size (int, optional): The batch size for training. Defaults to 32.
+            extension (str, optional): The kind of files to search for. Defaults to "jpg".
+            shuffle (bool, optional): Whether or not to shuffle when reading. Defaults to True.
+            num_workers (int, optional): How many workers to use for loading. Defaults to 16.
+        """
         super().__init__()
-        self.data_dir = data_dir
-        self.train_transform = transforms.Compose([
-            preprocessing.DielemanTransformation(
-                rotation_range=[0,360],
-                translation_range=[4./424,4./424],
-                scaling_range=[1/1.1,1.1],
-                flip=0.5),
-            preprocessing.CropAndScale((424,424), (424,424))
-        ])
-        self.val_transform = transforms.Compose([
-            preprocessing.CropAndScale((424,424), (424,424))
-        ])
 
+        self.data_directory = data_directory
         self.batch_size = batch_size
         self.extension = extension
         self.shuffle = shuffle
         self.num_workers = num_workers
 
+        self.transform_train = transforms.Compose(
+            [
+                preprocessing.DielemanTransformation(
+                    rotation_range=[0, 360],
+                    translation_range=[4.0 / 424, 4.0 / 424],
+                    scaling_range=[1 / 1.1, 1.1],
+                    flip=0.5,
+                ),
+                transforms.CenterCrop((363, 363)),
+                transforms.Resize((424, 424), antialias=True),
+            ]
+        )
+        self.transform_processing = transforms.CenterCrop((363, 363))
+        self.transform_images = self.transform_train
+        self.transform_thumbnail_images = transforms.Compose(
+            [
+                self.transform_processing,
+                transforms.Resize((100, 100), antialias=True),
+            ]
+        )
+
     def setup(self, stage: str):
-        if stage == "fit":
-            self.data_train = galaxy_zoo_dataset.GalaxyZooDataset(data_directory=self.data_dir,
-                                                                   extension=self.extension,
-                                                                   transform=self.train_transform)
-
-            self.dataloader_train = DataLoader(self.data_train,
-                                               batch_size=self.batch_size,
-                                               shuffle=self.shuffle,
-                                               num_workers=self.num_workers)
-        elif stage =="val":
-            self.data_val = galaxy_zoo_dataset.GalaxyZooDataset(data_directory=self.data_dir,
-                                                                extension=self.extension,
-                                                                transform=self.val_transform)
-            self.dataloader_val = DataLoader(self.data_train,
-                                             batch_size=self.batch_size,
-                                             shuffle=False,
-                                             num_workers=self.num_workers)
-
-    def train_dataloader(self):
-        return self.dataloader_train
-
-    def val_dataloader(self):
-        return self.dataloader_val
+        """Sets up the data set and data loaders.
+        Args:
+            stage (str): Defines for which stage the data is needed.
+        """
+        if not stage in ["fit", "processing", "images", "thumbnail_images"]:
+            raise ValueError(f"Stage {stage} not supported.")
+
+        if stage == "fit" and self.data_train is None:
+            self.data_train = GalaxyZooDataset(
+                data_directory=self.data_directory,
+                extension=self.extension,
+                transform=self.transform_train,
+            )
+            self.dataloader_train = DataLoader(
+                self.data_train,
+                batch_size=self.batch_size,
+                shuffle=self.shuffle,
+                num_workers=self.num_workers,
+            )
+        elif stage == "processing" and self.data_processing is None:
+            self.data_processing = GalaxyZooDataset(
+                data_directory=self.data_directory,
+                extension=self.extension,
+                transform=self.transform_processing,
+            )
+            self.dataloader_processing = DataLoader(
+                self.data_train,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+        elif stage == "images" and self.data_images is None:
+            self.data_images = GalaxyZooDataset(
+                data_directory=self.data_directory,
+                extension=self.extension,
+                transform=self.transform_images,
+            )
+            self.dataloader_images = DataLoader(
+                self.data_train,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+        elif stage == "thumbnail_images" and self.data_thumbnail_images is None:
+            self.data_thumbnail_images = GalaxyZooDataset(
+                data_directory=self.data_directory,
+                extension=self.extension,
+                transform=self.transform_thumbnail_images,
+            )
+            self.dataloader_thumbnail_images = DataLoader(
+                self.data_train,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
+    def write_catalog(
+        self, model: SpherinatorModule, catalog_file: Path, hipster_url: str, title: str
+    ):
+        """Writes a catalog to disk."""
+        self.setup("processing")
+        with open(catalog_file, "w", encoding="utf-8") as output:
+            output.write("#filename,RMSD,rotation,x,y,z\n")
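A minimal usage sketch under the same caveat: the import path and the data directory are placeholders, since the hunk does not show the module's file name, and the exact batch format depends on GalaxyZooDataset, which is also not shown.

# Hypothetical usage; the module path and data directory are placeholders.
from data.galaxy_zoo_data_module import GalaxyZooDataModule  # assumed module name

datamodule = GalaxyZooDataModule(
    data_directory="/path/to/galaxy_zoo/jpg",  # placeholder directory of JPEG cutouts
    batch_size=64,
    extension="jpg",
    shuffle=True,
    num_workers=8,
)

# Each stage lazily builds its dataset and loader exactly once,
# thanks to the "is None" guards in setup().
datamodule.setup("fit")               # augmented 424x424 training images
datamodule.setup("thumbnail_images")  # 100x100 preview images

first_batch = next(iter(datamodule.dataloader_train))

Because setup() is keyed on stage names rather than Lightning's fit/validate/test stages, callers such as write_catalog() can request exactly the view of the data they need, as write_catalog() does with setup("processing") for the un-augmented center crop.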