diff --git a/.gitignore b/.gitignore index 6199eb12..d7ae58f9 100755 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,6 @@ results hparams.yaml -data/pretrained_models \ No newline at end of file +data/pretrained_models + +*.tar \ No newline at end of file diff --git a/only_for_me/narval/finetune.sh b/only_for_me/narval/finetune.sh index 84506204..665cfba2 100644 --- a/only_for_me/narval/finetune.sh +++ b/only_for_me/narval/finetune.sh @@ -6,6 +6,7 @@ #SBATCH --cpus-per-task=12 #SBATCH --gres=gpu:a100:2 +# https://github.com/webdataset/webdataset-lightning/blob/main/simple_cluster.py #### SBATCH --mem=32G #### SBATCH --nodes=1 #### SBATCH --time=0:20:0 diff --git a/only_for_me/narval/gz_decals_webdataset.py b/only_for_me/narval/gz_decals_webdataset.py new file mode 100644 index 00000000..4abce9eb --- /dev/null +++ b/only_for_me/narval/gz_decals_webdataset.py @@ -0,0 +1,120 @@ +import logging +import os +import shutil +import sys +import cv2 +import json +from itertools import islice +import glob + +import tqdm +import numpy as np +import pandas as pd +from PIL import Image # necessary to avoid PIL.Image error assumption in web_datasets + +from galaxy_datasets.shared import label_metadata +from galaxy_datasets import gz_decals_5 +from galaxy_datasets.transforms import default_transforms +from galaxy_datasets.pytorch import galaxy_dataset + +import webdataset as wds + +def galaxy_to_wds(galaxy: pd.Series, label_cols): + + im = cv2.imread(galaxy['file_loc']) + # cv2 loads BGR for 'history', fix + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + labels = json.dumps(galaxy[label_cols].to_dict()) + id_str = str(galaxy['id_str']) + # print(id_str) + return { + "__key__": id_str, + "image.jpg": im, + "labels.json": labels + } + +def df_to_wds(df: pd.DataFrame, label_cols, save_loc, n_shards): + df['id_str'] = df['id_str'].str.replace('.', '_') + + shard_dfs = np.array_split(df, n_shards) + print('shards: ', len(shard_dfs)) + print('shard size: ', len(shard_dfs[0])) + for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)): + shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar') + print(shard_save_loc) + sink = wds.TarWriter(shard_save_loc) + for index, galaxy in shard_df.iterrows(): + sink.write(galaxy_to_wds(galaxy, label_cols)) + sink.close() + +def check_wds(wds_loc): + + dataset = wds.WebDataset(wds_loc) \ + .decode("rgb") + + for sample in islice(dataset, 0, 3): + print(sample['__key__']) + print(sample['image.jpg'].shape) # .decode(jpg) converts to decoded to 0-1 RGB float, was 0-255 + print(type(sample['labels.json'])) # automatically decoded + +def identity(x): + # no lambda to be pickleable + return x + + +def load_wds(wds_loc): + + augmentation_transform = default_transforms() # A.Compose object + def do_transform(img): + return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32) + + dataset = wds.WebDataset(wds_loc) \ + .decode("rgb") \ + .to_tuple('image.jpg', 'labels.json') \ + .map_tuple(do_transform, identity) + + for sample in islice(dataset, 0, 3): + print(sample[0].shape) + print(sample[1]) + + +def main(): + + train_catalog, _ = gz_decals_5(root='/home/walml/repos/zoobot/only_for_me/narval/temp', download=False, train=True) + + train_catalog = train_catalog[:512*64] + label_cols = label_metadata.decals_dr5_ortho_label_cols + + save_loc = "gz_decals_5_train.tar" + + # df_to_wds(train_catalog, label_cols, save_loc, n_shards=8) + + # check_wds(save_loc) + + # load_wds(save_loc) + + import zoobot.pytorch.training.webdatamodule as webdatamodule + + wdm = webdatamodule.WebDataModule( + train_urls=glob.glob(save_loc.replace('.tar', '_*.tar')), + val_urls=[], + # train_size=len(train_catalog), + # val_size=0, + label_cols=label_cols, + num_workers=1 + ) + wdm.setup('fit') + + for sample in islice(wdm.train_dataloader(), 0, 3): + images, labels = sample + print(images.shape) + # print(len(labels)) # list of dicts + print(labels) + exit() + + + +if __name__ == '__main__': + + main() + diff --git a/only_for_me/narval/train.py b/only_for_me/narval/train.py new file mode 100644 index 00000000..eb616450 --- /dev/null +++ b/only_for_me/narval/train.py @@ -0,0 +1,130 @@ +import logging +import os +import argparse +import glob + +from pytorch_lightning.loggers import WandbLogger +import wandb + +from zoobot.pytorch.training import train_with_pytorch_lightning +from zoobot.shared import benchmark_datasets, schemas + + +if __name__ == '__main__': + + """ + Used to create the PyTorch pretrained weights checkpoints + See .sh file of the same name for args used. + + See zoobot/pytorch/examples/minimal_examples.py for a friendlier example + """ + parser = argparse.ArgumentParser() + parser.add_argument('--save-dir', dest='save_dir', type=str) + # parser.add_argument('--data-dir', dest='data_dir', type=str) + # parser.add_argument('--dataset', dest='dataset', type=str, help='dataset to use, either "gz_decals_dr5" or "gz_evo"') + parser.add_argument('--architecture', dest='architecture_name', default='efficientnet', type=str) + parser.add_argument('--resize-after-crop', dest='resize_after_crop', + type=int, default=224) + parser.add_argument('--color', default=False, action='store_true') + parser.add_argument('--batch-size', dest='batch_size', + default=256, type=int) + parser.add_argument('--gpus', dest='gpus', default=1, type=int) + parser.add_argument('--nodes', dest='nodes', default=1, type=int) + parser.add_argument('--mixed-precision', dest='mixed_precision', + default=False, action='store_true') + parser.add_argument('--debug', dest='debug', + default=False, action='store_true') + parser.add_argument('--wandb', dest='wandb', + default=False, action='store_true') + parser.add_argument('--seed', dest='random_state', default=42, type=int) + args = parser.parse_args() + + """ + debug + python only_for_me/narval/train.py --save-dir only_for_me/narval/debug_models --batch-size 32 --color + """ + + logging.basicConfig(level=logging.INFO) + + random_state = args.random_state + + # if args.nodes > 1: + # # at Manchester, our slurm cluster sets TASKS not NTASKS, which then confuses lightning + # if 'SLURM_NTASKS_PER_NODE' not in os.environ.keys(): + # os.environ['SLURM_NTASKS_PER_NODE'] = os.environ['SLURM_TASKS_PER_NODE'] + # # log the rest to help debug + # logging.info([(x, y) for (x, y) in os.environ.items() if 'SLURM' in x]) + + if args.debug: + download = False + else: + # download = True # for first use + download = False # for speed afterwards + + if os.path.isdir('/home/walml/repos/zoobot'): + search_str = '/home/walml/repos/zoobot/gz_decals_5_train_*.tar' + + else: + search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/gz_decals_5/gz_decals_5_train_*.tar' + + all_urls = glob.glob(search_str) + assert len(all_urls) > 0, search_str + train_urls, val_urls = all_urls[:6], all_urls[6:] + schema = schemas.decals_dr5_ortho_schema + + + # if args.dataset == 'gz_decals_dr5': + # schema, (train_catalog, val_catalog, test_catalog) = benchmark_datasets.get_gz_decals_dr5_benchmark_dataset(args.data_dir, random_state, download=download) + # elif args.dataset == 'gz_evo': + # schema, (train_catalog, val_catalog, test_catalog) = benchmark_datasets.get_gz_evo_benchmark_dataset(args.data_dir, random_state, download=download) + # else: + # raise ValueError(f'Dataset {args.dataset} not recognised: should be "gz_decals_dr5" or "gz_evo"') + + + # logging.info('First val galaxy: {}'.format(val_catalog.iloc[0]['id_str'])) + + + # debug mode + if args.debug: + logging.warning( + 'Using debug mode: cutting urls down to 2') + train_urls = train_urls[:2] + val_urls = val_urls[:2] + epochs = 2 + else: + epochs = 1000 + + if args.wandb: + wandb_logger = WandbLogger( + project='narval', + name=os.path.basename(args.save_dir), + log_model=False + ) + else: + wandb_logger = None + + train_with_pytorch_lightning.train_default_zoobot_from_scratch( + save_dir=args.save_dir, + schema=schema, + train_urls = train_urls, + val_urls = val_urls, + test_urls = None, + architecture_name=args.architecture_name, + batch_size=args.batch_size, + epochs=epochs, # rely on early stopping + patience=10, + # augmentation parameters + color=args.color, + resize_after_crop=args.resize_after_crop, + # hardware parameters + gpus=args.gpus, + nodes=args.nodes, + mixed_precision=args.mixed_precision, + wandb_logger=wandb_logger, + prefetch_factor=4, + num_workers=11, # system has 24 cpu, 12 cpu per gpu, leave a little wiggle room + random_state=random_state, + learning_rate=1e-3, + ) + + wandb.finish() \ No newline at end of file diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py index 5690e0e1..9fdaf83c 100644 --- a/zoobot/pytorch/training/train_with_pytorch_lightning.py +++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py @@ -11,6 +11,7 @@ from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule from zoobot.pytorch.estimators import define_model +from zoobot.pytorch.training import webdatamodule def train_default_zoobot_from_scratch( @@ -22,6 +23,9 @@ def train_default_zoobot_from_scratch( train_catalog=None, val_catalog=None, test_catalog=None, + train_urls=None, + val_urls=None, + test_urls=None, # training time parameters epochs=1000, patience=8, @@ -167,22 +171,6 @@ def train_default_zoobot_from_scratch( Suggest reducing num_workers.""" ) - - if catalog is not None: - assert train_catalog is None - assert val_catalog is None - assert test_catalog is None - catalogs_to_use = { - 'catalog': catalog - } - else: - assert catalog is None - catalogs_to_use = { - 'train_catalog': train_catalog, - 'val_catalog': val_catalog, - 'test_catalog': test_catalog # may be None - } - if wandb_logger is not None: wandb_logger.log_hyperparams({ 'random_state': random_state, @@ -201,20 +189,50 @@ def train_default_zoobot_from_scratch( 'framework': 'pytorch' }) - datamodule = GalaxyDataModule( - label_cols=schema.label_cols, - # can take either a catalog (and split it), or a pre-split catalog - **catalogs_to_use, - # augmentations parameters - greyscale=not color, - crop_scale_bounds=crop_scale_bounds, - crop_ratio_bounds=crop_ratio_bounds, - resize_after_crop=resize_after_crop, - # hardware parameters - batch_size=batch_size, # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch) - num_workers=num_workers, - prefetch_factor=prefetch_factor - ) + # work out what dataset the user has passed + single_catalog = catalog is not None + split_catalogs = train_catalog is not None + webdatasets = train_urls is not None + + if single_catalog or split_catalogs: + # this branch will use GalaxyDataModule to load catalogs + assert not webdatasets + if single_catalog: + assert not split_catalogs + data_to_use = { + 'catalog': catalog + } + else: + data_to_use = { + 'train_catalog': train_catalog, + 'val_catalog': val_catalog, + 'test_catalog': test_catalog # may be None + } + datamodule = GalaxyDataModule( + label_cols=schema.label_cols, + # can take either a catalog (and split it), or a pre-split catalog + **data_to_use, + # augmentations parameters + greyscale=not color, + crop_scale_bounds=crop_scale_bounds, + crop_ratio_bounds=crop_ratio_bounds, + resize_after_crop=resize_after_crop, + # hardware parameters + batch_size=batch_size, # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch) + num_workers=num_workers, + prefetch_factor=prefetch_factor + ) + else: + # this branch will use WebDataModule to load premade webdatasets + datamodule = webdatamodule.WebDataModule( + train_urls=train_urls, + val_urls=val_urls, + batch_size=batch_size, + num_workers=num_workers, + label_cols=schema.label_cols + # TODO pass through the rest + ) + datamodule.setup(stage='fit') # these args are automatically logged diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py new file mode 100644 index 00000000..bf1fb91b --- /dev/null +++ b/zoobot/pytorch/training/webdatamodule.py @@ -0,0 +1,130 @@ +import os + +import torch.utils.data +import numpy as np +import pytorch_lightning as pl + +import webdataset as wds + +from galaxy_datasets.transforms import default_transforms + +# https://github.com/webdataset/webdataset-lightning/blob/main/train.py +class WebDataModule(pl.LightningDataModule): + def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4): + super().__init__() + self.train_urls = train_urls + self.val_urls = val_urls + + if train_size is None: + # assume the size of each shard is encoded in the filename as ..._{size}.tar + train_size = sum([int(url.rstrip('.tar').split('_')[-1]) for url in train_urls]) + if val_size is None: + val_size = sum([int(url.rstrip('.tar').split('_')[-1]) for url in val_urls]) + + self.train_size = train_size + self.val_size = val_size + + self.label_cols = label_cols + + self.batch_size = batch_size + self.num_workers = num_workers + + print("train_urls = ", self.train_urls) + print("val_urls = ", self.val_urls) + print("train_size = ", self.train_size) + print("val_size = ", self.val_size) + print("batch_size", self.batch_size, "num_workers", self.num_workers) + + def make_image_transform(self, mode="train"): + # if mode == "train": + # elif mode == "val": + + augmentation_transform = default_transforms() # A.Compose object + def do_transform(img): + return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32) + return do_transform + + def make_label_transform(self): + if self.label_cols is not None: + def label_transform(label_dict): + return torch.from_numpy(np.array([label_dict.get(col, 0) for col in self.label_cols])) + return label_transform + else: + return identity # do nothing + + def make_loader(self, urls, mode="train"): + if mode == "train": + dataset_size = self.train_size + shuffle = 5000 + elif mode == "val": + dataset_size = self.val_size + shuffle = 0 + + transform_image = self.make_image_transform(mode=mode) + + transform_label = self.make_label_transform() + + dataset = ( + # https://webdataset.github.io/webdataset/multinode/ + # WDS 'knows' which worker it is running on and selects a subset of urls accordingly + wds.WebDataset(urls) + .shuffle(shuffle) + .decode("rgb") + .to_tuple('image.jpg', 'labels.json') + .map_tuple(transform_image, transform_label) + # torch collate stacks dicts nicely while webdataset only lists them + # so use the torch collate instead + .batched(self.batch_size, torch.utils.data.default_collate, partial=False) + ) + + # from itertools import islice + # for batch in islice(dataset, 0, 3): + # images, labels = batch + # # print(len(sample)) + # print(images.shape) + # print(len(labels)) # list of dicts + # # exit() + + loader = wds.WebLoader( + dataset, + batch_size=None, # already batched + shuffle=False, + num_workers=self.num_workers, + ) + + # print('sampling') + # for sample in islice(loader, 0, 3): + # images, labels = sample + # print(images.shape) + # print(len(labels)) # list of dicts + # exit() + + loader.length = dataset_size // self.batch_size + + # temp hack instead + assert dataset_size % self.batch_size == 0 + # if mode == "train": + # ensure same number of batches in all clients + # loader = loader.ddp_equalize(dataset_size // self.batch_size) + # print("# loader length", len(loader)) + + return loader + + def train_dataloader(self): + return self.make_loader(self.train_urls, mode="train") + + def val_dataloader(self): + return self.make_loader(self.val_urls, mode="val") + + # @staticmethod + # def add_loader_specific_args(parser): + # parser.add_argument("-b", "--batch-size", type=int, default=128) + # parser.add_argument("--workers", type=int, default=6) + # parser.add_argument("--bucket", default="./shards") + # parser.add_argument("--shards", default="imagenet-train-{000000..001281}.tar") + # parser.add_argument("--valshards", default="imagenet-val-{000000..000006}.tar") + # return parser + + +def identity(x): + return x \ No newline at end of file